// Source: kontenhumas-be/app/middleware/client.middleware.go (249 lines, 7.0 KiB, Go)

package middleware
import (
	"errors"
	"strings"

	"github.com/gofiber/fiber/v2"
	"github.com/google/uuid"
	"gorm.io/gorm"

	"netidhub-saas-be/app/database/entity"
	"netidhub-saas-be/utils/client"
)
// Header name and fiber.Ctx Locals keys shared by the client middleware and
// the accessor helpers in this file.
const (
// ClientKeyHeader is the request header from which the client UUID is read.
ClientKeyHeader = "X-Client-Key"
// ClientContextKey is the Locals key under which the resolved client UUID is
// stored; it is kept alongside CurrentClientIDKey for backward compatibility.
ClientContextKey = "client_id"
// UserIDContextKey is the Locals key read to detect an authenticated user.
// The value is expected to be a uint; this file only reads it.
UserIDContextKey = "user_id"
// IsSuperAdminContextKey is the Locals key holding the user's super admin
// flag as a bool.
IsSuperAdminContextKey = "is_super_admin"
// AccessibleClientIDsKey is the Locals key holding the result of
// client.GetAccessibleClientIDs for the authenticated user.
AccessibleClientIDsKey = "accessible_client_ids"
// CurrentClientIDKey is the Locals key holding the client UUID the request
// is operating on.
CurrentClientIDKey = "current_client_id"
)
// excludedPaths lists the request path patterns that bypass client key
// validation. As interpreted by isPathExcluded, a pattern is one of:
//   - "*text*": matches any path containing text
//   - "*text":  matches any path ending with text
//   - "text*":  matches any path starting with text
//   - "text":   matches only the identical path
//
// The list is mutable at runtime through AddExcludedPath.
var excludedPaths = []string{
"/swagger/*",
"/docs/*",
"/users/login",
"/health/*",
"/clients",
"/clients/*",
"*/viewer/*",
"/bookmarks/test-table",
}
// isPathExcluded reports whether path matches any pattern in excludedPaths
// and should therefore skip client key validation.
//
// Supported pattern forms:
//   - "*text*": path contains text
//   - "*text":  path ends with text
//   - "text*":  path starts with text
//   - "text":   path equals text exactly
//
// A pattern consisting of a single "*" is treated as a leading wildcard with
// an empty suffix, so it matches every path. Without the length guard below,
// such a pattern would be sliced as [1:0] and panic.
func isPathExcluded(path string) bool {
	for _, pattern := range excludedPaths {
		leading := strings.HasPrefix(pattern, "*")
		trailing := strings.HasSuffix(pattern, "*")

		switch {
		case leading && trailing && len(pattern) > 1:
			// Wildcard on both ends, e.g. "*/viewer/*": substring match.
			if strings.Contains(path, pattern[1:len(pattern)-1]) {
				return true
			}
		case leading:
			// Wildcard at the beginning: suffix match.
			if strings.HasSuffix(path, pattern[1:]) {
				return true
			}
		case trailing:
			// Wildcard at the end: prefix match.
			if strings.HasPrefix(path, pattern[:len(pattern)-1]) {
				return true
			}
		default:
			// No wildcard: exact match only.
			if path == pattern {
				return true
			}
		}
	}
	return false
}
// ClientMiddleware builds the fiber handler that resolves which client a
// request operates on.
//
// Paths matched by isPathExcluded pass straight through. For all other
// requests the handler takes one of two routes:
//   - a non-nil UserIDContextKey local (expected to be set by an earlier auth
//     middleware) selects the multi-client / super admin flow;
//   - otherwise the request must carry a valid X-Client-Key header
//     (backward-compatible behavior).
func ClientMiddleware(db *gorm.DB) fiber.Handler {
	return func(c *fiber.Ctx) error {
		if isPathExcluded(c.Path()) {
			return c.Next()
		}

		if userID := c.Locals(UserIDContextKey); userID != nil {
			// Authenticated request: apply the multi-client logic.
			return handleAuthenticatedUser(c, db, userID)
		}

		// Unauthenticated request: fall back to header-only validation.
		return handleClientKeyValidation(c, db)
	}
}
// handleAuthenticatedUser resolves the client scope for a request whose user
// ID is already present in the context.
//
// userID is the raw value read from the UserIDContextKey local; it must be a
// uint, otherwise the request is rejected with 401.
//
// On success the following locals are set before calling the next handler:
//   - IsSuperAdminContextKey: the user's super admin flag (bool)
//   - AccessibleClientIDsKey: the result of client.GetAccessibleClientIDs
//   - CurrentClientIDKey and ClientContextKey: the client UUID taken from the
//     X-Client-Key header when present, otherwise the user's own ClientId;
//     neither is set when both are absent
//
// Responses on failure:
//   - 401 when the user ID has the wrong type or the user does not exist
//   - 400 when the X-Client-Key header is not a valid UUID
//   - 403 when the user has no access to the requested client
//   - 500 when a lookup fails for any other reason, so that infrastructure
//     errors are not reported as authentication or authorization failures
func handleAuthenticatedUser(c *fiber.Ctx, db *gorm.DB, userID interface{}) error {
	uid, ok := userID.(uint)
	if !ok {
		return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
			"success":  false,
			"code":     401,
			"messages": []string{"Invalid user ID in context"},
		})
	}

	// Load only the columns needed for the access decision.
	var user entity.Users
	if err := db.Select("id, is_super_admin, client_id").
		Where("id = ?", uid).
		First(&user).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
				"success":  false,
				"code":     401,
				"messages": []string{"User not found"},
			})
		}
		return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
			"success":  false,
			"code":     500,
			"messages": []string{"Error retrieving user"},
		})
	}

	// A nil flag is treated as "not a super admin".
	isSuperAdmin := user.IsSuperAdmin != nil && *user.IsSuperAdmin
	c.Locals(IsSuperAdminContextKey, isSuperAdmin)

	accessibleClientIDs, err := client.GetAccessibleClientIDs(db, uid, isSuperAdmin)
	if err != nil {
		return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
			"success":  false,
			"code":     500,
			"messages": []string{"Error retrieving client access"},
		})
	}
	c.Locals(AccessibleClientIDsKey, accessibleClientIDs)

	// Determine the current client: an explicit header wins over the user's
	// primary client.
	var currentClientID *uuid.UUID
	if clientKeyHeader := c.Get(ClientKeyHeader); clientKeyHeader != "" {
		clientUUID, err := uuid.Parse(clientKeyHeader)
		if err != nil {
			return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
				"success":  false,
				"code":     400,
				"messages": []string{"Invalid Client Key format"},
			})
		}

		hasAccess, err := client.HasAccessToClient(db, uid, clientUUID, isSuperAdmin)
		if err != nil {
			// A failed check is still a rejection, but it is not a verdict
			// about the user's permissions.
			return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
				"success":  false,
				"code":     500,
				"messages": []string{"Error verifying client access"},
			})
		}
		if !hasAccess {
			return c.Status(fiber.StatusForbidden).JSON(fiber.Map{
				"success":  false,
				"code":     403,
				"messages": []string{"Access denied to this client"},
			})
		}
		currentClientID = &clientUUID
	} else if user.ClientId != nil {
		currentClientID = user.ClientId
	}

	if currentClientID != nil {
		c.Locals(CurrentClientIDKey, *currentClientID)
		c.Locals(ClientContextKey, *currentClientID) // Backward compatibility
	}
	return c.Next()
}
// handleClientKeyValidation authorizes a request using only the X-Client-Key
// header (backward-compatible path for requests without a user in context).
//
// The header must be present, parse as a UUID, and identify a client row with
// is_active = true. On success the client UUID is stored under both
// ClientContextKey and CurrentClientIDKey and the next handler is called.
//
// Responses on failure:
//   - 400 when the header is missing or is not a valid UUID
//   - 401 when no active client with that ID exists
//   - 500 when the lookup fails for any other reason
func handleClientKeyValidation(c *fiber.Ctx, db *gorm.DB) error {
	clientKey := c.Get(ClientKeyHeader)
	if clientKey == "" {
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
			"success":  false,
			"code":     400,
			"messages": []string{"Client Key is required in header: " + ClientKeyHeader},
		})
	}

	clientUUID, err := uuid.Parse(clientKey)
	if err != nil {
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
			"success":  false,
			"code":     400,
			"messages": []string{"Invalid Client Key format"},
		})
	}

	var clientEntity entity.Clients
	if err := db.Where("id = ? AND is_active = ?", clientUUID, true).First(&clientEntity).Error; err != nil {
		// errors.Is also recognizes a wrapped ErrRecordNotFound, which a
		// plain == comparison would misreport as a 500.
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
				"success":  false,
				"code":     401,
				"messages": []string{"Invalid or inactive Client Key"},
			})
		}
		return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
			"success":  false,
			"code":     500,
			"messages": []string{"Error validating Client Key"},
		})
	}

	c.Locals(ClientContextKey, clientUUID)
	c.Locals(CurrentClientIDKey, clientUUID)
	return c.Next()
}
// GetClientID returns the client UUID stored under ClientContextKey (the
// backward-compatible key), or nil when the local is unset or is not a
// uuid.UUID. The returned pointer refers to a copy of the stored value.
func GetClientID(c *fiber.Ctx) *uuid.UUID {
	id, isUUID := c.Locals(ClientContextKey).(uuid.UUID)
	if !isUUID {
		return nil
	}
	return &id
}
// GetAccessibleClientIDs returns the client UUIDs stored under
// AccessibleClientIDsKey, or nil when the local is unset or is not a
// []uuid.UUID.
//
// NOTE(review): the original comment says nil means "super admin or no
// restriction"; that depends on what client.GetAccessibleClientIDs returns,
// which is not visible here. Confirm before relying on it.
func GetAccessibleClientIDs(c *fiber.Ctx) []uuid.UUID {
	// A failed assertion yields the zero value of the slice type, i.e. nil.
	ids, _ := c.Locals(AccessibleClientIDsKey).([]uuid.UUID)
	return ids
}
// GetCurrentClientID returns the working client UUID stored under
// CurrentClientIDKey, or nil when the local is unset or is not a uuid.UUID.
// The returned pointer refers to a copy of the stored value.
func GetCurrentClientID(c *fiber.Ctx) *uuid.UUID {
	id, isUUID := c.Locals(CurrentClientIDKey).(uuid.UUID)
	if !isUUID {
		return nil
	}
	return &id
}
// IsSuperAdmin reports the super admin flag stored under
// IsSuperAdminContextKey. It returns false when the local is unset or is not
// a bool.
func IsSuperAdmin(c *fiber.Ctx) bool {
	// A failed assertion yields the zero value of bool, i.e. false.
	flag, _ := c.Locals(IsSuperAdminContextKey).(bool)
	return flag
}
// AddExcludedPath appends path to the package-level list of patterns that
// bypass client key validation. The pattern is stored as given, without
// validation or de-duplication, and no locking is performed here.
func AddExcludedPath(path string) {
	extended := append(excludedPaths, path)
	excludedPaths = extended
}
// GetExcludedPaths returns a snapshot of the patterns that currently bypass
// client key validation.
//
// The result is a copy: writing to it or appending to it does not affect the
// list consulted by the middleware. Use AddExcludedPath to change that list.
func GetExcludedPaths() []string {
	snapshot := make([]string, len(excludedPaths))
	copy(snapshot, excludedPaths)
	return snapshot
}