Goit

Simple and lightweight Git web server
git clone http://git.omkov.net/Goit
Log | Tree | Refs | README | Download

AuthorJakob Wakeling <[email protected]>
Date2023-10-23 02:14:42
Commit410f65bf45bdb8f9a277fe7da1ddb29f30fbd3cb
Parentaa9a5d5f7cb73e6ad41b6f5e346fd3ba12efc0f7

Add a mutex lock to the sessions map

Diffstat

M Makefile | 3 ++-
M main.go | 4 +++-
M src/auth.go | 63 ++++++++++++++++++++++++++++++++++++++++++++++++---------------
A src/auth_test.go | 64 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
M src/user.go | 50 --------------------------------------------------
A src/user/sessions.go | 63 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

6 files changed, 180 insertions, 67 deletions

diff --git a/Makefile b/Makefile
index 0f0fa25..4f0153a 100644
--- a/Makefile
+++ b/Makefile
@@ -1,11 +1,12 @@
 .PHONY: all build test help
 all: help
 
+MODULE = "github.com/Jamozed/Goit"
 PROGRAM = "goit"
 VERSION = "0.0.0"
 
 build: ## Build the project
-	@go build -ldflags "-X res.Version=$(VERSION)" -o ./bin/$(PROGRAM) .
+	@go build -ldflags "-X $(MODULE)/res.Version=$(VERSION)" -o ./bin/$(PROGRAM) .
 
 test: ## Run unit tests
 	@go test ./...
