You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

361 lines
14 KiB

"""
Flask middleware for request/response logging and monitoring.
This module provides middleware to automatically log HTTP requests, responses,
performance metrics, and security events.
"""
import time
import json
from datetime import datetime, timezone
from flask import request, g, current_app
from flask_login import current_user
from logging_config import (
performance_logger, security_logger, app_logger,
set_correlation_id, get_correlation_id, log_context
)
class RequestLoggingMiddleware:
    """Middleware for comprehensive request/response logging.

    Registers Flask hooks that log the start and completion of each
    non-static request, record timing metrics, flag suspicious query strings,
    log unhandled exceptions, and echo the correlation ID in the
    ``X-Correlation-ID`` response header.
    """

    # Requests slower than this many milliseconds are logged as SLOW_REQUEST.
    SLOW_REQUEST_THRESHOLD_MS = 1000
    # Query-string substrings treated as possible SQL-injection probes.
    SQL_INJECTION_PATTERNS = ('union', 'select', 'insert', 'delete', 'drop', 'truncate', '--', ';')
    # Query-string substrings treated as possible XSS probes.
    XSS_PATTERNS = ('<script', 'javascript:', 'onerror=', 'onload=', 'alert(')
    # POST requests to these endpoints are recorded as authentication attempts.
    AUTH_ENDPOINTS = ('auth.login', 'auth.register')

    def __init__(self, app=None):
        self.app = app
        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        """Register the before/after/teardown request hooks on *app*."""
        app.before_request(self.before_request)
        app.after_request(self.after_request)
        # teardown_request rather than teardown_appcontext: Flask pops the
        # application context *after* the request context, so ``request`` is
        # no longer available in an app-context teardown and the exception
        # log in teardown_request could never be emitted.
        app.teardown_request(self.teardown_request)

    @staticmethod
    def _is_static_request():
        """Return True for Flask's static endpoint or any /static/ path."""
        return request.endpoint == 'static' or request.path.startswith('/static/')

    def before_request(self):
        """Start timing, capture client details, and log the incoming request."""
        if self._is_static_request():
            return
        # Reuse the caller-supplied correlation ID when present; otherwise
        # generate one through the logging config.
        # NOTE(review): an incoming header is only stored on ``g``; it is not
        # passed into set_correlation_id(), so log records emitted during the
        # request may carry a different ID. Confirm set_correlation_id's
        # signature before propagating it.
        g.correlation_id = request.headers.get('X-Correlation-ID') or set_correlation_id()
        g.start_time = time.time()
        g.client_ip = self.get_client_ip()
        g.user_agent = request.headers.get('User-Agent', 'Unknown')
        self.log_request_start()
        self.check_security_events()

    def after_request(self, response):
        """Log completion/performance and attach the correlation ID header."""
        if self._is_static_request():
            return response
        # start_time is absent if another before_request hook short-circuited.
        if hasattr(g, 'start_time'):
            duration_ms = (time.time() - g.start_time) * 1000
            self.log_request_complete(response, duration_ms)
            self.log_performance_metrics(response, duration_ms)
        # Echo the ID chosen in before_request (possibly client-supplied);
        # get_correlation_id() is only a fallback for when that hook did not run.
        correlation_id = getattr(g, 'correlation_id', None) or get_correlation_id()
        if correlation_id:
            response.headers['X-Correlation-ID'] = correlation_id
        return response

    def teardown_request(self, exception=None):
        """Log the unhandled exception (if any) that terminated the request.

        Best-effort: failures while logging are swallowed so teardown always
        completes.
        """
        if not exception:
            return
        try:
            if self._is_static_request():
                return
            app_logger.error(
                f"Request failed with exception: {str(exception)}",
                endpoint=request.endpoint,
                method=request.method,
                path=request.path,
                exception_type=type(exception).__name__,
                user_id=self.get_user_id(),
            )
        except Exception:
            # RuntimeError when no request context is active, or any logger
            # failure; neither may break teardown.
            pass

    def get_client_ip(self):
        """Return the client IP, preferring reverse-proxy headers.

        NOTE(review): X-Forwarded-For / X-Real-IP are client-controlled unless
        a trusted proxy overwrites them, so this value can be spoofed.
        Consider werkzeug's ProxyFix plus ``request.remote_addr``.
        """
        forwarded_for = request.headers.get('X-Forwarded-For')
        if forwarded_for:
            return forwarded_for.split(',')[0].strip()
        real_ip = request.headers.get('X-Real-IP')
        if real_ip:
            return real_ip
        return request.environ.get('REMOTE_ADDR', 'Unknown')

    def get_user_id(self):
        """Return the authenticated user's ID, or None if unavailable."""
        try:
            if hasattr(current_user, 'is_authenticated') and current_user.is_authenticated:
                return current_user.id
            return None
        except Exception:
            return None

    def log_request_start(self):
        """Log method, path, client details and query parameters of the request."""
        # NOTE(review): query parameters are logged verbatim; redact them if
        # secrets (tokens, reset codes) can ever appear in URLs.
        app_logger.info(
            f"REQUEST_START: {request.method} {request.path}",
            method=request.method,
            path=request.path,
            endpoint=request.endpoint,
            client_ip=g.client_ip,
            user_agent=g.user_agent,
            user_id=self.get_user_id(),
            query_params=dict(request.args) if request.args else None,
        )

    @staticmethod
    def _response_size(response):
        """Return the response body size in bytes, or None if it cannot be read safely.

        ``get_data()`` raises RuntimeError for direct-passthrough responses
        (e.g. ``send_file``) and buffers generator responses in full, so it is
        only called when the body is already an in-memory sequence; otherwise
        the Content-Length value (possibly None) is used.
        """
        if getattr(response, 'is_sequence', False):
            return len(response.get_data())
        return getattr(response, 'content_length', None)

    def log_request_complete(self, response, duration_ms):
        """Log status, duration and body size of the finished request."""
        app_logger.info(
            f"REQUEST_COMPLETE: {request.method} {request.path} - {response.status_code} - {duration_ms:.2f}ms",
            method=request.method,
            path=request.path,
            endpoint=request.endpoint,
            status_code=response.status_code,
            duration_ms=duration_ms,
            response_size=self._response_size(response),
            client_ip=g.client_ip,
            user_id=self.get_user_id(),
        )

    def log_performance_metrics(self, response, duration_ms):
        """Warn about slow requests and record the request time metric."""
        if duration_ms > self.SLOW_REQUEST_THRESHOLD_MS:
            performance_logger.logger.warning(
                f"SLOW_REQUEST: {request.method} {request.path} took {duration_ms:.2f}ms",
                method=request.method,
                path=request.path,
                endpoint=request.endpoint,
                duration_ms=duration_ms,
                status_code=response.status_code,
                user_id=self.get_user_id(),
            )
        performance_logger.log_request_time(
            endpoint=request.endpoint or request.path,
            method=request.method,
            duration_ms=duration_ms,
            status_code=response.status_code,
            user_id=self.get_user_id(),
        )

    @staticmethod
    def _contains_any(patterns, raw_query):
        """Return True if any pattern occurs in the query string, raw or URL-decoded.

        Clients percent-encode characters such as ``<`` and spaces, so the
        decoded form is needed to see encoded payloads; the raw form is still
        checked so nothing matched before is missed (decoding can merge a
        ``%XX`` escape into a neighbouring word).
        """
        from urllib.parse import unquote_plus

        haystacks = (raw_query.lower(), unquote_plus(raw_query).lower())
        return any(pattern in text for pattern in patterns for text in haystacks)

    def check_security_events(self):
        """Log suspicious query strings and record authentication attempts."""
        raw_query = request.query_string.decode('utf-8', errors='ignore')
        if self._contains_any(self.SQL_INJECTION_PATTERNS, raw_query):
            security_logger.logger.warning(
                "POTENTIAL_SQL_INJECTION: Suspicious query string detected",
                path=request.path,
                query_string=raw_query,
                client_ip=g.client_ip,
                user_agent=g.user_agent,
                user_id=self.get_user_id(),
            )
        if self._contains_any(self.XSS_PATTERNS, raw_query):
            security_logger.logger.warning(
                "POTENTIAL_XSS: Suspicious script content detected",
                path=request.path,
                query_string=raw_query,
                client_ip=g.client_ip,
                user_agent=g.user_agent,
                user_id=self.get_user_id(),
            )
        if request.endpoint in self.AUTH_ENDPOINTS and request.method == 'POST':
            g.is_auth_attempt = True
        if hasattr(g, 'is_auth_attempt'):
            self.check_brute_force_attempt()

    def check_brute_force_attempt(self):
        """Log an authentication attempt.

        Placeholder: no per-IP counting is done here; a shared store (e.g.
        Redis or a database) would be needed to detect brute force.
        """
        security_logger.logger.info(
            f"AUTH_ATTEMPT: Authentication attempt from {g.client_ip}",
            client_ip=g.client_ip,
            user_agent=g.user_agent,
            endpoint=request.endpoint,
            timestamp=datetime.now(timezone.utc).isoformat(),
        )
