You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
361 lines
14 KiB
361 lines
14 KiB
"""
|
|
Flask middleware for request/response logging and monitoring.
|
|
|
|
This module provides middleware to automatically log HTTP requests, responses,
|
|
performance metrics, and security events.
|
|
"""
|
|
|
|
import json
import time
from datetime import datetime, timezone
from urllib.parse import unquote_plus

from flask import request, g, current_app
from flask_login import current_user

from logging_config import (
    performance_logger, security_logger, app_logger,
    set_correlation_id, get_correlation_id, log_context
)
|
|
|
|
class RequestLoggingMiddleware:
    """Middleware for comprehensive request/response logging.

    Once attached to a Flask app, every non-static request gets:

    * a correlation ID (taken from the ``X-Correlation-ID`` request header, or
      generated by ``set_correlation_id()``) that is echoed back in the
      ``X-Correlation-ID`` response header;
    * REQUEST_START / REQUEST_COMPLETE entries on ``app_logger``;
    * timing metrics on ``performance_logger`` (plus a SLOW_REQUEST warning
      above ``slow_request_threshold_ms``);
    * heuristic SQL-injection / XSS / auth-attempt entries on ``security_logger``;
    * an error entry on ``app_logger`` when the request ends with an exception.
    """

    #: Requests taking longer than this many milliseconds are logged as SLOW_REQUEST.
    slow_request_threshold_ms = 1000

    #: Substrings of the decoded, lower-cased query string that suggest SQL injection.
    sql_injection_patterns = ('union', 'select', 'insert', 'delete', 'drop', 'truncate', '--', ';')

    #: Substrings of the decoded, lower-cased query string that suggest XSS.
    xss_patterns = ('<script', 'javascript:', 'onerror=', 'onload=', 'alert(')

    #: Endpoints whose POST requests are treated as authentication attempts.
    auth_endpoints = ('auth.login', 'auth.register')

    def __init__(self, app=None):
        self.app = app
        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        """Initialize the middleware with a Flask app by registering its hooks."""
        app.before_request(self.before_request)
        app.after_request(self.after_request)
        # Must be teardown_request, not teardown_appcontext: the handler reads
        # ``request``, and the request context has already been popped when
        # app-context teardown callbacks run. Under teardown_appcontext every
        # call hit RuntimeError, so unhandled exceptions were never logged.
        app.teardown_request(self.teardown_request)

    @staticmethod
    def _is_static_request():
        """Return True if the current request targets a static file."""
        return request.endpoint == 'static' or request.path.startswith('/static/')

    def before_request(self):
        """Called before each request: record request state on ``g`` and log it."""
        if self._is_static_request():
            return

        # Honour an upstream correlation ID so logs can be joined across
        # services; otherwise let logging_config generate one.
        # NOTE(review): an upstream ID is only stored on ``g``; it is not passed
        # to set_correlation_id(), so records emitted through logging_config may
        # carry a different ID. Confirm whether set_correlation_id accepts a value.
        g.correlation_id = request.headers.get('X-Correlation-ID') or set_correlation_id()
        g.start_time = time.time()

        g.client_ip = self.get_client_ip()
        g.user_agent = request.headers.get('User-Agent', 'Unknown')

        self.log_request_start()
        self.check_security_events()

    def after_request(self, response):
        """Called after each request: log completion/timing and tag the response.

        :param response: the Flask response object about to be sent.
        :returns: the same response, with ``X-Correlation-ID`` set when known.
        """
        if self._is_static_request():
            return response

        # start_time is missing when an earlier before_request hook
        # short-circuited the request before ours ran.
        if hasattr(g, 'start_time'):
            duration_ms = (time.time() - g.start_time) * 1000
            self.log_request_complete(response, duration_ms)
            self.log_performance_metrics(response, duration_ms)

        # Prefer the ID chosen in before_request: a client-supplied ID lives on
        # ``g`` and may not be what get_correlation_id() returns.
        correlation_id = getattr(g, 'correlation_id', None) or get_correlation_id()
        if correlation_id:
            response.headers['X-Correlation-ID'] = correlation_id
        return response

    def teardown_request(self, exception=None):
        """Called when the request context is torn down; logs unhandled exceptions.

        Best-effort: any error raised while logging is swallowed so that
        teardown itself can never fail because of logging.

        :param exception: the unhandled exception, or None on success.
        """
        try:
            if self._is_static_request():
                return

            if exception is not None:
                app_logger.error(
                    f"Request failed with exception: {str(exception)}",
                    endpoint=request.endpoint,
                    method=request.method,
                    path=request.path,
                    exception_type=type(exception).__name__,
                    user_id=self.get_user_id(),
                )
        except Exception:
            # RuntimeError when no request context is active, or any failure
            # inside the logger: deliberately ignored (see docstring).
            pass

    def get_client_ip(self):
        """Return the best-guess client IP address.

        Order: first ``X-Forwarded-For`` entry, then ``X-Real-IP``, then the
        socket peer (``REMOTE_ADDR``), else ``'Unknown'``. Both headers are
        client-controllable unless a trusted proxy overwrites them, so treat
        the value as informational only.
        """
        forwarded_for = request.headers.get('X-Forwarded-For')
        if forwarded_for:
            return forwarded_for.split(',')[0].strip()
        real_ip = request.headers.get('X-Real-IP')
        if real_ip:
            return real_ip
        return request.environ.get('REMOTE_ADDR', 'Unknown')

    def get_user_id(self):
        """Return ``current_user.id`` for an authenticated user, else None.

        Also returns None if resolving ``current_user`` raises.
        """
        try:
            if getattr(current_user, 'is_authenticated', False):
                return current_user.id
        except Exception:
            # Best-effort: user lookup must never break request logging.
            pass
        return None

    def log_request_start(self):
        """Log the start of a request (method, path, client, user, query args)."""
        # NOTE(review): query_params is logged verbatim; consider redacting
        # keys such as tokens or passwords if they can appear in query strings.
        app_logger.info(
            f"REQUEST_START: {request.method} {request.path}",
            method=request.method,
            path=request.path,
            endpoint=request.endpoint,
            client_ip=g.client_ip,
            user_agent=g.user_agent,
            user_id=self.get_user_id(),
            query_params=dict(request.args) if request.args else None,
        )

    def log_request_complete(self, response, duration_ms):
        """Log the completion of a request.

        ``response_size`` is taken from the Content-Length header instead of
        ``response.get_data()``. Reading the body raises RuntimeError for
        direct-passthrough responses (e.g. ``send_file``) and buffers streamed
        responses into memory. It is None when no length is known.
        """
        app_logger.info(
            f"REQUEST_COMPLETE: {request.method} {request.path} - {response.status_code} - {duration_ms:.2f}ms",
            method=request.method,
            path=request.path,
            endpoint=request.endpoint,
            status_code=response.status_code,
            duration_ms=duration_ms,
            response_size=getattr(response, 'content_length', None),
            client_ip=g.client_ip,
            user_id=self.get_user_id(),
        )

    def log_performance_metrics(self, response, duration_ms):
        """Log performance metrics for the request.

        Emits a SLOW_REQUEST warning when ``duration_ms`` exceeds
        ``slow_request_threshold_ms`` and always records the timing through
        ``performance_logger.log_request_time``.
        """
        user_id = self.get_user_id()

        if duration_ms > self.slow_request_threshold_ms:
            performance_logger.logger.warning(
                f"SLOW_REQUEST: {request.method} {request.path} took {duration_ms:.2f}ms",
                method=request.method,
                path=request.path,
                endpoint=request.endpoint,
                duration_ms=duration_ms,
                status_code=response.status_code,
                user_id=user_id,
            )

        performance_logger.log_request_time(
            endpoint=request.endpoint or request.path,
            method=request.method,
            duration_ms=duration_ms,
            status_code=response.status_code,
            user_id=user_id,
        )

    def check_security_events(self):
        """Check for potential security events in the current request.

        Pattern matching runs on the percent-decoded, lower-cased query string,
        so encoded payloads such as ``%3Cscript%3E`` are caught. The raw query
        string is what gets logged. These are substring heuristics and will
        produce false positives (e.g. ``?selected=1`` matches ``select``).
        """
        raw_query = request.query_string.decode('utf-8', errors='ignore')
        # Clients percent-encode characters like '<' and ';', so match on the
        # decoded form rather than the raw bytes.
        decoded_query = unquote_plus(raw_query).lower()

        if any(pattern in decoded_query for pattern in self.sql_injection_patterns):
            security_logger.logger.warning(
                "POTENTIAL_SQL_INJECTION: Suspicious query string detected",
                path=request.path,
                query_string=raw_query,
                client_ip=g.client_ip,
                user_agent=g.user_agent,
                user_id=self.get_user_id(),
            )

        if any(pattern in decoded_query for pattern in self.xss_patterns):
            security_logger.logger.warning(
                "POTENTIAL_XSS: Suspicious script content detected",
                path=request.path,
                query_string=raw_query,
                client_ip=g.client_ip,
                user_agent=g.user_agent,
                user_id=self.get_user_id(),
            )

        # Flag form submissions to the authentication endpoints.
        if request.endpoint in self.auth_endpoints and request.method == 'POST':
            g.is_auth_attempt = True

        if hasattr(g, 'is_auth_attempt'):
            self.check_brute_force_attempt()

    def check_brute_force_attempt(self):
        """Record an authentication attempt for brute-force analysis.

        Currently only logs the attempt; tracking attempts across requests
        would need shared storage (e.g. Redis or a database).
        """
        security_logger.logger.info(
            f"AUTH_ATTEMPT: Authentication attempt from {g.client_ip}",
            client_ip=g.client_ip,
            user_agent=g.user_agent,
            endpoint=request.endpoint,
            timestamp=datetime.now(timezone.utc).isoformat(),
        )
