Author | Jakob Wakeling <[email protected]> |
Date | 2023-10-23 02:14:42 |
Commit | 410f65bf45bdb8f9a277fe7da1ddb29f30fbd3cb |
Parent | aa9a5d5f7cb73e6ad41b6f5e346fd3ba12efc0f7 |
Add a mutex lock to the sessions map
Diffstat
M | Makefile | | | 3 | ++- |
M | main.go | | | 4 | +++- |
M | src/auth.go | | | 63 | ++++++++++++++++++++++++++++++++++++++++++++++++--------------- |
A | src/auth_test.go | | | 64 | ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ |
M | src/user.go | | | 50 | -------------------------------------------------- |
A | src/user/sessions.go | | | 63 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ |
6 files changed, 180 insertions, 67 deletions
diff --git a/Makefile b/Makefile index 0f0fa25..4f0153a 100644 --- a/Makefile +++ b/Makefile @@ -1,11 +1,12 @@ .PHONY: all build test help all: help +MODULE = "github.com/Jamozed/Goit" PROGRAM = "goit" VERSION = "0.0.0" build: ## Build the project - @go build -ldflags "-X res.Version=$(VERSION)" -o ./bin/$(PROGRAM) . + @go build -ldflags "-X $(MODULE)/res.Version=$(VERSION)" -o ./bin/$(PROGRAM) . test: ## Run unit tests @go test ./... diff --git a/main.go b/main.go index e292e5b..eb37c8b 100644 --- a/main.go +++ b/main.go @@ -18,6 +18,8 @@ import ( ) func main() { + log.Println("Starting Goit", res.Version) + if err := goit.Goit(goit.ConfPath()); err != nil { log.Fatalln(err.Error()) } @@ -28,7 +30,7 @@ func main() { h.Path("/").HandlerFunc(goit.HandleIndex) h.Path("/user/login").Methods("GET", "POST").HandlerFunc(user.HandleLogin) h.Path("/user/logout").Methods("GET", "POST").HandlerFunc(goit.HandleUserLogout) - h.Path("/user/sessions").Methods("GET", "POST").HandlerFunc(goit.HandleUserSessions) + h.Path("/user/sessions").Methods("GET", "POST").HandlerFunc(user.HandleSessions) h.Path("/user/edit").Methods("GET", "POST").HandlerFunc(user.HandleEdit) h.Path("/repo/create").Methods("GET", "POST").HandlerFunc(repo.HandleCreate) h.Path("/repo/delete").Methods("DELETE").HandlerFunc(repo.HandleDelete) diff --git a/src/auth.go b/src/auth.go index eb63ed9..27158d0 100644 --- a/src/auth.go +++ b/src/auth.go @@ -13,6 +13,7 @@ import ( "net/http" "strconv" "strings" + "sync" "time" "github.com/Jamozed/Goit/src/util" @@ -25,25 +26,38 @@ type Session struct { } var Sessions = map[int64][]Session{} +var SessionsMutex = sync.RWMutex{} +/* Generate a new user session. */ func NewSession(uid int64, ip string, expiry time.Time) (Session, error) { - b := make([]byte, 24) + var b = make([]byte, 24) if _, err := rand.Read(b); err != nil { return Session{}, err } + var t = base64.StdEncoding.EncodeToString(b) + var s = Session{Token: t, Ip: util.If(Conf.IpSessions, ip, ""), Seen: time.Now(), Expiry: expiry} + + SessionsMutex.Lock() if Sessions[uid] == nil { Sessions[uid] = []Session{} } - t := base64.StdEncoding.EncodeToString(b) - s := Session{Token: t, Ip: util.If(Conf.IpSessions, ip, ""), Seen: time.Now(), Expiry: expiry} - Sessions[uid] = append(Sessions[uid], s) + SessionsMutex.Unlock() + return s, nil } +/* End a user session. */ func EndSession(uid int64, token string) { + SessionsMutex.Lock() + defer SessionsMutex.Unlock() + + if Sessions[uid] == nil { + return + } + for i, t := range Sessions[uid] { if t.Token == token { Sessions[uid] = append(Sessions[uid][:i], Sessions[uid][i+1:]...) @@ -56,23 +70,36 @@ func EndSession(uid int64, token string) { } } +/* Cleanup expired user sessions. */ func CleanupSessions() { - var n uint64 = 0 - - for k, v := range Sessions { - for _, v1 := range v { - if v1.Expiry.Before(time.Now()) { - EndSession(k, v1.Token) - n += 1 + var n int = 0 + + SessionsMutex.Lock() + for uid, v := range Sessions { + var i = 0 + for _, s := range v { + if s.Expiry.After(time.Now()) { + v[i] = s + i += 1 } } + + n += len(v) - i + + if i == 0 { + delete(Sessions, uid) + } else { + Sessions[uid] = v[:i] + } } + SessionsMutex.Unlock() if n > 0 { log.Println("[Cleanup] cleaned up", n, "expired sessions") } } +/* Set a user session cookie. */ func SetSessionCookie(w http.ResponseWriter, uid int64, s Session) { c := &http.Cookie{Name: "session", Value: fmt.Sprint(uid) + "." + s.Token, Path: "/", Expires: s.Expiry} if err := c.Valid(); err != nil { @@ -82,6 +109,7 @@ func SetSessionCookie(w http.ResponseWriter, uid int64, s Session) { http.SetCookie(w, c) } +/* Get a user session cookie if one is present. */ func GetSessionCookie(r *http.Request) (int64, Session) { if c := util.Cookie(r, "session"); c != nil { ss := strings.SplitN(c.Value, ".", 2) @@ -89,32 +117,36 @@ func GetSessionCookie(r *http.Request) (int64, Session) { return -1, Session{} } - id, err := strconv.ParseInt(ss[0], 10, 64) + uid, err := strconv.ParseInt(ss[0], 10, 64) if err != nil { return -1, Session{} } - for i, s := range Sessions[id] { + SessionsMutex.Lock() + for i, s := range Sessions[uid] { if ss[1] == s.Token { if s != (Session{}) { s.Seen = time.Now() - Sessions[id][i] = s + Sessions[uid][i] = s } - return id, s + return uid, s } } + SessionsMutex.Unlock() - return id, Session{} + return uid, Session{} } return -1, Session{} } +/* End the current user session cookie. */ func EndSessionCookie(w http.ResponseWriter) { http.SetCookie(w, &http.Cookie{Name: "session", Path: "/", MaxAge: -1}) } +/* Authenticate a user session cookie. */ func AuthCookie(w http.ResponseWriter, r *http.Request, renew bool) (bool, int64) { if uid, s := GetSessionCookie(r); s != (Session{}) { if s.Expiry.After(time.Now()) { @@ -138,6 +170,7 @@ func AuthCookie(w http.ResponseWriter, r *http.Request, renew bool) (bool, int64 return false, -1 } +/* Authenticate a user session cookie and check admin status. */ func AuthCookieAdmin(w http.ResponseWriter, r *http.Request, renew bool) (bool, bool, int64) { if ok, uid := AuthCookie(w, r, renew); ok { if user, err := GetUser(uid); err == nil && user.IsAdmin { diff --git a/src/auth_test.go b/src/auth_test.go new file mode 100644 index 0000000..7560460 --- /dev/null +++ b/src/auth_test.go @@ -0,0 +1,64 @@ +package goit_test + +import ( + "fmt" + "slices" + "sync" + "testing" + "time" + + goit "github.com/Jamozed/Goit/src" +) + +func TestNewSession(t *testing.T) { + goit.Sessions = map[int64][]goit.Session{} + goit.SessionsMutex = sync.RWMutex{} + + var uid int64 = 1 + var session = goit.Session{Ip: "127.0.0.1", Expiry: time.Unix(0, 0)} + + s, err := goit.NewSession(uid, session.Ip, session.Expiry) + if err != nil { + t.Fatal(err.Error()) + } + + if goit.Sessions[uid] == nil { + t.Fatal("UID slice not added to the sessions map") + } + if len(goit.Sessions[uid]) != 1 { + t.Fatal("Incorrect number of sessions added to the sessions map") + } + if s != goit.Sessions[uid][0] { + t.Fatal("Added and returned sessions do not match") + } + if s.Ip != session.Ip { + t.Fatal("Added session IP is incorrect") + } + if s.Expiry != session.Expiry { + t.Fatal("Added session expiry is incorrect") + } + if !s.Seen.Before(time.Now()) { + t.Fatal("Session seen time is in the future") + } + if len(s.Token) != 32 { + t.Fatal("Session token length is incorrect") + } + if goit.SessionsMutex.TryLock() == false { + t.Fatal("Sessions mutex was not unlocked") + } +} + +func TestHash(t *testing.T) { + var pass = "password" + var salt = make([]byte, 16) + var hash = []byte{ + 0x00, 0xB1, 0xEE, 0xD9, 0xBE, 0xE6, 0xDC, 0x06, 0x41, 0xA5, 0x07, 0x71, + 0x7D, 0xB7, 0x6B, 0x65, 0x20, 0xEC, 0x87, 0x6E, 0xCE, 0x6C, 0xD1, 0x09, + 0x25, 0xE4, 0x38, 0x75, 0xB5, 0x43, 0x57, 0x5E, + } + + if !slices.Equal(goit.Hash(pass, salt), hash) { + fmt.Printf("%x", goit.Hash(pass, salt)) + t.Fatal("Hash output is incorrect") + } +} diff --git a/src/user.go b/src/user.go index 967ab34..3b523ef 100644 --- a/src/user.go +++ b/src/user.go @@ -8,13 +8,8 @@ import ( "database/sql" "errors" "fmt" - "log" "net/http" - "strconv" "strings" - "time" - - "github.com/Jamozed/Goit/src/util" ) type User struct { @@ -36,51 +31,6 @@ func HandleUserLogout(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/", http.StatusFound) } -func HandleUserSessions(w http.ResponseWriter, r *http.Request) { - auth, uid := AuthCookie(w, r, true) - if !auth { - HttpError(w, http.StatusUnauthorized) - return - } - - _, ss := GetSessionCookie(r) - - revoke, err := strconv.ParseInt(r.FormValue("revoke"), 10, 64) - if err != nil { - revoke = -1 - } - if revoke >= 0 && revoke < int64(len(Sessions[uid])) { - current := Sessions[uid][revoke].Token == ss.Token - EndSession(uid, Sessions[uid][revoke].Token) - - if current { - EndSessionCookie(w) - http.Redirect(w, r, "/", http.StatusFound) - return - } - - http.Redirect(w, r, "/user/sessions", http.StatusFound) - return - } - - type row struct{ Index, Ip, Seen, Expiry, Current string } - data := struct { - Title string - Sessions []row - }{Title: "User - Sessions"} - - for i, v := range Sessions[uid] { - data.Sessions = append(data.Sessions, row{ - Index: fmt.Sprint(i), Ip: v.Ip, Seen: v.Seen.Format(time.DateTime), Expiry: v.Expiry.Format(time.DateTime), - Current: util.If(v.Token == ss.Token, "(current)", ""), - }) - } - - if err := Tmpl.ExecuteTemplate(w, "user/sessions", data); err != nil { - log.Println("[/user/login]", err.Error()) - } -} - func GetUser(id int64) (*User, error) { u := User{} diff --git a/src/user/sessions.go b/src/user/sessions.go new file mode 100644 index 0000000..6f20118 --- /dev/null +++ b/src/user/sessions.go @@ -0,0 +1,63 @@ +package user + +import ( + "fmt" + "log" + "net/http" + "strconv" + "time" + + goit "github.com/Jamozed/Goit/src" + "github.com/Jamozed/Goit/src/util" +) + +func HandleSessions(w http.ResponseWriter, r *http.Request) { + auth, uid := goit.AuthCookie(w, r, true) + if !auth { + goit.HttpError(w, http.StatusUnauthorized) + return + } + + _, ss := goit.GetSessionCookie(r) + + revoke, err := strconv.ParseInt(r.FormValue("revoke"), 10, 64) + if err != nil { + revoke = -1 + } + + type row struct{ Index, Ip, Seen, Expiry, Current string } + var data = struct { + Title string + Sessions []row + }{Title: "User - Sessions"} + + goit.SessionsMutex.RLock() + if revoke >= 0 && revoke < int64(len(goit.Sessions[uid])) { + var token = goit.Sessions[uid][revoke].Token + var current = token == ss.Token + + goit.SessionsMutex.RUnlock() + goit.EndSession(uid, token) + + if current { + goit.EndSessionCookie(w) + http.Redirect(w, r, "/", http.StatusFound) + return + } + + http.Redirect(w, r, "/user/sessions", http.StatusFound) + return + } + + for i, v := range goit.Sessions[uid] { + data.Sessions = append(data.Sessions, row{ + Index: fmt.Sprint(i), Ip: v.Ip, Seen: v.Seen.Format(time.DateTime), Expiry: v.Expiry.Format(time.DateTime), + Current: util.If(v.Token == ss.Token, "(current)", ""), + }) + } + goit.SessionsMutex.RUnlock() + + if err := goit.Tmpl.ExecuteTemplate(w, "user/sessions", data); err != nil { + log.Println("[/user/login]", err.Error()) + } +}