跳轉到

FastAPI 中間件進階技巧

本章節將深入探討 FastAPI 中間件的進階技巧,包括異步中間件優化、中間件與依賴注入的結合、中間件上下文管理、中間件測試策略以及其他高級應用場景。

異步中間件優化

FastAPI 建立在 ASGI 標準之上,充分利用 Python 的異步特性可以顯著提高應用性能。以下是一些異步中間件的優化技巧:

避免阻塞操作

在異步中間件中,應避免使用阻塞操作,如同步 I/O、長時間計算等:

from fastapi import FastAPI, Request
import aiohttp
import asyncio

app = FastAPI()

@app.middleware("http")
async def external_service_middleware(request: Request, call_next):
    # 錯誤示例:在異步中間件中使用阻塞操作
    # import requests
    # response = requests.get("https://api.example.com/data")  # 阻塞!

    # 正確示例:使用異步 HTTP 客戶端
    async with aiohttp.ClientSession() as session:
        async with session.get("https://api.example.com/data") as response:
            data = await response.json()

    # 將數據添加到請求狀態
    request.state.external_data = data

    # 繼續處理請求
    return await call_next(request)

使用異步上下文管理器

利用異步上下文管理器可以更優雅地處理資源獲取和釋放:

import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware

@asynccontextmanager
async def get_db_connection():
    # 假設這是一個異步數據庫連接
    conn = await create_async_connection()
    try:
        yield conn
    finally:
        await conn.close()

class DatabaseMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        async with get_db_connection() as conn:
            # 將數據庫連接添加到請求狀態
            request.state.db = conn

            # 處理請求
            response = await call_next(request)

            # 連接會在上下文管理器退出時自動關閉
            return response

app = FastAPI()
app.add_middleware(DatabaseMiddleware)

並行處理

在中間件中可以使用 asyncio.gather 並行執行多個異步任務:

@app.middleware("http")
async def parallel_tasks_middleware(request: Request, call_next):
    # 定義需要並行執行的任務
    async def fetch_user_data():
        # 假設這是從數據庫獲取用戶數據的異步操作
        await asyncio.sleep(0.1)  # 模擬 I/O 延遲
        return {"user_id": 123, "name": "John Doe"}

    async def fetch_metrics():
        # 假設這是從指標系統獲取數據的異步操作
        await asyncio.sleep(0.1)  # 模擬 I/O 延遲
        return {"api_calls": 1000, "error_rate": 0.01}

    # 並行執行任務
    user_data, metrics = await asyncio.gather(
        fetch_user_data(),
        fetch_metrics()
    )

    # 將結果存儲到請求狀態
    request.state.user_data = user_data
    request.state.metrics = metrics

    # 繼續處理請求
    return await call_next(request)

中間件與依賴注入結合

FastAPI 的中間件和依賴注入系統可以協同工作,實現更強大的功能:

在中間件中預處理依賴項

from fastapi import FastAPI, Request, Depends
from starlette.middleware.base import BaseHTTPMiddleware
from typing import Dict, Optional

# 定義一個依賴項
async def get_current_user(request: Request) -> Optional[Dict]:
    auth_header = request.headers.get("Authorization")
    if not auth_header or not auth_header.startswith("Bearer "):
        return None

    token = auth_header.split(" ")[1]
    # 這裡省略了實際的令牌驗證邏輯
    return {"user_id": 123, "username": "john_doe"}

class AuthPreprocessingMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        # 預先解析用戶信息
        user = await get_current_user(request)

        # 將用戶信息存儲到請求狀態
        request.state.user = user

        # 處理請求
        response = await call_next(request)
        return response

app = FastAPI()
app.add_middleware(AuthPreprocessingMiddleware)

# 在路由中使用預處理的用戶信息
@app.get("/user/profile")
async def get_profile(request: Request):
    user = request.state.user
    if not user:
        return {"detail": "Not authenticated"}
    return {"profile": user}

自定義中間件依賴項

創建可在多個中間件之間共享的依賴項:

from fastapi import FastAPI, Request, Depends
from starlette.middleware.base import BaseHTTPMiddleware
from typing import Callable, TypeVar, Generic, Optional

T = TypeVar('T')

class MiddlewareDependency(Generic[T]):
    def __init__(self, dependency: Callable[..., T]):
        self.dependency = dependency

    async def resolve(self, request: Request) -> T:
        return await self.dependency(request)

