feat: update

This commit is contained in:
2026-06-05 14:47:41 +03:00
parent fdae3ca554
commit 4d5506db4d
12 changed files with 52 additions and 244 deletions

View File

@@ -1,112 +1,29 @@
import asyncio
from datetime import datetime, timezone, timedelta
from ulid import ULID
from src.application.abstractions import IUnitOfWork
from src.application.contracts import IHashService, IJwtService, ILogger, ICache
from src.application.domain.dto import RefreshTokenPayload
from src.application.domain.exceptions import ApplicationException, RefreshConcurrentException
from src.infrastructure.config import settings
from src.application.contracts import IJwtService, ILogger
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.database.decorators import transactional
class AdminJwtRefreshCommand:
_LOCK_PREFIX = 'admin:jwt:refresh:lock:'
_LOCK_TTL_SECONDS = 15
_LOCK_WAIT_ATTEMPTS = 40
_LOCK_WAIT_INTERVAL_SECONDS = 0.05
def __init__(
self,
unit_of_work: IUnitOfWork,
hash_service: IHashService,
jwt_service: IJwtService,
cache: ICache,
logger: ILogger,
):
self._unit_of_work = unit_of_work
self._hash_service = hash_service
self._jwt_service = jwt_service
self._cache = cache
self._logger = logger
@transactional
async def __call__(self, *, refresh_token: str, ip: str | None, user_agent: str | None) -> tuple[str, str]:
now = datetime.now(timezone.utc)
payload: RefreshTokenPayload = await self._jwt_service.decode_refresh_token(refresh_token)
async def __call__(self, *, refresh_token: str) -> tuple[str, str]:
payload = await self._jwt_service.decode_refresh_token(refresh_token)
admin = await self._unit_of_work.admin_user_repository.get_by_id(payload.sub)
sid = payload.sid
admin_user_id = payload.sub
jti = payload.jti
if not admin.is_active:
raise ApplicationException(status_code=403, message='Admin account is inactive')
lock_key = f'{self._LOCK_PREFIX}{sid}'
locked = await self._cache.set_nx(lock_key, '1', self._LOCK_TTL_SECONDS)
if not locked:
for _ in range(self._LOCK_WAIT_ATTEMPTS):
await asyncio.sleep(self._LOCK_WAIT_INTERVAL_SECONDS)
if await self._cache.get(lock_key) is None:
raise RefreshConcurrentException()
raise ApplicationException(status_code=429, message='Refresh in progress')
try:
return await self._refresh_locked(
sid=sid,
admin_user_id=admin_user_id,
jti=jti,
now=now,
ip=ip,
user_agent=user_agent,
)
finally:
await self._cache.delete(lock_key)
async def _refresh_locked(
self,
*,
sid: str,
admin_user_id: str,
jti: str,
now: datetime,
ip: str | None,
user_agent: str | None,
) -> tuple[str, str]:
sess = await self._unit_of_work.admin_session_repository.get_by_sid(sid)
if sess is None:
raise ApplicationException(status_code=401, message='Session not found')
if sess.revoked_at is not None:
raise ApplicationException(status_code=401, message='Session revoked')
if sess.refresh_expires_at is None or sess.refresh_expires_at <= now:
raise ApplicationException(status_code=401, message='Session expired')
if str(sess.admin_user_id) != str(admin_user_id):
raise ApplicationException(status_code=401, message='Invalid session subject')
ok = await self._hash_service.verify(plain_value=jti, hashed_value=sess.refresh_jti_hash)
if not ok:
await self._unit_of_work.admin_session_repository.revoke_by_sid(sid=sid, now=now)
raise ApplicationException(status_code=401, message='Refresh token reuse detected')
admin = await self._unit_of_work.admin_user_repository.get_by_id(admin_user_id)
new_jti = str(ULID())
new_jti_hash = await self._hash_service.hash(value=new_jti)
new_refresh_expires_at = now + timedelta(seconds=int(settings.JWT_REFRESH_TTL_SECONDS))
rotated = await self._unit_of_work.admin_session_repository.rotate_refresh_if_match(
sid=sid,
old_jti_hash=sess.refresh_jti_hash,
new_jti_hash=new_jti_hash,
new_refresh_expires_at=new_refresh_expires_at,
now=now,
ip=ip,
user_agent=user_agent,
)
if not rotated:
raise RefreshConcurrentException()
access = await self._jwt_service.create_access_token(
user_id=admin_user_id, sid=sid, role=admin.role
)
refresh = await self._jwt_service.create_refresh_token(
user_id=admin_user_id, sid=sid, refresh_jti=new_jti
)
access = await self._jwt_service.create_access_token(user_id=admin.id, role=admin.role)
refresh = await self._jwt_service.create_refresh_token(user_id=admin.id, role=admin.role)
self._logger.info(f'Admin tokens refreshed admin_user_id={admin.id}')
return access, refresh

