# File: users/src/presentation/middleware/trace_id.py
# (136 lines, 5.0 KiB, Python)

from __future__ import annotations
from typing import Optional
from contextvars import Token
from starlette.requests import Request
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from ulid import ULID
from src.application.contracts import ILogger
from src.infrastructure.config import settings
from src.infrastructure.context_vars import trace_id_var
class TraceIDMiddleware:
    """Pure ASGI middleware that assigns a trace ID to every HTTP request.

    A client-supplied trace ID is reused when a valid one is present in the
    request headers; otherwise a new ULID is generated. For the duration of
    the request the ID is stored on ``request.state.trace_id`` and in the
    ``trace_id_var`` context variable. It is logged (debug) when the request
    starts and finishes, and optionally echoed back in a response header.

    Non-HTTP scopes and paths matching ``settings.EXCLUDED_PATHS`` (or lying
    beneath one of them) are passed straight through untouched.
    """

    # Incoming headers consulted after ``response_header_name`` when looking
    # for a client-supplied trace ID, in priority order.
    _INCOMING_HEADER_NAMES: tuple[str, ...] = ("X-Trace-ID", "X-Request-ID")

    def __init__(
        self,
        app: ASGIApp,
        logger: ILogger,
        response_header_name: str = "X-Trace-ID",
        attach_response_header: bool = True,
        *,
        max_trace_id_length: int = 128,
    ) -> None:
        """Wrap *app* with trace-ID handling.

        Args:
            app: The downstream ASGI application.
            logger: Logger used for the request start/finish debug lines.
            response_header_name: Header used to echo the trace ID back to the
                client. It is also the first header checked on the request.
            attach_response_header: When False, responses are not modified.
            max_trace_id_length: Client-supplied trace IDs longer than this
                are ignored and a fresh ULID is generated instead.
        """
        self.app = app
        self.logger = logger
        self.response_header_name = response_header_name
        self.attach_response_header = attach_response_header
        self.max_trace_id_length = max_trace_id_length

    def _is_excluded(self, path: str) -> bool:
        """Return True if *path* equals an excluded path or lies beneath one."""
        return any(
            path == p or path.startswith(p.rstrip("/") + "/")
            for p in settings.EXCLUDED_PATHS
        )

    def _is_valid_trace_id(self, value: str) -> bool:
        """Return True if the client-supplied *value* is safe to adopt.

        The value is echoed into log lines and a response header, so only
        non-empty printable ASCII without spaces, at most
        ``max_trace_id_length`` characters long, is accepted.
        """
        return (
            0 < len(value) <= self.max_trace_id_length
            and value.isascii()
            # isprintable() rejects control characters such as CR/LF/TAB.
            and value.isprintable()
            and " " not in value
        )

    def _extract_trace_id(self, request: Request) -> Optional[str]:
        """Return the first valid trace ID found in the request headers.

        ``response_header_name`` is checked first, then ``X-Trace-ID`` and
        ``X-Request-ID``. Values are stripped of surrounding whitespace and
        invalid ones are skipped.

        Returns:
            The trace ID, or None when no header holds a valid value.
        """
        for name in (self.response_header_name, *self._INCOMING_HEADER_NAMES):
            raw = request.headers.get(name)
            if raw is None:
                continue
            candidate = raw.strip()
            if self._is_valid_trace_id(candidate):
                return candidate
        return None

    def _with_trace_header(
        self, headers: list[tuple[bytes, bytes]], trace_id: str
    ) -> list[tuple[bytes, bytes]]:
        """Return *headers* carrying exactly one trace header set to *trace_id*.

        Existing headers with the same (case-insensitive) name are dropped, so
        the response never advertises two conflicting trace IDs. ASGI header
        names and values are byte strings; latin-1 matches the ASGI spec.
        """
        name = self.response_header_name.lower().encode("latin-1")
        result = [(key, value) for key, value in headers if key.lower() != name]
        result.append((name, trace_id.encode("latin-1")))
        return result

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Handle one ASGI connection; only non-excluded HTTP requests are traced."""
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        request = Request(scope)
        if self._is_excluded(request.url.path):
            await self.app(scope, receive, send)
            return

        trace_id = self._extract_trace_id(request) or str(ULID())
        request.state.trace_id = trace_id

        response_status: Optional[int] = None

        async def send_wrapper(message: Message) -> None:
            # Record the status for the "finished" log line and, if enabled,
            # stamp the trace ID onto the outgoing response headers.
            nonlocal response_status
            if message["type"] == "http.response.start":
                response_status = int(message["status"])
                if self.attach_response_header:
                    message["headers"] = self._with_trace_header(
                        list(message.get("headers", [])), trace_id
                    )
            await send(message)

        # Set the context var before the first log line so loggers that read
        # it see the trace ID; reset it in an outer finally so it is restored
        # even if one of the log calls raises.
        token: Token = trace_id_var.set(trace_id)
        try:
            self.logger.debug(f"Request started: {request.method} {request.url} - TraceID: {trace_id}")
            try:
                await self.app(scope, receive, send_wrapper)
            finally:
                # Status stays None if the app failed before starting a response.
                status_part = f"{response_status}" if response_status is not None else "unknown"
                self.logger.debug(
                    f"Request finished: {request.method} {request.url} - TraceID: {trace_id} - Status: {status_part}"
                )
        finally:
            trace_id_var.reset(token)
# from __future__ import annotations
# from typing import Optional
# from starlette.requests import Request
# from starlette.types import ASGIApp, Message, Receive, Scope, Send
# from ulid import ULID
# from src.application.contracts import ILogger
# from src.infrastructure.config.settings import settings
#
#
# class TraceIDMiddleware:
# def __init__(
# self,
# app: ASGIApp,
# logger: ILogger,
# response_header_name: str = 'X-Trace-ID',
# attach_response_header: bool = True,
# ) -> None:
# self.app = app
# self.logger = logger
# self.response_header_name = response_header_name
# self.attach_response_header = attach_response_header
#
# def _is_excluded(self, path: str) -> bool:
# return any(path == p or path.startswith(p.rstrip('/') + '/') for p in settings.EXCLUDED_PATHS)
#
# async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
# if scope['type'] != 'http':
# await self.app(scope, receive, send)
# return
#
# request = Request(scope)
#
# if self._is_excluded(request.url.path):
# await self.app(scope, receive, send)
# return
#
# trace_id = request.headers.get('X-Trace-ID') or request.headers.get('X-Request-ID')
# if not trace_id:
# trace_id = str(ULID())
#
# request.state.trace_id = trace_id
# self.logger.set_trace_id(trace_id)
#
# self.logger.debug(f'Request started: {request.method} {request.url} - TraceID: {trace_id}')
#
# status_code_holder: dict[str, Optional[int]] = {'status': None}
#
# async def send_wrapper(message: Message) -> None:
# if message['type'] == 'http.response.start':
# status_code_holder['status'] = int(message['status'])
#
# if self.attach_response_header:
# headers = list(message.get('headers', []))
# headers.append((self.response_header_name.lower().encode(), trace_id.encode()))
# message['headers'] = headers
# await send(message)
#
# try:
# await self.app(scope, receive, send_wrapper)
# finally:
# status = status_code_holder['status']
# status_part = f'{status}' if status is not None else 'unknown'
# self.logger.debug(f'Request finished: {request.method} {request.url} - TraceID: {trace_id} - Status: {status_part}')
# self.logger.clear_trace_id()