Goit

Simple and lightweight Git web server
git clone https://git.omkov.net/Goit
git clone [email protected]:Goit
Log | Tree | Refs | README | Download

Goit/src/goit/user.go (142 lines, 3.3 KiB) -rw-r--r-- blame download

0123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
// Copyright (C) 2023, Jakob Wakeling
// All rights reserved.

package goit

import (
	"database/sql"
	"errors"
	"fmt"
	"net/http"
	"strings"
)

// User is one row of the users table. The functions in this file read and
// write the columns id, name, name_full, pass, pass_algo, salt, is_admin and
// default_visibility into the fields below, in that order.
//
// Name is the login name. GetUserByName and UserExists lowercase their
// argument before querying, while CreateUser and UpdateUser store Name as
// given. NOTE(review): presumably callers lowercase Name before writing;
// confirm.
//
// Pass, PassAlgo and Salt hold the stored credential. UpdatePassword writes
// Hash(password, salt), the literal "argon2" and a salt obtained from Salt().
//
// NOTE(review): Pass and Salt carry json tags, so encoding a User with
// encoding/json would emit the password hash and salt. Confirm that no
// response serializes User directly, or consider tagging them `json:"-"`.
type User struct {
	Id                int64      `json:"id"`
	Name              string     `json:"name"`
	FullName          string     `json:"name_full"`
	Pass              []byte     `json:"pass"`
	PassAlgo          string     `json:"pass_algo"`
	Salt              []byte     `json:"salt"`
	IsAdmin           bool       `json:"is_admin"`
	DefaultVisibility Visibility `json:"default_visibility"`
}

// HandleUserLogout logs the requester out. It reads the session cookie from r,
// ends that session with EndSession, clears the cookie on w and redirects to
// the site root with 302 Found.
//
// NOTE(review): behaviour for a request without a session cookie depends on
// what GetSessionCookie returns in that case; confirm it yields a value whose
// Token field is safe to read.
func HandleUserLogout(w http.ResponseWriter, r *http.Request) {
	cookieID, sess := GetSessionCookie(r)
	EndSession(cookieID, sess.Token)
	EndSessionCookie(w)

	http.Redirect(w, r, "/", http.StatusFound)
}

// GetUsers returns every row of the users table.
//
// The result is a non-nil (possibly empty) slice on success. Any error from
// the query, from scanning a row, or from iterating the result set is returned
// unwrapped together with a nil slice.
func GetUsers() ([]User, error) {
	rows, err := db.Query("SELECT id, name, name_full, pass, pass_algo, salt, is_admin, default_visibility FROM users")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	// Start from an empty, non-nil slice so "no users" is distinguishable
	// from an error (and encodes as [] rather than null in JSON).
	all := make([]User, 0)
	for rows.Next() {
		var row User
		err = rows.Scan(
			&row.Id, &row.Name, &row.FullName, &row.Pass, &row.PassAlgo, &row.Salt, &row.IsAdmin, &row.DefaultVisibility,
		)
		if err != nil {
			return nil, err
		}
		all = append(all, row)
	}

	// rows.Next returning false may hide an iteration error; surface it.
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return all, nil
}

// GetUser loads the user whose id column equals id.
//
// It returns (nil, nil) when no such row exists. Any other database error is
// wrapped with a "[SELECT:user]" prefix and remains matchable with errors.Is.
func GetUser(id int64) (*User, error) {
	var found User

	row := db.QueryRow(
		"SELECT id, name, name_full, pass, pass_algo, salt, is_admin, default_visibility FROM users WHERE id = ?", id,
	)
	err := row.Scan(
		&found.Id, &found.Name, &found.FullName, &found.Pass, &found.PassAlgo, &found.Salt, &found.IsAdmin, &found.DefaultVisibility,
	)

	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, nil
	case err != nil:
		return nil, fmt.Errorf("[SELECT:user] %w", err)
	}
	return &found, nil
}

// GetUserByName loads the user whose name column equals name lowercased.
//
// It returns (nil, nil) when no such row exists. Any other database error is
// returned unwrapped.
func GetUserByName(name string) (*User, error) {
	var match User

	if err := db.QueryRow(
		"SELECT id, name, name_full, pass, pass_algo, salt, is_admin, default_visibility FROM users WHERE name = ?",
		strings.ToLower(name),
	).Scan(
		&match.Id, &match.Name, &match.FullName, &match.Pass, &match.PassAlgo, &match.Salt, &match.IsAdmin, &match.DefaultVisibility,
	); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, nil
		}
		return nil, err
	}

	return &match, nil
}

// UserExists reports whether a row in users has a name equal to name
// lowercased.
//
// A missing row yields (false, nil). Any other database error yields
// (false, err).
func UserExists(name string) (bool, error) {
	// The selected value is only used to detect row presence; scan it into a
	// throwaway local rather than the caller's argument.
	var stored string
	err := db.QueryRow("SELECT name FROM users WHERE name = ?", strings.ToLower(name)).Scan(&stored)

	switch {
	case err == nil:
		return true, nil
	case errors.Is(err, sql.ErrNoRows):
		return false, nil
	default:
		return false, err
	}
}

// CreateUser inserts user as a new row in the users table.
//
// Only name, name_full, pass, pass_algo, salt, is_admin and default_visibility
// are written. user.Id is not part of the insert; presumably the database
// assigns it. Returns the error from the insert, if any.
func CreateUser(user User) error {
	_, err := db.Exec(
		`INSERT INTO users (name, name_full, pass, pass_algo, salt, is_admin, default_visibility)
			VALUES (?, ?, ?, ?, ?, ?, ?)`,
		user.Name, user.FullName, user.Pass, user.PassAlgo, user.Salt, user.IsAdmin, user.DefaultVisibility,
	)
	return err
}

// UpdateUser overwrites name, name_full, is_admin and default_visibility of the
// user whose id is uid.
//
// user.Id and the password fields (Pass, PassAlgo, Salt) are ignored. Use
// UpdatePassword to change credentials. Returns the error from the update, if
// any.
//
// NOTE(review): user.Name is stored as given, whereas the name lookups in this
// file lowercase their input; confirm callers normalize it.
func UpdateUser(uid int64, user User) error {
	_, err := db.Exec(
		"UPDATE users SET name = ?, name_full = ?, is_admin = ?, default_visibility = ? WHERE id = ?",
		user.Name, user.FullName, user.IsAdmin, user.DefaultVisibility, uid,
	)
	return err
}

// UpdatePassword replaces the stored credential of the user whose id is uid.
//
// It generates a fresh salt with Salt, stores Hash(password, salt) in pass,
// records "argon2" in pass_algo and saves the salt. The label presumably
// matches what Hash implements; confirm against Hash. Returns the error from
// salt generation or from the update, if any.
func UpdatePassword(uid int64, password string) error {
	salt, err := Salt()
	if err != nil {
		return err
	}

	_, err = db.Exec(
		"UPDATE users SET pass = ?, pass_algo = ?, salt = ? WHERE id = ?",
		Hash(password, salt), "argon2", salt, uid,
	)
	return err
}