View File

@@ -1,14 +1,11 @@
from __future__ import annotations
from datetime import datetime, timezone, timedelta
from ulid import ULID
from datetime import datetime, timezone
from src.application.abstractions import IUnitOfWork
from src.application.contracts import IHashService, IJwtService, ILogger
from src.application.domain.dto.admin_auth import AdminLoginDto
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.config import settings
from src.infrastructure.database.decorators import transactional
@@ -26,15 +23,7 @@ class AdminLoginCommand:
self._logger = logger
@transactional
async def __call__(
self,
*,
login: str,
password: str,
device_id: str | None,
ip: str | None,
user_agent: str | None,
) -> AdminLoginDto:
async def __call__(self, *, login: str, password: str) -> AdminLoginDto:
login = (login or '').strip()
if not login:
raise ApplicationException(status_code=400, message='Login is required')
@@ -51,33 +40,8 @@ class AdminLoginCommand:
now = datetime.now(timezone.utc)
await self._unit_of_work.admin_user_repository.update_last_login(admin.id, last_login_at=now)
resolved_device_id = device_id or str(ULID())
sid = str(ULID())
jti = str(ULID())
jti_hash = await self._hash_service.hash(value=jti)
refresh_expires_at = now + timedelta(seconds=int(settings.JWT_REFRESH_TTL_SECONDS))
await self._unit_of_work.admin_session_repository.upsert_by_device(
admin_user_id=admin.id,
device_id=resolved_device_id,
sid=sid,
refresh_jti_hash=jti_hash,
refresh_expires_at=refresh_expires_at,
user_agent=user_agent,
ip=ip,
now=now,
)
access_token = await self._jwt_service.create_access_token(
user_id=admin.id,
role=admin.role,
sid=sid,
)
refresh_token = await self._jwt_service.create_refresh_token(
user_id=admin.id,
sid=sid,
refresh_jti=jti,
)
access_token = await self._jwt_service.create_access_token(user_id=admin.id, role=admin.role)
refresh_token = await self._jwt_service.create_refresh_token(user_id=admin.id, role=admin.role)
self._logger.info(f'Admin logged in admin_user_id={admin.id}')
@@ -89,6 +53,5 @@ class AdminLoginCommand:
role=admin.role,
access_token=access_token,
refresh_token=refresh_token,
device_id=resolved_device_id,
last_login_at=now,
)

View File

