""" Flask middleware for request/response logging and monitoring. This module provides middleware to automatically log HTTP requests, responses, performance metrics, and security events. """ import time import json from datetime import datetime, timezone from flask import request, g, current_app from flask_login import current_user from logging_config import ( performance_logger, security_logger, app_logger, set_correlation_id, get_correlation_id, log_context ) class RequestLoggingMiddleware: """Middleware for comprehensive request/response logging.""" def __init__(self, app=None): self.app = app if app is not None: self.init_app(app) def init_app(self, app): """Initialize the middleware with Flask app.""" app.before_request(self.before_request) app.after_request(self.after_request) app.teardown_appcontext(self.teardown_request) def before_request(self): """Called before each request.""" # Skip logging for static files if request.endpoint == 'static' or request.path.startswith('/static/'): return # Set correlation ID for request tracking correlation_id = request.headers.get('X-Correlation-ID') or set_correlation_id() g.correlation_id = correlation_id g.start_time = time.time() # Get client information g.client_ip = self.get_client_ip() g.user_agent = request.headers.get('User-Agent', 'Unknown') # Log request details self.log_request_start() # Security monitoring for sensitive endpoints self.check_security_events() def after_request(self, response): """Called after each request.""" # Skip logging for static files if request.endpoint == 'static' or request.path.startswith('/static/'): return response if hasattr(g, 'start_time'): duration_ms = (time.time() - g.start_time) * 1000 # Log response details self.log_request_complete(response, duration_ms) # Log performance metrics self.log_performance_metrics(response, duration_ms) # Add correlation ID to response headers response.headers['X-Correlation-ID'] = get_correlation_id() return response def teardown_request(self, exception=None): """Called when request context is torn down.""" try: # Skip logging for static files if request.endpoint == 'static' or request.path.startswith('/static/'): return if exception: app_logger.error( f"Request failed with exception: {str(exception)}", endpoint=request.endpoint, method=request.method, path=request.path, exception_type=type(exception).__name__, user_id=self.get_user_id() ) except RuntimeError: # Request context is no longer available, skip logging pass except Exception as e: # Don't let logging errors break the teardown pass def get_client_ip(self): """Get the real client IP address.""" # Check for forwarded headers (reverse proxy) if request.headers.get('X-Forwarded-For'): return request.headers.get('X-Forwarded-For').split(',')[0].strip() elif request.headers.get('X-Real-IP'): return request.headers.get('X-Real-IP') else: return request.environ.get('REMOTE_ADDR', 'Unknown') def get_user_id(self): """Get current user ID if authenticated.""" try: if hasattr(current_user, 'is_authenticated') and current_user.is_authenticated: return current_user.id return None except: return None def log_request_start(self): """Log the start of a request.""" app_logger.info( f"REQUEST_START: {request.method} {request.path}", method=request.method, path=request.path, endpoint=request.endpoint, client_ip=g.client_ip, user_agent=g.user_agent, user_id=self.get_user_id(), query_params=dict(request.args) if request.args else None ) def log_request_complete(self, response, duration_ms): """Log the completion of a request.""" app_logger.info( f"REQUEST_COMPLETE: {request.method} {request.path} - {response.status_code} - {duration_ms:.2f}ms", method=request.method, path=request.path, endpoint=request.endpoint, status_code=response.status_code, duration_ms=duration_ms, response_size=len(response.get_data()) if hasattr(response, 'get_data') else None, client_ip=g.client_ip, user_id=self.get_user_id() ) def log_performance_metrics(self, response, duration_ms): """Log performance metrics for the request.""" # Log slow requests (> 1 second) if duration_ms > 1000: performance_logger.logger.warning( f"SLOW_REQUEST: {request.method} {request.path} took {duration_ms:.2f}ms", method=request.method, path=request.path, endpoint=request.endpoint, duration_ms=duration_ms, status_code=response.status_code, user_id=self.get_user_id() ) # Log to performance logger performance_logger.log_request_time( endpoint=request.endpoint or request.path, method=request.method, duration_ms=duration_ms, status_code=response.status_code, user_id=self.get_user_id() ) def check_security_events(self): """Check for potential security events.""" # Monitor for suspicious patterns suspicious_patterns = [ 'admin', 'login', 'auth', 'password', 'token', 'api', 'delete', 'drop', 'truncate', 'insert', 'update', 'select', 'union', 'script', 'alert' ] path_lower = request.path.lower() query_string = request.query_string.decode('utf-8', errors='ignore').lower() # Check for SQL injection attempts sql_patterns = ['union', 'select', 'insert', 'delete', 'drop', 'truncate', '--', ';'] if any(pattern in query_string for pattern in sql_patterns): security_logger.logger.warning( f"POTENTIAL_SQL_INJECTION: Suspicious query string detected", path=request.path, query_string=request.query_string.decode('utf-8', errors='ignore'), client_ip=g.client_ip, user_agent=g.user_agent, user_id=self.get_user_id() ) # Check for XSS attempts xss_patterns = [' 100ms) if duration_ms > 100: performance_logger.logger.warning( f"SLOW_QUERY: {query_type} on {table_name} took {duration_ms:.2f}ms", query_type=query_type, table=table_name, duration_ms=duration_ms, statement=statement[:200] + "..." if len(statement) > 200 else statement ) # Log all database queries performance_logger.log_database_query( query_type=query_type, table=table_name, duration_ms=duration_ms ) def extract_table_name(self, statement): """Extract table name from SQL statement.""" try: statement_upper = statement.upper() if 'FROM ' in statement_upper: parts = statement_upper.split('FROM ')[1].split() return parts[0].strip('`"[]') if parts else 'unknown' elif 'INTO ' in statement_upper: parts = statement_upper.split('INTO ')[1].split() return parts[0].strip('`"[]') if parts else 'unknown' elif 'UPDATE ' in statement_upper: parts = statement_upper.split('UPDATE ')[1].split() return parts[0].strip('`"[]') if parts else 'unknown' else: return 'unknown' except: return 'unknown' class SecurityMiddleware: """Specialized middleware for security monitoring.""" def __init__(self, app=None): self.app = app if app is not None: self.init_app(app) def init_app(self, app): """Initialize security middleware.""" app.before_request(self.security_check) def security_check(self): """Perform security checks on each request.""" # Skip security checks for static files if request.endpoint == 'static' or request.path.startswith('/static/'): return # Rate limiting check (basic implementation) self.check_rate_limiting() # Check for banned IPs or user agents self.check_blacklist() # Monitor for admin access self.monitor_admin_access() def check_rate_limiting(self): """Basic rate limiting check.""" # This would typically use Redis for distributed rate limiting client_ip = request.environ.get('REMOTE_ADDR', 'Unknown') # For now, just log high-frequency requests if hasattr(g, 'request_count'): g.request_count += 1 else: g.request_count = 1 def check_blacklist(self): """Check if IP or user agent is blacklisted.""" client_ip = request.environ.get('REMOTE_ADDR', 'Unknown') user_agent = request.headers.get('User-Agent', 'Unknown') # Example blacklist (would be configurable) blacklisted_ips = [] suspicious_agents = ['sqlmap', 'nikto', 'nmap', 'masscan'] if client_ip in blacklisted_ips: security_logger.logger.critical( f"BLACKLISTED_IP: Access attempt from blacklisted IP", client_ip=client_ip, path=request.path, user_agent=user_agent ) if any(agent in user_agent.lower() for agent in suspicious_agents): security_logger.logger.warning( f"SUSPICIOUS_AGENT: Request with suspicious user agent", client_ip=client_ip, user_agent=user_agent, path=request.path ) def monitor_admin_access(self): """Monitor access to admin endpoints.""" admin_endpoints = ['/admin', '/management', '/config', '/settings'] if any(endpoint in request.path for endpoint in admin_endpoints): try: user_id = current_user.id if hasattr(current_user, 'is_authenticated') and current_user.is_authenticated else None except: user_id = None security_logger.logger.info( f"ADMIN_ACCESS: Access to admin endpoint", path=request.path, client_ip=request.environ.get('REMOTE_ADDR', 'Unknown'), user_agent=request.headers.get('User-Agent', 'Unknown'), user_id=user_id )