feat(account): GET /me user endpoint only, disable cache and extra routers
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
171
src/presentation/decorators/rate_limit.py
Normal file
171
src/presentation/decorators/rate_limit.py
Normal file
@@ -0,0 +1,171 @@
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
import inspect
|
||||
import hashlib
|
||||
from typing import Any, Awaitable, Callable, Literal, Optional, Protocol, runtime_checkable
|
||||
from fastapi import Request
|
||||
from redis.asyncio.client import Redis
|
||||
from src.application.contracts import ILogger
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.logger import get_logger
|
||||
from src.presentation.dependencies import get_redis
|
||||
|
||||
|
||||
def _find_request(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Request:
|
||||
req = kwargs.get('request')
|
||||
if isinstance(req, Request):
|
||||
return req
|
||||
for a in args:
|
||||
if isinstance(a, Request):
|
||||
return a
|
||||
raise RuntimeError('rate_limit decorator requires fastapi.Request argument')
|
||||
|
||||
|
||||
def _client_ip(request: Request) -> str:
|
||||
xff = request.headers.get('x-forwarded-for')
|
||||
if xff:
|
||||
return xff.split(',')[0].strip()
|
||||
if request.client:
|
||||
return request.client.host
|
||||
return 'unknown'
|
||||
|
||||
|
||||
_LUA_INCR_EXPIRE_TTL = '''
|
||||
local key = KEYS[1]
|
||||
local window = tonumber(ARGV[1])
|
||||
|
||||
local current = redis.call('INCR', key)
|
||||
if current == 1 then
|
||||
redis.call('EXPIRE', key, window)
|
||||
end
|
||||
|
||||
local ttl = redis.call('TTL', key)
|
||||
return { current, ttl }
|
||||
'''
|
||||
|
||||
|
||||
Scope = Literal['ip', 'device', 'user', 'key']
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class KeyBuilder1(Protocol):
|
||||
def __call__(self, request: Request) -> str: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class KeyBuilder3(Protocol):
|
||||
def __call__(self, request: Request, args: tuple[Any, ...], kwargs: dict[str, Any]) -> str: ...
|
||||
|
||||
|
||||
KeyBuilder = KeyBuilder1 | KeyBuilder3
|
||||
|
||||
|
||||
def _call_key_builder(builder: KeyBuilder, request: Request, args: tuple[Any, ...], kwargs: dict[str, Any]) -> str:
|
||||
try:
|
||||
sig = inspect.signature(builder)
|
||||
if len(sig.parameters) >= 3:
|
||||
return builder(request, args, kwargs)
|
||||
return builder(request)
|
||||
except Exception as e:
|
||||
try:
|
||||
return builder(request, args, kwargs)
|
||||
except Exception:
|
||||
raise e
|
||||
|
||||
def _email_rl_key(request: Request, args: tuple[Any, ...], kwargs: dict[str, Any]) -> str:
|
||||
|
||||
body = kwargs.get('body')
|
||||
if body is None and args:
|
||||
for a in args:
|
||||
if hasattr(a, 'email'):
|
||||
body = a
|
||||
break
|
||||
|
||||
email = (getattr(body, 'email', '') or '').strip().lower()
|
||||
if not email:
|
||||
email = _client_ip(request)
|
||||
|
||||
digest = hashlib.sha256(email.encode('utf-8')).hexdigest()[:24]
|
||||
return f'email:{digest}'
|
||||
|
||||
def rate_limit(
|
||||
*,
|
||||
limit: int,
|
||||
window_seconds: int,
|
||||
scope: Scope = 'ip',
|
||||
key_prefix: str = 'rl',
|
||||
key_builder: Optional[KeyBuilder] = None,
|
||||
fail_open: bool = True,
|
||||
) -> Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable[Any]]]:
|
||||
|
||||
if limit <= 0:
|
||||
raise ValueError('rate_limit: limit must be > 0')
|
||||
if window_seconds <= 0:
|
||||
raise ValueError('rate_limit: window_seconds must be > 0')
|
||||
if scope == 'key' and not key_builder:
|
||||
raise ValueError('rate_limit: scope="key" requires key_builder')
|
||||
|
||||
def decorator(func: Callable[..., Awaitable[Any]]):
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any):
|
||||
request = _find_request(args, kwargs)
|
||||
logger: ILogger = get_logger()
|
||||
|
||||
if scope == 'ip':
|
||||
ident = _client_ip(request)
|
||||
elif scope == 'device':
|
||||
ident = request.cookies.get('device_id') or _client_ip(request)
|
||||
elif scope == 'user':
|
||||
user = getattr(request.state, 'user', None)
|
||||
user_id = getattr(user, 'id', None) if user else None
|
||||
ident = str(user_id) if user_id else _client_ip(request)
|
||||
else:
|
||||
try:
|
||||
ident = _call_key_builder(key_builder, request, args, kwargs) # type: ignore[arg-type]
|
||||
except Exception as e:
|
||||
logger.error(f'RateLimit key_builder failed error={str(e)}')
|
||||
raise ApplicationException(500, 'Rate limiter key_builder failed')
|
||||
|
||||
route = request.url.path
|
||||
method = request.method
|
||||
redis_key = f'{key_prefix}:{scope}:{method}:{route}:{ident}'
|
||||
|
||||
logger.debug(f'RateLimit check key={redis_key} limit={limit} window={window_seconds}')
|
||||
|
||||
try:
|
||||
redis: Redis = get_redis(request)
|
||||
|
||||
result = await redis.eval(
|
||||
_LUA_INCR_EXPIRE_TTL,
|
||||
1,
|
||||
redis_key,
|
||||
str(window_seconds),
|
||||
)
|
||||
|
||||
count = int(result[0])
|
||||
ttl_raw = int(result[1]) if result and len(result) > 1 else window_seconds
|
||||
ttl = window_seconds if ttl_raw < 0 else ttl_raw
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'RateLimit redis failure key={redis_key} error={str(e)}')
|
||||
|
||||
if fail_open:
|
||||
logger.warning(f'RateLimit fail-open activated key={redis_key}')
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
raise ApplicationException(503, 'Rate limiter unavailable')
|
||||
|
||||
if count > limit:
|
||||
retry_after = max(ttl, 0)
|
||||
logger.warning(f'RateLimit exceeded key={redis_key} count={count} limit={limit} retry_after={retry_after}')
|
||||
raise ApplicationException(
|
||||
status_code=429,
|
||||
message='Too Many Requests',
|
||||
headers={'Retry-After': str(retry_after)},
|
||||
)
|
||||
|
||||
logger.debug(f'RateLimit passed key={redis_key} count={count}')
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
Reference in New Issue
Block a user