class DatabaseLoggingMiddleware:
    """Middleware for database query logging.

    Attaches SQLAlchemy ``before_cursor_execute`` / ``after_cursor_execute``
    listeners to the ``Engine`` class to time every statement, warn about
    slow ones, and report each one to ``performance_logger``.
    """

    # Queries slower than this many milliseconds are logged as SLOW_QUERY.
    SLOW_QUERY_THRESHOLD_MS = 100
    # The listeners are registered on the Engine *class*, i.e. for every
    # engine in the process, so they are installed at most once no matter how
    # many times init_app runs (otherwise each query would be logged N times).
    _listeners_installed = False

    def __init__(self, app=None):
        self.app = app
        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        """Install the query-timing listeners (once per process)."""
        from sqlalchemy import event
        from sqlalchemy.engine import Engine

        if DatabaseLoggingMiddleware._listeners_installed:
            return
        DatabaseLoggingMiddleware._listeners_installed = True

        @event.listens_for(Engine, "before_cursor_execute")
        def receive_before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
            # Keep start times on a per-connection stack rather than on
            # ``context``, which SQLAlchemy documents may be None.
            conn.info.setdefault('_query_start_times', []).append(time.perf_counter())

        @event.listens_for(Engine, "after_cursor_execute")
        def receive_after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
            start_times = conn.info.get('_query_start_times')
            if not start_times:
                # No matching "before" event was recorded; nothing to time.
                return
            duration_ms = (time.perf_counter() - start_times.pop()) * 1000

            table_name = self.extract_table_name(statement)
            tokens = statement.split()
            query_type = tokens[0].upper() if tokens else 'UNKNOWN'

            if duration_ms > self.SLOW_QUERY_THRESHOLD_MS:
                performance_logger.logger.warning(
                    f"SLOW_QUERY: {query_type} on {table_name} took {duration_ms:.2f}ms",
                    query_type=query_type,
                    table=table_name,
                    duration_ms=duration_ms,
                    statement=(statement[:200] + "...") if len(statement) > 200 else statement,
                )
            performance_logger.log_database_query(
                query_type=query_type,
                table=table_name,
                duration_ms=duration_ms,
            )

    def extract_table_name(self, statement):
        """Best-effort extraction of the table a SQL statement targets.

        Looks for ``FROM``, then ``INTO``, then ``UPDATE`` (case-insensitively)
        and returns the upper-cased token after the first keyword found, with
        quote/bracket characters stripped. Returns 'unknown' when no keyword
        matches, nothing follows it, or the statement cannot be parsed
        (e.g. it is None).
        """
        try:
            statement_upper = statement.upper()
            for keyword in ('FROM ', 'INTO ', 'UPDATE '):
                index = statement_upper.find(keyword)
                if index != -1:
                    parts = statement_upper[index + len(keyword):].split()
                    return parts[0].strip('`"[]') if parts else 'unknown'
            return 'unknown'
        except Exception:
            return 'unknown'
