feat: add csrf token and audit trails

This commit is contained in:
hanif salafi 2025-04-10 04:28:46 +07:00
parent c2447e619f
commit 3355d9b4a0
9 changed files with 216 additions and 12 deletions

2
.gitignore vendored
View File

@ -1,3 +1,3 @@
/vendor
debug.log
.idea
/.idea

View File

@ -0,0 +1,11 @@
package entity
import "time"
type CsrfTokenRecords struct {
ID uint `gorm:"primaryKey"`
Token string `gorm:"uniqueIndex;size:255"`
Value []byte `gorm:"value"`
ExpireAt time.Time `gorm:"index"`
CreatedAt time.Time
}

View File

@ -95,6 +95,7 @@ func Models() []interface{} {
entity.ArticleNulisAI{},
entity.AuditTrails{},
entity.Cities{},
entity.CsrfTokenRecords{},
entity.CustomStaticPages{},
entity.Districts{},
entity.Feedbacks{},

View File

@ -6,6 +6,7 @@ import (
"go-humas-be/app/database/entity"
utilSvc "go-humas-be/utils/service"
"gorm.io/gorm"
"log"
"time"
)
@ -40,3 +41,16 @@ func AuditTrailsMiddleware(db *gorm.DB) fiber.Handler {
return err
}
}
func StartAuditTrailCleanup(db *gorm.DB, retention int) {
go func() {
for {
time.Sleep(24 * time.Hour)
cutoff := time.Now().AddDate(0, 0, retention)
db.Where("created_at < ?", cutoff).Delete(&entity.AuditTrails{})
log.Printf("Audit Trail Cleanup at: %s", cutoff)
}
}()
}

View File

@ -0,0 +1,79 @@
package middleware
import (
"fmt"
"go-humas-be/app/database/entity"
"gorm.io/gorm"
"log"
"time"
)
type PostgresStorage struct {
DB *gorm.DB
}
func (s *PostgresStorage) Get(key string) ([]byte, error) {
log.Printf("CSRF Storage: Get token %s", key)
var record entity.CsrfTokenRecords
result := s.DB.Where("token = ?", key).First(&record)
if result.Error != nil {
log.Printf("CSRF Storage Get error: %v for token: %s", result.Error, key)
return nil, result.Error
}
if record.ExpireAt.Before(time.Now()) {
log.Printf("CSRF token %s is expired", key)
return nil, fmt.Errorf("CSRF token is expired")
}
return record.Value, nil
}
func (s *PostgresStorage) Set(key string, value []byte, exp time.Duration) error {
log.Printf("CSRF Storage: Setting token %s with expiration %v", key, exp)
// Calculate expiration time
expireAt := time.Now().Add(exp)
// Try to update existing record first
result := s.DB.Model(&entity.CsrfTokenRecords{}).
Where("token = ?", key).
Updates(map[string]interface{}{
"expire_at": expireAt,
})
// If no rows were affected (not found), create a new record
if result.RowsAffected == 0 {
record := entity.CsrfTokenRecords{
Token: key,
Value: value,
ExpireAt: expireAt,
CreatedAt: time.Now(),
}
if err := s.DB.Create(&record).Error; err != nil {
log.Printf("CSRF Storage: Error saving token: %v", err)
return err
}
} else if result.Error != nil {
log.Printf("CSRF Storage: Error updating token: %v", result.Error)
return result.Error
}
log.Printf("CSRF Storage: Successfully saved/updated token")
return nil
}
func (s *PostgresStorage) Delete(key string) error {
return s.DB.Where("token = ?", key).Delete(&entity.CsrfTokenRecords{}).Error
}
func (s *PostgresStorage) Reset() error {
return s.DB.Where("expire_at < ?", time.Now()).Delete(&entity.CsrfTokenRecords{}).Error
}
func (s *PostgresStorage) Close() error {
return nil
}

View File

