feat: add import
This commit is contained in:
@@ -16,7 +16,6 @@ dependencies = [
|
||||
"fastapi==0.128.7",
|
||||
"granian==2.6.1",
|
||||
"hvac==2.4.0",
|
||||
"itsdangerous>=2.2.0",
|
||||
"orjson==3.11.7",
|
||||
"pydantic-settings==2.12.0",
|
||||
"python-jose==3.5.0",
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from src.application.commands.admin_login import AdminLoginCommand
|
||||
from src.application.commands.admin_logout import AdminLogoutCommand
|
||||
from src.application.commands.admin_jwt_refresh import AdminJwtRefreshCommand
|
||||
from src.application.commands.get_admin_me import GetAdminMeCommand
|
||||
from src.application.commands.create_organization import CreateOrganizationCommand
|
||||
from src.application.commands.create_organization_wallets import CreateOrganizationWalletsCommand
|
||||
@@ -21,6 +23,8 @@ from src.application.commands.purchase_request_commands import (
|
||||
|
||||
__all__ = [
|
||||
'AdminLoginCommand',
|
||||
'AdminLogoutCommand',
|
||||
'AdminJwtRefreshCommand',
|
||||
'GetAdminMeCommand',
|
||||
'CreateOrganizationCommand',
|
||||
'CreateOrganizationWalletsCommand',
|
||||
|
||||
@@ -2,6 +2,5 @@ from src.application.contracts.i_hash_service import IHashService
|
||||
from src.application.contracts.i_logger import ILogger
|
||||
from src.application.contracts.i_user_service import IUserService
|
||||
from src.application.contracts.i_jwt_service import IJwtService
|
||||
from src.application.contracts.i_csrf_service import ICsrfService
|
||||
from src.application.contracts.i_cache import ICache
|
||||
from src.application.contracts.i_queue_messanger import IQueueMessanger
|
||||
@@ -1,26 +0,0 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional, Mapping
|
||||
|
||||
|
||||
class ICsrfService(ABC):
|
||||
@abstractmethod
|
||||
def issue(self, subject: Optional[str] = None) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def verify(self, token: str, expected_subject: Optional[str] = None) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def extract(self, cookies: Mapping[str, str], headers: Mapping[str, str]) -> tuple[Optional[str], Optional[str]]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def verify_pair(
|
||||
self,
|
||||
cookie_token: Optional[str],
|
||||
header_token: Optional[str],
|
||||
expected_subject: Optional[str] = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
@@ -1,13 +1,21 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from src.application.domain.dto import AccessTokenPayload
|
||||
from src.application.domain.dto import AccessTokenPayload, RefreshTokenPayload
|
||||
|
||||
|
||||
class IJwtService(ABC):
|
||||
@abstractmethod
|
||||
async def create_access_token(self, user_id: str, *, role: str) -> str:
|
||||
async def create_access_token(self, user_id: str, *, role: str, sid: str | None = None) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def create_refresh_token(self, user_id: str, *, sid: str, refresh_jti: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def decode_access_token(self, token: str) -> AccessTokenPayload:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def decode_refresh_token(self, token: str) -> RefreshTokenPayload:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from src.application.domain.dto.admin_auth import AdminLoginDto
|
||||
from src.application.domain.dto.token import AccessTokenPayload, AdminAuthContext
|
||||
from src.application.domain.dto.token import AccessTokenPayload, AdminAuthContext, RefreshTokenPayload
|
||||
from src.application.domain.dto.keys import JwtKeySet, JwtKeyPair
|
||||
from src.application.domain.dto.user import UserCreatedDto, UserLoginDto
|
||||
|
||||
@@ -5,6 +5,19 @@ class AccessTokenPayload(BaseModel):
|
||||
sub: str
|
||||
type: str
|
||||
role: str | None = None
|
||||
sid: str | None = None
|
||||
iat: int
|
||||
nbf: int
|
||||
exp: int
|
||||
iss: str | None = None
|
||||
aud: str | None = None
|
||||
|
||||
|
||||
class RefreshTokenPayload(BaseModel):
|
||||
sub: str
|
||||
type: str
|
||||
sid: str
|
||||
jti: str
|
||||
iat: int
|
||||
nbf: int
|
||||
exp: int
|
||||
|
||||
@@ -1,3 +1,2 @@
|
||||
from src.infrastructure.security.jwt import JwtService
|
||||
from src.infrastructure.security.csrf import CsrfService
|
||||
from src.infrastructure.security.hash import HashService
|
||||
@@ -1,81 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import secrets
|
||||
from typing import Any, Optional, Mapping
|
||||
from itsdangerous import URLSafeTimedSerializer, SignatureExpired, BadSignature
|
||||
from src.application.contracts import ICsrfService
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.config.settings import settings
|
||||
|
||||
|
||||
class CsrfService(ICsrfService):
|
||||
COOKIE_NAME = "csrf_token"
|
||||
HEADER_NAME = "X-CSRF-Token"
|
||||
SALT = "csrf"
|
||||
TTL_SECONDS = 3600
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._serializer = URLSafeTimedSerializer(
|
||||
secret_key=settings.CSRF_SECRET_KEY,
|
||||
salt=self.SALT,
|
||||
)
|
||||
|
||||
@property
|
||||
def cookie_name(self) -> str:
|
||||
return self.COOKIE_NAME
|
||||
|
||||
@property
|
||||
def header_name(self) -> str:
|
||||
return self.HEADER_NAME
|
||||
|
||||
@property
|
||||
def ttl_seconds(self) -> int:
|
||||
return self.TTL_SECONDS
|
||||
|
||||
def issue(self, subject: Optional[str] = None) -> str:
|
||||
payload = {
|
||||
"sub": subject,
|
||||
"nonce": secrets.token_urlsafe(32),
|
||||
}
|
||||
return self._serializer.dumps(payload)
|
||||
|
||||
def verify(self, token: str, expected_subject: Optional[str] = None) -> dict[str, Any]:
|
||||
try:
|
||||
data = self._serializer.loads(token, max_age=self.TTL_SECONDS)
|
||||
except SignatureExpired:
|
||||
raise ApplicationException(
|
||||
status_code=403,
|
||||
message="CSRF token expired",
|
||||
)
|
||||
except BadSignature:
|
||||
raise ApplicationException(
|
||||
status_code=403,
|
||||
message="CSRF token invalid",
|
||||
)
|
||||
|
||||
if expected_subject is not None and data.get("sub") != expected_subject:
|
||||
raise ApplicationException(
|
||||
status_code=403,
|
||||
message="CSRF token subject mismatch",
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
def extract(self, cookies: Mapping[str, str], headers: Mapping[str, str]) -> tuple[Optional[str], Optional[str]]:
|
||||
cookie_token = cookies.get(self.COOKIE_NAME)
|
||||
header_token = headers.get(self.HEADER_NAME)
|
||||
return cookie_token, header_token
|
||||
|
||||
def verify_pair(self, cookie_token: Optional[str], header_token: Optional[str], expected_subject: Optional[str] = None) -> None:
|
||||
if not cookie_token or not header_token:
|
||||
raise ApplicationException(
|
||||
status_code=403,
|
||||
message="CSRF token missing",
|
||||
)
|
||||
|
||||
if not secrets.compare_digest(cookie_token, header_token):
|
||||
raise ApplicationException(
|
||||
status_code=403,
|
||||
message="CSRF token mismatch",
|
||||
)
|
||||
|
||||
self.verify(cookie_token, expected_subject=expected_subject)
|
||||
@@ -5,7 +5,7 @@ from datetime import datetime, timezone, timedelta
|
||||
from jose import jwt, ExpiredSignatureError, JWTError
|
||||
|
||||
from src.application.contracts import ILogger, IJwtService
|
||||
from src.application.domain.dto import AccessTokenPayload
|
||||
from src.application.domain.dto import AccessTokenPayload, RefreshTokenPayload
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.config.settings import settings
|
||||
from src.infrastructure.vault import JwtKeyStore
|
||||
@@ -20,7 +20,7 @@ class JwtService(IJwtService):
|
||||
def _issuer(self) -> str | None:
|
||||
return settings.ADMIN_JWT_ISSUER
|
||||
|
||||
async def create_access_token(self, user_id: str, *, role: str) -> str:
|
||||
async def create_access_token(self, user_id: str, *, role: str, sid: str | None = None) -> str:
|
||||
now = datetime.now(timezone.utc)
|
||||
exp = now + timedelta(seconds=int(settings.JWT_ACCESS_TTL_SECONDS))
|
||||
|
||||
@@ -32,15 +32,40 @@ class JwtService(IJwtService):
|
||||
'nbf': int(now.timestamp()),
|
||||
'exp': int(exp.timestamp()),
|
||||
}
|
||||
if sid:
|
||||
payload['sid'] = sid
|
||||
if self._issuer:
|
||||
payload['iss'] = self._issuer
|
||||
if settings.JWT_AUDIENCE:
|
||||
payload['aud'] = settings.JWT_AUDIENCE
|
||||
|
||||
return await self._encode(payload, user_id=user_id, token_kind='access')
|
||||
|
||||
async def create_refresh_token(self, user_id: str, *, sid: str, refresh_jti: str) -> str:
|
||||
now = datetime.now(timezone.utc)
|
||||
exp = now + timedelta(seconds=int(settings.JWT_REFRESH_TTL_SECONDS))
|
||||
|
||||
payload: dict[str, object] = {
|
||||
'sub': user_id,
|
||||
'type': 'refresh',
|
||||
'sid': sid,
|
||||
'jti': refresh_jti,
|
||||
'iat': int(now.timestamp()),
|
||||
'nbf': int(now.timestamp()),
|
||||
'exp': int(exp.timestamp()),
|
||||
}
|
||||
if self._issuer:
|
||||
payload['iss'] = self._issuer
|
||||
if settings.JWT_AUDIENCE:
|
||||
payload['aud'] = settings.JWT_AUDIENCE
|
||||
|
||||
return await self._encode(payload, user_id=user_id, token_kind='refresh')
|
||||
|
||||
async def _encode(self, payload: dict[str, object], *, user_id: str, token_kind: str) -> str:
|
||||
try:
|
||||
kid, private_pem = await self._key_store.get_signing_key()
|
||||
token = jwt.encode(payload, private_pem, algorithm=settings.JWT_ALGORITHM, headers={'kid': kid})
|
||||
self._logger.info(f'Admin access token created admin_user_id={user_id} kid={kid}')
|
||||
self._logger.info(f'Admin {token_kind} token created admin_user_id={user_id} kid={kid}')
|
||||
return token
|
||||
except ApplicationException:
|
||||
raise
|
||||
@@ -57,6 +82,26 @@ class JwtService(IJwtService):
|
||||
sub=str(payload['sub']),
|
||||
type='access',
|
||||
role=str(payload['role']) if payload.get('role') else None,
|
||||
sid=str(payload['sid']) if payload.get('sid') else None,
|
||||
iat=int(payload['iat']),
|
||||
nbf=int(payload['nbf']),
|
||||
exp=int(payload['exp']),
|
||||
iss=payload.get('iss'),
|
||||
aud=payload.get('aud'),
|
||||
)
|
||||
except KeyError as exception:
|
||||
raise ApplicationException(status_code=401, message=f'Missing token claim: {exception}')
|
||||
|
||||
async def decode_refresh_token(self, token: str) -> RefreshTokenPayload:
|
||||
payload = await self._decode_and_verify(token)
|
||||
if payload.get('type') != 'refresh':
|
||||
raise ApplicationException(status_code=401, message='Invalid token type')
|
||||
try:
|
||||
return RefreshTokenPayload(
|
||||
sub=str(payload['sub']),
|
||||
type='refresh',
|
||||
sid=str(payload['sid']),
|
||||
jti=str(payload['jti']),
|
||||
iat=int(payload['iat']),
|
||||
nbf=int(payload['nbf']),
|
||||
exp=int(payload['exp']),
|
||||
@@ -104,8 +149,13 @@ class JwtService(IJwtService):
|
||||
)
|
||||
if 'type' not in payload:
|
||||
raise ApplicationException(status_code=401, message='Missing token claim: type')
|
||||
if payload.get('type') == 'access' and 'role' not in payload:
|
||||
token_type = payload.get('type')
|
||||
if token_type == 'access' and 'role' not in payload:
|
||||
raise ApplicationException(status_code=401, message='Missing token claim: role')
|
||||
if token_type == 'refresh':
|
||||
for claim in ('sid', 'jti'):
|
||||
if claim not in payload:
|
||||
raise ApplicationException(status_code=401, message=f'Missing token claim: {claim}')
|
||||
return payload
|
||||
except ExpiredSignatureError:
|
||||
raise ApplicationException(status_code=401, message='Token expired')
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import inspect
|
||||
from functools import wraps
|
||||
from typing import Callable, Awaitable, Any, Optional, Annotated
|
||||
from fastapi import Request, Header
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.security import CsrfService
|
||||
|
||||
|
||||
def csrf_protect(
|
||||
expected_subject_getter: Optional[Callable[[Request], Optional[str]]] = None,
|
||||
):
|
||||
def decorator(func: Callable[..., Awaitable[Any]]):
|
||||
sig = inspect.signature(func)
|
||||
params = list(sig.parameters.values())
|
||||
|
||||
has_request = any(p.annotation is Request or p.name == 'request' for p in params)
|
||||
if not has_request:
|
||||
raise RuntimeError('csrf_protect requires endpoint to accept `request: Request`')
|
||||
|
||||
has_header = any(p.name == 'x_csrf_token' for p in params)
|
||||
if not has_header:
|
||||
params.append(
|
||||
inspect.Parameter(
|
||||
name='x_csrf_token',
|
||||
kind=inspect.Parameter.KEYWORD_ONLY,
|
||||
default=None,
|
||||
annotation=Annotated[str | None, Header(alias='X-CSRF-Token')],
|
||||
)
|
||||
)
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
request: Request | None = kwargs.get('request')
|
||||
if request is None:
|
||||
for arg in args:
|
||||
if isinstance(arg, Request):
|
||||
request = arg
|
||||
break
|
||||
|
||||
if request is None:
|
||||
raise ApplicationException(
|
||||
status_code=500,
|
||||
message='Request is required for CSRF protection',
|
||||
)
|
||||
|
||||
csrf = CsrfService()
|
||||
|
||||
cookie_token, _ = csrf.extract(request.cookies, request.headers)
|
||||
header_token = kwargs.get('x_csrf_token')
|
||||
|
||||
expected_subject = expected_subject_getter(request) if expected_subject_getter else None
|
||||
csrf.verify_pair(cookie_token, header_token, expected_subject)
|
||||
|
||||
kwargs.pop('x_csrf_token', None)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
wrapper.__signature__ = sig.replace(parameters=params)
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
@@ -5,6 +5,8 @@ from fastapi import Depends
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.commands import (
|
||||
AdminLoginCommand,
|
||||
AdminLogoutCommand,
|
||||
AdminJwtRefreshCommand,
|
||||
GetAdminMeCommand,
|
||||
CreateOrganizationCommand,
|
||||
CreateOrganizationWalletsCommand,
|
||||
@@ -19,9 +21,10 @@ from src.application.commands import (
|
||||
UpdatePurchaseRequestStatusCommand,
|
||||
UploadOrganizationDocumentCommand,
|
||||
)
|
||||
from src.application.contracts import IHashService, IJwtService, ILogger
|
||||
from src.application.contracts import ICache, IHashService, IJwtService, ILogger
|
||||
from src.infrastructure.config import settings
|
||||
from src.infrastructure.storage.s3_documents_service import S3DocumentsService
|
||||
from src.presentation.dependencies.cache import get_cache
|
||||
from src.presentation.dependencies.logger import get_logger
|
||||
from src.presentation.dependencies.security import get_hash_service, get_jwt_service
|
||||
from src.presentation.dependencies.unit_of_work import get_unit_of_work
|
||||
@@ -60,6 +63,24 @@ def get_admin_me_command(
|
||||
return GetAdminMeCommand(uow, logger)
|
||||
|
||||
|
||||
def get_admin_logout_command(
|
||||
uow: IUnitOfWork = Depends(get_unit_of_work),
|
||||
jwt_service: IJwtService = Depends(get_jwt_service),
|
||||
logger: ILogger = Depends(get_logger),
|
||||
) -> AdminLogoutCommand:
|
||||
return AdminLogoutCommand(uow, jwt_service, logger)
|
||||
|
||||
|
||||
def get_admin_jwt_refresh_command(
|
||||
uow: IUnitOfWork = Depends(get_unit_of_work),
|
||||
hash_service: IHashService = Depends(get_hash_service),
|
||||
jwt_service: IJwtService = Depends(get_jwt_service),
|
||||
cache: ICache = Depends(get_cache),
|
||||
logger: ILogger = Depends(get_logger),
|
||||
) -> AdminJwtRefreshCommand:
|
||||
return AdminJwtRefreshCommand(uow, hash_service, jwt_service, cache, logger)
|
||||
|
||||
|
||||
def get_create_organization_command(
|
||||
uow: IUnitOfWork = Depends(get_unit_of_work),
|
||||
hash_service: IHashService = Depends(get_hash_service),
|
||||
|
||||
11
uv.lock
generated
11
uv.lock
generated
@@ -69,7 +69,6 @@ dependencies = [
|
||||
{ name = "fastapi" },
|
||||
{ name = "granian" },
|
||||
{ name = "hvac" },
|
||||
{ name = "itsdangerous" },
|
||||
{ name = "orjson" },
|
||||
{ name = "pydantic-settings" },
|
||||
{ name = "python-jose" },
|
||||
@@ -94,7 +93,6 @@ requires-dist = [
|
||||
{ name = "fastapi", specifier = "==0.128.7" },
|
||||
{ name = "granian", specifier = "==2.6.1" },
|
||||
{ name = "hvac", specifier = "==2.4.0" },
|
||||
{ name = "itsdangerous", specifier = ">=2.2.0" },
|
||||
{ name = "orjson", specifier = "==3.11.7" },
|
||||
{ name = "pydantic-settings", specifier = "==2.12.0" },
|
||||
{ name = "python-jose", specifier = "==3.5.0" },
|
||||
@@ -812,15 +810,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/8f/7b/2edca79b359fc9f95d774616867a03ecccdf333797baf5b3eea79733918c/ijson-3.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:f4f7fabd653459dcb004175235f310435959b1bb5dfa8878578391c6cc9ad944", size = 55500, upload-time = "2026-02-24T03:57:20.428Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itsdangerous"
|
||||
version = "2.2.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/9c/cb/8ac0172223afbccb63986cc25049b154ecfb5e85932587206f42317be31d/itsdangerous-2.2.0.tar.gz", hash = "sha256:e0050c0b7da1eea53ffaf149c0cfbb5c6e2e2b69c4bef22c81fa6eb73e5f6173", size = 54410, upload-time = "2024-04-16T21:28:15.614Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/04/96/92447566d16df59b2a776c0fb82dbc4d9e07cd95062562af01e408583fc4/itsdangerous-2.2.0-py3-none-any.whl", hash = "sha256:c6242fc49e35958c8b15141343aa660db5fc54d4f13a1db01a3f5891b98700ef", size = 16234, upload-time = "2024-04-16T21:28:14.499Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jmespath"
|
||||
version = "1.1.0"
|
||||
|
||||
Reference in New Issue
Block a user