# 創建一個配置依賴項
async def get_app_config(request: Request):
    # 這裡可以是從數據庫或配置文件加載配置
    return {
        "feature_flags": {
            "new_ui": True,
            "beta_features": False
        },
        "limits": {
            "max_requests_per_minute": 100
        }
    }

config_dependency = MiddlewareDependency(get_app_config)

class ConfigMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        # 解析配置依賴項
        config = await config_dependency.resolve(request)

        # 將配置存儲到請求狀態
        request.state.config = config

        # 處理請求
        response = await call_next(request)
        return response

app = FastAPI()
app.add_middleware(ConfigMiddleware)

# 在路由中使用配置
@app.get("/features")
async def get_features(request: Request):
    config = request.state.config
    return {"features": config["feature_flags"]}

中間件上下文管理

在複雜應用中,有效管理中間件上下文可以簡化代碼並提高可維護性:

請求上下文管理

from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
from contextvars import ContextVar
from typing import Optional, Dict, Any

# 定義上下文變量
request_id_var: ContextVar[str] = ContextVar("request_id", default=None)
user_var: ContextVar[Optional[Dict]] = ContextVar("user", default=None)

# 上下文管理器
class RequestContextMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        # 生成請求 ID
        request_id = str(uuid.uuid4())

        # 設置上下文變量
        request_id_token = request_id_var.set(request_id)

        # 獲取用戶信息(如果有)
        auth_header = request.headers.get("Authorization")
        user = None
        if auth_header and auth_header.startswith("Bearer "):
            # 這裡省略了實際的令牌驗證邏輯
            user = {"user_id": 123, "username": "john_doe"}

        user_token = user_var.set(user)

        try:
            # 處理請求
            response = await call_next(request)

            # 添加請求 ID 到響應頭
            response.headers["X-Request-ID"] = request_id

            return response
        finally:
            # 重置上下文變量
            request_id_var.reset(request_id_token)
            user_var.reset(user_token)

app = FastAPI()
app.add_middleware(RequestContextMiddleware)

# 在任何地方獲取當前請求 ID
def get_current_request_id() -> str:
    return request_id_var.get()

# 在任何地方獲取當前用戶
def get_current_user() -> Optional[Dict]:
    return user_var.get()

# 使用上下文變量
@app.get("/context-demo")
async def context_demo():
    request_id = get_current_request_id()
    user = get_current_user()

    return {
        "request_id": request_id,
        "user": user
    }

分層中間件上下文

處理複雜的多層中間件場景:

from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
from typing import Dict, Any, List
import time

app = FastAPI()

class ContextualizedMiddleware(BaseHTTPMiddleware):
    def __init__(self, app, name: str):
        super().__init__(app)
        self.name = name

    async def dispatch(self, request: Request, call_next):
        # 初始化上下文堆疊(如果不存在)
        if not hasattr(request.state, "middleware_context"):
            request.state.middleware_context = []
            request.state.timing_data = {}

        # 記錄進入中間件的時間
        start_time = time.time()

        # 將當前中間件添加到上下文堆疊
        context = {"name": self.name, "entered_at": start_time}
        request.state.middleware_context.append(context)

        try:
            # 處理請求
            response = await call_next(request)

            # 記錄退出中間件的時間
            end_time = time.time()
            elapsed = end_time - start_time

            # 更新計時數據
            request.state.timing_data[self.name] = elapsed

            return response
        finally:
            # 從上下文堆疊中移除當前中間件
            request.state.middleware_context.pop()

# 註冊多個中間件實例
app.add_middleware(ContextualizedMiddleware, name="outer")
app.add_middleware(ContextualizedMiddleware, name="middle")
app.add_middleware(ContextualizedMiddleware, name="inner")

@app.get("/context-stack")
async def get_context_stack(request: Request):
    return {
        "current_context": request.state.middleware_context,
        "timing_data": request.state.timing_data
    }

條件中間件執行

根據請求特性有條件地執行中間件邏輯:

from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
import re

app = FastAPI()