@@ -1,30 +1,11 @@
from __future__ import annotations
from datetime import datetime, timezone
from src.application.abstractions import IUnitOfWork
from src.application.contracts import IJwtService, ILogger
from src.application.domain.dto import RefreshTokenPayload
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.database.decorators import transactional
from src.application.contracts import ILogger
class AdminLogoutCommand:
def __init__(self, unit_of_work: IUnitOfWork, jwt_service: IJwtService, logger: ILogger):
self._unit_of_work = unit_of_work
self._jwt_service = jwt_service
def __init__(self, logger: ILogger):
self._logger = logger
@transactional
async def __call__(self, *, refresh_token: str | None) -> None:
if not refresh_token:
return
try:
payload: RefreshTokenPayload = await self._jwt_service.decode_refresh_token(refresh_token)
except ApplicationException:
self._logger.debug('Logout: refresh token invalid/expired, skipping revoke')
return
now = datetime.now(timezone.utc)
await self._unit_of_work.admin_session_repository.revoke_by_sid(sid=payload.sid, now=now)
self._logger.info(f'Logout: session revoked (sid={payload.sid}, admin_user_id={payload.sub})')
async def __call__(self) -> None:
self._logger.debug('Admin logout (stateless)')

View File

@@ -5,11 +5,11 @@ from src.application.domain.dto import AccessTokenPayload, RefreshTokenPayload
class IJwtService(ABC):
@abstractmethod
async def create_access_token(self, user_id: str, *, role: str, sid: str | None = None) -> str:
async def create_access_token(self, user_id: str, *, role: str) -> str:
raise NotImplementedError
@abstractmethod
async def create_refresh_token(self, user_id: str, *, sid: str, refresh_jti: str) -> str:
async def create_refresh_token(self, user_id: str, *, role: str) -> str:
raise NotImplementedError
@abstractmethod

View File

@@ -13,5 +13,4 @@ class AdminLoginDto:
role: str
access_token: str
refresh_token: str
device_id: str
last_login_at: datetime | None = None

View File

@@ -5,7 +5,6 @@ class AccessTokenPayload(BaseModel):
sub: str
type: str
role: str | None = None
sid: str | None = None
iat: int
nbf: int
exp: int
@@ -16,8 +15,7 @@ class AccessTokenPayload(BaseModel):
class RefreshTokenPayload(BaseModel):
sub: str
type: str
sid: str
jti: str
role: str
iat: int
nbf: int
exp: int

View File