|
|
|
|
class DatabaseLoggingMiddleware:
    """Middleware for database query logging.

    ``init_app`` attaches SQLAlchemy ``before_cursor_execute`` /
    ``after_cursor_execute`` listeners to the ``Engine`` class itself, so every
    engine in the process is instrumented. Each timed statement is reported via
    ``performance_logger.log_database_query``. Statements slower than
    ``slow_query_threshold_ms`` also produce a SLOW_QUERY warning.
    """

    #: Statements slower than this many milliseconds are logged as SLOW_QUERY.
    slow_query_threshold_ms = 100

    # The listeners are attached to the Engine *class*, which is process-global.
    # Without this guard every init_app() call (e.g. an app factory invoked once
    # per test) would add another pair, and each query would be logged N times.
    _listeners_registered = False

    def __init__(self, app=None):
        self.app = app
        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        """Initialize database logging; listeners are registered once per process.

        :param app: accepted for Flask-extension symmetry but not used, because
            the listeners are attached to ``sqlalchemy.engine.Engine``.
        """
        from sqlalchemy import event
        from sqlalchemy.engine import Engine

        if DatabaseLoggingMiddleware._listeners_registered:
            return

        @event.listens_for(Engine, "before_cursor_execute")
        def receive_before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
            # SQLAlchemy documents ``context`` as possibly None; such
            # statements are simply not timed.
            if context is not None:
                context._query_start_time = time.perf_counter()

        @event.listens_for(Engine, "after_cursor_execute")
        def receive_after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
            start = getattr(context, '_query_start_time', None)
            if start is None:
                # No context, or the statement began before the listener was
                # attached: nothing to measure.
                return
            duration_ms = (time.perf_counter() - start) * 1000
            self._log_query(statement, duration_ms)

        DatabaseLoggingMiddleware._listeners_registered = True

    def _log_query(self, statement, duration_ms):
        """Report one executed statement (and flag it if slow) to performance_logger."""
        table_name = self.extract_table_name(statement)
        stripped = statement.strip()
        query_type = stripped.split()[0].upper() if stripped else 'UNKNOWN'

        if duration_ms > self.slow_query_threshold_ms:
            performance_logger.logger.warning(
                f"SLOW_QUERY: {query_type} on {table_name} took {duration_ms:.2f}ms",
                query_type=query_type,
                table=table_name,
                duration_ms=duration_ms,
                statement=statement[:200] + "..." if len(statement) > 200 else statement,
            )

        performance_logger.log_database_query(
            query_type=query_type,
            table=table_name,
            duration_ms=duration_ms,
        )

    def extract_table_name(self, statement):
        """Extract the (upper-cased) primary table name from a SQL statement.

        The anchor keyword depends on the statement's first word:

        * INSERT/REPLACE use the word after ``INTO``;
        * UPDATE uses the word after ``UPDATE``;
        * anything else uses the first ``FROM``, then ``INTO``, then ``UPDATE``.

        Matching is on whole whitespace-separated tokens. Column names such as
        ``created_from`` and sub-selects inside an INSERT/UPDATE therefore no
        longer hijack the result.

        :returns: the table name (quotes/brackets and trailing ``;,)`` removed),
            or ``'unknown'`` if none is found or ``statement`` is not a string.
        """
        if not isinstance(statement, str):
            return 'unknown'
        tokens = statement.upper().split()
        if not tokens:
            return 'unknown'

        leading = tokens[0]
        if leading in ('INSERT', 'REPLACE'):
            anchors = ('INTO',)
        elif leading == 'UPDATE':
            anchors = ('UPDATE',)
        else:
            anchors = ('FROM', 'INTO', 'UPDATE')

        for anchor in anchors:
            if anchor not in tokens:
                continue
            position = tokens.index(anchor)
            if position + 1 < len(tokens):
                name = tokens[position + 1].strip('`"[];,)')
                if name:
                    return name
        return 'unknown'
