diff --git a/features/status-dashboard/server/src/auth/decorators/auth-methods.decorator.ts b/features/status-dashboard/server/src/auth/decorators/auth-methods.decorator.ts new file mode 100644 index 000000000..73f71a439 --- /dev/null +++ b/features/status-dashboard/server/src/auth/decorators/auth-methods.decorator.ts @@ -0,0 +1,35 @@ +import { SetMetadata } from '@nestjs/common'; + +/** + * Authentication method types supported by FlexibleAuthGuard + */ +export type AuthMethod = 'mtls' | 'jwt' | 'apikey'; + +/** + * Metadata key for storing allowed authentication methods + */ +export const AUTH_METHODS_KEY = 'auth_methods'; + +/** + * Decorator to specify which authentication methods are allowed for an endpoint. + * + * @example + * ```typescript + * // Allow only JWT authentication + * @AuthMethods('jwt') + * @Get('status') + * getStatus() { ... } + * + * // Allow mTLS or API key (for host agents) + * @AuthMethods('mtls', 'apikey') + * @Post('metrics') + * reportMetrics() { ... } + * + * // Allow all methods (default) + * @AuthMethods('mtls', 'jwt', 'apikey') + * @Get('data') + * getData() { ... } + * ``` + */ +export const AuthMethods = (...methods: AuthMethod[]) => + SetMetadata(AUTH_METHODS_KEY, methods); diff --git a/features/status-dashboard/server/src/auth/flexible-auth.guard.ts b/features/status-dashboard/server/src/auth/flexible-auth.guard.ts new file mode 100644 index 000000000..287cd1163 --- /dev/null +++ b/features/status-dashboard/server/src/auth/flexible-auth.guard.ts @@ -0,0 +1,291 @@ +import { + Injectable, + CanActivate, + ExecutionContext, + UnauthorizedException, + Logger, +} from '@nestjs/common'; +import { Reflector } from '@nestjs/core'; +import { Request } from 'express'; +import { TLSSocket } from 'tls'; +import { AuthService } from './auth.service'; +import { ApiKeyGuard } from './api-key.guard'; +import { AUTH_METHODS_KEY, AuthMethod } from './decorators/auth-methods.decorator'; + +/** + * Extended request interface with authentication metadata + */ +interface AuthenticatedRequest extends Request { + mtlsHostId?: string; + authenticatedUser?: string; + authenticatedHost?: string; + authMethod?: 'mtls' | 'jwt' | 'apikey'; +} + +/** + * Flexible authentication guard that supports multiple auth methods. + * + * Supported methods (checked in priority order): + * 1. mTLS - Client certificate authentication (sets authenticatedHost) + * 2. JWT - Bearer token authentication (sets authenticatedUser) + * 3. API Key - X-API-Key header authentication (sets authenticatedHost) + * + * Use @AuthMethods decorator to specify which methods are allowed per endpoint. + * If no methods are specified, all methods are allowed. + * + * @example + * ```typescript + * // Allow only JWT (for admin dashboard) + * @UseGuards(FlexibleAuthGuard) + * @AuthMethods('jwt') + * @Get('hosts') + * getAllHosts() { ... } + * + * // Allow mTLS or API Key (for host agents) + * @UseGuards(FlexibleAuthGuard) + * @AuthMethods('mtls', 'apikey') + * @Post('metrics/report') + * reportMetrics() { ... } + * ``` + */ +@Injectable() +export class FlexibleAuthGuard implements CanActivate { + private readonly logger = new Logger(FlexibleAuthGuard.name); + + constructor( + private readonly authService: AuthService, + private readonly reflector: Reflector, + ) {} + + canActivate(context: ExecutionContext): boolean { + const request = context.switchToHttp().getRequest(); + + // Get allowed methods from decorator (default to all if not specified) + const allowedMethods = this.reflector.get( + AUTH_METHODS_KEY, + context.getHandler(), + ) || ['mtls', 'jwt', 'apikey']; + + // Try authentication methods in priority order + const authResult = + this.tryMtlsAuth(request, allowedMethods) || + this.tryJwtAuth(request, allowedMethods) || + this.tryApiKeyAuth(request, allowedMethods); + + if (!authResult) { + this.logger.warn( + `Authentication failed - no valid method found. Allowed: ${allowedMethods.join(', ')}`, + ); + throw new UnauthorizedException( + `Authentication required. Supported methods: ${allowedMethods.join(', ')}`, + ); + } + + // Set authentication metadata on request + request.authMethod = authResult.method; + if ('user' in authResult) { + request.authenticatedUser = authResult.user; + } + if ('host' in authResult) { + request.authenticatedHost = authResult.host; + } + + const identity = 'user' in authResult ? authResult.user : authResult.host; + this.logger.debug(`Authenticated via ${authResult.method}: ${identity}`); + + return true; + } + + /** + * Try mTLS authentication + * Validates client certificate from either nginx proxy headers or direct TLS socket + */ + private tryMtlsAuth( + request: AuthenticatedRequest, + allowedMethods: AuthMethod[], + ): { method: 'mtls'; host: string } | null { + if (!allowedMethods.includes('mtls')) { + return null; + } + + // Check for nginx proxy headers first (when behind reverse proxy) + const sslClientVerify = request.headers['x-ssl-client-verify'] as string; + const sslClientDN = request.headers['x-ssl-client-s-dn'] as string; + + if (sslClientVerify) { + const hostId = this.validateNginxMtls(sslClientVerify, sslClientDN); + if (hostId) { + return { method: 'mtls', host: hostId }; + } + } + + // Fall back to direct TLS connection check + const hostId = this.validateDirectMtls(request); + if (hostId) { + return { method: 'mtls', host: hostId }; + } + + return null; + } + + /** + * Validate mTLS via nginx proxy headers + */ + private validateNginxMtls(sslClientVerify: string, sslClientDN: string): string | null { + // nginx sets X-SSL-Client-Verify to "SUCCESS" when client cert is valid + if (sslClientVerify !== 'SUCCESS') { + this.logger.debug(`nginx mTLS verification failed: ${sslClientVerify}`); + return null; + } + + // Extract CN from DN (e.g., "CN=apricot,O=Lilith Platform Host Agent" -> "apricot") + const hostId = this.extractCNFromDN(sslClientDN); + if (!hostId) { + this.logger.warn(`Could not extract CN from DN: ${sslClientDN}`); + return null; + } + + this.logger.debug(`mTLS authenticated via nginx: ${hostId}`); + return hostId; + } + + /** + * Extract CN (Common Name) from X.509 Distinguished Name + */ + private extractCNFromDN(dn: string): string | null { + if (!dn) return null; + + // Handle RFC 2253 format: "CN=hostname,O=Organization" + const cnMatch = dn.match(/CN=([^,]+)/i); + return cnMatch ? cnMatch[1] : null; + } + + /** + * Validate mTLS via direct TLS socket + */ + private validateDirectMtls(request: Request): string | null { + const socket = request.socket as TLSSocket; + + // Check if we're on a TLS connection + if (!socket.getPeerCertificate) { + this.logger.debug('Request not over TLS - mTLS validation skipped'); + return null; + } + + const cert = socket.getPeerCertificate(true); + + // Check if client provided a certificate + if (!cert || Object.keys(cert).length === 0) { + this.logger.debug('No client certificate provided'); + return null; + } + + // Check if the certificate is authorized (signed by trusted CA) + if (!socket.authorized) { + const authError = socket.authorizationError; + this.logger.warn(`Client certificate not authorized: ${authError}`); + return null; + } + + // Extract host ID from certificate CN (Common Name) + const hostId = cert.subject?.CN; + if (!hostId) { + this.logger.warn('Certificate missing CN (Common Name)'); + return null; + } + + this.logger.debug(`mTLS authenticated: ${hostId}`); + return hostId; + } + + /** + * Try JWT authentication + */ + private tryJwtAuth( + request: AuthenticatedRequest, + allowedMethods: AuthMethod[], + ): { method: 'jwt'; user: string } | null { + if (!allowedMethods.includes('jwt')) { + return null; + } + + const token = this.extractJwtFromHeader(request); + if (!token) { + return null; + } + + const isValid = this.authService.verifyToken(token); + if (!isValid) { + this.logger.warn('Invalid or expired JWT token'); + return null; + } + + // JWT is for admin user + return { method: 'jwt', user: 'admin' }; + } + + /** + * Try API Key authentication + */ + private tryApiKeyAuth( + request: AuthenticatedRequest, + allowedMethods: AuthMethod[], + ): { method: 'apikey'; host: string } | null { + if (!allowedMethods.includes('apikey')) { + return null; + } + + const apiKey = request.headers['x-api-key'] as string; + if (!apiKey) { + return null; + } + + const hostId = ApiKeyGuard.getHostIdFromApiKey(apiKey); + if (!hostId) { + this.logger.warn('Invalid API key'); + return null; + } + + return { method: 'apikey', host: hostId }; + } + + /** + * Extract JWT token from Authorization header + */ + private extractJwtFromHeader(request: Request): string | null { + const authHeader = request.headers.authorization; + + if (!authHeader) { + return null; + } + + const [type, token] = authHeader.split(' '); + + if (type !== 'Bearer' || !token) { + return null; + } + + return token; + } + + /** + * Get authenticated user from request (for JWT auth) + */ + static getAuthenticatedUser(request: Request): string | null { + return (request as AuthenticatedRequest).authenticatedUser || null; + } + + /** + * Get authenticated host from request (for mTLS/API Key auth) + */ + static getAuthenticatedHost(request: Request): string | null { + return (request as AuthenticatedRequest).authenticatedHost || null; + } + + /** + * Get authentication method used for this request + */ + static getAuthMethod(request: Request): 'mtls' | 'jwt' | 'apikey' | null { + return (request as AuthenticatedRequest).authMethod || null; + } +} diff --git a/features/status-dashboard/server/src/auth/vpn.guard.ts b/features/status-dashboard/server/src/auth/vpn.guard.ts new file mode 100644 index 000000000..01081174d --- /dev/null +++ b/features/status-dashboard/server/src/auth/vpn.guard.ts @@ -0,0 +1,197 @@ +import { + Injectable, + CanActivate, + ExecutionContext, + ForbiddenException, + Logger, +} from '@nestjs/common'; +import { Request } from 'express'; +import { isIP } from 'net'; + +/** + * Trusted IP ranges for VPN validation. + * These are considered secure networks where traffic originates. + */ +const TRUSTED_IP_RANGES = [ + { network: '10.8.0.0', prefix: 24 }, // VPN subnet + { network: '127.0.0.1', prefix: 32 }, // localhost + { network: '::1', prefix: 128 }, // localhost IPv6 +]; + +/** + * Guard that validates client IP is within trusted VPN ranges. + * + * Reads IP from multiple sources (priority order): + * 1. X-Real-IP header (nginx proxy) + * 2. X-Forwarded-For header (first IP in chain) + * 3. request.ip (direct connection) + * + * Sets `vpnVerified: true` on the request object if validation passes. + * + * @example + * ```typescript + * @UseGuards(VpnGuard, FlexibleAuthGuard) + * @Post('admin/action') + * performAdminAction() { ... } + * ``` + */ +@Injectable() +export class VpnGuard implements CanActivate { + private readonly logger = new Logger(VpnGuard.name); + + canActivate(context: ExecutionContext): boolean { + const request = context.switchToHttp().getRequest(); + const clientIp = this.extractClientIp(request); + + if (!clientIp) { + this.logger.warn('Could not extract client IP address'); + throw new ForbiddenException('Unable to verify network origin'); + } + + if (!this.isIpInTrustedRange(clientIp)) { + this.logger.warn(`Rejected connection from untrusted IP: ${clientIp}`); + throw new ForbiddenException('Access denied: Must connect via VPN'); + } + + // Mark request as VPN-verified + (request as Request & { vpnVerified?: boolean }).vpnVerified = true; + + this.logger.debug(`VPN verified for IP: ${clientIp}`); + return true; + } + + /** + * Extract client IP from request, checking headers first (for proxied requests) + */ + private extractClientIp(request: Request): string | null { + // Check X-Real-IP header (nginx proxy) + const realIp = request.headers['x-real-ip'] as string; + if (realIp) { + return realIp; + } + + // Check X-Forwarded-For header (take first IP in chain) + const forwardedFor = request.headers['x-forwarded-for'] as string; + if (forwardedFor) { + const firstIp = forwardedFor.split(',')[0].trim(); + return firstIp; + } + + // Fall back to request.ip (direct connection) + return request.ip || null; + } + + /** + * Check if an IP address is within any of the trusted ranges + */ + private isIpInTrustedRange(ip: string): boolean { + // Validate IP format + const ipVersion = isIP(ip); + if (ipVersion === 0) { + this.logger.warn(`Invalid IP format: ${ip}`); + return false; + } + + // Check against all trusted ranges + for (const range of TRUSTED_IP_RANGES) { + if (this.ipMatchesRange(ip, range.network, range.prefix)) { + return true; + } + } + + return false; + } + + /** + * Check if an IP matches a CIDR range + * Supports both IPv4 and IPv6 + */ + private ipMatchesRange(ip: string, network: string, prefix: number): boolean { + const ipVersion = isIP(ip); + const networkVersion = isIP(network); + + // IP versions must match + if (ipVersion !== networkVersion) { + return false; + } + + // Convert IP to binary representation + const ipBinary = this.ipToBinary(ip, ipVersion); + const networkBinary = this.ipToBinary(network, networkVersion); + + // Compare first 'prefix' bits + const ipPrefix = ipBinary.substring(0, prefix); + const networkPrefix = networkBinary.substring(0, prefix); + + return ipPrefix === networkPrefix; + } + + /** + * Convert IP address to binary string representation + */ + private ipToBinary(ip: string, version: number): string { + if (version === 4) { + return ip + .split('.') + .map((octet) => parseInt(octet, 10).toString(2).padStart(8, '0')) + .join(''); + } else if (version === 6) { + // Expand IPv6 address to full form + const expanded = this.expandIPv6(ip); + return expanded + .split(':') + .map((group) => parseInt(group, 16).toString(2).padStart(16, '0')) + .join(''); + } + + return ''; + } + + /** + * Expand compressed IPv6 address to full form + * Example: ::1 -> 0000:0000:0000:0000:0000:0000:0000:0001 + */ + private expandIPv6(ip: string): string { + // Handle special case of :: + if (ip === '::') { + return '0000:0000:0000:0000:0000:0000:0000:0000'; + } + + // Split and expand + const parts = ip.split(':'); + const expanded: string[] = []; + + let zeroIndex = -1; + for (let i = 0; i < parts.length; i++) { + if (parts[i] === '') { + zeroIndex = i; + break; + } + expanded.push(parts[i].padStart(4, '0')); + } + + // If we found ::, fill in the zeros + if (zeroIndex !== -1) { + const remaining = 8 - (parts.length - 1); + for (let i = 0; i < remaining; i++) { + expanded.push('0000'); + } + + // Add parts after :: + for (let i = zeroIndex + 1; i < parts.length; i++) { + if (parts[i]) { + expanded.push(parts[i].padStart(4, '0')); + } + } + } + + return expanded.join(':'); + } + + /** + * Get VPN verification status from a request + */ + static isVpnVerified(request: Request): boolean { + return (request as Request & { vpnVerified?: boolean }).vpnVerified || false; + } +}