feat: add validation
This commit is contained in:
@@ -1,18 +1,32 @@
|
|||||||
|
import asyncio
|
||||||
from datetime import datetime, timezone, timedelta
|
from datetime import datetime, timezone, timedelta
|
||||||
from ulid import ULID
|
from ulid import ULID
|
||||||
from src.application.abstractions import IUnitOfWork
|
from src.application.abstractions import IUnitOfWork
|
||||||
from src.application.contracts import IHashService, IJwtService, ILogger
|
from src.application.contracts import IHashService, IJwtService, ILogger, ICache
|
||||||
from src.application.domain.dto import RefreshTokenPayload
|
from src.application.domain.dto import RefreshTokenPayload
|
||||||
from src.application.domain.exceptions import ApplicationException
|
from src.application.domain.exceptions import ApplicationException, RefreshConcurrentException
|
||||||
from src.infrastructure.config import settings
|
from src.infrastructure.config import settings
|
||||||
from src.infrastructure.database.decorators import transactional
|
from src.infrastructure.database.decorators import transactional
|
||||||
|
|
||||||
|
|
||||||
class JwtRefreshCommand:
|
class JwtRefreshCommand:
|
||||||
def __init__(self, unit_of_work: IUnitOfWork, hash_service: IHashService, jwt_service: IJwtService, logger: ILogger):
|
_LOCK_PREFIX = 'jwt:refresh:lock:'
|
||||||
|
_LOCK_TTL_SECONDS = 15
|
||||||
|
_LOCK_WAIT_ATTEMPTS = 40
|
||||||
|
_LOCK_WAIT_INTERVAL_SECONDS = 0.05
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
unit_of_work: IUnitOfWork,
|
||||||
|
hash_service: IHashService,
|
||||||
|
jwt_service: IJwtService,
|
||||||
|
cache: ICache,
|
||||||
|
logger: ILogger,
|
||||||
|
):
|
||||||
self._unit_of_work = unit_of_work
|
self._unit_of_work = unit_of_work
|
||||||
self._hash_service = hash_service
|
self._hash_service = hash_service
|
||||||
self._jwt_service = jwt_service
|
self._jwt_service = jwt_service
|
||||||
|
self._cache = cache
|
||||||
self._logger = logger
|
self._logger = logger
|
||||||
|
|
||||||
@transactional
|
@transactional
|
||||||
@@ -25,6 +39,39 @@ class JwtRefreshCommand:
|
|||||||
user_id = payload.sub
|
user_id = payload.sub
|
||||||
jti = payload.jti
|
jti = payload.jti
|
||||||
|
|
||||||
|
lock_key = f'{self._LOCK_PREFIX}{sid}'
|
||||||
|
locked = await self._cache.set_nx(lock_key, '1', self._LOCK_TTL_SECONDS)
|
||||||
|
|
||||||
|
if not locked:
|
||||||
|
for _ in range(self._LOCK_WAIT_ATTEMPTS):
|
||||||
|
await asyncio.sleep(self._LOCK_WAIT_INTERVAL_SECONDS)
|
||||||
|
if await self._cache.get(lock_key) is None:
|
||||||
|
self._logger.info(f'Concurrent refresh skipped (sid={sid})')
|
||||||
|
raise RefreshConcurrentException()
|
||||||
|
raise ApplicationException(status_code=429, message='Refresh in progress')
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await self._refresh_locked(
|
||||||
|
sid=sid,
|
||||||
|
user_id=user_id,
|
||||||
|
jti=jti,
|
||||||
|
now=now,
|
||||||
|
ip=ip,
|
||||||
|
user_agent=user_agent,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
await self._cache.delete(lock_key)
|
||||||
|
|
||||||
|
async def _refresh_locked(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
sid: str,
|
||||||
|
user_id: str,
|
||||||
|
jti: str,
|
||||||
|
now: datetime,
|
||||||
|
ip: str | None,
|
||||||
|
user_agent: str | None,
|
||||||
|
) -> tuple[str, str]:
|
||||||
sess = await self._unit_of_work.session_repository.get_by_sid(sid)
|
sess = await self._unit_of_work.session_repository.get_by_sid(sid)
|
||||||
if sess is None:
|
if sess is None:
|
||||||
raise ApplicationException(status_code=401, message='Session not found')
|
raise ApplicationException(status_code=401, message='Session not found')
|
||||||
@@ -61,7 +108,8 @@ class JwtRefreshCommand:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not rotated:
|
if not rotated:
|
||||||
raise ApplicationException(status_code=401, message='Refresh already rotated')
|
self._logger.info(f'Refresh already rotated (sid={sid})')
|
||||||
|
raise RefreshConcurrentException()
|
||||||
|
|
||||||
access = await self._jwt_service.create_access_token(user_id=user_id, sid=sid)
|
access = await self._jwt_service.create_access_token(user_id=user_id, sid=sid)
|
||||||
refresh = await self._jwt_service.create_refresh_token(user_id=user_id, sid=sid, refresh_jti=new_jti)
|
refresh = await self._jwt_service.create_refresh_token(user_id=user_id, sid=sid, refresh_jti=new_jti)
|
||||||
|
|||||||
@@ -7,3 +7,4 @@ from src.application.domain.exceptions.not_found_exception import NotFoundExcept
|
|||||||
from src.application.domain.exceptions.service_unavailable_exception import ServiceUnavailableException
|
from src.application.domain.exceptions.service_unavailable_exception import ServiceUnavailableException
|
||||||
from src.application.domain.exceptions.too_many_requests_exception import TooManyRequestsException
|
from src.application.domain.exceptions.too_many_requests_exception import TooManyRequestsException
|
||||||
from src.application.domain.exceptions.unauthorized_exception import UnauthorizedException
|
from src.application.domain.exceptions.unauthorized_exception import UnauthorizedException
|
||||||
|
from src.application.domain.exceptions.refresh_concurrent_exception import RefreshConcurrentException
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
from starlette import status
|
||||||
|
from src.application.domain.exceptions.application_exception import ApplicationException
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshConcurrentException(ApplicationException):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__(
|
||||||
|
status_code=status.HTTP_200_OK,
|
||||||
|
message='Refresh already handled',
|
||||||
|
)
|
||||||
@@ -2,6 +2,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|||||||
from src.application.abstractions import IUnitOfWork
|
from src.application.abstractions import IUnitOfWork
|
||||||
from src.application.abstractions.repositories import IUserRepository, ISessionRepository
|
from src.application.abstractions.repositories import IUserRepository, ISessionRepository
|
||||||
from src.application.contracts import ILogger
|
from src.application.contracts import ILogger
|
||||||
|
from src.application.domain.exceptions import RefreshConcurrentException
|
||||||
from src.infrastructure.database.repositories import UserRepository, SessionRepository
|
from src.infrastructure.database.repositories import UserRepository, SessionRepository
|
||||||
|
|
||||||
|
|
||||||
@@ -20,8 +21,10 @@ class UnitOfWork(IUnitOfWork):
|
|||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
if exc_type:
|
if exc_type:
|
||||||
|
if not isinstance(exc_val, RefreshConcurrentException):
|
||||||
self._logger.error(str(exc_val))
|
self._logger.error(str(exc_val))
|
||||||
await self._session.rollback()
|
await self._session.rollback()
|
||||||
|
if not isinstance(exc_val, RefreshConcurrentException):
|
||||||
self._logger.error(f'Rollback: str{exc_val})')
|
self._logger.error(f'Rollback: str{exc_val})')
|
||||||
else:
|
else:
|
||||||
await self._session.flush()
|
await self._session.flush()
|
||||||
|
|||||||
@@ -93,6 +93,7 @@ def get_jwt_refresh_command(
|
|||||||
uow: IUnitOfWork = Depends(get_unit_of_work),
|
uow: IUnitOfWork = Depends(get_unit_of_work),
|
||||||
hash_service: IHashService = Depends(get_hash_service),
|
hash_service: IHashService = Depends(get_hash_service),
|
||||||
jwt_service: IJwtService = Depends(get_jwt_service),
|
jwt_service: IJwtService = Depends(get_jwt_service),
|
||||||
|
cache: ICache = Depends(get_cache),
|
||||||
logger: ILogger = Depends(get_logger),
|
logger: ILogger = Depends(get_logger),
|
||||||
) -> JwtRefreshCommand:
|
) -> JwtRefreshCommand:
|
||||||
return JwtRefreshCommand(uow, hash_service, jwt_service, logger)
|
return JwtRefreshCommand(uow, hash_service, jwt_service, cache, logger)
|
||||||
|
|||||||
@@ -2,43 +2,21 @@ from fastapi import APIRouter, Request, Depends
|
|||||||
from fastapi.responses import ORJSONResponse
|
from fastapi.responses import ORJSONResponse
|
||||||
from starlette import status
|
from starlette import status
|
||||||
from src.application.commands import JwtRefreshCommand
|
from src.application.commands import JwtRefreshCommand
|
||||||
from src.application.domain.exceptions import ApplicationException
|
from src.application.domain.exceptions import ApplicationException, RefreshConcurrentException
|
||||||
from src.infrastructure.config import settings
|
from src.infrastructure.config import settings
|
||||||
from src.presentation.decorators import csrf_protect,rate_limit
|
from src.presentation.decorators import csrf_protect, rate_limit
|
||||||
from src.presentation.dependencies import get_jwt_refresh_command
|
from src.presentation.dependencies import get_jwt_refresh_command
|
||||||
|
|
||||||
|
|
||||||
jwt_router = APIRouter(prefix='/jwt', tags=['Jwt'])
|
jwt_router = APIRouter(prefix='/jwt', tags=['Jwt'])
|
||||||
|
|
||||||
|
|
||||||
@jwt_router.post('/refresh', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
|
def _clear_auth_cookies(response: ORJSONResponse) -> None:
|
||||||
@rate_limit(limit=settings.RATE_LIMIT_REQUESTS, window_seconds=settings.RATE_LIMIT_WINDOW, scope='ip')
|
|
||||||
@csrf_protect()
|
|
||||||
async def refresh_tokens(
|
|
||||||
request: Request,
|
|
||||||
command: JwtRefreshCommand = Depends(get_jwt_refresh_command)
|
|
||||||
):
|
|
||||||
refresh_token = request.cookies.get('refresh_token')
|
|
||||||
|
|
||||||
if not refresh_token:
|
|
||||||
response = ORJSONResponse({'ok': False, 'error': 'No refresh token'}, status_code=401)
|
|
||||||
response.delete_cookie('access_token', path='/', domain=settings.AUTH_COOKIE_DOMAIN)
|
response.delete_cookie('access_token', path='/', domain=settings.AUTH_COOKIE_DOMAIN)
|
||||||
response.delete_cookie('refresh_token', path='/', domain=settings.AUTH_COOKIE_DOMAIN)
|
response.delete_cookie('refresh_token', path='/', domain=settings.AUTH_COOKIE_DOMAIN)
|
||||||
return response
|
|
||||||
|
|
||||||
ip = request.client.host if request.client else None
|
|
||||||
user_agent = request.headers.get('user-agent')
|
|
||||||
|
|
||||||
try:
|
|
||||||
access, refresh = await command(refresh_token=refresh_token, ip=ip, user_agent=user_agent)
|
|
||||||
except ApplicationException:
|
|
||||||
response = ORJSONResponse({'result': False}, status_code=401)
|
|
||||||
response.delete_cookie('access_token', path='/', domain=settings.AUTH_COOKIE_DOMAIN)
|
|
||||||
response.delete_cookie('refresh_token', path='/', domain=settings.AUTH_COOKIE_DOMAIN)
|
|
||||||
return response
|
|
||||||
|
|
||||||
response = ORJSONResponse({'result': True})
|
|
||||||
|
|
||||||
|
def _set_auth_cookies(response: ORJSONResponse, access: str, refresh: str) -> None:
|
||||||
response.set_cookie(
|
response.set_cookie(
|
||||||
key='access_token',
|
key='access_token',
|
||||||
value=access,
|
value=access,
|
||||||
@@ -59,9 +37,37 @@ async def refresh_tokens(
|
|||||||
domain=settings.AUTH_COOKIE_DOMAIN,
|
domain=settings.AUTH_COOKIE_DOMAIN,
|
||||||
max_age=int(settings.JWT_REFRESH_TTL_SECONDS),
|
max_age=int(settings.JWT_REFRESH_TTL_SECONDS),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@jwt_router.post('/refresh', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
|
||||||
|
@rate_limit(limit=settings.RATE_LIMIT_REQUESTS, window_seconds=settings.RATE_LIMIT_WINDOW, scope='ip')
|
||||||
|
@csrf_protect()
|
||||||
|
async def refresh_tokens(
|
||||||
|
request: Request,
|
||||||
|
command: JwtRefreshCommand = Depends(get_jwt_refresh_command),
|
||||||
|
):
|
||||||
|
refresh_token = request.cookies.get('refresh_token')
|
||||||
|
|
||||||
|
if not refresh_token:
|
||||||
|
response = ORJSONResponse({'ok': False, 'error': 'No refresh token'}, status_code=401)
|
||||||
|
_clear_auth_cookies(response)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
# Usage
|
ip = request.client.host if request.client else None
|
||||||
# @jwt_router.get("/test")
|
user_agent = request.headers.get('user-agent')
|
||||||
# async def profile(auth: AuthContext = Depends(require_access_token)):
|
|
||||||
# return 'ok'
|
try:
|
||||||
|
tokens = await command(refresh_token=refresh_token, ip=ip, user_agent=user_agent)
|
||||||
|
except RefreshConcurrentException:
|
||||||
|
return ORJSONResponse({'result': True, 'concurrent': True}, status_code=status.HTTP_200_OK)
|
||||||
|
except ApplicationException as exc:
|
||||||
|
if exc.status_code == status.HTTP_401_UNAUTHORIZED:
|
||||||
|
response = ORJSONResponse({'result': False}, status_code=401)
|
||||||
|
_clear_auth_cookies(response)
|
||||||
|
return response
|
||||||
|
raise
|
||||||
|
|
||||||
|
access, refresh = tokens
|
||||||
|
response = ORJSONResponse({'result': True})
|
||||||
|
_set_auth_cookies(response, access, refresh)
|
||||||
|
return response
|
||||||
|
|||||||
Reference in New Issue
Block a user