feat: update

This commit is contained in:
2026-06-05 14:47:41 +03:00
parent fdae3ca554
commit 4d5506db4d
12 changed files with 52 additions and 244 deletions

View File

@@ -1,112 +1,29 @@
import asyncio
from datetime import datetime, timezone, timedelta
from ulid import ULID
from src.application.abstractions import IUnitOfWork from src.application.abstractions import IUnitOfWork
from src.application.contracts import IHashService, IJwtService, ILogger, ICache from src.application.contracts import IJwtService, ILogger
from src.application.domain.dto import RefreshTokenPayload from src.application.domain.exceptions import ApplicationException
from src.application.domain.exceptions import ApplicationException, RefreshConcurrentException
from src.infrastructure.config import settings
from src.infrastructure.database.decorators import transactional from src.infrastructure.database.decorators import transactional
class AdminJwtRefreshCommand: class AdminJwtRefreshCommand:
_LOCK_PREFIX = 'admin:jwt:refresh:lock:'
_LOCK_TTL_SECONDS = 15
_LOCK_WAIT_ATTEMPTS = 40
_LOCK_WAIT_INTERVAL_SECONDS = 0.05
def __init__( def __init__(
self, self,
unit_of_work: IUnitOfWork, unit_of_work: IUnitOfWork,
hash_service: IHashService,
jwt_service: IJwtService, jwt_service: IJwtService,
cache: ICache,
logger: ILogger, logger: ILogger,
): ):
self._unit_of_work = unit_of_work self._unit_of_work = unit_of_work
self._hash_service = hash_service
self._jwt_service = jwt_service self._jwt_service = jwt_service
self._cache = cache
self._logger = logger self._logger = logger
@transactional @transactional
async def __call__(self, *, refresh_token: str, ip: str | None, user_agent: str | None) -> tuple[str, str]: async def __call__(self, *, refresh_token: str) -> tuple[str, str]:
now = datetime.now(timezone.utc) payload = await self._jwt_service.decode_refresh_token(refresh_token)
payload: RefreshTokenPayload = await self._jwt_service.decode_refresh_token(refresh_token) admin = await self._unit_of_work.admin_user_repository.get_by_id(payload.sub)
sid = payload.sid if not admin.is_active:
admin_user_id = payload.sub raise ApplicationException(status_code=403, message='Admin account is inactive')
jti = payload.jti
lock_key = f'{self._LOCK_PREFIX}{sid}' access = await self._jwt_service.create_access_token(user_id=admin.id, role=admin.role)
locked = await self._cache.set_nx(lock_key, '1', self._LOCK_TTL_SECONDS) refresh = await self._jwt_service.create_refresh_token(user_id=admin.id, role=admin.role)
if not locked: self._logger.info(f'Admin tokens refreshed admin_user_id={admin.id}')
for _ in range(self._LOCK_WAIT_ATTEMPTS):
await asyncio.sleep(self._LOCK_WAIT_INTERVAL_SECONDS)
if await self._cache.get(lock_key) is None:
raise RefreshConcurrentException()
raise ApplicationException(status_code=429, message='Refresh in progress')
try:
return await self._refresh_locked(
sid=sid,
admin_user_id=admin_user_id,
jti=jti,
now=now,
ip=ip,
user_agent=user_agent,
)
finally:
await self._cache.delete(lock_key)
async def _refresh_locked(
self,
*,
sid: str,
admin_user_id: str,
jti: str,
now: datetime,
ip: str | None,
user_agent: str | None,
) -> tuple[str, str]:
sess = await self._unit_of_work.admin_session_repository.get_by_sid(sid)
if sess is None:
raise ApplicationException(status_code=401, message='Session not found')
if sess.revoked_at is not None:
raise ApplicationException(status_code=401, message='Session revoked')
if sess.refresh_expires_at is None or sess.refresh_expires_at <= now:
raise ApplicationException(status_code=401, message='Session expired')
if str(sess.admin_user_id) != str(admin_user_id):
raise ApplicationException(status_code=401, message='Invalid session subject')
ok = await self._hash_service.verify(plain_value=jti, hashed_value=sess.refresh_jti_hash)
if not ok:
await self._unit_of_work.admin_session_repository.revoke_by_sid(sid=sid, now=now)
raise ApplicationException(status_code=401, message='Refresh token reuse detected')
admin = await self._unit_of_work.admin_user_repository.get_by_id(admin_user_id)
new_jti = str(ULID())
new_jti_hash = await self._hash_service.hash(value=new_jti)
new_refresh_expires_at = now + timedelta(seconds=int(settings.JWT_REFRESH_TTL_SECONDS))
rotated = await self._unit_of_work.admin_session_repository.rotate_refresh_if_match(
sid=sid,
old_jti_hash=sess.refresh_jti_hash,
new_jti_hash=new_jti_hash,
new_refresh_expires_at=new_refresh_expires_at,
now=now,
ip=ip,
user_agent=user_agent,
)
if not rotated:
raise RefreshConcurrentException()
access = await self._jwt_service.create_access_token(
user_id=admin_user_id, sid=sid, role=admin.role
)
refresh = await self._jwt_service.create_refresh_token(
user_id=admin_user_id, sid=sid, refresh_jti=new_jti
)
return access, refresh return access, refresh