class ConditionalMiddleware(BaseHTTPMiddleware):
    def __init__(
        self, 
        app, 
        func,
        include_paths: List[str] = None,
        exclude_paths: List[str] = None,
        include_methods: List[str] = None
    ):
        super().__init__(app)
        self.func = func
        self.include_paths = [re.compile(path) for path in (include_paths or [])]
        self.exclude_paths = [re.compile(path) for path in (exclude_paths or [])]
        self.include_methods = [m.upper() for m in (include_methods or [])]

    def should_process(self, request: Request) -> bool:
        path = request.url.path
        method = request.method

        # 檢查排除路徑
        if any(pattern.match(path) for pattern in self.exclude_paths):
            return False

        # 檢查包含路徑
        if self.include_paths and not any(pattern.match(path) for pattern in self.include_paths):
            return False

        # 檢查包含方法
        if self.include_methods and method not in self.include_methods:
            return False

        return True

    async def dispatch(self, request: Request, call_next):
        if self.should_process(request):
            # 執行中間件邏輯
            return await self.func(request, call_next)
        else:
            # 跳過中間件邏輯
            return await call_next(request)

# 定義中間件邏輯
async def rate_limit_logic(request: Request, call_next):
    # 這裡實現限流邏輯
    print(f"Rate limiting applied to {request.url.path}")
    return await call_next(request)

# 註冊條件中間件
app.add_middleware(
    ConditionalMiddleware,
    func=rate_limit_logic,
    include_paths=[r"/api/.*"],  # 僅對 /api/ 開頭的路徑應用
    exclude_paths=[r"/api/public/.*"],  # 排除 /api/public/ 開頭的路徑
    include_methods=["POST", "PUT", "DELETE"]  # 僅對寫操作應用
)

中間件工廠模式

使用工廠模式創建可配置的中間件:

from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
from typing import Callable, Dict, Any, Optional
import json
import time

def create_logging_middleware(
    log_format: str = "default",
    include_headers: bool = False,
    include_body: bool = False,
    log_level: str = "info"
):
    """中間件工廠函數,創建可配置的日誌中間件"""

    class LoggingMiddleware(BaseHTTPMiddleware):
        async def dispatch(self, request: Request, call_next):
            # 記錄請求開始時間
            start_time = time.time()

            # 收集請求信息
            request_info = {
                "method": request.method,
                "path": request.url.path,
                "query_params": dict(request.query_params),
                "client_host": request.client.host if request.client else "unknown",
            }

            # 根據配置添加頭信息
            if include_headers:
                request_info["headers"] = dict(request.headers)

            # 根據配置添加請求體
            if include_body and request.method in ["POST", "PUT", "PATCH"]:
                try:
                    body = await request.body()
                    if body:
                        try:
                            # 嘗試解析 JSON
                            request_info["body"] = json.loads(body)
                        except json.JSONDecodeError:
                            # 如果不是 JSON,則保存原始字符串
                            request_info["body"] = body.decode("utf-8", errors="replace")
                except Exception:
                    pass

            # 記錄請求信息
            log_message = f"Request: {json.dumps(request_info)}"
            if log_level == "debug":
                print(f"DEBUG: {log_message}")
            else:
                print(f"INFO: {log_message}")

            # 處理請求
            response = await call_next(request)

            # 計算處理時間
            process_time = time.time() - start_time

            # 記錄響應信息
            response_info = {
                "status_code": response.status_code,
                "process_time": f"{process_time:.4f}s"
            }

            # 根據配置添加響應頭
            if include_headers:
                response_info["headers"] = dict(response.headers)

            # 記錄響應信息
            log_message = f"Response: {json.dumps(response_info)}"
            if log_level == "debug":
                print(f"DEBUG: {log_message}")
            else:
                print(f"INFO: {log_message}")

            return response

    return LoggingMiddleware

app = FastAPI()

# 使用工廠函數創建不同配置的中間件
app.add_middleware(create_logging_middleware(log_level="debug", include_headers=True))

# 可以為不同的應用創建不同配置的中間件
api_app = FastAPI()
api_app.add_middleware(create_logging_middleware(include_body=True, include_headers=True))

# 掛載子應用
app.mount("/api", api_app)

中間件測試策略

有效測試中間件是確保其正確性和穩定性的關鍵:

import pytest
from fastapi import FastAPI, Request
from fastapi.testclient import TestClient
from starlette.middleware.base import BaseHTTPMiddleware

# 待測試的中間件
class HeaderMiddleware(BaseHTTPMiddleware):
    def __init__(self, app, header_name: str, header_value: str):
        super().__init__(app)
        self.header_name = header_name
        self.header_value = header_value

    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        response.headers[self.header_name] = self.header_value
        return response