@ -1,9 +1,12 @@
package middleware
import (
"github.com/gofiber/fiber/v2/middleware/csrf"
"github.com/gofiber/fiber/v2/middleware/session"
"go-humas-be/app/database"
"go-humas-be/config/config"
"go-humas-be/utils"
utilsSvc "go-humas-be/utils"
"log"
"time"
"github.com/gofiber/fiber/v2"
@ -13,8 +16,7 @@ import (
"github.com/gofiber/fiber/v2/middleware/monitor"
"github.com/gofiber/fiber/v2/middleware/pprof"
"github.com/gofiber/fiber/v2/middleware/recover"
auditTrails "go-humas-be/config/middleware"
"github.com/gofiber/fiber/v2/utils"
)
// Middleware is a struct that contains all the middleware functions
@ -35,26 +37,26 @@ func (m *Middleware) Register(db *database.Database) {
// Add Extra Middlewares
m.App.Use(limiter.New(limiter.Config{
Next: utils.IsEnabled(m.Cfg.Middleware.Limiter.Enable),
Next: utilsSvc.IsEnabled(m.Cfg.Middleware.Limiter.Enable),
Max: m.Cfg.Middleware.Limiter.Max,
Expiration: m.Cfg.Middleware.Limiter.Expiration * time.Second,
}))
m.App.Use(compress.New(compress.Config{
Next: utils.IsEnabled(m.Cfg.Middleware.Compress.Enable),
Next: utilsSvc.IsEnabled(m.Cfg.Middleware.Compress.Enable),
Level: m.Cfg.Middleware.Compress.Level,
}))
m.App.Use(recover.New(recover.Config{
Next: utils.IsEnabled(m.Cfg.Middleware.Recover.Enable),
Next: utilsSvc.IsEnabled(m.Cfg.Middleware.Recover.Enable),
}))
m.App.Use(pprof.New(pprof.Config{
Next: utils.IsEnabled(m.Cfg.Middleware.Pprof.Enable),
Next: utilsSvc.IsEnabled(m.Cfg.Middleware.Pprof.Enable),
}))
m.App.Use(cors.New(cors.Config{
Next: utils.IsEnabled(m.Cfg.Middleware.Cors.Enable),
Next: utilsSvc.IsEnabled(m.Cfg.Middleware.Cors.Enable),
AllowOrigins: "*",
AllowMethods: "HEAD, GET, POST, PUT, DELETE, OPTION, PATCH",
AllowHeaders: "Origin, Content-Type, Accept, Accept-Language, Authorization, X-Requested-With, Access-Control-Request-Method, Access-Control-Request-Headers",
@ -63,7 +65,64 @@ func (m *Middleware) Register(db *database.Database) {
MaxAge: 12,
}))
m.App.Use(auditTrails.AuditTrailsMiddleware(db.DB))
//===============================
// CSRF CONFIG
//===============================
// Custom storage for CSRF
csrfSessionStorage := &PostgresStorage{
DB: db.DB,
}
// Store initialization for session
store := session.New(session.Config{
Storage: csrfSessionStorage,
})
m.App.Use(func(c *fiber.Ctx) error {
sess, err := store.Get(c)
if err != nil {
return err
}
c.Locals("session", sess)
return c.Next()
})
// Cleanup the expired token
go func() {
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for range ticker.C {
if err := csrfSessionStorage.Reset(); err != nil {
log.Printf("Error cleaning up expired CSRF tokens: %v", err)
}
}
}()
m.App.Use(csrf.New(csrf.Config{
Next: utilsSvc.IsEnabled(m.Cfg.Middleware.Csrf.Enable),
KeyLookup: "header:" + csrf.HeaderName,
CookieName: "csrf_",
CookieSameSite: "Lax",
CookieSecure: false,
CookieSessionOnly: true,
CookieHTTPOnly: true,
Expiration: 1 * time.Hour,
KeyGenerator: utils.UUIDv4,
ContextKey: "csrf",
ErrorHandler: func(c *fiber.Ctx, err error) error {
return utilsSvc.CsrfErrorHandler(c, err)
},
Extractor: csrf.CsrfFromHeader(csrf.HeaderName),
Session: store,
SessionKey: "fiber.csrf.token",
}))
//===============================
m.App.Use(AuditTrailsMiddleware(db.DB))
StartAuditTrailCleanup(db.DB, m.Cfg.Middleware.AuditTrails.Retention)
//m.App.Use(filesystem.New(filesystem.Config{
// Next: utils.IsEnabled(m.Cfg.Middleware.FileSystem.Enable),
@ -72,7 +131,32 @@ func (m *Middleware) Register(db *database.Database) {
// MaxAge: m.Cfg.Middleware.FileSystem.MaxAge,
//}))
// ==================================================
m.App.Get(m.Cfg.Middleware.Monitor.Path, monitor.New(monitor.Config{
Next: utils.IsEnabled(m.Cfg.Middleware.Monitor.Enable),
Next: utilsSvc.IsEnabled(m.Cfg.Middleware.Monitor.Enable),
}))
// Route for generate CSRF token
m.App.Get("/csrf-token", func(c *fiber.Ctx) error {
// Retrieve CSRF token from Fiber's middleware context
token, ok := c.Locals("csrf").(string)
//c.Context().VisitUserValues(func(key []byte, value interface{}) {
// log.Printf("Local Key: %s, Value: %v", key, value)
//})
if !ok || token == "" {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
"success": false,
"code": 500,
"messages": []string{"Failed to retrieve CSRF token"},
})
}
return c.JSON(fiber.Map{
"success": true,
"csrf_token": token,
})
})
}

View File

@ -73,8 +73,13 @@ type middleware = struct {
Expiration time.Duration `toml:"expiration_seconds"`
}
Csrf struct {
Enable bool
}
AuditTrails struct {
Enable bool
Retention int
}
}

View File

@ -53,8 +53,12 @@ enable = true
max = 20
expiration_seconds = 60
[middleware.csrf]
enable = true
[middleware.audittrails]
enable = true
retention = 30
[keycloak]
endpoint = "http://38.47.180.165:8008"

View File

@ -9,3 +9,9 @@ func IsEnabled(key bool) func(c *fiber.Ctx) bool {
return func(c *fiber.Ctx) bool { return true }
}
func CsrfErrorHandler(c *fiber.Ctx, err error) error {
return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
"error": "CSRF protection: " + err.Error(),
})
}