diff --git a/src/application/commands/jwt_refresh.py b/src/application/commands/jwt_refresh.py index 2de4700..bcb7d9d 100644 --- a/src/application/commands/jwt_refresh.py +++ b/src/application/commands/jwt_refresh.py @@ -1,18 +1,32 @@ +import asyncio from datetime import datetime, timezone, timedelta from ulid import ULID from src.application.abstractions import IUnitOfWork -from src.application.contracts import IHashService, IJwtService, ILogger +from src.application.contracts import IHashService, IJwtService, ILogger, ICache from src.application.domain.dto import RefreshTokenPayload -from src.application.domain.exceptions import ApplicationException +from src.application.domain.exceptions import ApplicationException, RefreshConcurrentException from src.infrastructure.config import settings from src.infrastructure.database.decorators import transactional class JwtRefreshCommand: - def __init__(self, unit_of_work: IUnitOfWork, hash_service: IHashService, jwt_service: IJwtService, logger: ILogger): + _LOCK_PREFIX = 'jwt:refresh:lock:' + _LOCK_TTL_SECONDS = 15 + _LOCK_WAIT_ATTEMPTS = 40 + _LOCK_WAIT_INTERVAL_SECONDS = 0.05 + + def __init__( + self, + unit_of_work: IUnitOfWork, + hash_service: IHashService, + jwt_service: IJwtService, + cache: ICache, + logger: ILogger, + ): self._unit_of_work = unit_of_work self._hash_service = hash_service self._jwt_service = jwt_service + self._cache = cache self._logger = logger @transactional @@ -25,6 +39,39 @@ class JwtRefreshCommand: user_id = payload.sub jti = payload.jti + lock_key = f'{self._LOCK_PREFIX}{sid}' + locked = await self._cache.set_nx(lock_key, '1', self._LOCK_TTL_SECONDS) + + if not locked: + for _ in range(self._LOCK_WAIT_ATTEMPTS): + await asyncio.sleep(self._LOCK_WAIT_INTERVAL_SECONDS) + if await self._cache.get(lock_key) is None: + self._logger.info(f'Concurrent refresh skipped (sid={sid})') + raise RefreshConcurrentException() + raise ApplicationException(status_code=429, message='Refresh in progress') + + try: + return await self._refresh_locked( + sid=sid, + user_id=user_id, + jti=jti, + now=now, + ip=ip, + user_agent=user_agent, + ) + finally: + await self._cache.delete(lock_key) + + async def _refresh_locked( + self, + *, + sid: str, + user_id: str, + jti: str, + now: datetime, + ip: str | None, + user_agent: str | None, + ) -> tuple[str, str]: sess = await self._unit_of_work.session_repository.get_by_sid(sid) if sess is None: raise ApplicationException(status_code=401, message='Session not found') @@ -61,7 +108,8 @@ class JwtRefreshCommand: ) if not rotated: - raise ApplicationException(status_code=401, message='Refresh already rotated') + self._logger.info(f'Refresh already rotated (sid={sid})') + raise RefreshConcurrentException() access = await self._jwt_service.create_access_token(user_id=user_id, sid=sid) refresh = await self._jwt_service.create_refresh_token(user_id=user_id, sid=sid, refresh_jti=new_jti) diff --git a/src/application/domain/exceptions/__init__.py b/src/application/domain/exceptions/__init__.py index 5305794..67deede 100644 --- a/src/application/domain/exceptions/__init__.py +++ b/src/application/domain/exceptions/__init__.py @@ -6,4 +6,5 @@ from src.application.domain.exceptions.internal_server_exception import Internal from src.application.domain.exceptions.not_found_exception import NotFoundException from src.application.domain.exceptions.service_unavailable_exception import ServiceUnavailableException from src.application.domain.exceptions.too_many_requests_exception import TooManyRequestsException -from src.application.domain.exceptions.unauthorized_exception import UnauthorizedException \ No newline at end of file +from src.application.domain.exceptions.unauthorized_exception import UnauthorizedException +from src.application.domain.exceptions.refresh_concurrent_exception import RefreshConcurrentException \ No newline at end of file diff --git a/src/application/domain/exceptions/refresh_concurrent_exception.py b/src/application/domain/exceptions/refresh_concurrent_exception.py new file mode 100644 index 0000000..6885c74 --- /dev/null +++ b/src/application/domain/exceptions/refresh_concurrent_exception.py @@ -0,0 +1,10 @@ +from starlette import status +from src.application.domain.exceptions.application_exception import ApplicationException + + +class RefreshConcurrentException(ApplicationException): + def __init__(self) -> None: + super().__init__( + status_code=status.HTTP_200_OK, + message='Refresh already handled', + ) diff --git a/src/infrastructure/database/unit_of_work.py b/src/infrastructure/database/unit_of_work.py index 3e2363c..793fede 100644 --- a/src/infrastructure/database/unit_of_work.py +++ b/src/infrastructure/database/unit_of_work.py @@ -2,6 +2,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from src.application.abstractions import IUnitOfWork from src.application.abstractions.repositories import IUserRepository, ISessionRepository from src.application.contracts import ILogger +from src.application.domain.exceptions import RefreshConcurrentException from src.infrastructure.database.repositories import UserRepository, SessionRepository @@ -20,9 +21,11 @@ class UnitOfWork(IUnitOfWork): async def __aexit__(self, exc_type, exc_val, exc_tb): if exc_type: - self._logger.error(str(exc_val)) + if not isinstance(exc_val, RefreshConcurrentException): + self._logger.error(str(exc_val)) await self._session.rollback() - self._logger.error(f'Rollback: str{exc_val})') + if not isinstance(exc_val, RefreshConcurrentException): + self._logger.error(f'Rollback: str{exc_val})') else: await self._session.flush() await self._session.commit() diff --git a/src/presentation/dependencies/commands.py b/src/presentation/dependencies/commands.py index e0f65f4..aceecc0 100644 --- a/src/presentation/dependencies/commands.py +++ b/src/presentation/dependencies/commands.py @@ -93,6 +93,7 @@ def get_jwt_refresh_command( uow: IUnitOfWork = Depends(get_unit_of_work), hash_service: IHashService = Depends(get_hash_service), jwt_service: IJwtService = Depends(get_jwt_service), + cache: ICache = Depends(get_cache), logger: ILogger = Depends(get_logger), ) -> JwtRefreshCommand: - return JwtRefreshCommand(uow, hash_service, jwt_service, logger) + return JwtRefreshCommand(uow, hash_service, jwt_service, cache, logger) diff --git a/src/presentation/routing/jwt.py b/src/presentation/routing/jwt.py index b953a74..7074b83 100644 --- a/src/presentation/routing/jwt.py +++ b/src/presentation/routing/jwt.py @@ -2,43 +2,21 @@ from fastapi import APIRouter, Request, Depends from fastapi.responses import ORJSONResponse from starlette import status from src.application.commands import JwtRefreshCommand -from src.application.domain.exceptions import ApplicationException +from src.application.domain.exceptions import ApplicationException, RefreshConcurrentException from src.infrastructure.config import settings -from src.presentation.decorators import csrf_protect,rate_limit +from src.presentation.decorators import csrf_protect, rate_limit from src.presentation.dependencies import get_jwt_refresh_command jwt_router = APIRouter(prefix='/jwt', tags=['Jwt']) -@jwt_router.post('/refresh', response_class=ORJSONResponse, status_code=status.HTTP_200_OK) -@rate_limit(limit=settings.RATE_LIMIT_REQUESTS, window_seconds=settings.RATE_LIMIT_WINDOW, scope='ip') -@csrf_protect() -async def refresh_tokens( - request: Request, - command: JwtRefreshCommand = Depends(get_jwt_refresh_command) -): - refresh_token = request.cookies.get('refresh_token') +def _clear_auth_cookies(response: ORJSONResponse) -> None: + response.delete_cookie('access_token', path='/', domain=settings.AUTH_COOKIE_DOMAIN) + response.delete_cookie('refresh_token', path='/', domain=settings.AUTH_COOKIE_DOMAIN) - if not refresh_token: - response = ORJSONResponse({'ok': False, 'error': 'No refresh token'}, status_code=401) - response.delete_cookie('access_token', path='/', domain=settings.AUTH_COOKIE_DOMAIN) - response.delete_cookie('refresh_token', path='/', domain=settings.AUTH_COOKIE_DOMAIN) - return response - - ip = request.client.host if request.client else None - user_agent = request.headers.get('user-agent') - - try: - access, refresh = await command(refresh_token=refresh_token, ip=ip, user_agent=user_agent) - except ApplicationException: - response = ORJSONResponse({'result': False}, status_code=401) - response.delete_cookie('access_token', path='/', domain=settings.AUTH_COOKIE_DOMAIN) - response.delete_cookie('refresh_token', path='/', domain=settings.AUTH_COOKIE_DOMAIN) - return response - - response = ORJSONResponse({'result': True}) +def _set_auth_cookies(response: ORJSONResponse, access: str, refresh: str) -> None: response.set_cookie( key='access_token', value=access, @@ -59,9 +37,37 @@ async def refresh_tokens( domain=settings.AUTH_COOKIE_DOMAIN, max_age=int(settings.JWT_REFRESH_TTL_SECONDS), ) - return response -# Usage -# @jwt_router.get("/test") -# async def profile(auth: AuthContext = Depends(require_access_token)): -# return 'ok' \ No newline at end of file + +@jwt_router.post('/refresh', response_class=ORJSONResponse, status_code=status.HTTP_200_OK) +@rate_limit(limit=settings.RATE_LIMIT_REQUESTS, window_seconds=settings.RATE_LIMIT_WINDOW, scope='ip') +@csrf_protect() +async def refresh_tokens( + request: Request, + command: JwtRefreshCommand = Depends(get_jwt_refresh_command), +): + refresh_token = request.cookies.get('refresh_token') + + if not refresh_token: + response = ORJSONResponse({'ok': False, 'error': 'No refresh token'}, status_code=401) + _clear_auth_cookies(response) + return response + + ip = request.client.host if request.client else None + user_agent = request.headers.get('user-agent') + + try: + tokens = await command(refresh_token=refresh_token, ip=ip, user_agent=user_agent) + except RefreshConcurrentException: + return ORJSONResponse({'result': True, 'concurrent': True}, status_code=status.HTTP_200_OK) + except ApplicationException as exc: + if exc.status_code == status.HTTP_401_UNAUTHORIZED: + response = ORJSONResponse({'result': False}, status_code=401) + _clear_auth_cookies(response) + return response + raise + + access, refresh = tokens + response = ORJSONResponse({'result': True}) + _set_auth_cookies(response, access, refresh) + return response