feat: add validation

This commit is contained in:
2026-05-19 22:29:02 +03:00
parent 666f2f67cb
commit caf7f003fa
6 changed files with 110 additions and 41 deletions

View File

@@ -1,18 +1,32 @@
import asyncio
from datetime import datetime, timezone, timedelta
from ulid import ULID
from src.application.abstractions import IUnitOfWork
from src.application.contracts import IHashService, IJwtService, ILogger
from src.application.contracts import IHashService, IJwtService, ILogger, ICache
from src.application.domain.dto import RefreshTokenPayload
from src.application.domain.exceptions import ApplicationException
from src.application.domain.exceptions import ApplicationException, RefreshConcurrentException
from src.infrastructure.config import settings
from src.infrastructure.database.decorators import transactional
class JwtRefreshCommand:
def __init__(self, unit_of_work: IUnitOfWork, hash_service: IHashService, jwt_service: IJwtService, logger: ILogger):
_LOCK_PREFIX = 'jwt:refresh:lock:'
_LOCK_TTL_SECONDS = 15
_LOCK_WAIT_ATTEMPTS = 40
_LOCK_WAIT_INTERVAL_SECONDS = 0.05
def __init__(
self,
unit_of_work: IUnitOfWork,
hash_service: IHashService,
jwt_service: IJwtService,
cache: ICache,
logger: ILogger,
):
self._unit_of_work = unit_of_work
self._hash_service = hash_service
self._jwt_service = jwt_service
self._cache = cache
self._logger = logger
@transactional
@@ -25,6 +39,39 @@ class JwtRefreshCommand:
user_id = payload.sub
jti = payload.jti
lock_key = f'{self._LOCK_PREFIX}{sid}'
locked = await self._cache.set_nx(lock_key, '1', self._LOCK_TTL_SECONDS)
if not locked:
for _ in range(self._LOCK_WAIT_ATTEMPTS):
await asyncio.sleep(self._LOCK_WAIT_INTERVAL_SECONDS)
if await self._cache.get(lock_key) is None:
self._logger.info(f'Concurrent refresh skipped (sid={sid})')
raise RefreshConcurrentException()
raise ApplicationException(status_code=429, message='Refresh in progress')
try:
return await self._refresh_locked(
sid=sid,
user_id=user_id,
jti=jti,
now=now,
ip=ip,
user_agent=user_agent,
)
finally:
await self._cache.delete(lock_key)
async def _refresh_locked(
self,
*,
sid: str,
user_id: str,
jti: str,
now: datetime,
ip: str | None,
user_agent: str | None,
) -> tuple[str, str]:
sess = await self._unit_of_work.session_repository.get_by_sid(sid)
if sess is None:
raise ApplicationException(status_code=401, message='Session not found')
@@ -61,7 +108,8 @@ class JwtRefreshCommand:
)
if not rotated:
raise ApplicationException(status_code=401, message='Refresh already rotated')
self._logger.info(f'Refresh already rotated (sid={sid})')
raise RefreshConcurrentException()
access = await self._jwt_service.create_access_token(user_id=user_id, sid=sid)
refresh = await self._jwt_service.create_refresh_token(user_id=user_id, sid=sid, refresh_jti=new_jti)

View File

@@ -7,3 +7,4 @@ from src.application.domain.exceptions.not_found_exception import NotFoundExcept
from src.application.domain.exceptions.service_unavailable_exception import ServiceUnavailableException
from src.application.domain.exceptions.too_many_requests_exception import TooManyRequestsException
from src.application.domain.exceptions.unauthorized_exception import UnauthorizedException
from src.application.domain.exceptions.refresh_concurrent_exception import RefreshConcurrentException

View File

@@ -0,0 +1,10 @@
from starlette import status
from src.application.domain.exceptions.application_exception import ApplicationException
class RefreshConcurrentException(ApplicationException):
def __init__(self) -> None:
super().__init__(
status_code=status.HTTP_200_OK,
message='Refresh already handled',
)

View File

@@ -2,6 +2,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from src.application.abstractions import IUnitOfWork
from src.application.abstractions.repositories import IUserRepository, ISessionRepository
from src.application.contracts import ILogger
from src.application.domain.exceptions import RefreshConcurrentException
from src.infrastructure.database.repositories import UserRepository, SessionRepository
@@ -20,8 +21,10 @@ class UnitOfWork(IUnitOfWork):
async def __aexit__(self, exc_type, exc_val, exc_tb):
if exc_type:
if not isinstance(exc_val, RefreshConcurrentException):
self._logger.error(str(exc_val))
await self._session.rollback()
if not isinstance(exc_val, RefreshConcurrentException):
self._logger.error(f'Rollback: str{exc_val})')
else:
await self._session.flush()