diff --git a/main.go b/main.go
index e292e5b..eb37c8b 100644
--- a/main.go
+++ b/main.go
@@ -18,6 +18,8 @@ import (
 )
 
 func main() {
+	log.Println("Starting Goit", res.Version)
+
 	if err := goit.Goit(goit.ConfPath()); err != nil {
 		log.Fatalln(err.Error())
 	}
@@ -28,7 +30,7 @@ func main() {
 	h.Path("/").HandlerFunc(goit.HandleIndex)
 	h.Path("/user/login").Methods("GET", "POST").HandlerFunc(user.HandleLogin)
 	h.Path("/user/logout").Methods("GET", "POST").HandlerFunc(goit.HandleUserLogout)
-	h.Path("/user/sessions").Methods("GET", "POST").HandlerFunc(goit.HandleUserSessions)
+	h.Path("/user/sessions").Methods("GET", "POST").HandlerFunc(user.HandleSessions)
 	h.Path("/user/edit").Methods("GET", "POST").HandlerFunc(user.HandleEdit)
 	h.Path("/repo/create").Methods("GET", "POST").HandlerFunc(repo.HandleCreate)
 	h.Path("/repo/delete").Methods("DELETE").HandlerFunc(repo.HandleDelete)
diff --git a/src/auth.go b/src/auth.go
index eb63ed9..27158d0 100644
--- a/src/auth.go
+++ b/src/auth.go
@@ -13,6 +13,7 @@ import (
 	"net/http"
 	"strconv"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/Jamozed/Goit/src/util"
@@ -25,25 +26,38 @@ type Session struct {
 }
 
 var Sessions = map[int64][]Session{}
+var SessionsMutex = sync.RWMutex{}
 
+/* Generate a new user session. */
 func NewSession(uid int64, ip string, expiry time.Time) (Session, error) {
-	b := make([]byte, 24)
+	var b = make([]byte, 24)
 	if _, err := rand.Read(b); err != nil {
 		return Session{}, err
 	}
 
+	var t = base64.StdEncoding.EncodeToString(b)
+	var s = Session{Token: t, Ip: util.If(Conf.IpSessions, ip, ""), Seen: time.Now(), Expiry: expiry}
+
+	SessionsMutex.Lock()
 	if Sessions[uid] == nil {
 		Sessions[uid] = []Session{}
 	}
 
-	t := base64.StdEncoding.EncodeToString(b)
-	s := Session{Token: t, Ip: util.If(Conf.IpSessions, ip, ""), Seen: time.Now(), Expiry: expiry}
-
 	Sessions[uid] = append(Sessions[uid], s)
+	SessionsMutex.Unlock()
+
 	return s, nil
 }
 
+/* End a user session. */
 func EndSession(uid int64, token string) {
+	SessionsMutex.Lock()
+	defer SessionsMutex.Unlock()
+
+	if Sessions[uid] == nil {
+		return
+	}
+
 	for i, t := range Sessions[uid] {
 		if t.Token == token {
 			Sessions[uid] = append(Sessions[uid][:i], Sessions[uid][i+1:]...)
@@ -56,23 +70,36 @@ func EndSession(uid int64, token string) {
 	}
 }
 
+/* Cleanup expired user sessions. */
 func CleanupSessions() {
-	var n uint64 = 0
-
-	for k, v := range Sessions {
-		for _, v1 := range v {
-			if v1.Expiry.Before(time.Now()) {
-				EndSession(k, v1.Token)
-				n += 1
+	var n int = 0
+
+	SessionsMutex.Lock()
+	for uid, v := range Sessions {
+		var i = 0
+		for _, s := range v {
+			if s.Expiry.After(time.Now()) {
+				v[i] = s
+				i += 1
 			}
 		}
+
+		n += len(v) - i
+
+		if i == 0 {
+			delete(Sessions, uid)
+		} else {
+			Sessions[uid] = v[:i]
+		}
 	}
+	SessionsMutex.Unlock()
 
 	if n > 0 {
 		log.Println("[Cleanup] cleaned up", n, "expired sessions")
 	}
 }
 
+/* Set a user session cookie. */
 func SetSessionCookie(w http.ResponseWriter, uid int64, s Session) {
 	c := &http.Cookie{Name: "session", Value: fmt.Sprint(uid) + "." + s.Token, Path: "/", Expires: s.Expiry}
 	if err := c.Valid(); err != nil {
@@ -82,6 +109,7 @@ func SetSessionCookie(w http.ResponseWriter, uid int64, s Session) {
 	http.SetCookie(w, c)
 }
 
+/* Get a user session cookie if one is present. */
 func GetSessionCookie(r *http.Request) (int64, Session) {
 	if c := util.Cookie(r, "session"); c != nil {
 		ss := strings.SplitN(c.Value, ".", 2)
@@ -89,32 +117,36 @@ func GetSessionCookie(r *http.Request) (int64, Session) {
 			return -1, Session{}
 		}
 
-		id, err := strconv.ParseInt(ss[0], 10, 64)
+		uid, err := strconv.ParseInt(ss[0], 10, 64)
 		if err != nil {
 			return -1, Session{}
 		}
 
-		for i, s := range Sessions[id] {
+		SessionsMutex.Lock()
+		for i, s := range Sessions[uid] {
 			if ss[1] == s.Token {
 				if s != (Session{}) {
 					s.Seen = time.Now()
-					Sessions[id][i] = s
+					Sessions[uid][i] = s
 				}
 
-				return id, s
+				return uid, s
 			}
 		}
+		SessionsMutex.Unlock()
 
-		return id, Session{}
+		return uid, Session{}
 	}
 
 	return -1, Session{}
 }
 
+/* End the current user session cookie. */
 func EndSessionCookie(w http.ResponseWriter) {
 	http.SetCookie(w, &http.Cookie{Name: "session", Path: "/", MaxAge: -1})
 }
 
+/* Authenticate a user session cookie. */
 func AuthCookie(w http.ResponseWriter, r *http.Request, renew bool) (bool, int64) {
 	if uid, s := GetSessionCookie(r); s != (Session{}) {
 		if s.Expiry.After(time.Now()) {
@@ -138,6 +170,7 @@ func AuthCookie(w http.ResponseWriter, r *http.Request, renew bool) (bool, int64
 	return false, -1
 }
 
+/* Authenticate a user session cookie and check admin status. */
 func AuthCookieAdmin(w http.ResponseWriter, r *http.Request, renew bool) (bool, bool, int64) {
 	if ok, uid := AuthCookie(w, r, renew); ok {
 		if user, err := GetUser(uid); err == nil && user.IsAdmin {
diff --git a/src/auth_test.go b/src/auth_test.go
new file mode 100644
index 0000000..7560460
--- /dev/null
+++ b/src/auth_test.go
@@ -0,0 +1,64 @@
+package goit_test
+
+import (
+	"fmt"
+	"slices"
+	"sync"
+	"testing"
+	"time"
+
+	goit "github.com/Jamozed/Goit/src"
+)
+
+func TestNewSession(t *testing.T) {
+	goit.Sessions = map[int64][]goit.Session{}
+	goit.SessionsMutex = sync.RWMutex{}
+
+	var uid int64 = 1
+	var session = goit.Session{Ip: "127.0.0.1", Expiry: time.Unix(0, 0)}
+
+	s, err := goit.NewSession(uid, session.Ip, session.Expiry)
+	if err != nil {
+		t.Fatal(err.Error())
+	}
+
+	if goit.Sessions[uid] == nil {
+		t.Fatal("UID slice not added to the sessions map")
+	}
+	if len(goit.Sessions[uid]) != 1 {
+		t.Fatal("Incorrect number of sessions added to the sessions map")
+	}
+	if s != goit.Sessions[uid][0] {
+		t.Fatal("Added and returned sessions do not match")
+	}
+	if s.Ip != session.Ip {
+		t.Fatal("Added session IP is incorrect")
+	}
+	if s.Expiry != session.Expiry {
+		t.Fatal("Added session expiry is incorrect")
+	}
+	if !s.Seen.Before(time.Now()) {
+		t.Fatal("Session seen time is in the future")
+	}
+	if len(s.Token) != 32 {
+		t.Fatal("Session token length is incorrect")
+	}
+	if goit.SessionsMutex.TryLock() == false {
+		t.Fatal("Sessions mutex was not unlocked")
+	}
+}
+
+func TestHash(t *testing.T) {
+	var pass = "password"
+	var salt = make([]byte, 16)
+	var hash = []byte{
+		0x00, 0xB1, 0xEE, 0xD9, 0xBE, 0xE6, 0xDC, 0x06, 0x41, 0xA5, 0x07, 0x71,
+		0x7D, 0xB7, 0x6B, 0x65, 0x20, 0xEC, 0x87, 0x6E, 0xCE, 0x6C, 0xD1, 0x09,
+		0x25, 0xE4, 0x38, 0x75, 0xB5, 0x43, 0x57, 0x5E,
+	}
+
+	if !slices.Equal(goit.Hash(pass, salt), hash) {
+		fmt.Printf("%x", goit.Hash(pass, salt))
+		t.Fatal("Hash output is incorrect")
+	}
+}
diff --git a/src/user.go b/src/user.go
index 967ab34..3b523ef 100644
--- a/src/user.go
+++ b/src/user.go
@@ -8,13 +8,8 @@ import (
 	"database/sql"
 	"errors"
 	"fmt"
-	"log"
 	"net/http"
-	"strconv"
 	"strings"
-	"time"
-
-	"github.com/Jamozed/Goit/src/util"
 )
 
 type User struct {
@@ -36,51 +31,6 @@ func HandleUserLogout(w http.ResponseWriter, r *http.Request) {
 	http.Redirect(w, r, "/", http.StatusFound)
 }
 
-func HandleUserSessions(w http.ResponseWriter, r *http.Request) {
-	auth, uid := AuthCookie(w, r, true)
-	if !auth {
-		HttpError(w, http.StatusUnauthorized)
-		return
-	}
-
-	_, ss := GetSessionCookie(r)
-
-	revoke, err := strconv.ParseInt(r.FormValue("revoke"), 10, 64)
-	if err != nil {
-		revoke = -1
-	}
-	if revoke >= 0 && revoke < int64(len(Sessions[uid])) {
-		current := Sessions[uid][revoke].Token == ss.Token
-		EndSession(uid, Sessions[uid][revoke].Token)
-
-		if current {
-			EndSessionCookie(w)
-			http.Redirect(w, r, "/", http.StatusFound)
-			return
-		}
-
-		http.Redirect(w, r, "/user/sessions", http.StatusFound)
-		return
-	}
-
-	type row struct{ Index, Ip, Seen, Expiry, Current string }
-	data := struct {
-		Title    string
-		Sessions []row
-	}{Title: "User - Sessions"}
-
-	for i, v := range Sessions[uid] {
-		data.Sessions = append(data.Sessions, row{
-			Index: fmt.Sprint(i), Ip: v.Ip, Seen: v.Seen.Format(time.DateTime), Expiry: v.Expiry.Format(time.DateTime),
-			Current: util.If(v.Token == ss.Token, "(current)", ""),
-		})
-	}
-
-	if err := Tmpl.ExecuteTemplate(w, "user/sessions", data); err != nil {
-		log.Println("[/user/login]", err.Error())
-	}
-}
-
 func GetUser(id int64) (*User, error) {
 	u := User{}
 
diff --git a/src/user/sessions.go b/src/user/sessions.go
new file mode 100644
index 0000000..6f20118
--- /dev/null
+++ b/src/user/sessions.go
@@ -0,0 +1,63 @@
+package user
+
+import (
+	"fmt"
+	"log"
+	"net/http"
+	"strconv"
+	"time"
+
+	goit "github.com/Jamozed/Goit/src"
+	"github.com/Jamozed/Goit/src/util"
+)
+
+func HandleSessions(w http.ResponseWriter, r *http.Request) {
+	auth, uid := goit.AuthCookie(w, r, true)
+	if !auth {
+		goit.HttpError(w, http.StatusUnauthorized)
+		return
+	}
+
+	_, ss := goit.GetSessionCookie(r)
+
+	revoke, err := strconv.ParseInt(r.FormValue("revoke"), 10, 64)
+	if err != nil {
+		revoke = -1
+	}
+
+	type row struct{ Index, Ip, Seen, Expiry, Current string }
+	var data = struct {
+		Title    string
+		Sessions []row
+	}{Title: "User - Sessions"}
+
+	goit.SessionsMutex.RLock()
+	if revoke >= 0 && revoke < int64(len(goit.Sessions[uid])) {
+		var token = goit.Sessions[uid][revoke].Token
+		var current = token == ss.Token
+
+		goit.SessionsMutex.RUnlock()
+		goit.EndSession(uid, token)
+
+		if current {
+			goit.EndSessionCookie(w)
+			http.Redirect(w, r, "/", http.StatusFound)
+			return
+		}
+
+		http.Redirect(w, r, "/user/sessions", http.StatusFound)
+		return
+	}
+
+	for i, v := range goit.Sessions[uid] {
+		data.Sessions = append(data.Sessions, row{
+			Index: fmt.Sprint(i), Ip: v.Ip, Seen: v.Seen.Format(time.DateTime), Expiry: v.Expiry.Format(time.DateTime),
+			Current: util.If(v.Token == ss.Token, "(current)", ""),
+		})
+	}
+	goit.SessionsMutex.RUnlock()
+
+	if err := goit.Tmpl.ExecuteTemplate(w, "user/sessions", data); err != nil {
+		log.Println("[/user/login]", err.Error())
+	}
+}