Goit

Simple and lightweight Git web server
git clone http://git.omkov.net/Goit
Log | Tree | Refs | README | Download

AuthorJakob Wakeling <[email protected]>
Date2023-09-28 05:04:37
Commit6fb9830bd1c4a268ef12cc400d946e08260f1dcf
Parent68784b660505ad75a691be1e80ac2cc82788a37e

Allow users to revoke their sessions

Diffstat

M res/user/sessions.html | 2 +-
M src/auth.go | 43 ++++++++++++++++++++++++++++---------------
M src/user.go | 27 +++++++++++++++++++++++----

3 files changed, 52 insertions, 20 deletions

diff --git a/res/user/sessions.html b/res/user/sessions.html
index 2fea513..7dfa729 100644
--- a/res/user/sessions.html
+++ b/res/user/sessions.html
@@ -19,7 +19,7 @@
 					<td>{{.Ip}}</a></td>
 					<td>{{.Seen}}</td>
 					<td>{{.Expiry}}</td>
-					<td><a href="">revoke</a></td>
+					<td><a href="/user/sessions?revoke={{.Index}}">revoke</a></td>
 					<td>{{.Current}}</td>
 				</tr>
 			{{end}}
diff --git a/src/auth.go b/src/auth.go
index d78b2c2..eb63ed9 100644
--- a/src/auth.go
+++ b/src/auth.go
@@ -24,7 +24,7 @@ type Session struct {
 	Seen, Expiry time.Time
 }
 
-var Sessions = map[int64]map[string]Session{}
+var Sessions = map[int64][]Session{}
 
 func NewSession(uid int64, ip string, expiry time.Time) (Session, error) {
 	b := make([]byte, 24)
@@ -33,18 +33,26 @@ func NewSession(uid int64, ip string, expiry time.Time) (Session, error) {
 	}
 
 	if Sessions[uid] == nil {
-		Sessions[uid] = map[string]Session{}
+		Sessions[uid] = []Session{}
 	}
 
 	t := base64.StdEncoding.EncodeToString(b)
-	Sessions[uid][t] = Session{Token: t, Ip: util.If(Conf.IpSessions, ip, ""), Seen: time.Now(), Expiry: expiry}
-	return Sessions[uid][t], nil
+	s := Session{Token: t, Ip: util.If(Conf.IpSessions, ip, ""), Seen: time.Now(), Expiry: expiry}
+
+	Sessions[uid] = append(Sessions[uid], s)
+	return s, nil
 }
 
-func EndSession(id int64, token string) {
-	delete(Sessions[id], token)
-	if len(Sessions[id]) == 0 {
-		delete(Sessions, id)
+func EndSession(uid int64, token string) {
+	for i, t := range Sessions[uid] {
+		if t.Token == token {
+			Sessions[uid] = append(Sessions[uid][:i], Sessions[uid][i+1:]...)
+			break
+		}
+	}
+
+	if len(Sessions[uid]) == 0 {
+		delete(Sessions, uid)
 	}
 }
 
@@ -52,9 +60,9 @@ func CleanupSessions() {
 	var n uint64 = 0
 
 	for k, v := range Sessions {
-		for k1, v1 := range v {
+		for _, v1 := range v {
 			if v1.Expiry.Before(time.Now()) {
-				EndSession(k, k1)
+				EndSession(k, v1.Token)
 				n += 1
 			}
 		}
@@ -86,13 +94,18 @@ func GetSessionCookie(r *http.Request) (int64, Session) {
 			return -1, Session{}
 		}
 
-		s := Sessions[id][ss[1]]
-		if s != (Session{}) {
-			s.Seen = time.Now()
-			Sessions[id][ss[1]] = s
+		for i, s := range Sessions[id] {
+			if ss[1] == s.Token {
+				if s != (Session{}) {
+					s.Seen = time.Now()
+					Sessions[id][i] = s
+				}
+
+				return id, s
+			}
 		}
 
-		return id, s
+		return id, Session{}
 	}
 
 	return -1, Session{}
diff --git a/src/user.go b/src/user.go
index 1ddf752..967ab34 100644
--- a/src/user.go
+++ b/src/user.go
@@ -10,6 +10,7 @@ import (
 	"fmt"
 	"log"
 	"net/http"
+	"strconv"
 	"strings"
 	"time"
 
@@ -44,16 +45,34 @@ func HandleUserSessions(w http.ResponseWriter, r *http.Request) {
 
 	_, ss := GetSessionCookie(r)
 
-	type row struct{ Ip, Seen, Expiry, Current string }
+	revoke, err := strconv.ParseInt(r.FormValue("revoke"), 10, 64)
+	if err != nil {
+		revoke = -1
+	}
+	if revoke >= 0 && revoke < int64(len(Sessions[uid])) {
+		current := Sessions[uid][revoke].Token == ss.Token
+		EndSession(uid, Sessions[uid][revoke].Token)
+
+		if current {
+			EndSessionCookie(w)
+			http.Redirect(w, r, "/", http.StatusFound)
+			return
+		}
+
+		http.Redirect(w, r, "/user/sessions", http.StatusFound)
+		return
+	}
+
+	type row struct{ Index, Ip, Seen, Expiry, Current string }
 	data := struct {
 		Title    string
 		Sessions []row
 	}{Title: "User - Sessions"}
 
-	for k, v := range Sessions[uid] {
+	for i, v := range Sessions[uid] {
 		data.Sessions = append(data.Sessions, row{
-			Ip: v.Ip, Seen: v.Seen.Format(time.DateTime), Expiry: v.Expiry.Format(time.DateTime),
-			Current: util.If(k == ss.Token, "(current)", ""),
+			Index: fmt.Sprint(i), Ip: v.Ip, Seen: v.Seen.Format(time.DateTime), Expiry: v.Expiry.Format(time.DateTime),
+			Current: util.If(v.Token == ss.Token, "(current)", ""),
 		})
 	}