class SecurityMiddleware:
    """Specialized middleware for security monitoring.

    Registers a before_request hook that, for every non-static request,
    updates a placeholder request counter, logs blacklisted IPs and known
    scanner user agents, and records access to admin-like paths.
    """

    # Example blacklist (would be configurable); compared against REMOTE_ADDR.
    BLACKLISTED_IPS = frozenset()
    # Lower-case user-agent substrings of common scanning tools.
    SUSPICIOUS_AGENTS = ('sqlmap', 'nikto', 'nmap', 'masscan')
    # A request path containing any of these substrings is logged as admin access.
    ADMIN_PATH_FRAGMENTS = ('/admin', '/management', '/config', '/settings')

    def __init__(self, app=None):
        self.app = app
        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        """Register the security check as a before_request hook on *app*."""
        app.before_request(self.security_check)

    @staticmethod
    def _is_static_request():
        """Return True for Flask's static endpoint or any /static/ path."""
        return request.endpoint == 'static' or request.path.startswith('/static/')

    @classmethod
    def _is_suspicious_agent(cls, user_agent):
        """Return True if *user_agent* contains a known scanner signature."""
        lowered = user_agent.lower()
        return any(agent in lowered for agent in cls.SUSPICIOUS_AGENTS)

    @classmethod
    def _is_admin_path(cls, path):
        """Return True if *path* contains any admin path fragment."""
        return any(fragment in path for fragment in cls.ADMIN_PATH_FRAGMENTS)

    def security_check(self):
        """Run every security check for the current non-static request."""
        if self._is_static_request():
            return
        self.check_rate_limiting()
        self.check_blacklist()
        self.monitor_admin_access()

    def check_rate_limiting(self):
        """Placeholder rate limiting: increment a counter stored on ``g``.

        NOTE(review): Flask's ``g`` lives for one application context
        (normally one request), so this does not measure request frequency
        across requests; a shared store such as Redis is needed for that.
        """
        g.request_count = getattr(g, 'request_count', 0) + 1

    def check_blacklist(self):
        """Log requests from blacklisted IPs or with scanner user agents."""
        client_ip = request.environ.get('REMOTE_ADDR', 'Unknown')
        user_agent = request.headers.get('User-Agent', 'Unknown')
        if client_ip in self.BLACKLISTED_IPS:
            security_logger.logger.critical(
                "BLACKLISTED_IP: Access attempt from blacklisted IP",
                client_ip=client_ip,
                path=request.path,
                user_agent=user_agent,
            )
        if self._is_suspicious_agent(user_agent):
            security_logger.logger.warning(
                "SUSPICIOUS_AGENT: Request with suspicious user agent",
                client_ip=client_ip,
                user_agent=user_agent,
                path=request.path,
            )

    def monitor_admin_access(self):
        """Log any request whose path looks like an admin endpoint."""
        if not self._is_admin_path(request.path):
            return
        try:
            user_id = current_user.id if hasattr(current_user, 'is_authenticated') and current_user.is_authenticated else None
        except Exception:
            user_id = None
        security_logger.logger.info(
            "ADMIN_ACCESS: Access to admin endpoint",
            path=request.path,
            client_ip=request.environ.get('REMOTE_ADDR', 'Unknown'),
            user_agent=request.headers.get('User-Agent', 'Unknown'),
            user_id=user_id,
        )