Files
BagExchange/controllers/auth_controller.go
2026-02-19 20:08:06 +01:00

521 lines
18 KiB
Go

package controllers
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"log"
"net/mail"
"strings"
"time"
"unicode"
"github.com/gofiber/fiber/v3"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
// AuthController implements the HTTP handlers for account registration, login,
// logout, session lookup, email verification and password reset.
//
// DB must be non-nil; every other field is checked for nil before use, and the
// handlers degrade (e.g. skip sending mail) when an optional hook is missing.
type AuthController struct {
	// DB is the database handle used for the users, sessions,
	// password_reset_tokens and email_verification_tokens tables.
	DB *gorm.DB
	// IsUniqueConstraint reports whether an insert error is a unique-constraint
	// violation; Register uses it to answer 409 email_exists.
	IsUniqueConstraint func(error) bool
	// AppBaseURL returns the prefix used to build verification and reset links.
	AppBaseURL func() string
	// MailConfigured reports whether outgoing mail can be sent at all.
	MailConfigured func() bool
	// SendPasswordResetEmail delivers a reset link to recipientEmail.
	SendPasswordResetEmail func(recipientEmail, resetURL string) error
	// SendEmailVerificationMail delivers a verification link to recipientEmail.
	SendEmailVerificationMail func(recipientEmail, verifyURL string) error
}
// JSON request bodies accepted by the AuthController handlers.
type (
	// registerRequest is the body of Register.
	registerRequest struct {
		Email           string `json:"email"`
		Password        string `json:"password"`
		ConfirmPassword string `json:"confirmPassword"`
	}

	// loginRequest is the body of Login.
	loginRequest struct {
		Email    string `json:"email"`
		Password string `json:"password"`
	}

	// forgotPasswordRequest is the body of ForgotPassword.
	forgotPasswordRequest struct {
		Email string `json:"email"`
	}

	// resetPasswordRequest is the body of ResetPassword; Token is the raw
	// token taken from the emailed reset link.
	resetPasswordRequest struct {
		Token           string `json:"token"`
		Password        string `json:"password"`
		ConfirmPassword string `json:"confirmPassword"`
	}

	// resendVerificationRequest is the body of ResendVerification.
	resendVerificationRequest struct {
		Email string `json:"email"`
	}
)
// Authentication tunables. They are left untyped so they combine directly with
// time.Duration values and int cookie fields at their use sites.
const (
	// sessionCookieName is the cookie carrying the raw session token.
	sessionCookieName = "bag_exchange_session"
	// sessionDurationDays is how long a login session stays valid.
	sessionDurationDays = 7
	// passwordResetDurationMinutes is the lifetime of a password reset token.
	passwordResetDurationMinutes = 60
	// emailVerificationDurationHours is the lifetime of an email verification token.
	emailVerificationDurationHours = 24
)
// Register creates a new, not-yet-verified user account and emails it a
// verification link.
//
// The body must be a JSON registerRequest. Responses:
//   - 201 {"status": "registered_pending_verification"} on success;
//   - 400 invalid_payload, invalid_email, password_mismatch, password_too_weak
//     or password_too_long for bad input;
//   - 409 email_exists when IsUniqueConstraint recognises the insert error;
//   - 500 unable_to_register for hashing or database failures.
//
// The user row, the id read-back and the verification token are written in one
// transaction. A failure part-way through therefore does not leave behind an
// account without a verification token (which a retry would then report as
// email_exists). Mail delivery is best-effort: send failures are logged only.
func (a *AuthController) Register(c fiber.Ctx) error {
	// bcrypt only uses the first 72 bytes of a password, and recent
	// golang.org/x/crypto releases reject longer input outright; report that as
	// a client error instead of letting hashing fail with a 500.
	const maxPasswordBytes = 72

	var req registerRequest
	if err := json.Unmarshal(c.Body(), &req); err != nil {
		return c.Status(fiber.StatusBadRequest).JSON(map[string]string{"error": "invalid_payload"})
	}
	email := strings.ToLower(strings.TrimSpace(req.Email))
	// mail.ParseAddress also accepts forms like `Bob <bob@example.com>`; insist
	// on the bare address so the stored value is exactly what was validated.
	if addr, err := mail.ParseAddress(email); err != nil || addr.Address != email {
		return c.Status(fiber.StatusBadRequest).JSON(map[string]string{"error": "invalid_email"})
	}
	if req.Password != req.ConfirmPassword {
		return c.Status(fiber.StatusBadRequest).JSON(map[string]string{"error": "password_mismatch"})
	}
	if !isStrongPassword(req.Password) {
		return c.Status(fiber.StatusBadRequest).JSON(map[string]string{"error": "password_too_weak"})
	}
	if len(req.Password) > maxPasswordBytes {
		return c.Status(fiber.StatusBadRequest).JSON(map[string]string{"error": "password_too_long"})
	}
	passwordHash, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
	if err != nil {
		log.Printf("unable to hash password: %v", err)
		return c.Status(fiber.StatusInternalServerError).JSON(map[string]string{"error": "unable_to_register"})
	}

	var verifyToken string
	emailTaken := false
	err = a.DB.Transaction(func(tx *gorm.DB) error {
		if insertErr := tx.Table("users").Create(map[string]any{
			"email":          email,
			"password_hash":  string(passwordHash),
			"email_verified": 0,
		}).Error; insertErr != nil {
			// Classify the raw driver error here, before any wrapping, so the
			// injected predicate sees exactly what it saw before.
			emailTaken = a.IsUniqueConstraint != nil && a.IsUniqueConstraint(insertErr)
			return fmt.Errorf("insert user: %w", insertErr)
		}
		// Creating from a map does not back-fill the primary key, so read it back.
		var createdUser struct {
			ID int64 `gorm:"column:id"`
		}
		if readErr := tx.Table("users").Select("id").Where("email = ?", email).Take(&createdUser).Error; readErr != nil {
			return fmt.Errorf("read new user id: %w", readErr)
		}
		token, tokenErr := createEmailVerificationToken(tx, createdUser.ID)
		if tokenErr != nil {
			return fmt.Errorf("create verify token: %w", tokenErr)
		}
		verifyToken = token
		return nil
	})
	if err != nil {
		if emailTaken {
			return c.Status(fiber.StatusConflict).JSON(map[string]string{"error": "email_exists"})
		}
		log.Printf("unable to register user: %v", err)
		return c.Status(fiber.StatusInternalServerError).JSON(map[string]string{"error": "unable_to_register"})
	}

	if a.MailConfigured != nil && a.MailConfigured() && a.AppBaseURL != nil && a.SendEmailVerificationMail != nil {
		verifyURL := fmt.Sprintf("%s/auth/verify-email?token=%s", a.AppBaseURL(), verifyToken)
		if err := a.SendEmailVerificationMail(email, verifyURL); err != nil {
			log.Printf("unable to send verification email: %v", err)
		}
	} else {
		log.Printf("smtp not configured: skip verification email")
	}
	return c.Status(fiber.StatusCreated).JSON(map[string]string{"status": "registered_pending_verification"})
}
// Login checks an email/password pair and, on success, opens a server-side
// session and hands its token to the client in the session cookie.
//
// A malformed email, an empty password, an unknown email and a wrong password
// all produce the same 401 invalid_credentials. A correct password for an
// unverified account yields 403 email_not_verified. Database and token failures
// yield 500 unable_to_login. On success the response is 200 with the email.
func (a *AuthController) Login(c fiber.Ctx) error {
	rejectCredentials := func() error {
		return c.Status(fiber.StatusUnauthorized).JSON(map[string]string{"error": "invalid_credentials"})
	}
	internalFailure := func() error {
		return c.Status(fiber.StatusInternalServerError).JSON(map[string]string{"error": "unable_to_login"})
	}

	var req loginRequest
	if err := json.Unmarshal(c.Body(), &req); err != nil {
		return c.Status(fiber.StatusBadRequest).JSON(map[string]string{"error": "invalid_payload"})
	}
	email := strings.ToLower(strings.TrimSpace(req.Email))
	if _, parseErr := mail.ParseAddress(email); parseErr != nil || req.Password == "" {
		return rejectCredentials()
	}

	var loginUser struct {
		ID            int64  `gorm:"column:id"`
		PasswordHash  string `gorm:"column:password_hash"`
		EmailVerified int    `gorm:"column:email_verified"`
	}
	lookupErr := a.DB.Table("users").
		Select("id, password_hash, email_verified").
		Where("email = ?", email).
		Take(&loginUser).Error
	switch {
	case errors.Is(lookupErr, gorm.ErrRecordNotFound):
		return rejectCredentials()
	case lookupErr != nil:
		log.Printf("unable to fetch user for login: %v", lookupErr)
		return internalFailure()
	}

	if bcrypt.CompareHashAndPassword([]byte(loginUser.PasswordHash), []byte(req.Password)) != nil {
		return rejectCredentials()
	}
	// Only reveal the verification state to someone who knows the password.
	if loginUser.EmailVerified == 0 {
		return c.Status(fiber.StatusForbidden).JSON(map[string]string{"error": "email_not_verified"})
	}

	token, tokenErr := generateSessionToken()
	if tokenErr != nil {
		log.Printf("unable to generate session token: %v", tokenErr)
		return internalFailure()
	}
	// Only the SHA-256 of the token is stored; the raw value lives in the cookie.
	expiry := time.Now().AddDate(0, 0, sessionDurationDays).Unix()
	sessionRow := map[string]any{
		"user_id":    loginUser.ID,
		"token_hash": hashToken(token),
		"expires_at": expiry,
	}
	if createErr := a.DB.Table("sessions").Create(sessionRow).Error; createErr != nil {
		log.Printf("unable to create session: %v", createErr)
		return internalFailure()
	}

	setSessionCookie(c, token, expiry)
	return c.Status(fiber.StatusOK).JSON(map[string]any{
		"status": "authenticated",
		"email":  email,
	})
}
// Logout ends the caller's session: it deletes the stored session row matching
// the cookie's token (if a cookie was sent) and always clears the cookie.
// A failure to delete the row is logged, but the response is still 200
// {"status": "logged_out"}.
func (a *AuthController) Logout(c fiber.Ctx) error {
	if token := c.Cookies(sessionCookieName); token != "" {
		deletion := a.DB.Table("sessions").Where("token_hash = ?", hashToken(token)).Delete(nil)
		if deletion.Error != nil {
			log.Printf("unable to delete session: %v", deletion.Error)
		}
	}
	clearSessionCookie(c)
	return c.Status(fiber.StatusOK).JSON(map[string]string{"status": "logged_out"})
}
// Me reports who the current session belongs to.
//
// It resolves the session cookie to a non-expired session row joined with its
// user and returns 200 {"authenticated": true, "userId", "email"}. A missing
// cookie yields 401 {"authenticated": false}. An unknown or expired token yields
// the same 401 and also clears the stale cookie. Database errors yield 500
// unable_to_fetch_session.
func (a *AuthController) Me(c fiber.Ctx) error {
	anonymous := func() error {
		return c.Status(fiber.StatusUnauthorized).JSON(map[string]any{"authenticated": false})
	}

	token := c.Cookies(sessionCookieName)
	if token == "" {
		return anonymous()
	}

	var sessionUser struct {
		UserID int64  `gorm:"column:user_id"`
		Email  string `gorm:"column:email"`
	}
	tokenHash := hashToken(token)
	now := time.Now().Unix()
	err := a.DB.Table("sessions").
		Select("users.id AS user_id, users.email").
		Joins("JOIN users ON users.id = sessions.user_id").
		Where("sessions.token_hash = ? AND sessions.expires_at > ?", tokenHash, now).
		Take(&sessionUser).Error

	switch {
	case err == nil:
		return c.Status(fiber.StatusOK).JSON(map[string]any{
			"authenticated": true,
			"userId":        sessionUser.UserID,
			"email":         sessionUser.Email,
		})
	case errors.Is(err, gorm.ErrRecordNotFound):
		clearSessionCookie(c)
		return anonymous()
	default:
		log.Printf("unable to resolve session: %v", err)
		return c.Status(fiber.StatusInternalServerError).JSON(map[string]string{"error": "unable_to_fetch_session"})
	}
}
// ForgotPassword issues a password reset token for the given email and mails
// the reset link when mail is configured.
//
// To avoid revealing which addresses are registered, malformed and unknown
// emails also get 200 {"status": "ok"}. Only a bad JSON body (400
// invalid_payload) or an internal failure (500 unable_to_process) differ. A
// mail send failure is logged and still answered with "ok".
func (a *AuthController) ForgotPassword(c fiber.Ctx) error {
	respondOK := func() error {
		return c.Status(fiber.StatusOK).JSON(map[string]string{"status": "ok"})
	}
	respondFailure := func() error {
		return c.Status(fiber.StatusInternalServerError).JSON(map[string]string{"error": "unable_to_process"})
	}

	var req forgotPasswordRequest
	if err := json.Unmarshal(c.Body(), &req); err != nil {
		return c.Status(fiber.StatusBadRequest).JSON(map[string]string{"error": "invalid_payload"})
	}
	email := strings.ToLower(strings.TrimSpace(req.Email))
	if _, err := mail.ParseAddress(email); err != nil {
		return respondOK()
	}

	var userID int64
	switch err := a.DB.Table("users").Select("id").Where("email = ?", email).Take(&userID).Error; {
	case errors.Is(err, gorm.ErrRecordNotFound):
		return respondOK()
	case err != nil:
		log.Printf("unable to fetch user for forgot-password: %v", err)
		return respondFailure()
	}

	resetToken, err := generateSessionToken()
	if err != nil {
		log.Printf("unable to generate reset token: %v", err)
		return respondFailure()
	}
	// Persist only the hash; the raw token travels solely in the emailed link.
	tokenRow := map[string]any{
		"user_id":    userID,
		"token_hash": hashToken(resetToken),
		"expires_at": time.Now().Add(passwordResetDurationMinutes * time.Minute).Unix(),
	}
	if err := a.DB.Table("password_reset_tokens").Create(tokenRow).Error; err != nil {
		log.Printf("unable to persist reset token: %v", err)
		return respondFailure()
	}

	canSendMail := a.MailConfigured != nil && a.MailConfigured() && a.AppBaseURL != nil && a.SendPasswordResetEmail != nil
	if !canSendMail {
		log.Printf("smtp not configured: skip password reset email")
		return respondOK()
	}
	resetURL := fmt.Sprintf("%s/reset-password?token=%s", a.AppBaseURL(), resetToken)
	if err := a.SendPasswordResetEmail(email, resetURL); err != nil {
		log.Printf("unable to send reset email: %v", err)
	}
	return respondOK()
}
// ResendVerification issues a fresh email verification token for an unverified
// account and mails the link when mail is configured.
//
// Malformed emails, unknown emails and already-verified accounts all get 200
// {"status": "ok"}, so the response does not reveal account state. A bad JSON
// body yields 400 invalid_payload. Database and token failures yield 500
// unable_to_process.
func (a *AuthController) ResendVerification(c fiber.Ctx) error {
	respondOK := func() error {
		return c.Status(fiber.StatusOK).JSON(map[string]string{"status": "ok"})
	}
	respondFailure := func() error {
		return c.Status(fiber.StatusInternalServerError).JSON(map[string]string{"error": "unable_to_process"})
	}

	var req resendVerificationRequest
	if err := json.Unmarshal(c.Body(), &req); err != nil {
		return c.Status(fiber.StatusBadRequest).JSON(map[string]string{"error": "invalid_payload"})
	}
	email := strings.ToLower(strings.TrimSpace(req.Email))
	if _, err := mail.ParseAddress(email); err != nil {
		return respondOK()
	}

	var resendUser struct {
		ID            int64 `gorm:"column:id"`
		EmailVerified int   `gorm:"column:email_verified"`
	}
	switch err := a.DB.Table("users").Select("id, email_verified").Where("email = ?", email).Take(&resendUser).Error; {
	case errors.Is(err, gorm.ErrRecordNotFound):
		return respondOK()
	case err != nil:
		log.Printf("unable to fetch user for resend verification: %v", err)
		return respondFailure()
	}
	if resendUser.EmailVerified != 0 {
		return respondOK()
	}

	verifyToken, err := createEmailVerificationToken(a.DB, resendUser.ID)
	if err != nil {
		log.Printf("unable to create verify token on resend: %v", err)
		return respondFailure()
	}

	canSendMail := a.MailConfigured != nil && a.MailConfigured() && a.AppBaseURL != nil && a.SendEmailVerificationMail != nil
	if canSendMail {
		verifyURL := fmt.Sprintf("%s/auth/verify-email?token=%s", a.AppBaseURL(), verifyToken)
		if sendErr := a.SendEmailVerificationMail(email, verifyURL); sendErr != nil {
			log.Printf("unable to resend verification email: %v", sendErr)
		}
	}
	return respondOK()
}
// ResetPassword sets a new password using a token from a reset link.
//
// The body must be a JSON resetPasswordRequest. On success the password hash is
// replaced, the token and any other outstanding reset tokens for the user are
// marked used, all of the user's sessions are deleted, and the caller's session
// cookie is cleared. Responses:
//   - 200 {"status": "password_reset"} on success;
//   - 400 invalid_payload, invalid_token, password_mismatch, password_too_weak,
//     password_too_long, or invalid_or_expired_token (unknown, expired or
//     already-used token, including one consumed by a concurrent request);
//   - 500 unable_to_reset_password for hashing or database failures.
func (a *AuthController) ResetPassword(c fiber.Ctx) error {
	// bcrypt only uses the first 72 bytes of a password, and recent
	// golang.org/x/crypto releases reject longer input; answer 400, not 500.
	const maxPasswordBytes = 72
	// errTokenConsumed signals that another request claimed the token between
	// our lookup and our update.
	errTokenConsumed := errors.New("reset token already used")

	var req resetPasswordRequest
	if err := json.Unmarshal(c.Body(), &req); err != nil {
		return c.Status(fiber.StatusBadRequest).JSON(map[string]string{"error": "invalid_payload"})
	}
	token := strings.TrimSpace(req.Token)
	if token == "" {
		return c.Status(fiber.StatusBadRequest).JSON(map[string]string{"error": "invalid_token"})
	}
	if req.Password != req.ConfirmPassword {
		return c.Status(fiber.StatusBadRequest).JSON(map[string]string{"error": "password_mismatch"})
	}
	if !isStrongPassword(req.Password) {
		return c.Status(fiber.StatusBadRequest).JSON(map[string]string{"error": "password_too_weak"})
	}
	if len(req.Password) > maxPasswordBytes {
		return c.Status(fiber.StatusBadRequest).JSON(map[string]string{"error": "password_too_long"})
	}

	var resetTokenRow struct {
		ID     int64 `gorm:"column:id"`
		UserID int64 `gorm:"column:user_id"`
	}
	err := a.DB.Table("password_reset_tokens").
		Select("id, user_id").
		Where("token_hash = ? AND used_at IS NULL AND expires_at > ?", hashToken(token), time.Now().Unix()).
		Order("id DESC").
		Take(&resetTokenRow).Error
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return c.Status(fiber.StatusBadRequest).JSON(map[string]string{"error": "invalid_or_expired_token"})
		}
		log.Printf("unable to fetch reset token: %v", err)
		return c.Status(fiber.StatusInternalServerError).JSON(map[string]string{"error": "unable_to_reset_password"})
	}
	resetID := resetTokenRow.ID
	userID := resetTokenRow.UserID

	passwordHash, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
	if err != nil {
		log.Printf("unable to hash password in reset: %v", err)
		return c.Status(fiber.StatusInternalServerError).JSON(map[string]string{"error": "unable_to_reset_password"})
	}

	now := time.Now().Unix()
	err = a.DB.Transaction(func(tx *gorm.DB) error {
		// Claim the token atomically: only the request whose UPDATE still sees
		// used_at IS NULL may proceed. Two concurrent submissions of the same
		// link therefore cannot both change the password.
		claim := tx.Table("password_reset_tokens").
			Where("id = ? AND used_at IS NULL AND expires_at > ?", resetID, now).
			Update("used_at", now)
		if claim.Error != nil {
			return fmt.Errorf("mark reset token used: %w", claim.Error)
		}
		if claim.RowsAffected != 1 {
			return errTokenConsumed
		}
		if updateErr := tx.Table("users").Where("id = ?", userID).Updates(map[string]any{
			"password_hash": string(passwordHash),
			"updated_at":    gorm.Expr("datetime('now')"),
		}).Error; updateErr != nil {
			return fmt.Errorf("update password: %w", updateErr)
		}
		// Other reset links issued before this reset must not remain usable.
		if retireErr := tx.Table("password_reset_tokens").
			Where("user_id = ? AND used_at IS NULL", userID).
			Update("used_at", now).Error; retireErr != nil {
			return fmt.Errorf("retire other reset tokens: %w", retireErr)
		}
		// Log the user out everywhere: existing sessions predate the new password.
		if deleteErr := tx.Table("sessions").Where("user_id = ?", userID).Delete(nil).Error; deleteErr != nil {
			return fmt.Errorf("delete sessions after reset: %w", deleteErr)
		}
		return nil
	})
	if err != nil {
		if errors.Is(err, errTokenConsumed) {
			return c.Status(fiber.StatusBadRequest).JSON(map[string]string{"error": "invalid_or_expired_token"})
		}
		log.Printf("unable to reset password: %v", err)
		return c.Status(fiber.StatusInternalServerError).JSON(map[string]string{"error": "unable_to_reset_password"})
	}

	clearSessionCookie(c)
	return c.Status(fiber.StatusOK).JSON(map[string]string{"status": "password_reset"})
}
// VerifyEmail consumes the ?token= query parameter from a verification link.
//
// If the token is known, unused and not expired, it marks the user's email as
// verified and the token as used, in one transaction. Responses are plain text:
// 400 for a missing or invalid/expired token, 500 when the database fails, and
// 200 with a success message otherwise.
func (a *AuthController) VerifyEmail(c fiber.Ctx) error {
	const unavailable = "Unable to verify email right now."

	token := strings.TrimSpace(c.Query("token"))
	if token == "" {
		return c.Status(fiber.StatusBadRequest).SendString("Invalid verification token.")
	}

	var verifyTokenRow struct {
		ID     int64 `gorm:"column:id"`
		UserID int64 `gorm:"column:user_id"`
	}
	lookupErr := a.DB.Table("email_verification_tokens").
		Select("id, user_id").
		Where("token_hash = ? AND used_at IS NULL AND expires_at > ?", hashToken(token), time.Now().Unix()).
		Order("id DESC").
		Take(&verifyTokenRow).Error
	switch {
	case errors.Is(lookupErr, gorm.ErrRecordNotFound):
		return c.Status(fiber.StatusBadRequest).SendString("Verification link is invalid or expired.")
	case lookupErr != nil:
		log.Printf("unable to validate verification token: %v", lookupErr)
		return c.Status(fiber.StatusInternalServerError).SendString(unavailable)
	}

	tx := a.DB.Begin()
	if tx.Error != nil {
		return c.Status(fiber.StatusInternalServerError).SendString(unavailable)
	}
	// abort rolls the transaction back, logs "<what>: <err>" and answers 500.
	abort := func(what string, err error) error {
		_ = tx.Rollback().Error
		log.Printf("%s: %v", what, err)
		return c.Status(fiber.StatusInternalServerError).SendString(unavailable)
	}

	if err := tx.Table("users").Where("id = ?", verifyTokenRow.UserID).Updates(map[string]any{
		"email_verified": 1,
		"updated_at":     gorm.Expr("datetime('now')"),
	}).Error; err != nil {
		return abort("unable to mark email verified", err)
	}
	if err := tx.Table("email_verification_tokens").Where("id = ?", verifyTokenRow.ID).Update("used_at", time.Now().Unix()).Error; err != nil {
		return abort("unable to mark verify token used", err)
	}
	if err := tx.Commit().Error; err != nil {
		log.Printf("unable to commit verify email tx: %v", err)
		return c.Status(fiber.StatusInternalServerError).SendString(unavailable)
	}
	return c.SendString("Email verified successfully. You can now log in.")
}
// setSessionCookie stores the raw session token in the session cookie.
//
// The cookie is HTTP-only, scoped to "/", SameSite=Lax, with Expires set to
// expiresAtUnix and Max-Age to sessionDurationDays. It is marked Secure
// whenever the current request arrived over HTTPS, as reported by c.Secure().
// Previously the flag was hard-coded to false, so the token could be sent back
// over plain HTTP even in TLS deployments. Plain-HTTP development setups behave
// exactly as before.
// NOTE(review): behind a TLS-terminating proxy, c.Secure() depends on Fiber's
// trusted-proxy configuration — confirm for the deployment.
func setSessionCookie(c fiber.Ctx, value string, expiresAtUnix int64) {
	c.Cookie(&fiber.Cookie{
		Name:     sessionCookieName,
		Value:    value,
		Path:     "/",
		HTTPOnly: true,
		Secure:   c.Secure(),
		SameSite: "Lax",
		Expires:  time.Unix(expiresAtUnix, 0),
		MaxAge:   60 * 60 * 24 * sessionDurationDays,
	})
}
// clearSessionCookie instructs the client to drop the session cookie. It
// overwrites the cookie with an empty value, an expiry in the past and
// Max-Age=-1. The other attributes match the ones used when the session cookie
// is set, including the Secure flag, which follows whether the request arrived
// over HTTPS.
func clearSessionCookie(c fiber.Ctx) {
	c.Cookie(&fiber.Cookie{
		Name:     sessionCookieName,
		Value:    "",
		Path:     "/",
		HTTPOnly: true,
		Secure:   c.Secure(),
		SameSite: "Lax",
		Expires:  time.Unix(0, 0),
		MaxAge:   -1,
	})
}
// hashToken returns the lowercase hex SHA-256 digest of value. Only these
// digests of session, reset and verification tokens are stored in the
// database, never the raw tokens.
func hashToken(value string) string {
	digest := sha256.New()
	_, _ = digest.Write([]byte(value)) // hash.Hash.Write never returns an error
	return hex.EncodeToString(digest.Sum(nil))
}
// generateSessionToken returns 32 bytes from crypto/rand, encoded as unpadded
// URL-safe base64 (43 characters), so the token can be placed directly in
// cookies and URL query strings. It is used for session, reset and
// verification tokens alike.
func generateSessionToken() (string, error) {
	var entropy [32]byte
	if _, err := rand.Read(entropy[:]); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(entropy[:]), nil
}
// isStrongPassword reports whether value meets the minimum password policy: at
// least 8 characters, containing at least one Unicode letter and one Unicode
// decimal digit.
//
// Length is counted in characters (runes), not bytes. The previous byte count
// let multi-byte passwords through with fewer than 8 characters (e.g. four CJK
// ideographs plus a digit is 13 bytes but only 5 characters).
func isStrongPassword(value string) bool {
	const minChars = 8

	chars := 0
	hasLetter, hasDigit := false, false
	for _, r := range value {
		chars++
		// Letter (L*) and decimal digit (Nd) categories are disjoint, so a
		// switch loses nothing compared with two independent checks.
		switch {
		case unicode.IsLetter(r):
			hasLetter = true
		case unicode.IsDigit(r):
			hasDigit = true
		}
	}
	return chars >= minChars && hasLetter && hasDigit
}
// createEmailVerificationToken generates a new verification token for userID.
// It stores the token's hash with an expiry emailVerificationDurationHours from
// now in email_verification_tokens and returns the raw token for the email link.
// db may be a transaction handle. On any error the returned token is "".
func createEmailVerificationToken(db *gorm.DB, userID int64) (string, error) {
	token, err := generateSessionToken()
	if err != nil {
		return "", err
	}
	row := map[string]any{
		"user_id":    userID,
		"token_hash": hashToken(token),
		"expires_at": time.Now().Add(emailVerificationDurationHours * time.Hour).Unix(),
	}
	if err := db.Table("email_verification_tokens").Create(row).Error; err != nil {
		return "", err
	}
	return token, nil
}