|
|
|
|
class SecurityMiddleware:
    """Specialized middleware for security monitoring.

    Registers a ``before_request`` hook that, for every non-static request,
    logs requests from blacklisted IPs, requests whose User-Agent contains a
    known scanner name, and requests to admin-looking paths. It only logs;
    it never rejects a request.
    """

    #: IPs (compared with REMOTE_ADDR) whose requests are logged as CRITICAL.
    blacklisted_ips = frozenset()
    #: Lower-case User-Agent substrings that identify scanning tools.
    suspicious_agents = ('sqlmap', 'nikto', 'nmap', 'masscan')
    #: Path substrings treated as admin endpoints. This is plain substring
    #: matching, so e.g. '/administrator' and '/api/admin/x' both match '/admin'.
    admin_endpoints = ('/admin', '/management', '/config', '/settings')

    def __init__(self, app=None, blacklisted_ips=None, suspicious_agents=None):
        """Create the middleware, optionally overriding the default lists.

        :param app: Flask app to attach to immediately (optional).
        :param blacklisted_ips: iterable of IP strings to flag as CRITICAL;
            defaults to the class-level (empty) set.
        :param suspicious_agents: iterable of User-Agent substrings to flag,
            matched case-insensitively; defaults to the class-level tuple.
        """
        if blacklisted_ips is not None:
            self.blacklisted_ips = frozenset(blacklisted_ips)
        if suspicious_agents is not None:
            self.suspicious_agents = tuple(agent.lower() for agent in suspicious_agents)
        self.app = app
        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        """Initialize security middleware by registering the per-request check."""
        app.before_request(self.security_check)

    def security_check(self):
        """Perform security checks on each request (static files are skipped)."""
        if request.endpoint == 'static' or request.path.startswith('/static/'):
            return

        self.check_rate_limiting()
        self.check_blacklist()
        self.monitor_admin_access()

    def check_rate_limiting(self):
        """Placeholder for rate limiting; only increments ``g.request_count``.

        NOTE(review): Flask's ``g`` is scoped to one app context (pushed per
        request), so this counter does not measure request frequency across
        requests. Real rate limiting needs shared storage such as Redis.
        """
        g.request_count = getattr(g, 'request_count', 0) + 1

    def check_blacklist(self):
        """Log requests from blacklisted IPs or with scanner-like User-Agents."""
        client_ip = request.environ.get('REMOTE_ADDR', 'Unknown')
        user_agent = request.headers.get('User-Agent', 'Unknown')

        if client_ip in self.blacklisted_ips:
            security_logger.logger.critical(
                "BLACKLISTED_IP: Access attempt from blacklisted IP",
                client_ip=client_ip,
                path=request.path,
                user_agent=user_agent,
            )

        user_agent_lower = user_agent.lower()
        if any(agent in user_agent_lower for agent in self.suspicious_agents):
            security_logger.logger.warning(
                "SUSPICIOUS_AGENT: Request with suspicious user agent",
                client_ip=client_ip,
                user_agent=user_agent,
                path=request.path,
            )

    def monitor_admin_access(self):
        """Log any request whose path contains one of ``admin_endpoints``."""
        if not any(endpoint in request.path for endpoint in self.admin_endpoints):
            return

        security_logger.logger.info(
            "ADMIN_ACCESS: Access to admin endpoint",
            path=request.path,
            client_ip=request.environ.get('REMOTE_ADDR', 'Unknown'),
            user_agent=request.headers.get('User-Agent', 'Unknown'),
            user_id=self._get_user_id(),
        )

    @staticmethod
    def _get_user_id():
        """Return ``current_user.id`` when authenticated, else None; never raises."""
        try:
            if getattr(current_user, 'is_authenticated', False):
                return current_user.id
        except Exception:
            # Best-effort: a failing user lookup must not break monitoring.
            pass
        return None
|