Goit

Simple and lightweight Git web server
git clone http://git.omkov.net/Goit
Log | Tree | Refs | README | Download

Goit/src/goit/auth.go (213 lines, 4.5 KiB) -rw-r--r-- blame download

0123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
// Copyright (C) 2023, Jakob Wakeling
// All rights reserved.

package goit

import (
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"fmt"
	"log"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/Jamozed/Goit/src/util"
	"golang.org/x/crypto/argon2"
)

type Session struct {
	Token, Ip    string
	Seen, Expiry time.Time
}

var Sessions = map[int64][]Session{}
var SessionsMutex = sync.RWMutex{}

/* Generate a new user session. */
/*
	NewSession creates a session for user uid that expires at expiry, appends it
	to Sessions[uid], and returns it. The token is 24 bytes from crypto/rand
	encoded with standard Base64. The client ip is stored only when
	Conf.IpSessions is set; otherwise the session's Ip is empty. If the random
	source fails the error is returned and nothing is recorded.
*/
func NewSession(uid int64, ip string, expiry time.Time) (Session, error) {
	var b = make([]byte, 24)
	if _, err := rand.Read(b); err != nil {
		return Session{}, err
	}

	var t = base64.StdEncoding.EncodeToString(b)
	var s = Session{Token: t, Ip: util.If(Conf.IpSessions, ip, ""), Seen: time.Now(), Expiry: expiry}

	SessionsMutex.Lock()
	util.Debugln("[goit.NewSession] SessionsMutex lock")

	/* append allocates the slice on a user's first session, so no nil check is needed */
	Sessions[uid] = append(Sessions[uid], s)

	SessionsMutex.Unlock()
	util.Debugln("[goit.NewSession] SessionsMutex unlock")

	return s, nil
}

/* End a user session. */
/*
	EndSession removes the first session of user uid whose token matches.
	Nothing happens when the user has no session list. If the user is left with
	no sessions, the user's entry is removed from Sessions entirely.
*/
func EndSession(uid int64, token string) {
	SessionsMutex.Lock()
	util.Debugln("[goit.EndSession] SessionsMutex lock")
	defer SessionsMutex.Unlock()
	defer util.Debugln("[goit.EndSession] SessionsMutex unlock")

	list := Sessions[uid]
	if list == nil {
		return
	}

	for idx := range list {
		if list[idx].Token != token {
			continue
		}

		/* Shift the remaining sessions down over the removed one (in place) */
		list = append(list[:idx], list[idx+1:]...)
		Sessions[uid] = list
		break
	}

	if len(list) == 0 {
		delete(Sessions, uid)
	}
}

/* Cleanup expired user sessions. */
/*
	CleanupSessions removes every expired session from Sessions. A session
	survives only if its Expiry is strictly after the moment the sweep starts.
	Users left without sessions are removed from the map. The number of removed
	sessions is logged when it is non-zero.
*/
func CleanupSessions() {
	var n int = 0

	SessionsMutex.Lock()
	util.Debugln("[goit.CleanupSessions] SessionsMutex lock")

	/* One cutoff for the whole sweep, taken once the lock is held */
	now := time.Now()

	for uid, v := range Sessions {
		i := 0
		for _, s := range v {
			if s.Expiry.After(now) {
				v[i] = s
				i++
			}
		}

		/* Zero the discarded tail so expired tokens are not kept reachable through the backing array */
		for j := i; j < len(v); j++ {
			v[j] = Session{}
		}

		n += len(v) - i

		if i == 0 {
			delete(Sessions, uid)
		} else {
			Sessions[uid] = v[:i]
		}
	}

	SessionsMutex.Unlock()
	util.Debugln("[goit.CleanupSessions] SessionsMutex unlock")

	if n > 0 {
		log.Println("[Cleanup] cleaned up", n, "expired sessions")
	}
}

/* Set a user session cookie. */
/*
	SetSessionCookie writes the "session" cookie for user uid and session s.
	The value has the form "<uid>.<token>" and the cookie expires with the
	session. It is HttpOnly and SameSite=Lax, and it is Secure when
	Conf.UsesHttps is set. An invalid cookie is logged but still passed to
	http.SetCookie.
*/
func SetSessionCookie(w http.ResponseWriter, uid int64, s Session) {
	c := &http.Cookie{
		Name: "session", Value: strconv.FormatInt(uid, 10) + "." + s.Token, Path: "/", Expires: s.Expiry,
		Secure: Conf.UsesHttps, HttpOnly: true, SameSite: http.SameSiteLaxMode,
	}

	if err := c.Valid(); err != nil {
		log.Println("[Cookie]", err.Error())
	}

	http.SetCookie(w, c)
}