View File

@@ -1,14 +1,11 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime, timezone, timedelta from datetime import datetime, timezone
from ulid import ULID
from src.application.abstractions import IUnitOfWork from src.application.abstractions import IUnitOfWork
from src.application.contracts import IHashService, IJwtService, ILogger from src.application.contracts import IHashService, IJwtService, ILogger
from src.application.domain.dto.admin_auth import AdminLoginDto from src.application.domain.dto.admin_auth import AdminLoginDto
from src.application.domain.exceptions import ApplicationException from src.application.domain.exceptions import ApplicationException
from src.infrastructure.config import settings
from src.infrastructure.database.decorators import transactional from src.infrastructure.database.decorators import transactional
@@ -26,15 +23,7 @@ class AdminLoginCommand:
self._logger = logger self._logger = logger
@transactional @transactional
async def __call__( async def __call__(self, *, login: str, password: str) -> AdminLoginDto:
self,
*,
login: str,
password: str,
device_id: str | None,
ip: str | None,
user_agent: str | None,
) -> AdminLoginDto:
login = (login or '').strip() login = (login or '').strip()
if not login: if not login:
raise ApplicationException(status_code=400, message='Login is required') raise ApplicationException(status_code=400, message='Login is required')
@@ -51,33 +40,8 @@ class AdminLoginCommand:
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
await self._unit_of_work.admin_user_repository.update_last_login(admin.id, last_login_at=now) await self._unit_of_work.admin_user_repository.update_last_login(admin.id, last_login_at=now)
resolved_device_id = device_id or str(ULID()) access_token = await self._jwt_service.create_access_token(user_id=admin.id, role=admin.role)
sid = str(ULID()) refresh_token = await self._jwt_service.create_refresh_token(user_id=admin.id, role=admin.role)
jti = str(ULID())
jti_hash = await self._hash_service.hash(value=jti)
refresh_expires_at = now + timedelta(seconds=int(settings.JWT_REFRESH_TTL_SECONDS))
await self._unit_of_work.admin_session_repository.upsert_by_device(
admin_user_id=admin.id,
device_id=resolved_device_id,
sid=sid,
refresh_jti_hash=jti_hash,
refresh_expires_at=refresh_expires_at,
user_agent=user_agent,
ip=ip,
now=now,
)
access_token = await self._jwt_service.create_access_token(
user_id=admin.id,
role=admin.role,
sid=sid,
)
refresh_token = await self._jwt_service.create_refresh_token(
user_id=admin.id,
sid=sid,
refresh_jti=jti,
)
self._logger.info(f'Admin logged in admin_user_id={admin.id}') self._logger.info(f'Admin logged in admin_user_id={admin.id}')
@@ -89,6 +53,5 @@ class AdminLoginCommand:
role=admin.role, role=admin.role,
access_token=access_token, access_token=access_token,
refresh_token=refresh_token, refresh_token=refresh_token,
device_id=resolved_device_id,
last_login_at=now, last_login_at=now,
) )

