171 lines
5.6 KiB
Python
171 lines
5.6 KiB
Python
from __future__ import annotations
|
|
import functools
|
|
import inspect
|
|
import hashlib
|
|
from typing import Any, Awaitable, Callable, Literal, Optional, Protocol, runtime_checkable
|
|
from fastapi import Request
|
|
from redis.asyncio.client import Redis
|
|
from src.application.contracts import ILogger
|
|
from src.application.domain.exceptions import ApplicationException
|
|
from src.infrastructure.logger import get_logger
|
|
from src.presentation.dependencies import get_redis
|
|
|
|
|
|
def _find_request(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Request:
    """Locate the ``fastapi.Request`` among an endpoint's call arguments.

    Search order: the conventional ``request`` keyword, then positional
    arguments, then every other keyword value. The last step matters because
    FastAPI calls endpoints with keyword arguments, so a parameter declared as
    e.g. ``req: Request`` would otherwise never be found.

    Raises:
        RuntimeError: no argument is a ``Request`` instance.
    """
    preferred = kwargs.get('request')
    if isinstance(preferred, Request):
        return preferred

    for candidate in (*args, *kwargs.values()):
        if isinstance(candidate, Request):
            return candidate

    raise RuntimeError('rate_limit decorator requires fastapi.Request argument')
|
|
|
|
|
|
def _client_ip(request: Request) -> str:
|
|
xff = request.headers.get('x-forwarded-for')
|
|
if xff:
|
|
return xff.split(',')[0].strip()
|
|
if request.client:
|
|
return request.client.host
|
|
return 'unknown'
|
|
|
|
|
|
_LUA_INCR_EXPIRE_TTL = '''
|
|
local key = KEYS[1]
|
|
local window = tonumber(ARGV[1])
|
|
|
|
local current = redis.call('INCR', key)
|
|
if current == 1 then
|
|
redis.call('EXPIRE', key, window)
|
|
end
|
|
|
|
local ttl = redis.call('TTL', key)
|
|
return { current, ttl }
|
|
'''
|
|
|
|
|
|
Scope = Literal['ip', 'device', 'user', 'key']
|
|
|
|
|
|
@runtime_checkable
|
|
class KeyBuilder1(Protocol):
|
|
def __call__(self, request: Request) -> str: ...
|
|
|
|
|
|
@runtime_checkable
|
|
class KeyBuilder3(Protocol):
|
|
def __call__(self, request: Request, args: tuple[Any, ...], kwargs: dict[str, Any]) -> str: ...
|
|
|
|
|
|
KeyBuilder = KeyBuilder1 | KeyBuilder3
|
|
|
|
|
|
def _call_key_builder(builder: KeyBuilder, request: Request, args: tuple[Any, ...], kwargs: dict[str, Any]) -> str:
|
|
try:
|
|
sig = inspect.signature(builder)
|
|
if len(sig.parameters) >= 3:
|
|
return builder(request, args, kwargs)
|
|
return builder(request)
|
|
except Exception as e:
|
|
try:
|
|
return builder(request, args, kwargs)
|
|
except Exception:
|
|
raise e
|
|
|
|
def _email_rl_key(request: Request, args: tuple[Any, ...], kwargs: dict[str, Any]) -> str:
|
|
|
|
body = kwargs.get('body')
|
|
if body is None and args:
|
|
for a in args:
|
|
if hasattr(a, 'email'):
|
|
body = a
|
|
break
|
|
|
|
email = (getattr(body, 'email', '') or '').strip().lower()
|
|
if not email:
|
|
email = _client_ip(request)
|
|
|
|
digest = hashlib.sha256(email.encode('utf-8')).hexdigest()[:24]
|
|
return f'email:{digest}'
|
|
|
|
def _resolve_identity(
    scope: Scope,
    key_builder: Optional[KeyBuilder],
    request: Request,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
    logger: ILogger,
) -> str:
    """Return the identity string the rate-limit counter is keyed on for *scope*.

    Raises:
        ApplicationException: 500 when the key builder (any scope other than
            ip/device/user) raises.
    """
    if scope == 'ip':
        return _client_ip(request)
    if scope == 'device':
        # NOTE(review): the device_id cookie is client-controlled, so a client
        # can rotate it to obtain fresh buckets. Falls back to the IP when absent.
        return request.cookies.get('device_id') or _client_ip(request)
    if scope == 'user':
        user = getattr(request.state, 'user', None)
        user_id = getattr(user, 'id', None) if user else None
        return str(user_id) if user_id else _client_ip(request)

    try:
        return _call_key_builder(key_builder, request, args, kwargs)  # type: ignore[arg-type]
    except Exception as e:
        logger.error(f'RateLimit key_builder failed error={str(e)}')
        raise ApplicationException(500, 'Rate limiter key_builder failed') from e


async def _incr_window(request: Request, redis_key: str, window_seconds: int) -> tuple[int, int]:
    """Count one hit on *redis_key* with a single EVAL; return ``(count, ttl_seconds)``.

    A missing or negative TTL in the script result is reported as *window_seconds*.
    """
    redis: Redis = get_redis(request)
    result = await redis.eval(_LUA_INCR_EXPIRE_TTL, 1, redis_key, str(window_seconds))

    count = int(result[0])
    ttl_raw = int(result[1]) if result and len(result) > 1 else window_seconds
    return count, (window_seconds if ttl_raw < 0 else ttl_raw)


def rate_limit(
    *,
    limit: int,
    window_seconds: int,
    scope: Scope = 'ip',
    key_prefix: str = 'rl',
    key_builder: Optional[KeyBuilder] = None,
    fail_open: bool = True,
) -> Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable[Any]]]:
    """Decorate an async endpoint with a fixed-window Redis rate limit.

    Each call increments a counter keyed by
    ``{key_prefix}:{scope}:{METHOD}:{path}:{identity}``; the counter expires
    ``window_seconds`` after it is first created. Once the count exceeds
    ``limit``, the call is rejected with a 429 ``ApplicationException`` that
    carries a ``Retry-After`` header.

    Args:
        limit: maximum calls allowed per window (must be > 0).
        window_seconds: window length in seconds (must be > 0).
        scope: identity dimension. Built-in values:
            ``'ip'`` = client IP;
            ``'device'`` = ``device_id`` cookie, else IP;
            ``'user'`` = ``request.state.user.id``, else IP;
            ``'key'`` = ``key_builder`` result.
            Any other value is also resolved through ``key_builder``.
        key_prefix: first segment of the Redis key.
        key_builder: callable producing the identity; required unless
            ``scope`` is ``'ip'``, ``'device'`` or ``'user'``.
        fail_open: on any Redis failure, run the endpoint anyway (True) or
            raise a 503 ``ApplicationException`` (False).

    Raises:
        ValueError: at decoration time, for invalid arguments.

    The decorated endpoint must receive a ``fastapi.Request`` argument.
    """
    if limit <= 0:
        raise ValueError('rate_limit: limit must be > 0')
    if window_seconds <= 0:
        raise ValueError('rate_limit: window_seconds must be > 0')
    if scope == 'key' and not key_builder:
        raise ValueError('rate_limit: scope="key" requires key_builder')
    # Scopes outside ip/device/user are resolved through key_builder. Without
    # one, every request would fail with a 500, so reject the configuration up front.
    if scope not in ('ip', 'device', 'user', 'key') and not key_builder:
        raise ValueError(f'rate_limit: unknown scope {scope!r} requires key_builder')

    def decorator(func: Callable[..., Awaitable[Any]]):
        @functools.wraps(func)
        async def wrapper(*args: Any, **kwargs: Any):
            request = _find_request(args, kwargs)
            logger: ILogger = get_logger()

            ident = _resolve_identity(scope, key_builder, request, args, kwargs, logger)

            # NOTE(review): keyed on the concrete URL path, so every distinct
            # path-parameter value gets its own counter. Confirm this is the
            # intended granularity rather than one counter per route template.
            route = request.url.path
            method = request.method
            redis_key = f'{key_prefix}:{scope}:{method}:{route}:{ident}'

            logger.debug(f'RateLimit check key={redis_key} limit={limit} window={window_seconds}')

            try:
                count, ttl = await _incr_window(request, redis_key, window_seconds)
            except Exception as e:
                logger.error(f'RateLimit redis failure key={redis_key} error={str(e)}')
                if not fail_open:
                    raise ApplicationException(503, 'Rate limiter unavailable') from e
                logger.warning(f'RateLimit fail-open activated key={redis_key}')
            else:
                if count > limit:
                    retry_after = max(ttl, 0)
                    logger.warning(f'RateLimit exceeded key={redis_key} count={count} limit={limit} retry_after={retry_after}')
                    raise ApplicationException(
                        status_code=429,
                        message='Too Many Requests',
                        headers={'Retry-After': str(retry_after)},
                    )
                logger.debug(f'RateLimit passed key={redis_key} count={count}')

            # Called outside the except block so that endpoint errors raised
            # on the fail-open path are not chained onto the Redis failure.
            return await func(*args, **kwargs)

        return wrapper

    return decorator