FastAPI 中間件應用場景
本章節將探討 FastAPI 中間件的常見應用場景,並提供實用的實例代碼,幫助開發者理解如何在實際項目中有效利用中間件。
請求日誌記錄中間件
日誌記錄是中間件最常見的應用場景之一。通過中間件,我們可以統一記錄所有請求的詳細信息,包括請求方法、路徑、處理時間、狀態碼等。
import time
import logging
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
# 配置日誌
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[logging.StreamHandler()]
)
logger = logging.getLogger("app")
app = FastAPI()
class LoggingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
# 記錄請求開始時間
start_time = time.time()
# 收集請求信息
method = request.method
path = request.url.path
query_params = dict(request.query_params)
client_host = request.client.host if request.client else "unknown"
# 記錄請求開始
logger.info(f"Request started: {method} {path} from {client_host} with params {query_params}")
try:
# 處理請求
response = await call_next(request)
# 計算處理時間
process_time = time.time() - start_time
# 記錄成功的請求
logger.info(
f"Request completed: {method} {path} - Status: {response.status_code} - "
f"Duration: {process_time:.4f}s"
)
return response
except Exception as e:
# 記錄失敗的請求
process_time = time.time() - start_time
logger.error(
f"Request failed: {method} {path} - Error: {str(e)} - "
f"Duration: {process_time:.4f}s"
)
raise
app.add_middleware(LoggingMiddleware)
身份驗證與授權中間件
中間件是實現身份驗證和授權邏輯的理想位置,可以在請求到達路由處理函數之前驗證用戶身份。
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
import jwt
from typing import List, Optional
app = FastAPI()
class AuthMiddleware(BaseHTTPMiddleware):
def __init__(
self,
app,
secret_key: str,
exclude_paths: List[str] = None,
algorithm: str = "HS256"
):
super().__init__(app)
self.secret_key = secret_key
self.exclude_paths = exclude_paths or ["/login", "/docs", "/openapi.json"]
self.algorithm = algorithm
async def dispatch(self, request: Request, call_next):
# 檢查是否為排除路徑
if request.url.path in self.exclude_paths:
return await call_next(request)
# 獲取授權頭
auth_header = request.headers.get("Authorization")
if not auth_header or not auth_header.startswith("Bearer "):
return JSONResponse(
status_code=401,
content={"detail": "Missing or invalid authentication token"}
)
token = auth_header.split(" ")[1]
try:
# 驗證 JWT 令牌
payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
# 將用戶信息添加到請求狀態
request.state.user = payload
request.state.user_id = payload.get("sub")
request.state.user_role = payload.get("role")
# 繼續處理請求
return await call_next(request)
except jwt.ExpiredSignatureError:
return JSONResponse(
status_code=401,
content={"detail": "Token has expired"}
)
except jwt.InvalidTokenError:
return JSONResponse(
status_code=401,
content={"detail": "Invalid authentication token"}
)
# 註冊中間件
app.add_middleware(
AuthMiddleware,
secret_key="your-secret-key",
exclude_paths=["/login", "/register", "/docs", "/openapi.json"]
)
# 在路由中使用請求狀態中的用戶信息
@app.get("/profile")
async def get_profile(request: Request):
user = request.state.user
return {"user_id": user.get("sub"), "username": user.get("username")}
CORS 處理中間件
跨域資源共享 (CORS) 是 Web 應用中常見的需求,FastAPI 提供了內置的 CORS 中間件:
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI()
# 配置 CORS
app.add_middleware(
CORSMiddleware,
# 允許的源列表
allow_origins=[
"http://localhost:3000",
"https://frontend.example.com"
],
# 是否允許發送憑證(如 cookies)
allow_credentials=True,
# 允許的 HTTP 方法
allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"],
# 允許的 HTTP 頭
allow_headers=["Content-Type", "Authorization", "X-Requested-With"],
# 允許瀏覽器緩存預檢請求的時間(秒)
max_age=600,
)
請求限流中間件
為了防止 API 被濫用,可以實現請求限流中間件,限制特定時間內的請求次數:
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
import time
from collections import defaultdict
app = FastAPI()
class RateLimitMiddleware(BaseHTTPMiddleware):
def __init__(
self,
app,
limit: int = 10,
window: int = 60,
exclude_paths: list = None
):
super().__init__(app)
self.limit = limit # 每個窗口允許的最大請求數
self.window = window # 窗口大小(秒)
self.exclude_paths = exclude_paths or []
self.requests = defaultdict(list) # 儲存每個 IP 的請求時間
async def dispatch(self, request: Request, call_next):
# 檢查是否為排除路徑
if request.url.path in self.exclude_paths:
return await call_next(request)
# 獲取客戶端 IP
client_ip = request.client.host if request.client else "unknown"
# 當前時間
current_time = time.time()
# 清理過期的請求記錄
self.requests[client_ip] = [
req_time for req_time in self.requests[client_ip]
if current_time - req_time < self.window
]
# 檢查是否超過限制
if len(self.requests[client_ip]) >= self.limit:
return JSONResponse(
status_code=429,
content={
"detail": "Too many requests",
"retry_after": self.window - (current_time - self.requests[client_ip][0])
}
)
# 記錄當前請求
self.requests[client_ip].append(current_time)
# 處理請求
return await call_next(request)
# 註冊中間件
app.add_middleware(
RateLimitMiddleware,
limit=100, # 每分鐘 100 個請求
window=60, # 1 分鐘窗口
exclude_paths=["/docs", "/openapi.json"]
)
響應壓縮中間件
對於大型響應,可以使用壓縮中間件減少傳輸數據量,提高性能:
from fastapi import FastAPI
from fastapi.middleware.gzip import GZipMiddleware
app = FastAPI()
# 註冊 GZip 壓縮中間件
app.add_middleware(
GZipMiddleware,
minimum_size=1000 # 僅壓縮大於 1000 字節的響應
)
全局異常處理中間件
中間件可以捕獲應用中的所有異常,提供統一的錯誤處理機制:
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
import traceback
import logging
logger = logging.getLogger("app")
app = FastAPI()
class ExceptionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
try:
return await call_next(request)
except HTTPException as e:
# 處理 FastAPI 的 HTTPException
logger.warning(f"HTTP Exception: {e.detail} (status_code={e.status_code})")
return JSONResponse(
status_code=e.status_code,
content={"detail": e.detail}
)
except Exception as e:
# 處理未捕獲的異常
error_id = str(uuid.uuid4())
# 記錄詳細錯誤信息
logger.error(
f"Unhandled exception: {str(e)} (error_id={error_id})\n"
f"{traceback.format_exc()}"
)
# 返回用戶友好的錯誤信息
return JSONResponse(
status_code=500,
content={
"detail": "An unexpected error occurred",
"error_id": error_id
}
)
app.add_middleware(ExceptionMiddleware)
性能監控中間件
監控 API 端點的性能,識別可能的瓶頸:
import time
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
import statistics
app = FastAPI()
class PerformanceMonitoringMiddleware(BaseHTTPMiddleware):
def __init__(self, app, threshold_ms: float = 500):
super().__init__(app)
self.threshold_ms = threshold_ms
self.request_times = {} # 儲存每個路徑的處理時間
async def dispatch(self, request: Request, call_next):
start_time = time.time()
response = await call_next(request)
# 計算處理時間(毫秒)
process_time = (time.time() - start_time) * 1000
# 獲取請求路徑
path = request.url.path
# 更新路徑的處理時間統計
if path not in self.request_times:
self.request_times[path] = []
self.request_times[path].append(process_time)
# 保留最近 100 個請求的數據
if len(self.request_times[path]) > 100:
self.request_times[path].pop(0)
# 如果處理時間超過閾值,記錄警告
if process_time > self.threshold_ms:
print(f"WARNING: Slow request detected - {request.method} {path} took {process_time:.2f}ms")
# 每 100 個請求計算一次統計數據
if len(self.request_times[path]) % 100 == 0:
times = self.request_times[path]
avg_time = statistics.mean(times)
p95_time = sorted(times)[int(len(times) * 0.95)]
print(f"Performance stats for {path}:")
print(f" Average: {avg_time:.2f}ms")
print(f" 95th percentile: {p95_time:.2f}ms")
print(f" Min: {min(times):.2f}ms")
print(f" Max: {max(times):.2f}ms")
# 將處理時間添加到響應頭
response.headers["X-Process-Time"] = f"{process_time:.2f}ms"
return response
app.add_middleware(PerformanceMonitoringMiddleware, threshold_ms=200)