View File

@@ -1,30 +1,11 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime, timezone from src.application.contracts import ILogger
from src.application.abstractions import IUnitOfWork
from src.application.contracts import IJwtService, ILogger
from src.application.domain.dto import RefreshTokenPayload
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.database.decorators import transactional
class AdminLogoutCommand: class AdminLogoutCommand:
def __init__(self, unit_of_work: IUnitOfWork, jwt_service: IJwtService, logger: ILogger): def __init__(self, logger: ILogger):
self._unit_of_work = unit_of_work
self._jwt_service = jwt_service
self._logger = logger self._logger = logger
@transactional async def __call__(self) -> None:
async def __call__(self, *, refresh_token: str | None) -> None: self._logger.debug('Admin logout (stateless)')
if not refresh_token:
return
try:
payload: RefreshTokenPayload = await self._jwt_service.decode_refresh_token(refresh_token)
except ApplicationException:
self._logger.debug('Logout: refresh token invalid/expired, skipping revoke')
return
now = datetime.now(timezone.utc)
await self._unit_of_work.admin_session_repository.revoke_by_sid(sid=payload.sid, now=now)
self._logger.info(f'Logout: session revoked (sid={payload.sid}, admin_user_id={payload.sub})')

View File

@@ -5,11 +5,11 @@ from src.application.domain.dto import AccessTokenPayload, RefreshTokenPayload
class IJwtService(ABC): class IJwtService(ABC):
@abstractmethod @abstractmethod
async def create_access_token(self, user_id: str, *, role: str, sid: str | None = None) -> str: async def create_access_token(self, user_id: str, *, role: str) -> str:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
async def create_refresh_token(self, user_id: str, *, sid: str, refresh_jti: str) -> str: async def create_refresh_token(self, user_id: str, *, role: str) -> str:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod

View File

@@ -13,5 +13,4 @@ class AdminLoginDto:
role: str role: str
access_token: str access_token: str
refresh_token: str refresh_token: str
device_id: str
last_login_at: datetime | None = None last_login_at: datetime | None = None

View File

@@ -5,7 +5,6 @@ class AccessTokenPayload(BaseModel):
sub: str sub: str
type: str type: str
role: str | None = None role: str | None = None
sid: str | None = None
iat: int iat: int
nbf: int nbf: int
exp: int exp: int
@@ -16,8 +15,7 @@ class AccessTokenPayload(BaseModel):
class RefreshTokenPayload(BaseModel): class RefreshTokenPayload(BaseModel):
sub: str sub: str
type: str type: str
sid: str role: str
jti: str
iat: int iat: int
nbf: int nbf: int
exp: int exp: int

View File

