package middleware

import (
	"encoding/json"
	"errors"
	"log"
	"net"
	"strings"
	"time"

	"narasi-ahli-be/app/database/entity"
	utilSvc "narasi-ahli-be/utils/service"

	"github.com/gofiber/fiber/v2"
	"gorm.io/gorm"
)

// AuditTrailsMiddleware returns a Fiber handler that records one
// entity.AuditTrails row per request. The row holds the method, original URL,
// client IP and final status. It also holds the user ID derived from the
// Authorization header, the request headers and body, the response body and
// the handling duration.
//
// The row is inserted by a background goroutine so the response is not
// delayed by the database write. Insert failures are only logged. Because the
// write is asynchronous, a record can be lost if the process exits before the
// goroutine finishes.
func AuditTrailsMiddleware(db *gorm.DB) fiber.Handler {
	return func(c *fiber.Ctx) error {
		start := time.Now()

		// Fiber hands out zero-copy views of fasthttp buffers that are reused
		// once the handler returns. Every value that ends up in the record is
		// therefore copied, by string conversion or strings.Clone, before the
		// background goroutine is started.
		requestBody := string(c.Body())

		// Marshalling a map of header strings cannot fail, so the error is
		// dropped.
		// NOTE(review): this stores every request header verbatim, including
		// Authorization and Cookie; consider redacting credentials.
		headersJSON, _ := json.Marshal(c.GetReqHeaders())

		userID := utilSvc.GetUserId(c.Get("Authorization"))

		// Execute the rest of the handler chain.
		err := c.Next()

		statusCode := auditStatusCode(c, err)

		responseBody := c.Response().Body()
		if len(responseBody) == 0 && err != nil {
			// The app-level ErrorHandler runs only after this middleware has
			// returned, so no error body exists yet; synthesize one for the record.
			responseBody = auditErrorBody(statusCode, err)
		}

		audit := entity.AuditTrails{
			Method:         strings.Clone(c.Method()),
			Path:           strings.Clone(c.OriginalURL()),
			IP:             strings.Clone(getIP(c)),
			Status:         statusCode,
			UserID:         userID,
			RequestHeaders: string(headersJSON),
			RequestBody:    requestBody,
			ResponseBody:   string(responseBody),
			DurationMs:     time.Since(start).Milliseconds(),
			CreatedAt:      time.Now(),
		}

		// Persist asynchronously so the request is not blocked on the insert.
		// The record is passed by value and owns all of its string data.
		// NOTE(review): this spawns one goroutine per request with no bound.
		go func(record entity.AuditTrails) {
			if saveErr := db.Create(&record).Error; saveErr != nil {
				log.Printf("Failed to save audit trail for %s %s (status: %d): %v", record.Method, record.Path, record.Status, saveErr)
			}
		}(audit)

		return err
	}
}

// auditStatusCode picks the HTTP status to record for the request.
// If err is nil, it uses the response status as-is.
// If err is a *fiber.Error, it uses the error's code, mirroring Fiber's default
// ErrorHandler; the response status has usually not been updated yet at this
// point.
// For any other error, it keeps an explicitly set non-200 status and otherwise
// falls back to 500.
func auditStatusCode(c *fiber.Ctx, err error) int {
	status := c.Response().StatusCode()
	if err == nil {
		return status
	}
	var fiberErr *fiber.Error
	if errors.As(err, &fiberErr) {
		return fiberErr.Code
	}
	if status == fiber.StatusOK || status == 0 {
		return fiber.StatusInternalServerError
	}
	return status
}

// auditErrorBody builds the JSON body {"success": false, "code": status,
// "message": err.Error()} for failed requests that produced no response body.
// If marshalling fails, it falls back to the raw error text.
func auditErrorBody(status int, err error) []byte {
	payload, marshalErr := json.Marshal(map[string]interface{}{
		"success": false,
		"code":    status,
		"message": err.Error(),
	})
	if marshalErr != nil {
		return []byte(err.Error())
	}
	return payload
}

// StartAuditTrailCleanup starts a background goroutine that, every 24 hours,
// deletes audit trail rows whose created_at is older than retention days. The
// first cleanup runs 24 hours after this call. The goroutine lives for the
// rest of the process and cannot be stopped.
//
// retention is the number of days to keep. Its sign is ignored, so 30 and -30
// both keep the last 30 days. A value of 0 deletes everything older than the
// moment of the sweep.
func StartAuditTrailCleanup(db *gorm.DB, retention int) {
	keepDays := retention
	if keepDays < 0 {
		keepDays = -keepDays
	}

	go func() {
		ticker := time.NewTicker(24 * time.Hour)
		defer ticker.Stop()

		for range ticker.C {
			cutoff := time.Now().AddDate(0, 0, -keepDays)
			result := db.Where("created_at < ?", cutoff).Delete(&entity.AuditTrails{})
			if result.Error != nil {
				log.Printf("audit trail cleanup failed (cutoff %s): %v", cutoff.Format(time.RFC3339), result.Error)
				continue
			}
			log.Printf("audit trail cleanup: deleted %d records created before %s", result.RowsAffected, cutoff.Format(time.RFC3339))
		}
	}()
}

// getIP returns the client IP for the audit record. It uses the first entry of
// X-Forwarded-For when that header is present and falls back to Fiber's
// c.IP() otherwise. A trailing port is stripped, and IPv6 addresses are
// preserved intact.
//
// NOTE(review): X-Forwarded-For is client-controlled unless a trusted proxy
// overwrites it, so this value can be spoofed.
func getIP(c *fiber.Ctx) string {
	ip := c.Get("X-Forwarded-For")
	if ip == "" {
		ip = c.IP()
	}
	return normalizeAuditIP(ip)
}

// normalizeAuditIP reduces a raw address to a bare host. It accepts
// "client, proxy", "host:port", "[v6]:port" and "[v6]" forms.
// Input it cannot parse is returned trimmed but otherwise unchanged.
func normalizeAuditIP(raw string) string {
	first, _, _ := strings.Cut(raw, ",")
	first = strings.TrimSpace(first)
	if net.ParseIP(first) != nil {
		// Plain IPv4 or IPv6 literal (including "::1"); keep every colon.
		return first
	}
	if host, _, err := net.SplitHostPort(first); err == nil {
		return host
	}
	return strings.TrimSuffix(strings.TrimPrefix(first, "["), "]")
}