/* Get a user session cookie if one is present. */
/*
	GetSessionCookie looks up the session named by the request's "session"
	cookie, whose value has the form "<uid>.<token>".

	It returns (-1, Session{}) when the cookie is missing, has no ".", or has a
	non-integer uid. It returns (uid, Session{}) when the uid parses but no live
	session of that user has the token. On a match, the session's Seen time is
	refreshed in Sessions and the updated copy is returned.
*/
func GetSessionCookie(r *http.Request) (int64, Session) {
	c := util.Cookie(r, "session")
	if c == nil {
		return -1, Session{}
	}

	id, token, ok := strings.Cut(c.Value, ".")
	if !ok {
		return -1, Session{}
	}

	uid, err := strconv.ParseInt(id, 10, 64)
	if err != nil {
		return -1, Session{}
	}

	/* Generated tokens are never empty, so an empty token cannot name a session */
	if token == "" {
		return uid, Session{}
	}

	SessionsMutex.Lock()
	util.Debugln("[goit.GetSessionCookie] SessionsMutex lock")
	defer SessionsMutex.Unlock()
	defer util.Debugln("[goit.GetSessionCookie] SessionsMutex unlock")

	list := Sessions[uid]
	for i := range list {
		/* Constant-time comparison: the token comes from an untrusted cookie */
		if subtle.ConstantTimeCompare([]byte(token), []byte(list[i].Token)) == 1 {
			list[i].Seen = time.Now()
			return uid, list[i]
		}
	}

	return uid, Session{}
}

/* End the current user session cookie. */
/*
	EndSessionCookie tells the client to delete its "session" cookie. It sends
	an empty replacement with a negative MaxAge, which net/http emits as
	"Max-Age=0".
*/
func EndSessionCookie(w http.ResponseWriter) {
	expired := &http.Cookie{
		Name:   "session",
		Path:   "/",
		MaxAge: -1,
	}
	http.SetCookie(w, expired)
}

/* Authenticate a user session, returns auth, user, error. */
/*
	Auth authenticates the request's session cookie and returns
	(authenticated, user, error).

	A missing or unknown session yields (false, nil, nil). A GetUser failure is
	returned wrapped with "[auth]". If the session's user no longer exists, or
	the session has expired, the session is ended and (false, nil, nil) is
	returned.

	When renew is true and the session has under 24 hours left, Auth creates a
	replacement session lasting 48 hours, sets its cookie, and ends the old
	session. A renewal failure is only logged; the request still authenticates.

	NOTE(review): Session.Ip is recorded by NewSession, but it is not compared
	against Ip(r) here or in GetSessionCookie. Confirm whether IP binding is
	enforced elsewhere.
*/
func Auth(w http.ResponseWriter, r *http.Request, renew bool) (bool, *User, error) {
	uid, sess := GetSessionCookie(r)
	if sess == (Session{}) {
		return false, nil, nil
	}

	user, err := GetUser(uid)
	switch {
	case err != nil:
		return false, nil, fmt.Errorf("[auth] %w", err)
	case user == nil, sess.Expiry.Before(time.Now()):
		/* The session points at a deleted user or has lapsed, so drop it */
		EndSession(uid, sess.Token)
		return false, nil, nil
	}

	if !renew || time.Until(sess.Expiry) >= 24*time.Hour {
		return true, user, nil
	}

	fresh, err := NewSession(uid, Ip(r), time.Now().Add(48*time.Hour))
	if err != nil {
		log.Println("[auth/renew]", err.Error())
		return true, user, nil
	}

	/* Issue the new cookie before dropping the old session */
	SetSessionCookie(w, uid, fresh)
	EndSession(uid, sess.Token)
	return true, user, nil
}

/* Hash a password with a salt using Argon2. */
/*
	Hash derives a 32-byte key from pass and salt using Argon2id. It uses 3
	passes, 64 MiB of memory, and 4 threads.
*/
func Hash(pass string, salt []byte) []byte {
	const (
		passes  = 3
		memKiB  = 64 * 1024 /* argon2 takes memory in KiB */
		threads = 4
		keyLen  = 32
	)

	return argon2.IDKey([]byte(pass), salt, passes, memKiB, threads, keyLen)
}

/* Generate a random Base64 salt. */
/*
	Salt returns 16 random bytes from crypto/rand for use as the salt argument
	of Hash. The bytes are raw, not Base64-encoded; callers that need text must
	encode them. If the random source fails, the error is returned with a nil
	slice.
*/
func Salt() ([]byte, error) {
	const saltLen = 16

	b := make([]byte, saltLen)
	if _, err := rand.Read(b); err != nil {
		return nil, err
	}

	return b, nil
}