From 3355d9b4a00a91a20288a9c500ca8f7bfaf2ecb0 Mon Sep 17 00:00:00 2001 From: hanif salafi Date: Thu, 10 Apr 2025 04:28:46 +0700 Subject: [PATCH] feat: add csrf token and audit trails --- .gitignore | 2 +- .../entity/csrf_token_records.entity.go | 11 ++ app/database/index.database.go | 1 + .../middleware/audit_trails.middleware.go | 14 +++ app/middleware/csrf.middleware.go | 79 +++++++++++++ app/middleware/register.middleware.go | 104 ++++++++++++++++-- config/config/index.config.go | 7 +- config/toml/config.toml | 4 + utils/index.utils.go | 6 + 9 files changed, 216 insertions(+), 12 deletions(-) create mode 100644 app/database/entity/csrf_token_records.entity.go rename {config => app}/middleware/audit_trails.middleware.go (76%) create mode 100644 app/middleware/csrf.middleware.go diff --git a/.gitignore b/.gitignore index 7abf38c..fc9de7c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,3 @@ /vendor debug.log -.idea \ No newline at end of file +/.idea \ No newline at end of file diff --git a/app/database/entity/csrf_token_records.entity.go b/app/database/entity/csrf_token_records.entity.go new file mode 100644 index 0000000..e5681ca --- /dev/null +++ b/app/database/entity/csrf_token_records.entity.go @@ -0,0 +1,11 @@ +package entity + +import "time" + +type CsrfTokenRecords struct { + ID uint `gorm:"primaryKey"` + Token string `gorm:"uniqueIndex;size:255"` + Value []byte `gorm:"value"` + ExpireAt time.Time `gorm:"index"` + CreatedAt time.Time +} diff --git a/app/database/index.database.go b/app/database/index.database.go index c869254..0b0eea2 100644 --- a/app/database/index.database.go +++ b/app/database/index.database.go @@ -95,6 +95,7 @@ func Models() []interface{} { entity.ArticleNulisAI{}, entity.AuditTrails{}, entity.Cities{}, + entity.CsrfTokenRecords{}, entity.CustomStaticPages{}, entity.Districts{}, entity.Feedbacks{}, diff --git a/config/middleware/audit_trails.middleware.go b/app/middleware/audit_trails.middleware.go similarity index 76% rename from config/middleware/audit_trails.middleware.go rename to app/middleware/audit_trails.middleware.go index 67b3eb5..9f108a7 100644 --- a/config/middleware/audit_trails.middleware.go +++ b/app/middleware/audit_trails.middleware.go @@ -6,6 +6,7 @@ import ( "go-humas-be/app/database/entity" utilSvc "go-humas-be/utils/service" "gorm.io/gorm" + "log" "time" ) @@ -40,3 +41,16 @@ func AuditTrailsMiddleware(db *gorm.DB) fiber.Handler { return err } } + +func StartAuditTrailCleanup(db *gorm.DB, retention int) { + go func() { + for { + time.Sleep(24 * time.Hour) + + cutoff := time.Now().AddDate(0, 0, retention) + db.Where("created_at < ?", cutoff).Delete(&entity.AuditTrails{}) + + log.Printf("Audit Trail Cleanup at: %s", cutoff) + } + }() +} diff --git a/app/middleware/csrf.middleware.go b/app/middleware/csrf.middleware.go new file mode 100644 index 0000000..d518b57 --- /dev/null +++ b/app/middleware/csrf.middleware.go @@ -0,0 +1,79 @@ +package middleware + +import ( + "fmt" + "go-humas-be/app/database/entity" + "gorm.io/gorm" + "log" + "time" +) + +type PostgresStorage struct { + DB *gorm.DB +} + +func (s *PostgresStorage) Get(key string) ([]byte, error) { + log.Printf("CSRF Storage: Get token %s", key) + + var record entity.CsrfTokenRecords + result := s.DB.Where("token = ?", key).First(&record) + + if result.Error != nil { + log.Printf("CSRF Storage Get error: %v for token: %s", result.Error, key) + return nil, result.Error + } + + if record.ExpireAt.Before(time.Now()) { + log.Printf("CSRF token %s is expired", key) + return nil, fmt.Errorf("CSRF token is expired") + } + + return record.Value, nil +} + +func (s *PostgresStorage) Set(key string, value []byte, exp time.Duration) error { + log.Printf("CSRF Storage: Setting token %s with expiration %v", key, exp) + + // Calculate expiration time + expireAt := time.Now().Add(exp) + + // Try to update existing record first + result := s.DB.Model(&entity.CsrfTokenRecords{}). + Where("token = ?", key). + Updates(map[string]interface{}{ + "expire_at": expireAt, + }) + + // If no rows were affected (not found), create a new record + if result.RowsAffected == 0 { + record := entity.CsrfTokenRecords{ + Token: key, + Value: value, + ExpireAt: expireAt, + CreatedAt: time.Now(), + } + + if err := s.DB.Create(&record).Error; err != nil { + log.Printf("CSRF Storage: Error saving token: %v", err) + return err + } + } else if result.Error != nil { + log.Printf("CSRF Storage: Error updating token: %v", result.Error) + return result.Error + } + + log.Printf("CSRF Storage: Successfully saved/updated token") + return nil +} + +func (s *PostgresStorage) Delete(key string) error { + return s.DB.Where("token = ?", key).Delete(&entity.CsrfTokenRecords{}).Error +} + +func (s *PostgresStorage) Reset() error { + return s.DB.Where("expire_at < ?", time.Now()).Delete(&entity.CsrfTokenRecords{}).Error +} + +func (s *PostgresStorage) Close() error { + return nil +} diff --git a/app/middleware/register.middleware.go b/app/middleware/register.middleware.go index fedb801..16434e8 100644 --- a/app/middleware/register.middleware.go +++ b/app/middleware/register.middleware.go @@ -1,9 +1,12 @@ package middleware import ( + "github.com/gofiber/fiber/v2/middleware/csrf" + "github.com/gofiber/fiber/v2/middleware/session" "go-humas-be/app/database" "go-humas-be/config/config" - "go-humas-be/utils" + utilsSvc "go-humas-be/utils" + "log" "time" "github.com/gofiber/fiber/v2" @@ -13,8 +16,7 @@ import ( "github.com/gofiber/fiber/v2/middleware/monitor" "github.com/gofiber/fiber/v2/middleware/pprof" "github.com/gofiber/fiber/v2/middleware/recover" - - auditTrails "go-humas-be/config/middleware" + "github.com/gofiber/fiber/v2/utils" ) // Middleware is a struct that contains all the middleware functions @@ -35,26 +37,26 @@ func (m *Middleware) Register(db *database.Database) { // Add Extra Middlewares m.App.Use(limiter.New(limiter.Config{ - Next: utils.IsEnabled(m.Cfg.Middleware.Limiter.Enable), + Next: utilsSvc.IsEnabled(m.Cfg.Middleware.Limiter.Enable), Max: m.Cfg.Middleware.Limiter.Max, Expiration: m.Cfg.Middleware.Limiter.Expiration * time.Second, })) m.App.Use(compress.New(compress.Config{ - Next: utils.IsEnabled(m.Cfg.Middleware.Compress.Enable), + Next: utilsSvc.IsEnabled(m.Cfg.Middleware.Compress.Enable), Level: m.Cfg.Middleware.Compress.Level, })) m.App.Use(recover.New(recover.Config{ - Next: utils.IsEnabled(m.Cfg.Middleware.Recover.Enable), + Next: utilsSvc.IsEnabled(m.Cfg.Middleware.Recover.Enable), })) m.App.Use(pprof.New(pprof.Config{ - Next: utils.IsEnabled(m.Cfg.Middleware.Pprof.Enable), + Next: utilsSvc.IsEnabled(m.Cfg.Middleware.Pprof.Enable), })) m.App.Use(cors.New(cors.Config{ - Next: utils.IsEnabled(m.Cfg.Middleware.Cors.Enable), + Next: utilsSvc.IsEnabled(m.Cfg.Middleware.Cors.Enable), AllowOrigins: "*", AllowMethods: "HEAD, GET, POST, PUT, DELETE, OPTION, PATCH", AllowHeaders: "Origin, Content-Type, Accept, Accept-Language, Authorization, X-Requested-With, Access-Control-Request-Method, Access-Control-Request-Headers", @@ -63,7 +65,64 @@ func (m *Middleware) Register(db *database.Database) { MaxAge: 12, })) - m.App.Use(auditTrails.AuditTrailsMiddleware(db.DB)) + //=============================== + // CSRF CONFIG + //=============================== + + // Custom storage for CSRF + csrfSessionStorage := &PostgresStorage{ + DB: db.DB, + } + + // Store initialization for session + store := session.New(session.Config{ + Storage: csrfSessionStorage, + }) + + m.App.Use(func(c *fiber.Ctx) error { + sess, err := store.Get(c) + if err != nil { + return err + } + c.Locals("session", sess) + return c.Next() + }) + + // Cleanup the expired token + go func() { + ticker := time.NewTicker(1 * time.Hour) + defer ticker.Stop() + + for range ticker.C { + if err := csrfSessionStorage.Reset(); err != nil { + log.Printf("Error cleaning up expired CSRF tokens: %v", err) + } + } + }() + + m.App.Use(csrf.New(csrf.Config{ + Next: utilsSvc.IsEnabled(m.Cfg.Middleware.Csrf.Enable), + KeyLookup: "header:" + csrf.HeaderName, + CookieName: "csrf_", + CookieSameSite: "Lax", + CookieSecure: false, + CookieSessionOnly: true, + CookieHTTPOnly: true, + Expiration: 1 * time.Hour, + KeyGenerator: utils.UUIDv4, + ContextKey: "csrf", + ErrorHandler: func(c *fiber.Ctx, err error) error { + return utilsSvc.CsrfErrorHandler(c, err) + }, + Extractor: csrf.CsrfFromHeader(csrf.HeaderName), + Session: store, + SessionKey: "fiber.csrf.token", + })) + + //=============================== + + m.App.Use(AuditTrailsMiddleware(db.DB)) + StartAuditTrailCleanup(db.DB, m.Cfg.Middleware.AuditTrails.Retention) //m.App.Use(filesystem.New(filesystem.Config{ // Next: utils.IsEnabled(m.Cfg.Middleware.FileSystem.Enable), @@ -72,7 +131,32 @@ func (m *Middleware) Register(db *database.Database) { // MaxAge: m.Cfg.Middleware.FileSystem.MaxAge, //})) + // ================================================== + m.App.Get(m.Cfg.Middleware.Monitor.Path, monitor.New(monitor.Config{ - Next: utils.IsEnabled(m.Cfg.Middleware.Monitor.Enable), + Next: utilsSvc.IsEnabled(m.Cfg.Middleware.Monitor.Enable), })) + + // Route for generate CSRF token + m.App.Get("/csrf-token", func(c *fiber.Ctx) error { + // Retrieve CSRF token from Fiber's middleware context + token, ok := c.Locals("csrf").(string) + + //c.Context().VisitUserValues(func(key []byte, value interface{}) { + // log.Printf("Local Key: %s, Value: %v", key, value) + //}) + + if !ok || token == "" { + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + "success": false, + "code": 500, + "messages": []string{"Failed to retrieve CSRF token"}, + }) + } + + return c.JSON(fiber.Map{ + "success": true, + "csrf_token": token, + }) + }) } diff --git a/config/config/index.config.go b/config/config/index.config.go index 68252f0..200c697 100644 --- a/config/config/index.config.go +++ b/config/config/index.config.go @@ -73,9 +73,14 @@ type middleware = struct { Expiration time.Duration `toml:"expiration_seconds"` } - AuditTrails struct { + Csrf struct { Enable bool } + + AuditTrails struct { + Enable bool + Retention int + } } // minio struct config diff --git a/config/toml/config.toml b/config/toml/config.toml index c10f132..7a02f9f 100644 --- a/config/toml/config.toml +++ b/config/toml/config.toml @@ -53,8 +53,12 @@ enable = true max = 20 expiration_seconds = 60 +[middleware.csrf] +enable = true + [middleware.audittrails] enable = true +retention = 30 [keycloak] endpoint = "http://38.47.180.165:8008" diff --git a/utils/index.utils.go b/utils/index.utils.go index 6a98281..4d891a2 100644 --- a/utils/index.utils.go +++ b/utils/index.utils.go @@ -9,3 +9,9 @@ func IsEnabled(key bool) func(c *fiber.Ctx) bool { return func(c *fiber.Ctx) bool { return true } } + +func CsrfErrorHandler(c *fiber.Ctx, err error) error { + return c.Status(fiber.StatusForbidden).JSON(fiber.Map{ + "error": "CSRF protection: " + err.Error(), + }) +}