From 62a48c3f702b960a654ed12cb81399bec2c4ff08 Mon Sep 17 00:00:00 2001 From: Noloquideus Date: Thu, 4 Jun 2026 18:07:08 +0300 Subject: [PATCH] feat: add import --- pyproject.toml | 1 - src/application/commands/__init__.py | 4 + src/application/contracts/__init__.py | 1 - src/application/contracts/i_csrf_service.py | 26 ------- src/application/contracts/i_jwt_service.py | 12 ++- src/application/domain/dto/__init__.py | 2 +- src/application/domain/dto/token.py | 13 ++++ src/infrastructure/security/__init__.py | 1 - src/infrastructure/security/csrf.py | 81 --------------------- src/infrastructure/security/jwt.py | 58 ++++++++++++++- src/presentation/decorators/csrf.py | 61 ---------------- src/presentation/dependencies/commands.py | 23 +++++- uv.lock | 11 --- 13 files changed, 104 insertions(+), 190 deletions(-) delete mode 100644 src/application/contracts/i_csrf_service.py delete mode 100644 src/infrastructure/security/csrf.py delete mode 100644 src/presentation/decorators/csrf.py diff --git a/pyproject.toml b/pyproject.toml index 47cb8be..280fe73 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,6 @@ dependencies = [ "fastapi==0.128.7", "granian==2.6.1", "hvac==2.4.0", - "itsdangerous>=2.2.0", "orjson==3.11.7", "pydantic-settings==2.12.0", "python-jose==3.5.0", diff --git a/src/application/commands/__init__.py b/src/application/commands/__init__.py index 51d98cf..a450151 100644 --- a/src/application/commands/__init__.py +++ b/src/application/commands/__init__.py @@ -1,4 +1,6 @@ from src.application.commands.admin_login import AdminLoginCommand +from src.application.commands.admin_logout import AdminLogoutCommand +from src.application.commands.admin_jwt_refresh import AdminJwtRefreshCommand from src.application.commands.get_admin_me import GetAdminMeCommand from src.application.commands.create_organization import CreateOrganizationCommand from src.application.commands.create_organization_wallets import CreateOrganizationWalletsCommand @@ -21,6 +23,8 @@ from src.application.commands.purchase_request_commands import ( __all__ = [ 'AdminLoginCommand', + 'AdminLogoutCommand', + 'AdminJwtRefreshCommand', 'GetAdminMeCommand', 'CreateOrganizationCommand', 'CreateOrganizationWalletsCommand', diff --git a/src/application/contracts/__init__.py b/src/application/contracts/__init__.py index 0fd1d6d..31751c8 100644 --- a/src/application/contracts/__init__.py +++ b/src/application/contracts/__init__.py @@ -2,6 +2,5 @@ from src.application.contracts.i_hash_service import IHashService from src.application.contracts.i_logger import ILogger from src.application.contracts.i_user_service import IUserService from src.application.contracts.i_jwt_service import IJwtService -from src.application.contracts.i_csrf_service import ICsrfService from src.application.contracts.i_cache import ICache from src.application.contracts.i_queue_messanger import IQueueMessanger \ No newline at end of file diff --git a/src/application/contracts/i_csrf_service.py b/src/application/contracts/i_csrf_service.py deleted file mode 100644 index a493d60..0000000 --- a/src/application/contracts/i_csrf_service.py +++ /dev/null @@ -1,26 +0,0 @@ -from __future__ import annotations -from abc import ABC, abstractmethod -from typing import Any, Optional, Mapping - - -class ICsrfService(ABC): - @abstractmethod - def issue(self, subject: Optional[str] = None) -> str: - raise NotImplementedError - - @abstractmethod - def verify(self, token: str, expected_subject: Optional[str] = None) -> dict[str, Any]: - raise NotImplementedError - - @abstractmethod - def extract(self, cookies: Mapping[str, str], headers: Mapping[str, str]) -> tuple[Optional[str], Optional[str]]: - raise NotImplementedError - - @abstractmethod - def verify_pair( - self, - cookie_token: Optional[str], - header_token: Optional[str], - expected_subject: Optional[str] = None, - ) -> None: - raise NotImplementedError diff --git a/src/application/contracts/i_jwt_service.py b/src/application/contracts/i_jwt_service.py index 82359f2..34d5298 100644 --- a/src/application/contracts/i_jwt_service.py +++ b/src/application/contracts/i_jwt_service.py @@ -1,13 +1,21 @@ from abc import ABC, abstractmethod -from src.application.domain.dto import AccessTokenPayload +from src.application.domain.dto import AccessTokenPayload, RefreshTokenPayload class IJwtService(ABC): @abstractmethod - async def create_access_token(self, user_id: str, *, role: str) -> str: + async def create_access_token(self, user_id: str, *, role: str, sid: str | None = None) -> str: + raise NotImplementedError + + @abstractmethod + async def create_refresh_token(self, user_id: str, *, sid: str, refresh_jti: str) -> str: raise NotImplementedError @abstractmethod async def decode_access_token(self, token: str) -> AccessTokenPayload: raise NotImplementedError + + @abstractmethod + async def decode_refresh_token(self, token: str) -> RefreshTokenPayload: + raise NotImplementedError diff --git a/src/application/domain/dto/__init__.py b/src/application/domain/dto/__init__.py index 5d1b9fa..de4c517 100644 --- a/src/application/domain/dto/__init__.py +++ b/src/application/domain/dto/__init__.py @@ -1,4 +1,4 @@ from src.application.domain.dto.admin_auth import AdminLoginDto -from src.application.domain.dto.token import AccessTokenPayload, AdminAuthContext +from src.application.domain.dto.token import AccessTokenPayload, AdminAuthContext, RefreshTokenPayload from src.application.domain.dto.keys import JwtKeySet, JwtKeyPair from src.application.domain.dto.user import UserCreatedDto, UserLoginDto diff --git a/src/application/domain/dto/token.py b/src/application/domain/dto/token.py index dcdcd94..59e3fb3 100644 --- a/src/application/domain/dto/token.py +++ b/src/application/domain/dto/token.py @@ -5,6 +5,19 @@ class AccessTokenPayload(BaseModel): sub: str type: str role: str | None = None + sid: str | None = None + iat: int + nbf: int + exp: int + iss: str | None = None + aud: str | None = None + + +class RefreshTokenPayload(BaseModel): + sub: str + type: str + sid: str + jti: str iat: int nbf: int exp: int diff --git a/src/infrastructure/security/__init__.py b/src/infrastructure/security/__init__.py index 6dc434f..c0c586f 100644 --- a/src/infrastructure/security/__init__.py +++ b/src/infrastructure/security/__init__.py @@ -1,3 +1,2 @@ from src.infrastructure.security.jwt import JwtService -from src.infrastructure.security.csrf import CsrfService from src.infrastructure.security.hash import HashService \ No newline at end of file diff --git a/src/infrastructure/security/csrf.py b/src/infrastructure/security/csrf.py deleted file mode 100644 index 28f8faf..0000000 --- a/src/infrastructure/security/csrf.py +++ /dev/null @@ -1,81 +0,0 @@ -from __future__ import annotations -import secrets -from typing import Any, Optional, Mapping -from itsdangerous import URLSafeTimedSerializer, SignatureExpired, BadSignature -from src.application.contracts import ICsrfService -from src.application.domain.exceptions import ApplicationException -from src.infrastructure.config.settings import settings - - -class CsrfService(ICsrfService): - COOKIE_NAME = "csrf_token" - HEADER_NAME = "X-CSRF-Token" - SALT = "csrf" - TTL_SECONDS = 3600 - - def __init__(self) -> None: - self._serializer = URLSafeTimedSerializer( - secret_key=settings.CSRF_SECRET_KEY, - salt=self.SALT, - ) - - @property - def cookie_name(self) -> str: - return self.COOKIE_NAME - - @property - def header_name(self) -> str: - return self.HEADER_NAME - - @property - def ttl_seconds(self) -> int: - return self.TTL_SECONDS - - def issue(self, subject: Optional[str] = None) -> str: - payload = { - "sub": subject, - "nonce": secrets.token_urlsafe(32), - } - return self._serializer.dumps(payload) - - def verify(self, token: str, expected_subject: Optional[str] = None) -> dict[str, Any]: - try: - data = self._serializer.loads(token, max_age=self.TTL_SECONDS) - except SignatureExpired: - raise ApplicationException( - status_code=403, - message="CSRF token expired", - ) - except BadSignature: - raise ApplicationException( - status_code=403, - message="CSRF token invalid", - ) - - if expected_subject is not None and data.get("sub") != expected_subject: - raise ApplicationException( - status_code=403, - message="CSRF token subject mismatch", - ) - - return data - - def extract(self, cookies: Mapping[str, str], headers: Mapping[str, str]) -> tuple[Optional[str], Optional[str]]: - cookie_token = cookies.get(self.COOKIE_NAME) - header_token = headers.get(self.HEADER_NAME) - return cookie_token, header_token - - def verify_pair(self, cookie_token: Optional[str], header_token: Optional[str], expected_subject: Optional[str] = None) -> None: - if not cookie_token or not header_token: - raise ApplicationException( - status_code=403, - message="CSRF token missing", - ) - - if not secrets.compare_digest(cookie_token, header_token): - raise ApplicationException( - status_code=403, - message="CSRF token mismatch", - ) - - self.verify(cookie_token, expected_subject=expected_subject) diff --git a/src/infrastructure/security/jwt.py b/src/infrastructure/security/jwt.py index d5bfeec..e5c59bc 100644 --- a/src/infrastructure/security/jwt.py +++ b/src/infrastructure/security/jwt.py @@ -5,7 +5,7 @@ from datetime import datetime, timezone, timedelta from jose import jwt, ExpiredSignatureError, JWTError from src.application.contracts import ILogger, IJwtService -from src.application.domain.dto import AccessTokenPayload +from src.application.domain.dto import AccessTokenPayload, RefreshTokenPayload from src.application.domain.exceptions import ApplicationException from src.infrastructure.config.settings import settings from src.infrastructure.vault import JwtKeyStore @@ -20,7 +20,7 @@ class JwtService(IJwtService): def _issuer(self) -> str | None: return settings.ADMIN_JWT_ISSUER - async def create_access_token(self, user_id: str, *, role: str) -> str: + async def create_access_token(self, user_id: str, *, role: str, sid: str | None = None) -> str: now = datetime.now(timezone.utc) exp = now + timedelta(seconds=int(settings.JWT_ACCESS_TTL_SECONDS)) @@ -32,15 +32,40 @@ class JwtService(IJwtService): 'nbf': int(now.timestamp()), 'exp': int(exp.timestamp()), } + if sid: + payload['sid'] = sid if self._issuer: payload['iss'] = self._issuer if settings.JWT_AUDIENCE: payload['aud'] = settings.JWT_AUDIENCE + return await self._encode(payload, user_id=user_id, token_kind='access') + + async def create_refresh_token(self, user_id: str, *, sid: str, refresh_jti: str) -> str: + now = datetime.now(timezone.utc) + exp = now + timedelta(seconds=int(settings.JWT_REFRESH_TTL_SECONDS)) + + payload: dict[str, object] = { + 'sub': user_id, + 'type': 'refresh', + 'sid': sid, + 'jti': refresh_jti, + 'iat': int(now.timestamp()), + 'nbf': int(now.timestamp()), + 'exp': int(exp.timestamp()), + } + if self._issuer: + payload['iss'] = self._issuer + if settings.JWT_AUDIENCE: + payload['aud'] = settings.JWT_AUDIENCE + + return await self._encode(payload, user_id=user_id, token_kind='refresh') + + async def _encode(self, payload: dict[str, object], *, user_id: str, token_kind: str) -> str: try: kid, private_pem = await self._key_store.get_signing_key() token = jwt.encode(payload, private_pem, algorithm=settings.JWT_ALGORITHM, headers={'kid': kid}) - self._logger.info(f'Admin access token created admin_user_id={user_id} kid={kid}') + self._logger.info(f'Admin {token_kind} token created admin_user_id={user_id} kid={kid}') return token except ApplicationException: raise @@ -57,6 +82,26 @@ class JwtService(IJwtService): sub=str(payload['sub']), type='access', role=str(payload['role']) if payload.get('role') else None, + sid=str(payload['sid']) if payload.get('sid') else None, + iat=int(payload['iat']), + nbf=int(payload['nbf']), + exp=int(payload['exp']), + iss=payload.get('iss'), + aud=payload.get('aud'), + ) + except KeyError as exception: + raise ApplicationException(status_code=401, message=f'Missing token claim: {exception}') + + async def decode_refresh_token(self, token: str) -> RefreshTokenPayload: + payload = await self._decode_and_verify(token) + if payload.get('type') != 'refresh': + raise ApplicationException(status_code=401, message='Invalid token type') + try: + return RefreshTokenPayload( + sub=str(payload['sub']), + type='refresh', + sid=str(payload['sid']), + jti=str(payload['jti']), iat=int(payload['iat']), nbf=int(payload['nbf']), exp=int(payload['exp']), @@ -104,8 +149,13 @@ class JwtService(IJwtService): ) if 'type' not in payload: raise ApplicationException(status_code=401, message='Missing token claim: type') - if payload.get('type') == 'access' and 'role' not in payload: + token_type = payload.get('type') + if token_type == 'access' and 'role' not in payload: raise ApplicationException(status_code=401, message='Missing token claim: role') + if token_type == 'refresh': + for claim in ('sid', 'jti'): + if claim not in payload: + raise ApplicationException(status_code=401, message=f'Missing token claim: {claim}') return payload except ExpiredSignatureError: raise ApplicationException(status_code=401, message='Token expired') diff --git a/src/presentation/decorators/csrf.py b/src/presentation/decorators/csrf.py deleted file mode 100644 index 768e69e..0000000 --- a/src/presentation/decorators/csrf.py +++ /dev/null @@ -1,61 +0,0 @@ -from __future__ import annotations -import inspect -from functools import wraps -from typing import Callable, Awaitable, Any, Optional, Annotated -from fastapi import Request, Header -from src.application.domain.exceptions import ApplicationException -from src.infrastructure.security import CsrfService - - -def csrf_protect( - expected_subject_getter: Optional[Callable[[Request], Optional[str]]] = None, -): - def decorator(func: Callable[..., Awaitable[Any]]): - sig = inspect.signature(func) - params = list(sig.parameters.values()) - - has_request = any(p.annotation is Request or p.name == 'request' for p in params) - if not has_request: - raise RuntimeError('csrf_protect requires endpoint to accept `request: Request`') - - has_header = any(p.name == 'x_csrf_token' for p in params) - if not has_header: - params.append( - inspect.Parameter( - name='x_csrf_token', - kind=inspect.Parameter.KEYWORD_ONLY, - default=None, - annotation=Annotated[str | None, Header(alias='X-CSRF-Token')], - ) - ) - - @wraps(func) - async def wrapper(*args, **kwargs): - request: Request | None = kwargs.get('request') - if request is None: - for arg in args: - if isinstance(arg, Request): - request = arg - break - - if request is None: - raise ApplicationException( - status_code=500, - message='Request is required for CSRF protection', - ) - - csrf = CsrfService() - - cookie_token, _ = csrf.extract(request.cookies, request.headers) - header_token = kwargs.get('x_csrf_token') - - expected_subject = expected_subject_getter(request) if expected_subject_getter else None - csrf.verify_pair(cookie_token, header_token, expected_subject) - - kwargs.pop('x_csrf_token', None) - return await func(*args, **kwargs) - - wrapper.__signature__ = sig.replace(parameters=params) - return wrapper - - return decorator diff --git a/src/presentation/dependencies/commands.py b/src/presentation/dependencies/commands.py index 9b08d80..1adaaf8 100644 --- a/src/presentation/dependencies/commands.py +++ b/src/presentation/dependencies/commands.py @@ -5,6 +5,8 @@ from fastapi import Depends from src.application.abstractions import IUnitOfWork from src.application.commands import ( AdminLoginCommand, + AdminLogoutCommand, + AdminJwtRefreshCommand, GetAdminMeCommand, CreateOrganizationCommand, CreateOrganizationWalletsCommand, @@ -19,9 +21,10 @@ from src.application.commands import ( UpdatePurchaseRequestStatusCommand, UploadOrganizationDocumentCommand, ) -from src.application.contracts import IHashService, IJwtService, ILogger +from src.application.contracts import ICache, IHashService, IJwtService, ILogger from src.infrastructure.config import settings from src.infrastructure.storage.s3_documents_service import S3DocumentsService +from src.presentation.dependencies.cache import get_cache from src.presentation.dependencies.logger import get_logger from src.presentation.dependencies.security import get_hash_service, get_jwt_service from src.presentation.dependencies.unit_of_work import get_unit_of_work @@ -60,6 +63,24 @@ def get_admin_me_command( return GetAdminMeCommand(uow, logger) +def get_admin_logout_command( + uow: IUnitOfWork = Depends(get_unit_of_work), + jwt_service: IJwtService = Depends(get_jwt_service), + logger: ILogger = Depends(get_logger), +) -> AdminLogoutCommand: + return AdminLogoutCommand(uow, jwt_service, logger) + + +def get_admin_jwt_refresh_command( + uow: IUnitOfWork = Depends(get_unit_of_work), + hash_service: IHashService = Depends(get_hash_service), + jwt_service: IJwtService = Depends(get_jwt_service), + cache: ICache = Depends(get_cache), + logger: ILogger = Depends(get_logger), +) -> AdminJwtRefreshCommand: + return AdminJwtRefreshCommand(uow, hash_service, jwt_service, cache, logger) + + def get_create_organization_command( uow: IUnitOfWork = Depends(get_unit_of_work), hash_service: IHashService = Depends(get_hash_service), diff --git a/uv.lock b/uv.lock index d15011b..cc1a2e8 100644 --- a/uv.lock +++ b/uv.lock @@ -69,7 +69,6 @@ dependencies = [ { name = "fastapi" }, { name = "granian" }, { name = "hvac" }, - { name = "itsdangerous" }, { name = "orjson" }, { name = "pydantic-settings" }, { name = "python-jose" }, @@ -94,7 +93,6 @@ requires-dist = [ { name = "fastapi", specifier = "==0.128.7" }, { name = "granian", specifier = "==2.6.1" }, { name = "hvac", specifier = "==2.4.0" }, - { name = "itsdangerous", specifier = ">=2.2.0" }, { name = "orjson", specifier = "==3.11.7" }, { name = "pydantic-settings", specifier = "==2.12.0" }, { name = "python-jose", specifier = "==3.5.0" }, @@ -812,15 +810,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8f/7b/2edca79b359fc9f95d774616867a03ecccdf333797baf5b3eea79733918c/ijson-3.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:f4f7fabd653459dcb004175235f310435959b1bb5dfa8878578391c6cc9ad944", size = 55500, upload-time = "2026-02-24T03:57:20.428Z" }, ] -[[package]] -name = "itsdangerous" -version = "2.2.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/9c/cb/8ac0172223afbccb63986cc25049b154ecfb5e85932587206f42317be31d/itsdangerous-2.2.0.tar.gz", hash = "sha256:e0050c0b7da1eea53ffaf149c0cfbb5c6e2e2b69c4bef22c81fa6eb73e5f6173", size = 54410, upload-time = "2024-04-16T21:28:15.614Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/04/96/92447566d16df59b2a776c0fb82dbc4d9e07cd95062562af01e408583fc4/itsdangerous-2.2.0-py3-none-any.whl", hash = "sha256:c6242fc49e35958c8b15141343aa660db5fc54d4f13a1db01a3f5891b98700ef", size = 16234, upload-time = "2024-04-16T21:28:14.499Z" }, -] - [[package]] name = "jmespath" version = "1.1.0"