# 測試函數
def test_header_middleware():
    app = FastAPI()
    app.add_middleware(HeaderMiddleware, header_name="X-Test-Header", header_value="test-value")

    @app.get("/test")
    async def test_endpoint():
        return {"message": "test"}

    client = TestClient(app)
    response = client.get("/test")

    assert response.status_code == 200
    assert response.headers["X-Test-Header"] == "test-value"
    assert response.json() == {"message": "test"}

# 測試中間件執行順序
def test_middleware_order():
    app = FastAPI()
    execution_order = []

    class OrderMiddleware(BaseHTTPMiddleware):
        def __init__(self, app, name: str):
            super().__init__(app)
            self.name = name

        async def dispatch(self, request: Request, call_next):
            execution_order.append(f"{self.name}_before")
            response = await call_next(request)
            execution_order.append(f"{self.name}_after")
            return response

    app.add_middleware(OrderMiddleware, name="outer")
    app.add_middleware(OrderMiddleware, name="inner")

    @app.get("/test")
    async def test_endpoint():
        execution_order.append("endpoint")
        return {"message": "test"}

    client = TestClient(app)
    client.get("/test")

    # 驗證執行順序
    assert execution_order == [
        "outer_before",
        "inner_before",
        "endpoint",
        "inner_after",
        "outer_after"
    ]

動態中間件註冊與管理

在運行時動態管理中間件:

from fastapi import FastAPI, Request, Depends
from starlette.middleware.base import BaseHTTPMiddleware
from typing import Dict, List, Type, Optional
import inspect

app = FastAPI()

# 中間件註冊表
middleware_registry: Dict[str, Type[BaseHTTPMiddleware]] = {}

# 動態中間件管理器
class DynamicMiddlewareManager:
    def __init__(self, app: FastAPI):
        self.app = app
        self.active_middleware: Dict[str, BaseHTTPMiddleware] = {}

    def register_middleware_class(self, name: str, middleware_class: Type[BaseHTTPMiddleware]):
        """註冊中間件類到註冊表"""
        middleware_registry[name] = middleware_class

    def activate_middleware(self, name: str, **kwargs):
        """激活並配置中間件"""
        if name not in middleware_registry:
            raise ValueError(f"Middleware '{name}' not found in registry")

        if name in self.active_middleware:
            raise ValueError(f"Middleware '{name}' is already active")

        # 創建中間件實例
        middleware_class = middleware_registry[name]

        # 檢查參數是否匹配
        init_params = inspect.signature(middleware_class.__init__).parameters
        valid_params = {k: v for k, v in kwargs.items() if k in init_params}

        # 實例化中間件
        middleware_instance = middleware_class(self.app, **valid_params)

        # 添加到應用
        self.app.add_middleware(type(middleware_instance), **valid_params)

        # 記錄激活的中間件
        self.active_middleware[name] = middleware_instance

        return f"Middleware '{name}' activated with parameters: {valid_params}"

    def get_active_middleware(self) -> List[str]:
        """獲取所有激活的中間件"""
        return list(self.active_middleware.keys())

# 創建中間件管理器
middleware_manager = DynamicMiddlewareManager(app)

# 定義一些示例中間件
class HeaderInjectionMiddleware(BaseHTTPMiddleware):
    def __init__(self, app, header_name: str, header_value: str):
        super().__init__(app)
        self.header_name = header_name
        self.header_value = header_value

    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        response.headers[self.header_name] = self.header_value
        return response

class RequestLimitMiddleware(BaseHTTPMiddleware):
    def __init__(self, app, max_requests: int = 100):
        super().__init__(app)
        self.max_requests = max_requests
        self.request_count = 0

    async def dispatch(self, request: Request, call_next):
        self.request_count += 1
        if self.request_count > self.max_requests:
            return JSONResponse(
                status_code=429,
                content={"detail": "Too many requests"}
            )
        return await call_next(request)

# 註冊中間件類
middleware_manager.register_middleware_class("header_injection", HeaderInjectionMiddleware)
middleware_manager.register_middleware_class("request_limit", RequestLimitMiddleware)

# API 端點來管理中間件
@app.post("/admin/middleware/{name}/activate")
async def activate_middleware(name: str, params: Dict[str, Any]):
    return middleware_manager.activate_middleware(name, **params)

@app.get("/admin/middleware/active")
async def get_active_middleware():
    return {"active_middleware": middleware_manager.get_active_middleware()}