View File

@@ -93,6 +93,7 @@ def get_jwt_refresh_command(
uow: IUnitOfWork = Depends(get_unit_of_work),
hash_service: IHashService = Depends(get_hash_service),
jwt_service: IJwtService = Depends(get_jwt_service),
cache: ICache = Depends(get_cache),
logger: ILogger = Depends(get_logger),
) -> JwtRefreshCommand:
return JwtRefreshCommand(uow, hash_service, jwt_service, logger)
return JwtRefreshCommand(uow, hash_service, jwt_service, cache, logger)

View File

@@ -2,43 +2,21 @@ from fastapi import APIRouter, Request, Depends
from fastapi.responses import ORJSONResponse
from starlette import status
from src.application.commands import JwtRefreshCommand
from src.application.domain.exceptions import ApplicationException
from src.application.domain.exceptions import ApplicationException, RefreshConcurrentException
from src.infrastructure.config import settings
from src.presentation.decorators import csrf_protect,rate_limit
from src.presentation.decorators import csrf_protect, rate_limit
from src.presentation.dependencies import get_jwt_refresh_command
jwt_router = APIRouter(prefix='/jwt', tags=['Jwt'])
@jwt_router.post('/refresh', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
@rate_limit(limit=settings.RATE_LIMIT_REQUESTS, window_seconds=settings.RATE_LIMIT_WINDOW, scope='ip')
@csrf_protect()
async def refresh_tokens(
request: Request,
command: JwtRefreshCommand = Depends(get_jwt_refresh_command)
):
refresh_token = request.cookies.get('refresh_token')
if not refresh_token:
response = ORJSONResponse({'ok': False, 'error': 'No refresh token'}, status_code=401)
def _clear_auth_cookies(response: ORJSONResponse) -> None:
response.delete_cookie('access_token', path='/', domain=settings.AUTH_COOKIE_DOMAIN)
response.delete_cookie('refresh_token', path='/', domain=settings.AUTH_COOKIE_DOMAIN)
return response
ip = request.client.host if request.client else None
user_agent = request.headers.get('user-agent')
try:
access, refresh = await command(refresh_token=refresh_token, ip=ip, user_agent=user_agent)
except ApplicationException:
response = ORJSONResponse({'result': False}, status_code=401)
response.delete_cookie('access_token', path='/', domain=settings.AUTH_COOKIE_DOMAIN)
response.delete_cookie('refresh_token', path='/', domain=settings.AUTH_COOKIE_DOMAIN)
return response
response = ORJSONResponse({'result': True})
def _set_auth_cookies(response: ORJSONResponse, access: str, refresh: str) -> None:
response.set_cookie(
key='access_token',
value=access,
@@ -59,9 +37,37 @@ async def refresh_tokens(
domain=settings.AUTH_COOKIE_DOMAIN,
max_age=int(settings.JWT_REFRESH_TTL_SECONDS),
)
@jwt_router.post('/refresh', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
@rate_limit(limit=settings.RATE_LIMIT_REQUESTS, window_seconds=settings.RATE_LIMIT_WINDOW, scope='ip')
@csrf_protect()
async def refresh_tokens(
request: Request,
command: JwtRefreshCommand = Depends(get_jwt_refresh_command),
):
refresh_token = request.cookies.get('refresh_token')
if not refresh_token:
response = ORJSONResponse({'ok': False, 'error': 'No refresh token'}, status_code=401)
_clear_auth_cookies(response)
return response
# Usage
# @jwt_router.get("/test")
# async def profile(auth: AuthContext = Depends(require_access_token)):
# return 'ok'
ip = request.client.host if request.client else None
user_agent = request.headers.get('user-agent')
try:
tokens = await command(refresh_token=refresh_token, ip=ip, user_agent=user_agent)
except RefreshConcurrentException:
return ORJSONResponse({'result': True, 'concurrent': True}, status_code=status.HTTP_200_OK)
except ApplicationException as exc:
if exc.status_code == status.HTTP_401_UNAUTHORIZED:
response = ORJSONResponse({'result': False}, status_code=401)
_clear_auth_cookies(response)
return response
raise
access, refresh = tokens
response = ORJSONResponse({'result': True})
_set_auth_cookies(response, access, refresh)
return response