521 lines
18 KiB
Go
521 lines
18 KiB
Go
package controllers
|
|
|
|
import (
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"net/mail"
	"strings"
	"time"
	"unicode"
	"unicode/utf8"

	"github.com/gofiber/fiber/v3"
	"golang.org/x/crypto/bcrypt"
	"gorm.io/gorm"
)
|
|
|
|
type AuthController struct {
|
|
DB *gorm.DB
|
|
IsUniqueConstraint func(error) bool
|
|
AppBaseURL func() string
|
|
MailConfigured func() bool
|
|
SendPasswordResetEmail func(recipientEmail, resetURL string) error
|
|
SendEmailVerificationMail func(recipientEmail, verifyURL string) error
|
|
}
|
|
|
|
// registerRequest is the JSON body decoded by the Register handler.
type registerRequest struct {
	// Email is the address the account is created for.
	Email string `json:"email"`
	// Password is the requested password in clear text.
	Password string `json:"password"`
	// ConfirmPassword must equal Password for the request to be accepted.
	ConfirmPassword string `json:"confirmPassword"`
}
|
|
|
|
// loginRequest is the JSON body decoded by the Login handler.
type loginRequest struct {
	// Email identifies the account to log in to.
	Email string `json:"email"`
	// Password is the clear-text password checked against the stored hash.
	Password string `json:"password"`
}
|
|
|
|
// forgotPasswordRequest is the JSON body decoded by the ForgotPassword handler.
type forgotPasswordRequest struct {
	// Email is the address a password-reset link is requested for.
	Email string `json:"email"`
}
|
|
|
|
// resetPasswordRequest is the JSON body decoded by the ResetPassword handler.
type resetPasswordRequest struct {
	// Token is the clear-text reset token taken from the emailed link.
	Token string `json:"token"`
	// Password is the new password in clear text.
	Password string `json:"password"`
	// ConfirmPassword must equal Password for the request to be accepted.
	ConfirmPassword string `json:"confirmPassword"`
}
|
|
|
|
// resendVerificationRequest is the JSON body decoded by the
// ResendVerification handler.
type resendVerificationRequest struct {
	// Email is the address a fresh verification link is requested for.
	Email string `json:"email"`
}
|
|
|
|
// Session and token settings shared by the auth handlers. The durations are
// left untyped so they can be used both as plain ints and as multipliers of
// time.Duration values.
const (
	// sessionCookieName is the name of the cookie carrying the session token.
	sessionCookieName = "bag_exchange_session"
	// sessionDurationDays is how long a login session stays valid, in days.
	sessionDurationDays = 7
	// passwordResetDurationMinutes is how long a reset token stays valid.
	passwordResetDurationMinutes = 60
	// emailVerificationDurationHours is how long a verification token stays valid.
	emailVerificationDurationHours = 24
)
|
|
|
|
func (a *AuthController) Register(c fiber.Ctx) error {
|
|
var req registerRequest
|
|
if err := json.Unmarshal(c.Body(), &req); err != nil {
|
|
return c.Status(fiber.StatusBadRequest).JSON(map[string]string{"error": "invalid_payload"})
|
|
}
|
|
|
|
email := strings.ToLower(strings.TrimSpace(req.Email))
|
|
if _, err := mail.ParseAddress(email); err != nil {
|
|
return c.Status(fiber.StatusBadRequest).JSON(map[string]string{"error": "invalid_email"})
|
|
}
|
|
|
|
if req.Password != req.ConfirmPassword {
|
|
return c.Status(fiber.StatusBadRequest).JSON(map[string]string{"error": "password_mismatch"})
|
|
}
|
|
|
|
if !isStrongPassword(req.Password) {
|
|
return c.Status(fiber.StatusBadRequest).JSON(map[string]string{"error": "password_too_weak"})
|
|
}
|
|
|
|
passwordHash, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
log.Printf("unable to hash password: %v", err)
|
|
return c.Status(fiber.StatusInternalServerError).JSON(map[string]string{"error": "unable_to_register"})
|
|
}
|
|
|
|
err = a.DB.Table("users").Create(map[string]any{
|
|
"email": email,
|
|
"password_hash": string(passwordHash),
|
|
"email_verified": 0,
|
|
}).Error
|
|
if err != nil {
|
|
if a.IsUniqueConstraint != nil && a.IsUniqueConstraint(err) {
|
|
return c.Status(fiber.StatusConflict).JSON(map[string]string{"error": "email_exists"})
|
|
}
|
|
log.Printf("unable to register user: %v", err)
|
|
return c.Status(fiber.StatusInternalServerError).JSON(map[string]string{"error": "unable_to_register"})
|
|
}
|
|
|
|
var createdUser struct {
|
|
ID int64 `gorm:"column:id"`
|
|
}
|
|
err = a.DB.Table("users").Select("id").Where("email = ?", email).Take(&createdUser).Error
|
|
if err != nil {
|
|
log.Printf("unable to read new user id: %v", err)
|
|
return c.Status(fiber.StatusInternalServerError).JSON(map[string]string{"error": "unable_to_register"})
|
|
}
|
|
userID := createdUser.ID
|
|
|
|
verifyToken, err := createEmailVerificationToken(a.DB, userID)
|
|
if err != nil {
|
|
log.Printf("unable to create verify token: %v", err)
|
|
return c.Status(fiber.StatusInternalServerError).JSON(map[string]string{"error": "unable_to_register"})
|
|
}
|
|
|
|
if a.MailConfigured != nil && a.MailConfigured() && a.AppBaseURL != nil && a.SendEmailVerificationMail != nil {
|
|
verifyURL := fmt.Sprintf("%s/auth/verify-email?token=%s", a.AppBaseURL(), verifyToken)
|
|
if err := a.SendEmailVerificationMail(email, verifyURL); err != nil {
|
|
log.Printf("unable to send verification email: %v", err)
|
|
}
|
|
}
|
|
|
|
return c.Status(fiber.StatusCreated).JSON(map[string]string{"status": "registered_pending_verification"})
|
|
}
|
|
|
|
func (a *AuthController) Login(c fiber.Ctx) error {
|
|
var req loginRequest
|
|
if err := json.Unmarshal(c.Body(), &req); err != nil {
|
|
return c.Status(fiber.StatusBadRequest).JSON(map[string]string{"error": "invalid_payload"})
|
|
}
|
|
|
|
email := strings.ToLower(strings.TrimSpace(req.Email))
|
|
if _, err := mail.ParseAddress(email); err != nil || req.Password == "" {
|
|
return c.Status(fiber.StatusUnauthorized).JSON(map[string]string{"error": "invalid_credentials"})
|
|
}
|
|
|
|
var loginUser struct {
|
|
ID int64 `gorm:"column:id"`
|
|
PasswordHash string `gorm:"column:password_hash"`
|
|
EmailVerified int `gorm:"column:email_verified"`
|
|
}
|
|
err := a.DB.Table("users").
|
|
Select("id, password_hash, email_verified").
|
|
Where("email = ?", email).
|
|
Take(&loginUser).Error
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return c.Status(fiber.StatusUnauthorized).JSON(map[string]string{"error": "invalid_credentials"})
|
|
}
|
|
log.Printf("unable to fetch user for login: %v", err)
|
|
return c.Status(fiber.StatusInternalServerError).JSON(map[string]string{"error": "unable_to_login"})
|
|
}
|
|
userID := loginUser.ID
|
|
passwordHash := loginUser.PasswordHash
|
|
emailVerified := loginUser.EmailVerified
|
|
|
|
if err := bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(req.Password)); err != nil {
|
|
return c.Status(fiber.StatusUnauthorized).JSON(map[string]string{"error": "invalid_credentials"})
|
|
}
|
|
if emailVerified == 0 {
|
|
return c.Status(fiber.StatusForbidden).JSON(map[string]string{"error": "email_not_verified"})
|
|
}
|
|
|
|
sessionToken, err := generateSessionToken()
|
|
if err != nil {
|
|
log.Printf("unable to generate session token: %v", err)
|
|
return c.Status(fiber.StatusInternalServerError).JSON(map[string]string{"error": "unable_to_login"})
|
|
}
|
|
|
|
expiresAt := time.Now().AddDate(0, 0, sessionDurationDays).Unix()
|
|
err = a.DB.Table("sessions").Create(map[string]any{
|
|
"user_id": userID,
|
|
"token_hash": hashToken(sessionToken),
|
|
"expires_at": expiresAt,
|
|
}).Error
|
|
if err != nil {
|
|
log.Printf("unable to create session: %v", err)
|
|
return c.Status(fiber.StatusInternalServerError).JSON(map[string]string{"error": "unable_to_login"})
|
|
}
|
|
|
|
setSessionCookie(c, sessionToken, expiresAt)
|
|
return c.Status(fiber.StatusOK).JSON(map[string]any{
|
|
"status": "authenticated",
|
|
"email": email,
|
|
})
|
|
}
|
|
|
|
func (a *AuthController) Logout(c fiber.Ctx) error {
|
|
sessionToken := c.Cookies(sessionCookieName)
|
|
if sessionToken != "" {
|
|
if err := a.DB.Table("sessions").Where("token_hash = ?", hashToken(sessionToken)).Delete(nil).Error; err != nil {
|
|
log.Printf("unable to delete session: %v", err)
|
|
}
|
|
}
|
|
clearSessionCookie(c)
|
|
return c.Status(fiber.StatusOK).JSON(map[string]string{"status": "logged_out"})
|
|
}
|
|
|
|
func (a *AuthController) Me(c fiber.Ctx) error {
|
|
sessionToken := c.Cookies(sessionCookieName)
|
|
if sessionToken == "" {
|
|
return c.Status(fiber.StatusUnauthorized).JSON(map[string]any{"authenticated": false})
|
|
}
|
|
|
|
var sessionUser struct {
|
|
UserID int64 `gorm:"column:user_id"`
|
|
Email string `gorm:"column:email"`
|
|
}
|
|
err := a.DB.Table("sessions").
|
|
Select("users.id AS user_id, users.email").
|
|
Joins("JOIN users ON users.id = sessions.user_id").
|
|
Where("sessions.token_hash = ? AND sessions.expires_at > ?", hashToken(sessionToken), time.Now().Unix()).
|
|
Take(&sessionUser).Error
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
clearSessionCookie(c)
|
|
return c.Status(fiber.StatusUnauthorized).JSON(map[string]any{"authenticated": false})
|
|
}
|
|
log.Printf("unable to resolve session: %v", err)
|
|
return c.Status(fiber.StatusInternalServerError).JSON(map[string]string{"error": "unable_to_fetch_session"})
|
|
}
|
|
|
|
return c.Status(fiber.StatusOK).JSON(map[string]any{
|
|
"authenticated": true,
|
|
"userId": sessionUser.UserID,
|
|
"email": sessionUser.Email,
|
|
})
|
|
}
|
|
|
|
func (a *AuthController) ForgotPassword(c fiber.Ctx) error {
|
|
var req forgotPasswordRequest
|
|
if err := json.Unmarshal(c.Body(), &req); err != nil {
|
|
return c.Status(fiber.StatusBadRequest).JSON(map[string]string{"error": "invalid_payload"})
|
|
}
|
|
|
|
email := strings.ToLower(strings.TrimSpace(req.Email))
|
|
if _, err := mail.ParseAddress(email); err != nil {
|
|
return c.Status(fiber.StatusOK).JSON(map[string]string{"status": "ok"})
|
|
}
|
|
|
|
var userID int64
|
|
err := a.DB.Table("users").Select("id").Where("email = ?", email).Take(&userID).Error
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return c.Status(fiber.StatusOK).JSON(map[string]string{"status": "ok"})
|
|
}
|
|
log.Printf("unable to fetch user for forgot-password: %v", err)
|
|
return c.Status(fiber.StatusInternalServerError).JSON(map[string]string{"error": "unable_to_process"})
|
|
}
|
|
|
|
resetToken, err := generateSessionToken()
|
|
if err != nil {
|
|
log.Printf("unable to generate reset token: %v", err)
|
|
return c.Status(fiber.StatusInternalServerError).JSON(map[string]string{"error": "unable_to_process"})
|
|
}
|
|
expiresAt := time.Now().Add(time.Minute * passwordResetDurationMinutes).Unix()
|
|
|
|
err = a.DB.Table("password_reset_tokens").Create(map[string]any{
|
|
"user_id": userID,
|
|
"token_hash": hashToken(resetToken),
|
|
"expires_at": expiresAt,
|
|
}).Error
|
|
if err != nil {
|
|
log.Printf("unable to persist reset token: %v", err)
|
|
return c.Status(fiber.StatusInternalServerError).JSON(map[string]string{"error": "unable_to_process"})
|
|
}
|
|
|
|
if a.MailConfigured != nil && a.MailConfigured() && a.AppBaseURL != nil && a.SendPasswordResetEmail != nil {
|
|
resetURL := fmt.Sprintf("%s/reset-password?token=%s", a.AppBaseURL(), resetToken)
|
|
if err := a.SendPasswordResetEmail(email, resetURL); err != nil {
|
|
log.Printf("unable to send reset email: %v", err)
|
|
}
|
|
} else {
|
|
log.Printf("smtp not configured: skip password reset email")
|
|
}
|
|
|
|
return c.Status(fiber.StatusOK).JSON(map[string]string{"status": "ok"})
|
|
}
|
|
|
|
func (a *AuthController) ResendVerification(c fiber.Ctx) error {
|
|
var req resendVerificationRequest
|
|
if err := json.Unmarshal(c.Body(), &req); err != nil {
|
|
return c.Status(fiber.StatusBadRequest).JSON(map[string]string{"error": "invalid_payload"})
|
|
}
|
|
|
|
email := strings.ToLower(strings.TrimSpace(req.Email))
|
|
if _, err := mail.ParseAddress(email); err != nil {
|
|
return c.Status(fiber.StatusOK).JSON(map[string]string{"status": "ok"})
|
|
}
|
|
|
|
var resendUser struct {
|
|
ID int64 `gorm:"column:id"`
|
|
EmailVerified int `gorm:"column:email_verified"`
|
|
}
|
|
err := a.DB.Table("users").Select("id, email_verified").Where("email = ?", email).Take(&resendUser).Error
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return c.Status(fiber.StatusOK).JSON(map[string]string{"status": "ok"})
|
|
}
|
|
log.Printf("unable to fetch user for resend verification: %v", err)
|
|
return c.Status(fiber.StatusInternalServerError).JSON(map[string]string{"error": "unable_to_process"})
|
|
}
|
|
userID := resendUser.ID
|
|
verified := resendUser.EmailVerified
|
|
if verified != 0 {
|
|
return c.Status(fiber.StatusOK).JSON(map[string]string{"status": "ok"})
|
|
}
|
|
|
|
verifyToken, err := createEmailVerificationToken(a.DB, userID)
|
|
if err != nil {
|
|
log.Printf("unable to create verify token on resend: %v", err)
|
|
return c.Status(fiber.StatusInternalServerError).JSON(map[string]string{"error": "unable_to_process"})
|
|
}
|
|
if a.MailConfigured != nil && a.MailConfigured() && a.AppBaseURL != nil && a.SendEmailVerificationMail != nil {
|
|
verifyURL := fmt.Sprintf("%s/auth/verify-email?token=%s", a.AppBaseURL(), verifyToken)
|
|
if err := a.SendEmailVerificationMail(email, verifyURL); err != nil {
|
|
log.Printf("unable to resend verification email: %v", err)
|
|
}
|
|
}
|
|
return c.Status(fiber.StatusOK).JSON(map[string]string{"status": "ok"})
|
|
}
|
|
|
|
func (a *AuthController) ResetPassword(c fiber.Ctx) error {
|
|
var req resetPasswordRequest
|
|
if err := json.Unmarshal(c.Body(), &req); err != nil {
|
|
return c.Status(fiber.StatusBadRequest).JSON(map[string]string{"error": "invalid_payload"})
|
|
}
|
|
|
|
token := strings.TrimSpace(req.Token)
|
|
if token == "" {
|
|
return c.Status(fiber.StatusBadRequest).JSON(map[string]string{"error": "invalid_token"})
|
|
}
|
|
|
|
if req.Password != req.ConfirmPassword {
|
|
return c.Status(fiber.StatusBadRequest).JSON(map[string]string{"error": "password_mismatch"})
|
|
}
|
|
|
|
if !isStrongPassword(req.Password) {
|
|
return c.Status(fiber.StatusBadRequest).JSON(map[string]string{"error": "password_too_weak"})
|
|
}
|
|
|
|
var resetTokenRow struct {
|
|
ID int64 `gorm:"column:id"`
|
|
UserID int64 `gorm:"column:user_id"`
|
|
}
|
|
err := a.DB.Table("password_reset_tokens").
|
|
Select("id, user_id").
|
|
Where("token_hash = ? AND used_at IS NULL AND expires_at > ?", hashToken(token), time.Now().Unix()).
|
|
Order("id DESC").
|
|
Take(&resetTokenRow).Error
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return c.Status(fiber.StatusBadRequest).JSON(map[string]string{"error": "invalid_or_expired_token"})
|
|
}
|
|
log.Printf("unable to fetch reset token: %v", err)
|
|
return c.Status(fiber.StatusInternalServerError).JSON(map[string]string{"error": "unable_to_reset_password"})
|
|
}
|
|
resetID := resetTokenRow.ID
|
|
userID := resetTokenRow.UserID
|
|
|
|
passwordHash, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
log.Printf("unable to hash password in reset: %v", err)
|
|
return c.Status(fiber.StatusInternalServerError).JSON(map[string]string{"error": "unable_to_reset_password"})
|
|
}
|
|
|
|
tx := a.DB.Begin()
|
|
if tx.Error != nil {
|
|
return c.Status(fiber.StatusInternalServerError).JSON(map[string]string{"error": "unable_to_reset_password"})
|
|
}
|
|
|
|
if err := tx.Table("users").Where("id = ?", userID).Updates(map[string]any{
|
|
"password_hash": string(passwordHash),
|
|
"updated_at": gorm.Expr("datetime('now')"),
|
|
}).Error; err != nil {
|
|
_ = tx.Rollback().Error
|
|
log.Printf("unable to update password: %v", err)
|
|
return c.Status(fiber.StatusInternalServerError).JSON(map[string]string{"error": "unable_to_reset_password"})
|
|
}
|
|
if err := tx.Table("password_reset_tokens").Where("id = ?", resetID).Update("used_at", time.Now().Unix()).Error; err != nil {
|
|
_ = tx.Rollback().Error
|
|
log.Printf("unable to mark reset token used: %v", err)
|
|
return c.Status(fiber.StatusInternalServerError).JSON(map[string]string{"error": "unable_to_reset_password"})
|
|
}
|
|
if err := tx.Table("sessions").Where("user_id = ?", userID).Delete(nil).Error; err != nil {
|
|
_ = tx.Rollback().Error
|
|
log.Printf("unable to delete sessions after reset: %v", err)
|
|
return c.Status(fiber.StatusInternalServerError).JSON(map[string]string{"error": "unable_to_reset_password"})
|
|
}
|
|
if err := tx.Commit().Error; err != nil {
|
|
log.Printf("unable to commit reset tx: %v", err)
|
|
return c.Status(fiber.StatusInternalServerError).JSON(map[string]string{"error": "unable_to_reset_password"})
|
|
}
|
|
|
|
clearSessionCookie(c)
|
|
return c.Status(fiber.StatusOK).JSON(map[string]string{"status": "password_reset"})
|
|
}
|
|
|
|
func (a *AuthController) VerifyEmail(c fiber.Ctx) error {
|
|
token := strings.TrimSpace(c.Query("token"))
|
|
if token == "" {
|
|
return c.Status(fiber.StatusBadRequest).SendString("Invalid verification token.")
|
|
}
|
|
|
|
var verifyTokenRow struct {
|
|
ID int64 `gorm:"column:id"`
|
|
UserID int64 `gorm:"column:user_id"`
|
|
}
|
|
err := a.DB.Table("email_verification_tokens").
|
|
Select("id, user_id").
|
|
Where("token_hash = ? AND used_at IS NULL AND expires_at > ?", hashToken(token), time.Now().Unix()).
|
|
Order("id DESC").
|
|
Take(&verifyTokenRow).Error
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return c.Status(fiber.StatusBadRequest).SendString("Verification link is invalid or expired.")
|
|
}
|
|
log.Printf("unable to validate verification token: %v", err)
|
|
return c.Status(fiber.StatusInternalServerError).SendString("Unable to verify email right now.")
|
|
}
|
|
tokenID := verifyTokenRow.ID
|
|
userID := verifyTokenRow.UserID
|
|
|
|
tx := a.DB.Begin()
|
|
if tx.Error != nil {
|
|
return c.Status(fiber.StatusInternalServerError).SendString("Unable to verify email right now.")
|
|
}
|
|
if err := tx.Table("users").Where("id = ?", userID).Updates(map[string]any{
|
|
"email_verified": 1,
|
|
"updated_at": gorm.Expr("datetime('now')"),
|
|
}).Error; err != nil {
|
|
_ = tx.Rollback().Error
|
|
log.Printf("unable to mark email verified: %v", err)
|
|
return c.Status(fiber.StatusInternalServerError).SendString("Unable to verify email right now.")
|
|
}
|
|
if err := tx.Table("email_verification_tokens").Where("id = ?", tokenID).Update("used_at", time.Now().Unix()).Error; err != nil {
|
|
_ = tx.Rollback().Error
|
|
log.Printf("unable to mark verify token used: %v", err)
|
|
return c.Status(fiber.StatusInternalServerError).SendString("Unable to verify email right now.")
|
|
}
|
|
if err := tx.Commit().Error; err != nil {
|
|
log.Printf("unable to commit verify email tx: %v", err)
|
|
return c.Status(fiber.StatusInternalServerError).SendString("Unable to verify email right now.")
|
|
}
|
|
|
|
return c.SendString("Email verified successfully. You can now log in.")
|
|
}
|
|
|
|
func setSessionCookie(c fiber.Ctx, value string, expiresAtUnix int64) {
|
|
c.Cookie(&fiber.Cookie{
|
|
Name: sessionCookieName,
|
|
Value: value,
|
|
Path: "/",
|
|
HTTPOnly: true,
|
|
Secure: false,
|
|
SameSite: "Lax",
|
|
Expires: time.Unix(expiresAtUnix, 0),
|
|
MaxAge: 60 * 60 * 24 * sessionDurationDays,
|
|
})
|
|
}
|
|
|
|
func clearSessionCookie(c fiber.Ctx) {
|
|
c.Cookie(&fiber.Cookie{
|
|
Name: sessionCookieName,
|
|
Value: "",
|
|
Path: "/",
|
|
HTTPOnly: true,
|
|
Secure: false,
|
|
SameSite: "Lax",
|
|
Expires: time.Unix(0, 0),
|
|
MaxAge: -1,
|
|
})
|
|
}
|
|
|
|
// hashToken returns the lowercase hex encoding of the SHA-256 digest of value.
// The handlers store and look up tokens by this hash rather than in clear text.
func hashToken(value string) string {
	digest := sha256.Sum256([]byte(value))
	return hex.EncodeToString(digest[:])
}
|
|
|
|
// generateSessionToken returns 32 bytes from crypto/rand encoded as unpadded
// URL-safe base64 (43 characters). It is used for session, password-reset and
// email-verification tokens alike.
//
// The error from the random source is returned unchanged; the string is empty
// in that case.
func generateSessionToken() (string, error) {
	raw := make([]byte, 32)
	if _, err := rand.Read(raw); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(raw), nil
}
|
|
|
|
// isStrongPassword reports whether value is at least 8 characters long and
// contains at least one letter and at least one digit.
//
// The length is counted in runes, not bytes, so a short password written in a
// multi-byte script does not pass just because its UTF-8 encoding is long.
func isStrongPassword(value string) bool {
	const minChars = 8
	if utf8.RuneCountInString(value) < minChars {
		return false
	}

	var hasLetter, hasDigit bool
	for _, r := range value {
		switch {
		case unicode.IsLetter(r):
			hasLetter = true
		case unicode.IsDigit(r):
			hasDigit = true
		}
		if hasLetter && hasDigit {
			return true
		}
	}
	return false
}
|
|
|
|
func createEmailVerificationToken(db *gorm.DB, userID int64) (string, error) {
|
|
token, err := generateSessionToken()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
expiresAt := time.Now().Add(time.Hour * emailVerificationDurationHours).Unix()
|
|
err = db.Table("email_verification_tokens").Create(map[string]any{
|
|
"user_id": userID,
|
|
"token_hash": hashToken(token),
|
|
"expires_at": expiresAt,
|
|
}).Error
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return token, nil
|
|
}
|