@@ -20,7 +20,7 @@ class JwtService(IJwtService):
def _issuer(self) -> str | None: def _issuer(self) -> str | None:
return settings.ADMIN_JWT_ISSUER return settings.ADMIN_JWT_ISSUER
async def create_access_token(self, user_id: str, *, role: str, sid: str | None = None) -> str: async def create_access_token(self, user_id: str, *, role: str) -> str:
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
exp = now + timedelta(seconds=int(settings.JWT_ACCESS_TTL_SECONDS)) exp = now + timedelta(seconds=int(settings.JWT_ACCESS_TTL_SECONDS))
@@ -32,8 +32,6 @@ class JwtService(IJwtService):
'nbf': int(now.timestamp()), 'nbf': int(now.timestamp()),
'exp': int(exp.timestamp()), 'exp': int(exp.timestamp()),
} }
if sid:
payload['sid'] = sid
if self._issuer: if self._issuer:
payload['iss'] = self._issuer payload['iss'] = self._issuer
if settings.JWT_AUDIENCE: if settings.JWT_AUDIENCE:
@@ -41,15 +39,14 @@ class JwtService(IJwtService):
return await self._encode(payload, user_id=user_id, token_kind='access') return await self._encode(payload, user_id=user_id, token_kind='access')
async def create_refresh_token(self, user_id: str, *, sid: str, refresh_jti: str) -> str: async def create_refresh_token(self, user_id: str, *, role: str) -> str:
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
exp = now + timedelta(seconds=int(settings.JWT_REFRESH_TTL_SECONDS)) exp = now + timedelta(seconds=int(settings.JWT_REFRESH_TTL_SECONDS))
payload: dict[str, object] = { payload: dict[str, object] = {
'sub': user_id, 'sub': user_id,
'type': 'refresh', 'type': 'refresh',
'sid': sid, 'role': role,
'jti': refresh_jti,
'iat': int(now.timestamp()), 'iat': int(now.timestamp()),
'nbf': int(now.timestamp()), 'nbf': int(now.timestamp()),
'exp': int(exp.timestamp()), 'exp': int(exp.timestamp()),
@@ -82,7 +79,6 @@ class JwtService(IJwtService):
sub=str(payload['sub']), sub=str(payload['sub']),
type='access', type='access',
role=str(payload['role']) if payload.get('role') else None, role=str(payload['role']) if payload.get('role') else None,
sid=str(payload['sid']) if payload.get('sid') else None,
iat=int(payload['iat']), iat=int(payload['iat']),
nbf=int(payload['nbf']), nbf=int(payload['nbf']),
exp=int(payload['exp']), exp=int(payload['exp']),
@@ -100,8 +96,7 @@ class JwtService(IJwtService):
return RefreshTokenPayload( return RefreshTokenPayload(
sub=str(payload['sub']), sub=str(payload['sub']),
type='refresh', type='refresh',
sid=str(payload['sid']), role=str(payload['role']),
jti=str(payload['jti']),
iat=int(payload['iat']), iat=int(payload['iat']),
nbf=int(payload['nbf']), nbf=int(payload['nbf']),
exp=int(payload['exp']), exp=int(payload['exp']),
@@ -150,12 +145,10 @@ class JwtService(IJwtService):
if 'type' not in payload: if 'type' not in payload:
raise ApplicationException(status_code=401, message='Missing token claim: type') raise ApplicationException(status_code=401, message='Missing token claim: type')
token_type = payload.get('type') token_type = payload.get('type')
if token_type == 'access' and 'role' not in payload: if 'role' not in payload:
raise ApplicationException(status_code=401, message='Missing token claim: role') raise ApplicationException(status_code=401, message='Missing token claim: role')
if token_type == 'refresh': if token_type not in ('access', 'refresh'):
for claim in ('sid', 'jti'): raise ApplicationException(status_code=401, message='Invalid token type')
if claim not in payload:
raise ApplicationException(status_code=401, message=f'Missing token claim: {claim}')
return payload return payload
except ExpiredSignatureError: except ExpiredSignatureError:
raise ApplicationException(status_code=401, message='Token expired') raise ApplicationException(status_code=401, message='Token expired')

View File

@@ -0,0 +1,15 @@
from ulid import ULID
def new_ulid() -> str:
return str(ULID())
def is_valid_ulid(value: str | None) -> bool:
if not value:
return False
try:
ULID.parse(value)
return True
except ValueError:
return False

View File

@@ -29,16 +29,3 @@ def set_auth_cookies(response: ORJSONResponse, access: str, refresh: str) -> Non
domain=settings.ADMIN_COOKIE_DOMAIN, domain=settings.ADMIN_COOKIE_DOMAIN,
max_age=int(settings.JWT_REFRESH_TTL_SECONDS), max_age=int(settings.JWT_REFRESH_TTL_SECONDS),
) )
def set_device_id_cookie(response: ORJSONResponse, device_id: str) -> None:
response.set_cookie(
key='device_id',
value=device_id,
httponly=True,
secure=settings.ADMIN_COOKIE_SECURE,
samesite='lax',
path='/',
domain=settings.ADMIN_COOKIE_DOMAIN,
max_age=int(settings.JWT_REFRESH_TTL_SECONDS),
)

View File

@@ -21,10 +21,9 @@ from src.application.commands import (
UpdatePurchaseRequestStatusCommand, UpdatePurchaseRequestStatusCommand,
UploadOrganizationDocumentCommand, UploadOrganizationDocumentCommand,
) )
from src.application.contracts import ICache, IHashService, IJwtService, ILogger from src.application.contracts import IHashService, IJwtService, ILogger
from src.infrastructure.config import settings from src.infrastructure.config import settings
from src.infrastructure.storage.s3_documents_service import S3DocumentsService from src.infrastructure.storage.s3_documents_service import S3DocumentsService
from src.presentation.dependencies.cache import get_cache
from src.presentation.dependencies.logger import get_logger from src.presentation.dependencies.logger import get_logger
from src.presentation.dependencies.security import get_hash_service, get_jwt_service from src.presentation.dependencies.security import get_hash_service, get_jwt_service
from src.presentation.dependencies.unit_of_work import get_unit_of_work from src.presentation.dependencies.unit_of_work import get_unit_of_work
@@ -64,21 +63,17 @@ def get_admin_me_command(
def get_admin_logout_command( def get_admin_logout_command(
uow: IUnitOfWork = Depends(get_unit_of_work),
jwt_service: IJwtService = Depends(get_jwt_service),
logger: ILogger = Depends(get_logger), logger: ILogger = Depends(get_logger),
) -> AdminLogoutCommand: ) -> AdminLogoutCommand:
return AdminLogoutCommand(uow, jwt_service, logger) return AdminLogoutCommand(logger)
def get_admin_jwt_refresh_command( def get_admin_jwt_refresh_command(
uow: IUnitOfWork = Depends(get_unit_of_work), uow: IUnitOfWork = Depends(get_unit_of_work),
hash_service: IHashService = Depends(get_hash_service),
jwt_service: IJwtService = Depends(get_jwt_service), jwt_service: IJwtService = Depends(get_jwt_service),
cache: ICache = Depends(get_cache),
logger: ILogger = Depends(get_logger), logger: ILogger = Depends(get_logger),
) -> AdminJwtRefreshCommand: ) -> AdminJwtRefreshCommand:
return AdminJwtRefreshCommand(uow, hash_service, jwt_service, cache, logger) return AdminJwtRefreshCommand(uow, jwt_service, logger)
def get_create_organization_command( def get_create_organization_command(

View File

@@ -1,10 +1,9 @@
from fastapi import APIRouter, Depends, Request, status from fastapi import APIRouter, Depends, status
from fastapi.responses import ORJSONResponse from fastapi.responses import ORJSONResponse
from src.application.commands import AdminJwtRefreshCommand, AdminLoginCommand, GetAdminMeCommand from src.application.commands import AdminJwtRefreshCommand, AdminLoginCommand, GetAdminMeCommand
from src.application.domain.dto import AdminAuthContext from src.application.domain.dto import AdminAuthContext
from src.application.domain.exceptions import ApplicationException, RefreshConcurrentException from src.presentation.auth_cookies import set_auth_cookies
from src.presentation.auth_cookies import set_auth_cookies, set_device_id_cookie
from src.presentation.decorators.admin_auth import require_admin_access from src.presentation.decorators.admin_auth import require_admin_access
from src.presentation.dependencies.commands import ( from src.presentation.dependencies.commands import (
get_admin_jwt_refresh_command, get_admin_jwt_refresh_command,
@@ -22,26 +21,12 @@ from src.presentation.schemas.admin_auth import (
auth_router = APIRouter(prefix='/auth', tags=['auth']) auth_router = APIRouter(prefix='/auth', tags=['auth'])
def _client_ip(request: Request) -> str | None:
xff = request.headers.get('x-forwarded-for')
if xff:
return xff.split(',')[0].strip()
return request.client.host if request.client else None
@auth_router.post('/login', response_model=AdminLoginResponse, status_code=status.HTTP_200_OK) @auth_router.post('/login', response_model=AdminLoginResponse, status_code=status.HTTP_200_OK)
async def admin_login( async def admin_login(
body: AdminLoginRequest, body: AdminLoginRequest,
request: Request,
command: AdminLoginCommand = Depends(get_admin_login_command), command: AdminLoginCommand = Depends(get_admin_login_command),
): ):
dto = await command( dto = await command(login=body.login, password=body.password)
login=body.login,
password=body.password,
device_id=request.cookies.get('device_id'),
ip=_client_ip(request),
user_agent=request.headers.get('user-agent'),
)
response = ORJSONResponse( response = ORJSONResponse(
AdminLoginResponse( AdminLoginResponse(
access_token=dto.access_token, access_token=dto.access_token,
@@ -54,31 +39,20 @@ async def admin_login(
).model_dump() ).model_dump()
) )
set_auth_cookies(response, dto.access_token, dto.refresh_token) set_auth_cookies(response, dto.access_token, dto.refresh_token)
set_device_id_cookie(response, dto.device_id)
return response return response
@auth_router.post('/refresh', response_model=AdminRefreshResponse, status_code=status.HTTP_200_OK) @auth_router.post('/refresh', response_model=AdminRefreshResponse, status_code=status.HTTP_200_OK)
async def admin_refresh( async def admin_refresh(
body: AdminRefreshRequest, body: AdminRefreshRequest,
request: Request,
command: AdminJwtRefreshCommand = Depends(get_admin_jwt_refresh_command), command: AdminJwtRefreshCommand = Depends(get_admin_jwt_refresh_command),
): ):
try: access, refresh = await command(refresh_token=body.refresh_token)
access, refresh = await command(
refresh_token=body.refresh_token,
ip=_client_ip(request),
user_agent=request.headers.get('user-agent'),
)
except RefreshConcurrentException:
raise ApplicationException(status_code=409, message='Refresh already in progress')
return AdminRefreshResponse(access_token=access, refresh_token=refresh) return AdminRefreshResponse(access_token=access, refresh_token=refresh)
@auth_router.post('/logout', response_class=ORJSONResponse, status_code=status.HTTP_200_OK) @auth_router.post('/logout', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
async def admin_logout(): async def admin_logout():
"""Клиент удаляет access_token локально. Сервер stateless."""
return {'ok': True} return {'ok': True}

View File

@@ -3,20 +3,13 @@ from fastapi.responses import ORJSONResponse
from starlette import status from starlette import status
from src.application.commands import AdminJwtRefreshCommand from src.application.commands import AdminJwtRefreshCommand
from src.application.domain.exceptions import ApplicationException, RefreshConcurrentException from src.application.domain.exceptions import ApplicationException
from src.presentation.auth_cookies import clear_auth_cookies, set_auth_cookies from src.presentation.auth_cookies import clear_auth_cookies, set_auth_cookies
from src.presentation.dependencies.commands import get_admin_jwt_refresh_command from src.presentation.dependencies.commands import get_admin_jwt_refresh_command
jwt_router = APIRouter(prefix='/jwt', tags=['jwt']) jwt_router = APIRouter(prefix='/jwt', tags=['jwt'])
def _client_ip(request: Request) -> str | None:
xff = request.headers.get('x-forwarded-for')
if xff:
return xff.split(',')[0].strip()
return request.client.host if request.client else None
@jwt_router.post('/refresh', response_class=ORJSONResponse, status_code=status.HTTP_200_OK) @jwt_router.post('/refresh', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
async def refresh_tokens( async def refresh_tokens(
request: Request, request: Request,
@@ -29,13 +22,7 @@ async def refresh_tokens(
return response return response
try: try:
tokens = await command( access, refresh = await command(refresh_token=refresh_token)
refresh_token=refresh_token,
ip=_client_ip(request),
user_agent=request.headers.get('user-agent'),
)
except RefreshConcurrentException:
return ORJSONResponse({'result': True, 'concurrent': True}, status_code=status.HTTP_200_OK)
except ApplicationException as exc: except ApplicationException as exc:
if exc.status_code == status.HTTP_401_UNAUTHORIZED: if exc.status_code == status.HTTP_401_UNAUTHORIZED:
response = ORJSONResponse({'result': False}, status_code=401) response = ORJSONResponse({'result': False}, status_code=401)
@@ -43,7 +30,6 @@ async def refresh_tokens(
return response return response
raise raise
access, refresh = tokens
response = ORJSONResponse({'result': True}) response = ORJSONResponse({'result': True})
set_auth_cookies(response, access, refresh) set_auth_cookies(response, access, refresh)
return response return response