@@ -20,7 +20,7 @@ class JwtService(IJwtService):
def _issuer(self) -> str | None:
return settings.ADMIN_JWT_ISSUER
async def create_access_token(self, user_id: str, *, role: str, sid: str | None = None) -> str:
async def create_access_token(self, user_id: str, *, role: str) -> str:
now = datetime.now(timezone.utc)
exp = now + timedelta(seconds=int(settings.JWT_ACCESS_TTL_SECONDS))
@@ -32,8 +32,6 @@ class JwtService(IJwtService):
'nbf': int(now.timestamp()),
'exp': int(exp.timestamp()),
}
if sid:
payload['sid'] = sid
if self._issuer:
payload['iss'] = self._issuer
if settings.JWT_AUDIENCE:
@@ -41,15 +39,14 @@ class JwtService(IJwtService):
return await self._encode(payload, user_id=user_id, token_kind='access')
async def create_refresh_token(self, user_id: str, *, sid: str, refresh_jti: str) -> str:
async def create_refresh_token(self, user_id: str, *, role: str) -> str:
now = datetime.now(timezone.utc)
exp = now + timedelta(seconds=int(settings.JWT_REFRESH_TTL_SECONDS))
payload: dict[str, object] = {
'sub': user_id,
'type': 'refresh',
'sid': sid,
'jti': refresh_jti,
'role': role,
'iat': int(now.timestamp()),
'nbf': int(now.timestamp()),
'exp': int(exp.timestamp()),
@@ -82,7 +79,6 @@ class JwtService(IJwtService):
sub=str(payload['sub']),
type='access',
role=str(payload['role']) if payload.get('role') else None,
sid=str(payload['sid']) if payload.get('sid') else None,
iat=int(payload['iat']),
nbf=int(payload['nbf']),
exp=int(payload['exp']),
@@ -100,8 +96,7 @@ class JwtService(IJwtService):
return RefreshTokenPayload(
sub=str(payload['sub']),
type='refresh',
sid=str(payload['sid']),
jti=str(payload['jti']),
role=str(payload['role']),
iat=int(payload['iat']),
nbf=int(payload['nbf']),
exp=int(payload['exp']),
@@ -150,12 +145,10 @@ class JwtService(IJwtService):
if 'type' not in payload:
raise ApplicationException(status_code=401, message='Missing token claim: type')
token_type = payload.get('type')
if token_type == 'access' and 'role' not in payload:
if 'role' not in payload:
raise ApplicationException(status_code=401, message='Missing token claim: role')
if token_type == 'refresh':
for claim in ('sid', 'jti'):
if claim not in payload:
raise ApplicationException(status_code=401, message=f'Missing token claim: {claim}')
if token_type not in ('access', 'refresh'):
raise ApplicationException(status_code=401, message='Invalid token type')
return payload
except ExpiredSignatureError:
raise ApplicationException(status_code=401, message='Token expired')

View File

@@ -0,0 +1,15 @@
from ulid import ULID
def new_ulid() -> str:
return str(ULID())
def is_valid_ulid(value: str | None) -> bool:
if not value:
return False
try:
ULID.parse(value)
return True
except ValueError:
return False

View File

@@ -29,16 +29,3 @@ def set_auth_cookies(response: ORJSONResponse, access: str, refresh: str) -> Non
domain=settings.ADMIN_COOKIE_DOMAIN,
max_age=int(settings.JWT_REFRESH_TTL_SECONDS),
)
def set_device_id_cookie(response: ORJSONResponse, device_id: str) -> None:
response.set_cookie(
key='device_id',
value=device_id,
httponly=True,
secure=settings.ADMIN_COOKIE_SECURE,
samesite='lax',
path='/',
domain=settings.ADMIN_COOKIE_DOMAIN,
max_age=int(settings.JWT_REFRESH_TTL_SECONDS),
)

View File

@@ -21,10 +21,9 @@ from src.application.commands import (
UpdatePurchaseRequestStatusCommand,
UploadOrganizationDocumentCommand,
)
from src.application.contracts import ICache, IHashService, IJwtService, ILogger
from src.application.contracts import IHashService, IJwtService, ILogger
from src.infrastructure.config import settings
from src.infrastructure.storage.s3_documents_service import S3DocumentsService
from src.presentation.dependencies.cache import get_cache
from src.presentation.dependencies.logger import get_logger
from src.presentation.dependencies.security import get_hash_service, get_jwt_service
from src.presentation.dependencies.unit_of_work import get_unit_of_work
@@ -64,21 +63,17 @@ def get_admin_me_command(
def get_admin_logout_command(
uow: IUnitOfWork = Depends(get_unit_of_work),
jwt_service: IJwtService = Depends(get_jwt_service),
logger: ILogger = Depends(get_logger),
) -> AdminLogoutCommand:
return AdminLogoutCommand(uow, jwt_service, logger)
return AdminLogoutCommand(logger)
def get_admin_jwt_refresh_command(
uow: IUnitOfWork = Depends(get_unit_of_work),
hash_service: IHashService = Depends(get_hash_service),
jwt_service: IJwtService = Depends(get_jwt_service),
cache: ICache = Depends(get_cache),
logger: ILogger = Depends(get_logger),
) -> AdminJwtRefreshCommand:
return AdminJwtRefreshCommand(uow, hash_service, jwt_service, cache, logger)
return AdminJwtRefreshCommand(uow, jwt_service, logger)
def get_create_organization_command(

View File

@@ -1,10 +1,9 @@
from fastapi import APIRouter, Depends, Request, status
from fastapi import APIRouter, Depends, status
from fastapi.responses import ORJSONResponse
from src.application.commands import AdminJwtRefreshCommand, AdminLoginCommand, GetAdminMeCommand
from src.application.domain.dto import AdminAuthContext
from src.application.domain.exceptions import ApplicationException, RefreshConcurrentException
from src.presentation.auth_cookies import set_auth_cookies, set_device_id_cookie
from src.presentation.auth_cookies import set_auth_cookies
from src.presentation.decorators.admin_auth import require_admin_access
from src.presentation.dependencies.commands import (
get_admin_jwt_refresh_command,
@@ -22,26 +21,12 @@ from src.presentation.schemas.admin_auth import (
auth_router = APIRouter(prefix='/auth', tags=['auth'])
def _client_ip(request: Request) -> str | None:
xff = request.headers.get('x-forwarded-for')
if xff:
return xff.split(',')[0].strip()
return request.client.host if request.client else None
@auth_router.post('/login', response_model=AdminLoginResponse, status_code=status.HTTP_200_OK)
async def admin_login(
body: AdminLoginRequest,
request: Request,
command: AdminLoginCommand = Depends(get_admin_login_command),
):
dto = await command(
login=body.login,
password=body.password,
device_id=request.cookies.get('device_id'),
ip=_client_ip(request),
user_agent=request.headers.get('user-agent'),
)
dto = await command(login=body.login, password=body.password)
response = ORJSONResponse(
AdminLoginResponse(
access_token=dto.access_token,
@@ -54,31 +39,20 @@ async def admin_login(
).model_dump()
)
set_auth_cookies(response, dto.access_token, dto.refresh_token)
set_device_id_cookie(response, dto.device_id)
return response
@auth_router.post('/refresh', response_model=AdminRefreshResponse, status_code=status.HTTP_200_OK)
async def admin_refresh(
body: AdminRefreshRequest,
request: Request,
command: AdminJwtRefreshCommand = Depends(get_admin_jwt_refresh_command),
):
try:
access, refresh = await command(
refresh_token=body.refresh_token,
ip=_client_ip(request),
user_agent=request.headers.get('user-agent'),
)
except RefreshConcurrentException:
raise ApplicationException(status_code=409, message='Refresh already in progress')
access, refresh = await command(refresh_token=body.refresh_token)
return AdminRefreshResponse(access_token=access, refresh_token=refresh)
@auth_router.post('/logout', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
async def admin_logout():
"""Клиент удаляет access_token локально. Сервер stateless."""
return {'ok': True}

View File

@@ -3,20 +3,13 @@ from fastapi.responses import ORJSONResponse
from starlette import status
from src.application.commands import AdminJwtRefreshCommand
from src.application.domain.exceptions import ApplicationException, RefreshConcurrentException
from src.application.domain.exceptions import ApplicationException
from src.presentation.auth_cookies import clear_auth_cookies, set_auth_cookies
from src.presentation.dependencies.commands import get_admin_jwt_refresh_command
jwt_router = APIRouter(prefix='/jwt', tags=['jwt'])
def _client_ip(request: Request) -> str | None:
xff = request.headers.get('x-forwarded-for')
if xff:
return xff.split(',')[0].strip()
return request.client.host if request.client else None
@jwt_router.post('/refresh', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
async def refresh_tokens(
request: Request,
@@ -29,13 +22,7 @@ async def refresh_tokens(
return response
try:
tokens = await command(
refresh_token=refresh_token,
ip=_client_ip(request),
user_agent=request.headers.get('user-agent'),
)
except RefreshConcurrentException:
return ORJSONResponse({'result': True, 'concurrent': True}, status_code=status.HTTP_200_OK)
access, refresh = await command(refresh_token=refresh_token)
except ApplicationException as exc:
if exc.status_code == status.HTTP_401_UNAUTHORIZED:
response = ORJSONResponse({'result': False}, status_code=401)
@@ -43,7 +30,6 @@ async def refresh_tokens(
return response
raise
access, refresh = tokens
response = ORJSONResponse({'result': True})
set_auth_cookies(response, access, refresh)
return response