feat: add import

This commit is contained in:
2026-06-04 18:07:08 +03:00
parent 4eb2c78c03
commit 62a48c3f70
13 changed files with 104 additions and 190 deletions

View File

@@ -16,7 +16,6 @@ dependencies = [
"fastapi==0.128.7",
"granian==2.6.1",
"hvac==2.4.0",
"itsdangerous>=2.2.0",
"orjson==3.11.7",
"pydantic-settings==2.12.0",
"python-jose==3.5.0",

View File

@@ -1,4 +1,6 @@
from src.application.commands.admin_login import AdminLoginCommand
from src.application.commands.admin_logout import AdminLogoutCommand
from src.application.commands.admin_jwt_refresh import AdminJwtRefreshCommand
from src.application.commands.get_admin_me import GetAdminMeCommand
from src.application.commands.create_organization import CreateOrganizationCommand
from src.application.commands.create_organization_wallets import CreateOrganizationWalletsCommand
@@ -21,6 +23,8 @@ from src.application.commands.purchase_request_commands import (
__all__ = [
'AdminLoginCommand',
'AdminLogoutCommand',
'AdminJwtRefreshCommand',
'GetAdminMeCommand',
'CreateOrganizationCommand',
'CreateOrganizationWalletsCommand',

View File

@@ -2,6 +2,5 @@ from src.application.contracts.i_hash_service import IHashService
from src.application.contracts.i_logger import ILogger
from src.application.contracts.i_user_service import IUserService
from src.application.contracts.i_jwt_service import IJwtService
from src.application.contracts.i_csrf_service import ICsrfService
from src.application.contracts.i_cache import ICache
from src.application.contracts.i_queue_messanger import IQueueMessanger

View File

@@ -1,26 +0,0 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Optional, Mapping
class ICsrfService(ABC):
@abstractmethod
def issue(self, subject: Optional[str] = None) -> str:
raise NotImplementedError
@abstractmethod
def verify(self, token: str, expected_subject: Optional[str] = None) -> dict[str, Any]:
raise NotImplementedError
@abstractmethod
def extract(self, cookies: Mapping[str, str], headers: Mapping[str, str]) -> tuple[Optional[str], Optional[str]]:
raise NotImplementedError
@abstractmethod
def verify_pair(
self,
cookie_token: Optional[str],
header_token: Optional[str],
expected_subject: Optional[str] = None,
) -> None:
raise NotImplementedError

View File

@@ -1,13 +1,21 @@
from abc import ABC, abstractmethod
from src.application.domain.dto import AccessTokenPayload
from src.application.domain.dto import AccessTokenPayload, RefreshTokenPayload
class IJwtService(ABC):
@abstractmethod
async def create_access_token(self, user_id: str, *, role: str) -> str:
async def create_access_token(self, user_id: str, *, role: str, sid: str | None = None) -> str:
raise NotImplementedError
@abstractmethod
async def create_refresh_token(self, user_id: str, *, sid: str, refresh_jti: str) -> str:
raise NotImplementedError
@abstractmethod
async def decode_access_token(self, token: str) -> AccessTokenPayload:
raise NotImplementedError
@abstractmethod
async def decode_refresh_token(self, token: str) -> RefreshTokenPayload:
raise NotImplementedError

View File

@@ -1,4 +1,4 @@
from src.application.domain.dto.admin_auth import AdminLoginDto
from src.application.domain.dto.token import AccessTokenPayload, AdminAuthContext
from src.application.domain.dto.token import AccessTokenPayload, AdminAuthContext, RefreshTokenPayload
from src.application.domain.dto.keys import JwtKeySet, JwtKeyPair
from src.application.domain.dto.user import UserCreatedDto, UserLoginDto

View File

@@ -5,6 +5,19 @@ class AccessTokenPayload(BaseModel):
sub: str
type: str
role: str | None = None
sid: str | None = None
iat: int
nbf: int
exp: int
iss: str | None = None
aud: str | None = None
class RefreshTokenPayload(BaseModel):
sub: str
type: str
sid: str
jti: str
iat: int
nbf: int
exp: int

View File

@@ -1,3 +1,2 @@
from src.infrastructure.security.jwt import JwtService
from src.infrastructure.security.csrf import CsrfService
from src.infrastructure.security.hash import HashService

View File

@@ -1,81 +0,0 @@
from __future__ import annotations
import secrets
from typing import Any, Optional, Mapping
from itsdangerous import URLSafeTimedSerializer, SignatureExpired, BadSignature
from src.application.contracts import ICsrfService
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.config.settings import settings
class CsrfService(ICsrfService):
COOKIE_NAME = "csrf_token"
HEADER_NAME = "X-CSRF-Token"
SALT = "csrf"
TTL_SECONDS = 3600
def __init__(self) -> None:
self._serializer = URLSafeTimedSerializer(
secret_key=settings.CSRF_SECRET_KEY,
salt=self.SALT,
)
@property
def cookie_name(self) -> str:
return self.COOKIE_NAME
@property
def header_name(self) -> str:
return self.HEADER_NAME
@property
def ttl_seconds(self) -> int:
return self.TTL_SECONDS
def issue(self, subject: Optional[str] = None) -> str:
payload = {
"sub": subject,
"nonce": secrets.token_urlsafe(32),
}
return self._serializer.dumps(payload)
def verify(self, token: str, expected_subject: Optional[str] = None) -> dict[str, Any]:
try:
data = self._serializer.loads(token, max_age=self.TTL_SECONDS)
except SignatureExpired:
raise ApplicationException(
status_code=403,
message="CSRF token expired",
)
except BadSignature:
raise ApplicationException(
status_code=403,
message="CSRF token invalid",
)
if expected_subject is not None and data.get("sub") != expected_subject:
raise ApplicationException(
status_code=403,
message="CSRF token subject mismatch",
)
return data
def extract(self, cookies: Mapping[str, str], headers: Mapping[str, str]) -> tuple[Optional[str], Optional[str]]:
cookie_token = cookies.get(self.COOKIE_NAME)
header_token = headers.get(self.HEADER_NAME)
return cookie_token, header_token
def verify_pair(self, cookie_token: Optional[str], header_token: Optional[str], expected_subject: Optional[str] = None) -> None:
if not cookie_token or not header_token:
raise ApplicationException(
status_code=403,
message="CSRF token missing",
)
if not secrets.compare_digest(cookie_token, header_token):
raise ApplicationException(
status_code=403,
message="CSRF token mismatch",
)
self.verify(cookie_token, expected_subject=expected_subject)

View File

@@ -5,7 +5,7 @@ from datetime import datetime, timezone, timedelta
from jose import jwt, ExpiredSignatureError, JWTError
from src.application.contracts import ILogger, IJwtService
from src.application.domain.dto import AccessTokenPayload
from src.application.domain.dto import AccessTokenPayload, RefreshTokenPayload
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.config.settings import settings
from src.infrastructure.vault import JwtKeyStore
@@ -20,7 +20,7 @@ class JwtService(IJwtService):
def _issuer(self) -> str | None:
return settings.ADMIN_JWT_ISSUER
async def create_access_token(self, user_id: str, *, role: str) -> str:
async def create_access_token(self, user_id: str, *, role: str, sid: str | None = None) -> str:
now = datetime.now(timezone.utc)
exp = now + timedelta(seconds=int(settings.JWT_ACCESS_TTL_SECONDS))
@@ -32,15 +32,40 @@ class JwtService(IJwtService):
'nbf': int(now.timestamp()),
'exp': int(exp.timestamp()),
}
if sid:
payload['sid'] = sid
if self._issuer:
payload['iss'] = self._issuer
if settings.JWT_AUDIENCE:
payload['aud'] = settings.JWT_AUDIENCE
return await self._encode(payload, user_id=user_id, token_kind='access')
async def create_refresh_token(self, user_id: str, *, sid: str, refresh_jti: str) -> str:
now = datetime.now(timezone.utc)
exp = now + timedelta(seconds=int(settings.JWT_REFRESH_TTL_SECONDS))
payload: dict[str, object] = {
'sub': user_id,
'type': 'refresh',
'sid': sid,
'jti': refresh_jti,
'iat': int(now.timestamp()),
'nbf': int(now.timestamp()),
'exp': int(exp.timestamp()),
}
if self._issuer:
payload['iss'] = self._issuer
if settings.JWT_AUDIENCE:
payload['aud'] = settings.JWT_AUDIENCE
return await self._encode(payload, user_id=user_id, token_kind='refresh')
async def _encode(self, payload: dict[str, object], *, user_id: str, token_kind: str) -> str:
try:
kid, private_pem = await self._key_store.get_signing_key()
token = jwt.encode(payload, private_pem, algorithm=settings.JWT_ALGORITHM, headers={'kid': kid})
self._logger.info(f'Admin access token created admin_user_id={user_id} kid={kid}')
self._logger.info(f'Admin {token_kind} token created admin_user_id={user_id} kid={kid}')
return token
except ApplicationException:
raise
@@ -57,6 +82,26 @@ class JwtService(IJwtService):
sub=str(payload['sub']),
type='access',
role=str(payload['role']) if payload.get('role') else None,
sid=str(payload['sid']) if payload.get('sid') else None,
iat=int(payload['iat']),
nbf=int(payload['nbf']),
exp=int(payload['exp']),
iss=payload.get('iss'),
aud=payload.get('aud'),
)
except KeyError as exception:
raise ApplicationException(status_code=401, message=f'Missing token claim: {exception}')
async def decode_refresh_token(self, token: str) -> RefreshTokenPayload:
payload = await self._decode_and_verify(token)
if payload.get('type') != 'refresh':
raise ApplicationException(status_code=401, message='Invalid token type')
try:
return RefreshTokenPayload(
sub=str(payload['sub']),
type='refresh',
sid=str(payload['sid']),
jti=str(payload['jti']),
iat=int(payload['iat']),
nbf=int(payload['nbf']),
exp=int(payload['exp']),
@@ -104,8 +149,13 @@ class JwtService(IJwtService):
)
if 'type' not in payload:
raise ApplicationException(status_code=401, message='Missing token claim: type')
if payload.get('type') == 'access' and 'role' not in payload:
token_type = payload.get('type')
if token_type == 'access' and 'role' not in payload:
raise ApplicationException(status_code=401, message='Missing token claim: role')
if token_type == 'refresh':
for claim in ('sid', 'jti'):
if claim not in payload:
raise ApplicationException(status_code=401, message=f'Missing token claim: {claim}')
return payload
except ExpiredSignatureError:
raise ApplicationException(status_code=401, message='Token expired')

View File

@@ -1,61 +0,0 @@
from __future__ import annotations
import inspect
from functools import wraps
from typing import Callable, Awaitable, Any, Optional, Annotated
from fastapi import Request, Header
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.security import CsrfService
def csrf_protect(
expected_subject_getter: Optional[Callable[[Request], Optional[str]]] = None,
):
def decorator(func: Callable[..., Awaitable[Any]]):
sig = inspect.signature(func)
params = list(sig.parameters.values())
has_request = any(p.annotation is Request or p.name == 'request' for p in params)
if not has_request:
raise RuntimeError('csrf_protect requires endpoint to accept `request: Request`')
has_header = any(p.name == 'x_csrf_token' for p in params)
if not has_header:
params.append(
inspect.Parameter(
name='x_csrf_token',
kind=inspect.Parameter.KEYWORD_ONLY,
default=None,
annotation=Annotated[str | None, Header(alias='X-CSRF-Token')],
)
)
@wraps(func)
async def wrapper(*args, **kwargs):
request: Request | None = kwargs.get('request')
if request is None:
for arg in args:
if isinstance(arg, Request):
request = arg
break
if request is None:
raise ApplicationException(
status_code=500,
message='Request is required for CSRF protection',
)
csrf = CsrfService()
cookie_token, _ = csrf.extract(request.cookies, request.headers)
header_token = kwargs.get('x_csrf_token')
expected_subject = expected_subject_getter(request) if expected_subject_getter else None
csrf.verify_pair(cookie_token, header_token, expected_subject)
kwargs.pop('x_csrf_token', None)
return await func(*args, **kwargs)
wrapper.__signature__ = sig.replace(parameters=params)
return wrapper
return decorator

View File

@@ -5,6 +5,8 @@ from fastapi import Depends
from src.application.abstractions import IUnitOfWork
from src.application.commands import (
AdminLoginCommand,
AdminLogoutCommand,
AdminJwtRefreshCommand,
GetAdminMeCommand,
CreateOrganizationCommand,
CreateOrganizationWalletsCommand,
@@ -19,9 +21,10 @@ from src.application.commands import (
UpdatePurchaseRequestStatusCommand,
UploadOrganizationDocumentCommand,
)
from src.application.contracts import IHashService, IJwtService, ILogger
from src.application.contracts import ICache, IHashService, IJwtService, ILogger
from src.infrastructure.config import settings
from src.infrastructure.storage.s3_documents_service import S3DocumentsService
from src.presentation.dependencies.cache import get_cache
from src.presentation.dependencies.logger import get_logger
from src.presentation.dependencies.security import get_hash_service, get_jwt_service
from src.presentation.dependencies.unit_of_work import get_unit_of_work
@@ -60,6 +63,24 @@ def get_admin_me_command(
return GetAdminMeCommand(uow, logger)
def get_admin_logout_command(
uow: IUnitOfWork = Depends(get_unit_of_work),
jwt_service: IJwtService = Depends(get_jwt_service),
logger: ILogger = Depends(get_logger),
) -> AdminLogoutCommand:
return AdminLogoutCommand(uow, jwt_service, logger)
def get_admin_jwt_refresh_command(
uow: IUnitOfWork = Depends(get_unit_of_work),
hash_service: IHashService = Depends(get_hash_service),
jwt_service: IJwtService = Depends(get_jwt_service),
cache: ICache = Depends(get_cache),
logger: ILogger = Depends(get_logger),
) -> AdminJwtRefreshCommand:
return AdminJwtRefreshCommand(uow, hash_service, jwt_service, cache, logger)
def get_create_organization_command(
uow: IUnitOfWork = Depends(get_unit_of_work),
hash_service: IHashService = Depends(get_hash_service),

11
uv.lock generated
View File

@@ -69,7 +69,6 @@ dependencies = [
{ name = "fastapi" },
{ name = "granian" },
{ name = "hvac" },
{ name = "itsdangerous" },
{ name = "orjson" },
{ name = "pydantic-settings" },
{ name = "python-jose" },
@@ -94,7 +93,6 @@ requires-dist = [
{ name = "fastapi", specifier = "==0.128.7" },
{ name = "granian", specifier = "==2.6.1" },
{ name = "hvac", specifier = "==2.4.0" },
{ name = "itsdangerous", specifier = ">=2.2.0" },
{ name = "orjson", specifier = "==3.11.7" },
{ name = "pydantic-settings", specifier = "==2.12.0" },
{ name = "python-jose", specifier = "==3.5.0" },
@@ -812,15 +810,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/8f/7b/2edca79b359fc9f95d774616867a03ecccdf333797baf5b3eea79733918c/ijson-3.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:f4f7fabd653459dcb004175235f310435959b1bb5dfa8878578391c6cc9ad944", size = 55500, upload-time = "2026-02-24T03:57:20.428Z" },
]
[[package]]
name = "itsdangerous"
version = "2.2.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/9c/cb/8ac0172223afbccb63986cc25049b154ecfb5e85932587206f42317be31d/itsdangerous-2.2.0.tar.gz", hash = "sha256:e0050c0b7da1eea53ffaf149c0cfbb5c6e2e2b69c4bef22c81fa6eb73e5f6173", size = 54410, upload-time = "2024-04-16T21:28:15.614Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/04/96/92447566d16df59b2a776c0fb82dbc4d9e07cd95062562af01e408583fc4/itsdangerous-2.2.0-py3-none-any.whl", hash = "sha256:c6242fc49e35958c8b15141343aa660db5fc54d4f13a1db01a3f5891b98700ef", size = 16234, upload-time = "2024-04-16T21:28:14.499Z" },
]
[[package]]
name = "jmespath"
version = "1.1.0"