feat: update middleware for csrf checking
This commit is contained in:
parent
b3bfb2bc3d
commit
38a72b74c6
|
|
@ -70,59 +70,61 @@ func (m *Middleware) Register(db *database.Database) {
|
||||||
// CSRF CONFIG
|
// CSRF CONFIG
|
||||||
//===============================
|
//===============================
|
||||||
|
|
||||||
// Custom storage for CSRF
|
// Only setup CSRF middleware if enabled
|
||||||
csrfSessionStorage := &PostgresStorage{
|
if m.Cfg.Middleware.Csrf.Enable {
|
||||||
DB: db.DB,
|
// Custom storage for CSRF
|
||||||
}
|
csrfSessionStorage := &PostgresStorage{
|
||||||
|
DB: db.DB,
|
||||||
// Store initialization for session
|
|
||||||
store := session.New(session.Config{
|
|
||||||
CookieSameSite: m.Cfg.Middleware.Csrf.CookieSameSite,
|
|
||||||
CookieSecure: m.Cfg.Middleware.Csrf.CookieSecure,
|
|
||||||
CookieSessionOnly: m.Cfg.Middleware.Csrf.CookieSessionOnly,
|
|
||||||
CookieHTTPOnly: m.Cfg.Middleware.Csrf.CookieHttpOnly,
|
|
||||||
Storage: csrfSessionStorage,
|
|
||||||
})
|
|
||||||
|
|
||||||
m.App.Use(func(c *fiber.Ctx) error {
|
|
||||||
sess, err := store.Get(c)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
c.Locals("session", sess)
|
|
||||||
return c.Next()
|
|
||||||
})
|
|
||||||
|
|
||||||
// Cleanup the expired token
|
// Store initialization for session
|
||||||
go func() {
|
store := session.New(session.Config{
|
||||||
ticker := time.NewTicker(1 * time.Hour)
|
CookieSameSite: m.Cfg.Middleware.Csrf.CookieSameSite,
|
||||||
defer ticker.Stop()
|
CookieSecure: m.Cfg.Middleware.Csrf.CookieSecure,
|
||||||
|
CookieSessionOnly: m.Cfg.Middleware.Csrf.CookieSessionOnly,
|
||||||
|
CookieHTTPOnly: m.Cfg.Middleware.Csrf.CookieHttpOnly,
|
||||||
|
Storage: csrfSessionStorage,
|
||||||
|
})
|
||||||
|
|
||||||
for range ticker.C {
|
m.App.Use(func(c *fiber.Ctx) error {
|
||||||
if err := csrfSessionStorage.Reset(); err != nil {
|
sess, err := store.Get(c)
|
||||||
log.Printf("Error cleaning up expired CSRF tokens: %v", err)
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
}
|
c.Locals("session", sess)
|
||||||
}()
|
return c.Next()
|
||||||
|
})
|
||||||
|
|
||||||
m.App.Use(csrf.New(csrf.Config{
|
// Cleanup the expired token
|
||||||
Next: utilsSvc.IsEnabled(m.Cfg.Middleware.Csrf.Enable),
|
go func() {
|
||||||
KeyLookup: "header:" + csrf.HeaderName,
|
ticker := time.NewTicker(1 * time.Hour)
|
||||||
CookieName: m.Cfg.Middleware.Csrf.CookieName,
|
defer ticker.Stop()
|
||||||
CookieSameSite: m.Cfg.Middleware.Csrf.CookieSameSite,
|
|
||||||
CookieSecure: m.Cfg.Middleware.Csrf.CookieSecure,
|
for range ticker.C {
|
||||||
CookieSessionOnly: m.Cfg.Middleware.Csrf.CookieSessionOnly,
|
if err := csrfSessionStorage.Reset(); err != nil {
|
||||||
CookieHTTPOnly: m.Cfg.Middleware.Csrf.CookieHttpOnly,
|
log.Printf("Error cleaning up expired CSRF tokens: %v", err)
|
||||||
Expiration: 1 * time.Hour,
|
}
|
||||||
KeyGenerator: utils.UUIDv4,
|
}
|
||||||
ContextKey: "csrf",
|
}()
|
||||||
ErrorHandler: func(c *fiber.Ctx, err error) error {
|
|
||||||
return utilsSvc.CsrfErrorHandler(c, err)
|
m.App.Use(csrf.New(csrf.Config{
|
||||||
},
|
KeyLookup: "header:" + csrf.HeaderName,
|
||||||
Extractor: csrf.CsrfFromHeader(csrf.HeaderName),
|
CookieName: m.Cfg.Middleware.Csrf.CookieName,
|
||||||
Session: store,
|
CookieSameSite: m.Cfg.Middleware.Csrf.CookieSameSite,
|
||||||
SessionKey: "fiber.csrf.token",
|
CookieSecure: m.Cfg.Middleware.Csrf.CookieSecure,
|
||||||
}))
|
CookieSessionOnly: m.Cfg.Middleware.Csrf.CookieSessionOnly,
|
||||||
|
CookieHTTPOnly: m.Cfg.Middleware.Csrf.CookieHttpOnly,
|
||||||
|
Expiration: 1 * time.Hour,
|
||||||
|
KeyGenerator: utils.UUIDv4,
|
||||||
|
ContextKey: "csrf",
|
||||||
|
ErrorHandler: func(c *fiber.Ctx, err error) error {
|
||||||
|
return utilsSvc.CsrfErrorHandler(c, err)
|
||||||
|
},
|
||||||
|
Extractor: csrf.CsrfFromHeader(csrf.HeaderName),
|
||||||
|
Session: store,
|
||||||
|
SessionKey: "fiber.csrf.token",
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
//===============================
|
//===============================
|
||||||
m.App.Use(AuditTrailsMiddleware(db.DB))
|
m.App.Use(AuditTrailsMiddleware(db.DB))
|
||||||
|
|
@ -141,7 +143,7 @@ func (m *Middleware) Register(db *database.Database) {
|
||||||
Next: utilsSvc.IsEnabled(m.Cfg.Middleware.Monitor.Enable),
|
Next: utilsSvc.IsEnabled(m.Cfg.Middleware.Monitor.Enable),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
// Route for generate CSRF token
|
// Route for generate CSRF token (only available if CSRF is enabled)
|
||||||
m.App.Get("/csrf-token", func(c *fiber.Ctx) error {
|
m.App.Get("/csrf-token", func(c *fiber.Ctx) error {
|
||||||
// Retrieve CSRF token from Fiber's middleware context
|
// Retrieve CSRF token from Fiber's middleware context
|
||||||
token, ok := c.Locals("csrf").(string)
|
token, ok := c.Locals("csrf").(string)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue