init commit
This commit is contained in:
1
src/application/abstractions/__init__.py
Normal file
1
src/application/abstractions/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.application.abstractions.i_unit_of_work import IUnitOfWork
|
||||
19
src/application/abstractions/i_unit_of_work.py
Normal file
19
src/application/abstractions/i_unit_of_work.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
from typing import Protocol, runtime_checkable
|
||||
from src.application.abstractions.repositories import IKycRepository,IUserRepository
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class IUnitOfWork(Protocol):
|
||||
async def __aenter__(self) -> 'IUnitOfWork': ...
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: ...
|
||||
|
||||
async def commit(self) -> None: ...
|
||||
async def rollback(self) -> None: ...
|
||||
|
||||
@property
|
||||
def user_repository(self) -> IUserRepository: ...
|
||||
|
||||
@property
|
||||
def kyc_repository(self) -> IKycRepository: ...
|
||||
|
||||
2
src/application/abstractions/repositories/__init__.py
Normal file
2
src/application/abstractions/repositories/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from src.application.abstractions.repositories.i_kyc_repository import IKycRepository
|
||||
from src.application.abstractions.repositories.i_user_repository import IUserRepository
|
||||
@@ -0,0 +1,47 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from src.application.domain.entities import KycEntity
|
||||
|
||||
|
||||
class IKycRepository(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def create_started_session(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
user_token: str | None,
|
||||
client_user_token: str | None,
|
||||
link: str | None,
|
||||
qr_code: str | None,
|
||||
expires_at: datetime,
|
||||
error: str | None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def update_session_result(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
user_token: str,
|
||||
status: str,
|
||||
done_state: bool | None,
|
||||
set_id: str | None,
|
||||
result_data: Any,
|
||||
error: str | None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def expire_started_sessions(self,*,user_id: str,now: datetime) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def get_active_session(self,*,user_id: str,now: datetime) -> KycEntity | None:
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,34 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from datetime import date
|
||||
from src.application.domain.entities import UserEntity
|
||||
|
||||
|
||||
class IUserRepository(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def create_user(self, email: str, password_hash: str) -> UserEntity:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def get_user_by_email(self, email: str) -> UserEntity:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def exists_by_email(self, email: str) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def update_kyc_data(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
first_name: str,
|
||||
last_name: str,
|
||||
birth_date: date,
|
||||
middle_name: str | None,
|
||||
inn: str | None,
|
||||
) -> UserEntity:
|
||||
raise NotImplementedError
|
||||
1
src/application/commands/__init__.py
Normal file
1
src/application/commands/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.application.commands.create_kyc_command import CompleteKycCommand,GetKycSessionCommand,PassKycCommand
|
||||
183
src/application/commands/create_kyc_command.py
Normal file
183
src/application/commands/create_kyc_command.py
Normal file
@@ -0,0 +1,183 @@
|
||||
from __future__ import annotations
|
||||
from datetime import datetime,timedelta,timezone
|
||||
import orjson
|
||||
from ulid import ULID
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.contracts import IBeorgService,ICache,ILogger,IQueueMessanger
|
||||
from src.application.domain.dto import BeorgKycCreateResponse,BeorgKycResultResponse,KycSessionResponse
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.application.services import ensure_adult,extract_personal_data,parse_birth_date
|
||||
|
||||
|
||||
KYC_SESSION_TTL = 3600
|
||||
|
||||
|
||||
class PassKycCommand:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
unit_of_work: IUnitOfWork,
|
||||
logger: ILogger,
|
||||
cache: ICache,
|
||||
beorg_service: IBeorgService,
|
||||
) -> None:
|
||||
self._unit_of_work = unit_of_work
|
||||
self._logger = logger
|
||||
self._cache = cache
|
||||
self._beorg_service = beorg_service
|
||||
|
||||
|
||||
async def __call__(self,user_id: str) -> BeorgKycCreateResponse:
|
||||
result = await self._beorg_service.create_identification(client_user_token=user_id)
|
||||
expires_at = _utc_now() + timedelta(seconds=KYC_SESSION_TTL)
|
||||
async with self._unit_of_work as unit_of_work:
|
||||
await unit_of_work.kyc_repository.create_started_session(
|
||||
user_id=user_id,
|
||||
user_token=result.user_token,
|
||||
client_user_token=result.client_user_token,
|
||||
link=result.link,
|
||||
qr_code=result.qr_code,
|
||||
expires_at=expires_at,
|
||||
error=result.error,
|
||||
)
|
||||
await self._cache.set(f'kyc:session:{user_id}',result.model_dump_json(),ttl=KYC_SESSION_TTL)
|
||||
self._logger.info(f'KYC started for user {user_id}')
|
||||
return result
|
||||
|
||||
|
||||
class CompleteKycCommand:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
unit_of_work: IUnitOfWork,
|
||||
logger: ILogger,
|
||||
cache: ICache,
|
||||
beorg_service: IBeorgService,
|
||||
queue_messanger: IQueueMessanger,
|
||||
verified_queue: str,
|
||||
) -> None:
|
||||
self._unit_of_work = unit_of_work
|
||||
self._logger = logger
|
||||
self._cache = cache
|
||||
self._beorg_service = beorg_service
|
||||
self._queue_messanger = queue_messanger
|
||||
self._verified_queue = verified_queue
|
||||
|
||||
|
||||
async def __call__(self,user_id: str) -> BeorgKycResultResponse:
|
||||
session = await self._get_session(user_id)
|
||||
if not session.user_token:
|
||||
raise ApplicationException(status_code=409,message='KYC session has no user token')
|
||||
|
||||
result = await self._beorg_service.get_result(user_token=session.user_token)
|
||||
if result.done_state is None:
|
||||
raise ApplicationException(status_code=409,message='KYC is not completed yet')
|
||||
if result.done_state is False:
|
||||
async with self._unit_of_work as unit_of_work:
|
||||
await unit_of_work.kyc_repository.update_session_result(
|
||||
user_id=user_id,
|
||||
user_token=session.user_token,
|
||||
status='failed',
|
||||
done_state=result.done_state,
|
||||
set_id=result.set_id,
|
||||
result_data=result.data,
|
||||
error='KYC failed',
|
||||
)
|
||||
raise ApplicationException(status_code=400,message='KYC failed')
|
||||
|
||||
personal_data = extract_personal_data(result.data)
|
||||
birth_date = parse_birth_date(personal_data.birth_date)
|
||||
ensure_adult(birth_date)
|
||||
|
||||
async with self._unit_of_work as unit_of_work:
|
||||
user = await unit_of_work.user_repository.update_kyc_data(
|
||||
user_id=user_id,
|
||||
first_name=personal_data.first_name,
|
||||
last_name=personal_data.last_name,
|
||||
middle_name=personal_data.middle_name,
|
||||
birth_date=birth_date,
|
||||
inn=personal_data.inn,
|
||||
)
|
||||
await unit_of_work.kyc_repository.update_session_result(
|
||||
user_id=user_id,
|
||||
user_token=session.user_token,
|
||||
status='completed',
|
||||
done_state=result.done_state,
|
||||
set_id=result.set_id,
|
||||
result_data=result.data,
|
||||
error=None,
|
||||
)
|
||||
await self._cache.set_user(user_id,user,ttl=KYC_SESSION_TTL)
|
||||
await self._cache.delete(f'kyc:session:{user_id}')
|
||||
await self._queue_messanger.publish_to_queue(
|
||||
self._verified_queue,
|
||||
{
|
||||
'user_id': user_id,
|
||||
'kyc_verified': True,
|
||||
'first_name': user.first_name,
|
||||
'last_name': user.last_name,
|
||||
'middle_name': user.middle_name,
|
||||
'birth_date': str(user.birth_date) if user.birth_date else None,
|
||||
'inn': user.inn,
|
||||
'kyc_verified_at': user.kyc_verified_at.isoformat() if user.kyc_verified_at else None,
|
||||
},
|
||||
message_id=str(ULID()),
|
||||
correlation_id=user_id,
|
||||
)
|
||||
self._logger.info(f'KYC completed for user {user_id}')
|
||||
return result
|
||||
|
||||
|
||||
async def _get_session(self,user_id: str) -> BeorgKycCreateResponse:
|
||||
raw = await self._cache.get(f'kyc:session:{user_id}')
|
||||
if raw is not None:
|
||||
return BeorgKycCreateResponse.model_validate(orjson.loads(raw))
|
||||
|
||||
now = _utc_now()
|
||||
async with self._unit_of_work as unit_of_work:
|
||||
await unit_of_work.kyc_repository.expire_started_sessions(user_id=user_id,now=now)
|
||||
session = await unit_of_work.kyc_repository.get_active_session(user_id=user_id,now=now)
|
||||
if session is not None:
|
||||
return BeorgKycCreateResponse(
|
||||
status=True,
|
||||
link=session.link,
|
||||
user_token=session.user_token,
|
||||
client_user_token=session.client_user_token,
|
||||
qr_code=session.qr_code,
|
||||
)
|
||||
raise ApplicationException(status_code=404,message='KYC session expired')
|
||||
|
||||
|
||||
class GetKycSessionCommand:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
unit_of_work: IUnitOfWork,
|
||||
) -> None:
|
||||
self._unit_of_work = unit_of_work
|
||||
|
||||
|
||||
async def __call__(self,user_id: str) -> KycSessionResponse:
|
||||
now = _utc_now()
|
||||
async with self._unit_of_work as unit_of_work:
|
||||
await unit_of_work.kyc_repository.expire_started_sessions(user_id=user_id,now=now)
|
||||
session = await unit_of_work.kyc_repository.get_active_session(user_id=user_id,now=now)
|
||||
if session is None or session.expires_at is None:
|
||||
raise ApplicationException(status_code=404,message='KYC session expired')
|
||||
|
||||
expires_in = max(int((session.expires_at - now).total_seconds()),0)
|
||||
return KycSessionResponse(
|
||||
status=session.status or 'started',
|
||||
link=session.link,
|
||||
qr_code=session.qr_code,
|
||||
user_token=session.user_token,
|
||||
expires_at=session.expires_at,
|
||||
expires_in=expires_in,
|
||||
)
|
||||
|
||||
|
||||
def _utc_now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
7
src/application/contracts/__init__.py
Normal file
7
src/application/contracts/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from src.application.contracts.i_logger import ILogger
|
||||
from src.application.contracts.i_jwt_service import IJwtService
|
||||
from src.application.contracts.i_csrf_service import ICsrfService
|
||||
from src.application.contracts.i_cache import ICache
|
||||
from src.application.contracts.i_hash_service import IHashService
|
||||
from src.application.contracts.i_queue_messanger import IQueueMessanger
|
||||
from src.application.contracts.i_beorg_service import IBeorgService
|
||||
14
src/application/contracts/i_beorg_service.py
Normal file
14
src/application/contracts/i_beorg_service.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from abc import ABC,abstractmethod
|
||||
from src.application.domain.dto import BeorgKycCreateResponse,BeorgKycResultResponse
|
||||
|
||||
|
||||
class IBeorgService(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def create_identification(self,client_user_token: str) -> BeorgKycCreateResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def get_result(self,user_token: str) -> BeorgKycResultResponse:
|
||||
raise NotImplementedError
|
||||
34
src/application/contracts/i_cache.py
Normal file
34
src/application/contracts/i_cache.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from src.application.domain.entities.user import UserEntity
|
||||
|
||||
|
||||
class ICache(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def set(self, key: str, value: str, ttl: int) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def set_nx(self, key: str, value: str, ttl: int) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get(self, key: str) -> str | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def hget(self, key: str, field: str) -> str | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def delete(self, key: str) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_user(self, user_id: str) -> dict | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def set_user(self, user_id: str, user: UserEntity, ttl: int = 300) -> None:
|
||||
raise NotImplementedError
|
||||
23
src/application/contracts/i_csrf_service.py
Normal file
23
src/application/contracts/i_csrf_service.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional, Mapping
|
||||
|
||||
|
||||
class ICsrfService(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def verify(self, token: str, expected_subject: Optional[str] = None) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def extract(self, cookies: Mapping[str, str], headers: Mapping[str, str]) -> tuple[Optional[str], Optional[str]]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def verify_pair(
|
||||
self,
|
||||
cookie_token: Optional[str],
|
||||
header_token: Optional[str],
|
||||
expected_subject: Optional[str] = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
13
src/application/contracts/i_hash_service.py
Normal file
13
src/application/contracts/i_hash_service.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from abc import ABC,abstractmethod
|
||||
|
||||
|
||||
class IHashService(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def hash(self,value: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def verify(self,value: str,hashed_value: str) -> bool:
|
||||
raise NotImplementedError
|
||||
10
src/application/contracts/i_jwt_service.py
Normal file
10
src/application/contracts/i_jwt_service.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from src.application.domain.dto import AccessTokenPayload
|
||||
|
||||
|
||||
class IJwtService(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def decode_access_token(self, token: str) -> AccessTokenPayload:
|
||||
raise NotImplementedError
|
||||
68
src/application/contracts/i_logger.py
Normal file
68
src/application/contracts/i_logger.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from typing import Protocol, Optional, Callable
|
||||
from src.application.domain.enums.log_format import LogFormat
|
||||
from src.application.domain.enums.log_level import LogLevel
|
||||
|
||||
|
||||
class ILogger(Protocol):
|
||||
"""Interface for synchronous logger with ContextVar support for trace_id."""
|
||||
|
||||
log_format: LogFormat
|
||||
min_level: LogLevel
|
||||
id_generator: Optional[Callable[[], str]]
|
||||
instance_id: str
|
||||
|
||||
def set_format(self, log_format: LogFormat) -> None:
|
||||
"""Set log format using LogFormat enum"""
|
||||
...
|
||||
|
||||
def set_min_level(self, level: LogLevel) -> None:
|
||||
"""Set minimum log level"""
|
||||
...
|
||||
|
||||
def new_trace_id(self) -> str:
|
||||
"""Create and set new trace_id in context"""
|
||||
...
|
||||
|
||||
def set_trace_id(self, trace_id: str) -> None:
|
||||
"""Set existing trace_id in context"""
|
||||
...
|
||||
|
||||
def get_trace_id(self) -> str:
|
||||
"""Get current trace_id from context"""
|
||||
...
|
||||
|
||||
def clear_trace_id(self) -> None:
|
||||
"""Clear the trace_id in the context"""
|
||||
...
|
||||
|
||||
def set_instance_id(self, instance_id: str) -> None:
|
||||
"""Set service instance id (ULID recommended)"""
|
||||
...
|
||||
|
||||
def get_instance_id(self) -> str:
|
||||
"""Get current service instance id"""
|
||||
...
|
||||
|
||||
def debug(self, message: str) -> None:
|
||||
"""Log debug message"""
|
||||
...
|
||||
|
||||
def info(self, message: str) -> None:
|
||||
"""Log info message"""
|
||||
...
|
||||
|
||||
def warning(self, message: str) -> None:
|
||||
"""Log warning message"""
|
||||
...
|
||||
|
||||
def error(self, message: str) -> None:
|
||||
"""Log error message"""
|
||||
...
|
||||
|
||||
def critical(self, message: str) -> None:
|
||||
"""Log critical message"""
|
||||
...
|
||||
|
||||
def exception(self, message: str) -> None:
|
||||
"""Log exception with traceback"""
|
||||
...
|
||||
40
src/application/contracts/i_queue_messanger.py
Normal file
40
src/application/contracts/i_queue_messanger.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Mapping, Any
|
||||
|
||||
|
||||
class IQueueMessanger(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def connect(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def publish_to_queue(
|
||||
self,
|
||||
queue: str,
|
||||
message: Any,
|
||||
*,
|
||||
persist: bool = True,
|
||||
headers: Mapping[str, Any] | None = None,
|
||||
correlation_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def publish(
|
||||
self,
|
||||
message: Any,
|
||||
*,
|
||||
exchange: str,
|
||||
routing_key: str,
|
||||
persist: bool = True,
|
||||
headers: Mapping[str, Any] | None = None,
|
||||
correlation_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
3
src/application/domain/dto/__init__.py
Normal file
3
src/application/domain/dto/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from src.application.domain.dto.token import AccessTokenPayload, AuthContext
|
||||
from src.application.domain.dto.keys import JwtPublicKey, JwtPublicKeySet
|
||||
from src.application.domain.dto.beorg import BeorgKycCreateResponse,BeorgKycResultResponse,KycPersonalData,KycSessionResponse
|
||||
37
src/application/domain/dto/beorg.py
Normal file
37
src/application/domain/dto/beorg.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BeorgKycCreateResponse(BaseModel):
|
||||
status: bool
|
||||
error: str | None = None
|
||||
link: str | None = None
|
||||
user_token: str | None = None
|
||||
client_user_token: str | None = None
|
||||
qr_code: str | None = None
|
||||
|
||||
|
||||
class BeorgKycResultResponse(BaseModel):
|
||||
done_state: bool | None = None
|
||||
user_token: str
|
||||
client_user_token: str | None = None
|
||||
set_id: str | None = None
|
||||
data: Any = None
|
||||
|
||||
|
||||
class KycPersonalData(BaseModel):
|
||||
first_name: str
|
||||
last_name: str
|
||||
birth_date: str
|
||||
middle_name: str | None = None
|
||||
inn: str | None = None
|
||||
|
||||
|
||||
class KycSessionResponse(BaseModel):
|
||||
status: str
|
||||
link: str | None = None
|
||||
qr_code: str | None = None
|
||||
user_token: str | None = None
|
||||
expires_at: datetime
|
||||
expires_in: int
|
||||
20
src/application/domain/dto/keys.py
Normal file
20
src/application/domain/dto/keys.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Dict
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class JwtPublicKey:
|
||||
kid: str
|
||||
public_key_pem: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class JwtPublicKeySet:
|
||||
active: JwtPublicKey
|
||||
previous: Optional[JwtPublicKey] = None
|
||||
|
||||
def public_keys_by_kid(self) -> Dict[str, str]:
|
||||
out = {self.active.kid: self.active.public_key_pem}
|
||||
if self.previous:
|
||||
out[self.previous.kid] = self.previous.public_key_pem
|
||||
return out
|
||||
18
src/application/domain/dto/token.py
Normal file
18
src/application/domain/dto/token.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AccessTokenPayload(BaseModel):
|
||||
sub: str
|
||||
type: str
|
||||
sid: str
|
||||
iat: int
|
||||
nbf: int
|
||||
exp: int
|
||||
iss: str | None = None
|
||||
aud: str | None = None
|
||||
|
||||
|
||||
class AuthContext(BaseModel):
|
||||
user_id: str
|
||||
sid: str
|
||||
token: AccessTokenPayload
|
||||
5
src/application/domain/entities/__init__.py
Normal file
5
src/application/domain/entities/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from src.application.domain.entities.kyc import KycEntity
|
||||
from src.application.domain.entities.user import UserEntity
|
||||
|
||||
|
||||
__all__ = ['KycEntity','UserEntity']
|
||||
23
src/application/domain/entities/kyc.py
Normal file
23
src/application/domain/entities/kyc.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class KycEntity:
|
||||
id: str | None = None
|
||||
user_id: str | None = None
|
||||
user_token: str | None = None
|
||||
client_user_token: str | None = None
|
||||
link: str | None = None
|
||||
qr_code: str | None = None
|
||||
status: str | None = None
|
||||
done_state: bool | None = None
|
||||
set_id: str | None = None
|
||||
error: str | None = None
|
||||
result_data: Any = None
|
||||
expires_at: datetime | None = None
|
||||
completed_at: datetime | None = None
|
||||
created_at: datetime | None = None
|
||||
updated_at: datetime | None = None
|
||||
30
src/application/domain/entities/user.py
Normal file
30
src/application/domain/entities/user.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from datetime import date, datetime
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class UserEntity:
|
||||
id: str | None = None
|
||||
email: str | None = None
|
||||
password_hash: str | None = None
|
||||
|
||||
first_name: str | None = None
|
||||
middle_name: str | None = None
|
||||
last_name: str | None = None
|
||||
birth_date: date | None = None
|
||||
|
||||
crypto_wallet: str | None = None
|
||||
phone: str | None = None
|
||||
|
||||
bik: str | None = None
|
||||
account_number: str | None = None
|
||||
card_number: str | None = None
|
||||
inn: str | None = None
|
||||
|
||||
kyc_verified: bool | None = None
|
||||
is_deleted: bool | None = None
|
||||
|
||||
created_at: datetime | None = None
|
||||
updated_at: datetime | None = None
|
||||
kyc_verified_at: datetime | None = None
|
||||
2
src/application/domain/enums/__init__.py
Normal file
2
src/application/domain/enums/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from src.application.domain.enums.log_format import LogFormat
|
||||
from src.application.domain.enums.log_level import LogLevel
|
||||
6
src/application/domain/enums/log_format.py
Normal file
6
src/application/domain/enums/log_format.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class LogFormat(str,Enum):
|
||||
JSON = 'json'
|
||||
TEXT = 'text'
|
||||
10
src/application/domain/enums/log_level.py
Normal file
10
src/application/domain/enums/log_level.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from enum import IntEnum
|
||||
|
||||
|
||||
class LogLevel(IntEnum):
|
||||
DEBUG = 10
|
||||
INFO = 20
|
||||
WARNING = 30
|
||||
ERROR = 40
|
||||
CRITICAL = 50
|
||||
EXCEPTION = 60
|
||||
1
src/application/domain/exceptions/__init__.py
Normal file
1
src/application/domain/exceptions/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.application.domain.exceptions.application_exceptions import ApplicationException
|
||||
18
src/application/domain/exceptions/application_exceptions.py
Normal file
18
src/application/domain/exceptions/application_exceptions.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from __future__ import annotations
|
||||
from typing import Mapping
|
||||
|
||||
|
||||
class ApplicationException(Exception):
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int,
|
||||
message: str,
|
||||
headers: Mapping[str, str] | None = None,
|
||||
):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.headers = headers
|
||||
|
||||
def __str__(self):
|
||||
return f'{self.status_code}: {self.message}'
|
||||
1
src/application/services/__init__.py
Normal file
1
src/application/services/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.application.services.kyc_personal_data import ensure_adult,extract_personal_data,parse_birth_date
|
||||
78
src/application/services/kyc_personal_data.py
Normal file
78
src/application/services/kyc_personal_data.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from __future__ import annotations
|
||||
from datetime import date,datetime
|
||||
from typing import Any
|
||||
from src.application.domain.dto import KycPersonalData
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
|
||||
|
||||
FIELD_ALIASES = {
|
||||
'first_name': {'first_name','name','given_name','имя'},
|
||||
'last_name': {'last_name','surname','family_name','фамилия'},
|
||||
'middle_name': {'middle_name','patronymic','отчество'},
|
||||
'birth_date': {'birth_date','birthdate','date_birth','birthday','дата рождения'},
|
||||
'inn': {'inn','tax_id','инн'},
|
||||
}
|
||||
|
||||
|
||||
def extract_personal_data(data: Any) -> KycPersonalData:
|
||||
values: dict[str,str] = {}
|
||||
for key,value in _walk(data):
|
||||
normalized = _normalize_key(key)
|
||||
for field,aliases in FIELD_ALIASES.items():
|
||||
if field not in values and normalized in aliases and value not in (None,''):
|
||||
values[field] = str(value).strip()
|
||||
|
||||
missing = [field for field in ('first_name','last_name','birth_date') if not values.get(field)]
|
||||
if missing:
|
||||
raise ApplicationException(status_code=422,message='KYC personal data is incomplete')
|
||||
|
||||
return KycPersonalData(
|
||||
first_name=values['first_name'],
|
||||
last_name=values['last_name'],
|
||||
middle_name=values.get('middle_name'),
|
||||
birth_date=str(_parse_date(values['birth_date'])),
|
||||
inn=values.get('inn'),
|
||||
)
|
||||
|
||||
|
||||
def ensure_adult(birth_date: date) -> None:
|
||||
today = date.today()
|
||||
try:
|
||||
adult_from = date(today.year - 18,today.month,today.day)
|
||||
except ValueError:
|
||||
adult_from = date(today.year - 18,2,28)
|
||||
if birth_date > adult_from:
|
||||
raise ApplicationException(status_code=403,message='KYC is unavailable for users under 18')
|
||||
|
||||
|
||||
def parse_birth_date(value: str) -> date:
|
||||
return _parse_date(value)
|
||||
|
||||
|
||||
def _walk(data: Any) -> list[tuple[str,Any]]:
|
||||
items: list[tuple[str,Any]] = []
|
||||
if isinstance(data,dict):
|
||||
for key,value in data.items():
|
||||
if isinstance(value,dict | list):
|
||||
items.extend(_walk(value))
|
||||
else:
|
||||
items.append((str(key),value))
|
||||
elif isinstance(data,list):
|
||||
for item in data:
|
||||
items.extend(_walk(item))
|
||||
return items
|
||||
|
||||
|
||||
def _normalize_key(key: str) -> str:
|
||||
return key.strip().lower().replace('-','_').replace(' ','_')
|
||||
|
||||
|
||||
def _parse_date(value: str) -> date:
|
||||
clean = value.strip()
|
||||
formats = ('%Y-%m-%d','%d.%m.%Y','%d-%m-%Y','%d/%m/%Y','%Y.%m.%d')
|
||||
for date_format in formats:
|
||||
try:
|
||||
return datetime.strptime(clean,date_format).date()
|
||||
except ValueError:
|
||||
continue
|
||||
raise ApplicationException(status_code=422,message='KYC birth date has invalid format')
|
||||
1
src/infrastructure/beorg/__init__.py
Normal file
1
src/infrastructure/beorg/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.beorg.client import BeorgService
|
||||
83
src/infrastructure/beorg/client.py
Normal file
83
src/infrastructure/beorg/client.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any
|
||||
import aiohttp
|
||||
from src.application.contracts import IBeorgService
|
||||
from src.application.domain.dto import BeorgKycCreateResponse,BeorgKycResultResponse
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
|
||||
|
||||
class BeorgService(IBeorgService):
|
||||
BASE_URL = 'https://webapp.beorg.ru'
|
||||
CREATE_ENDPOINT = '/kyc/create'
|
||||
GET_RESULT_ENDPOINT = '/kyc/get_result'
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
project_id: str,
|
||||
machine_uid: str,
|
||||
token: str,
|
||||
process_info: list[dict[str,Any]],
|
||||
timeout: int = 15,
|
||||
) -> None:
|
||||
self._project_id = project_id
|
||||
self._machine_uid = machine_uid
|
||||
self._token = token
|
||||
self._process_info = process_info
|
||||
self._expires = 3600
|
||||
self._timeout = timeout
|
||||
|
||||
|
||||
async def create_identification(self,client_user_token: str) -> BeorgKycCreateResponse:
|
||||
self._ensure_configured()
|
||||
payload: dict[str,Any] = {
|
||||
'project_id': self._project_id,
|
||||
'machine_uid': self._machine_uid,
|
||||
'token': self._token,
|
||||
'process_info': self._process_info,
|
||||
'client_user_token': client_user_token,
|
||||
'expires': self._expires,
|
||||
}
|
||||
|
||||
timeout = aiohttp.ClientTimeout(total=self._timeout)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.post(
|
||||
f'{self.BASE_URL}{self.CREATE_ENDPOINT}',
|
||||
json=payload,
|
||||
headers={'Content-Type': 'application/json'},
|
||||
) as response:
|
||||
data = await response.json(content_type=None)
|
||||
|
||||
if response.status >= 500:
|
||||
raise ApplicationException(status_code=502,message='Beorg service unavailable')
|
||||
|
||||
result = BeorgKycCreateResponse.model_validate(data)
|
||||
if not result.status:
|
||||
raise ApplicationException(status_code=400,message=result.error or 'Beorg rejected kyc request')
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def get_result(self,user_token: str) -> BeorgKycResultResponse:
|
||||
self._ensure_configured()
|
||||
timeout = aiohttp.ClientTimeout(total=self._timeout)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.get(
|
||||
f'{self.BASE_URL}{self.GET_RESULT_ENDPOINT}',
|
||||
params={
|
||||
'token': self._token,
|
||||
'user_token': user_token,
|
||||
},
|
||||
headers={'Content-Type': 'application/json'},
|
||||
) as response:
|
||||
data = await response.json(content_type=None)
|
||||
|
||||
if response.status >= 500:
|
||||
raise ApplicationException(status_code=502,message='Beorg service unavailable')
|
||||
|
||||
return BeorgKycResultResponse.model_validate(data)
|
||||
|
||||
|
||||
def _ensure_configured(self) -> None:
|
||||
if not self._project_id or not self._machine_uid or not self._token or not self._process_info:
|
||||
raise ApplicationException(status_code=500,message='Beorg service is not configured')
|
||||
5
src/infrastructure/cache/__init__.py
vendored
Normal file
5
src/infrastructure/cache/__init__.py
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
from src.infrastructure.cache.client import create_redis_client
|
||||
from src.infrastructure.cache.keydb_client import KeydbCache
|
||||
from src.infrastructure.cache.remote_cache import RemoteCache
|
||||
|
||||
__all__ = ['create_redis_client', 'KeydbCache', 'RemoteCache']
|
||||
5
src/infrastructure/cache/client.py
vendored
Normal file
5
src/infrastructure/cache/client.py
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
from redis.asyncio import Redis
|
||||
|
||||
|
||||
def create_redis_client(url: str) -> Redis:
|
||||
return Redis.from_url(url,decode_responses=True)
|
||||
55
src/infrastructure/cache/keydb_client.py
vendored
Normal file
55
src/infrastructure/cache/keydb_client.py
vendored
Normal file
@@ -0,0 +1,55 @@
|
||||
from __future__ import annotations
|
||||
import orjson
|
||||
from redis.asyncio.client import Redis
|
||||
from src.application.contracts import ICache
|
||||
from src.application.domain.entities.user import UserEntity
|
||||
|
||||
|
||||
class KeydbCache(ICache):
|
||||
USER_PREFIX = 'user:me'
|
||||
|
||||
def __init__(self, redis_client: Redis):
|
||||
self._r = redis_client
|
||||
|
||||
async def set(self, key: str, value: str, ttl: int) -> bool:
|
||||
return bool(await self._r.set(key, value, ex=ttl))
|
||||
|
||||
async def set_nx(self, key: str, value: str, ttl: int) -> bool:
|
||||
return bool(await self._r.set(key, value, ex=ttl, nx=True))
|
||||
|
||||
async def get(self, key: str) -> str | None:
|
||||
return await self._r.get(key)
|
||||
|
||||
async def hget(self, key: str, field: str) -> str | None:
|
||||
return await self._r.hget(key, field)
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
return (await self._r.delete(key)) > 0
|
||||
|
||||
async def get_user(self, user_id: str) -> dict | None:
|
||||
raw = await self._r.get(f'{self.USER_PREFIX}:{user_id}')
|
||||
if raw is None:
|
||||
return None
|
||||
return orjson.loads(raw)
|
||||
|
||||
async def set_user(self, user_id: str, user: UserEntity, ttl: int = 300) -> None:
|
||||
data = orjson.dumps({
|
||||
'id': user.id,
|
||||
'email': user.email,
|
||||
'first_name': user.first_name,
|
||||
'middle_name': user.middle_name,
|
||||
'last_name': user.last_name,
|
||||
'birth_date': str(user.birth_date) if user.birth_date else None,
|
||||
'crypto_wallet': user.crypto_wallet,
|
||||
'phone': user.phone,
|
||||
'bik': user.bik,
|
||||
'account_number': user.account_number,
|
||||
'card_number': user.card_number,
|
||||
'inn': user.inn,
|
||||
'kyc_verified': user.kyc_verified,
|
||||
'is_deleted': user.is_deleted,
|
||||
'created_at': user.created_at.isoformat() if user.created_at else None,
|
||||
'updated_at': user.updated_at.isoformat() if user.updated_at else None,
|
||||
'kyc_verified_at': user.kyc_verified_at.isoformat() if user.kyc_verified_at else None,
|
||||
})
|
||||
await self._r.set(f'{self.USER_PREFIX}:{user_id}', data, ex=ttl)
|
||||
65
src/infrastructure/cache/remote_cache.py
vendored
Normal file
65
src/infrastructure/cache/remote_cache.py
vendored
Normal file
@@ -0,0 +1,65 @@
|
||||
from __future__ import annotations
|
||||
import orjson
|
||||
from redis.asyncio.client import Redis
|
||||
from src.application.contracts import ICache
|
||||
from src.application.domain.entities.user import UserEntity
|
||||
|
||||
|
||||
class RemoteCache(ICache):
|
||||
|
||||
|
||||
USER_PREFIX = 'user:me'
|
||||
|
||||
|
||||
def __init__(self,redis_client: Redis) -> None:
|
||||
self._r = redis_client
|
||||
|
||||
|
||||
async def set(self,key: str,value: str,ttl: int) -> bool:
|
||||
return bool(await self._r.set(key,value,ex=ttl))
|
||||
|
||||
|
||||
async def set_nx(self,key: str,value: str,ttl: int) -> bool:
|
||||
return bool(await self._r.set(key,value,ex=ttl,nx=True))
|
||||
|
||||
|
||||
async def get(self,key: str) -> str | None:
|
||||
return await self._r.get(key)
|
||||
|
||||
|
||||
async def hget(self,key: str,field: str) -> str | None:
|
||||
return await self._r.hget(key,field)
|
||||
|
||||
|
||||
async def delete(self,key: str) -> bool:
|
||||
return (await self._r.delete(key)) > 0
|
||||
|
||||
|
||||
async def get_user(self,user_id: str) -> dict | None:
|
||||
raw = await self._r.get(f'{self.USER_PREFIX}:{user_id}')
|
||||
if raw is None:
|
||||
return None
|
||||
return orjson.loads(raw)
|
||||
|
||||
|
||||
async def set_user(self,user_id: str,user: UserEntity,ttl: int = 300) -> None:
|
||||
data = orjson.dumps({
|
||||
'id': user.id,
|
||||
'email': user.email,
|
||||
'first_name': user.first_name,
|
||||
'middle_name': user.middle_name,
|
||||
'last_name': user.last_name,
|
||||
'birth_date': str(user.birth_date) if user.birth_date else None,
|
||||
'crypto_wallet': user.crypto_wallet,
|
||||
'phone': user.phone,
|
||||
'bik': user.bik,
|
||||
'account_number': user.account_number,
|
||||
'card_number': user.card_number,
|
||||
'inn': user.inn,
|
||||
'kyc_verified': user.kyc_verified,
|
||||
'is_deleted': user.is_deleted,
|
||||
'created_at': user.created_at.isoformat() if user.created_at else None,
|
||||
'updated_at': user.updated_at.isoformat() if user.updated_at else None,
|
||||
'kyc_verified_at': user.kyc_verified_at.isoformat() if user.kyc_verified_at else None,
|
||||
})
|
||||
await self._r.set(f'{self.USER_PREFIX}:{user_id}',data,ex=ttl)
|
||||
1
src/infrastructure/config/__init__.py
Normal file
1
src/infrastructure/config/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.config.settings import Settings,get_settings,settings
|
||||
132
src/infrastructure/config/settings.py
Normal file
132
src/infrastructure/config/settings.py
Normal file
@@ -0,0 +1,132 @@
|
||||
from functools import lru_cache
|
||||
import json
|
||||
from typing import Any
|
||||
from pydantic import Field,PrivateAttr
|
||||
from pydantic_settings import BaseSettings,SettingsConfigDict
|
||||
from src.infrastructure.vault import VaultClient
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_file='.env',extra='ignore')
|
||||
_vault_beorg_secrets: dict[str,Any] = PrivateAttr(default_factory=dict)
|
||||
|
||||
DOCS_USERNAME: str = 'admin'
|
||||
DOCS_PASSWORD: str = 'admin'
|
||||
KEYDB_URL: str = 'redis://localhost:6379/0'
|
||||
VAULT_ADDR: str = 'https://corp.vault.elcsa.ru'
|
||||
VAULT_ROLE_ID: str = ''
|
||||
VAULT_SECRET_ID: str = ''
|
||||
VAULT_NAMESPACE: str | None = None
|
||||
VAULT_MOUNT_POINT: str = 'dev-secrets'
|
||||
VAULT_APP_SECRET_PATH: str = 'app'
|
||||
VAULT_BEORG_SECRET_PATH: str = 'beorg'
|
||||
VAULT_DATABASE_SECRET_PATH: str = 'database'
|
||||
VAULT_JWT_SECRET_PATH: str = 'jwt'
|
||||
VAULT_RABBIT_SECRET_PATH: str = 'rabbitmq'
|
||||
VAULT_DOCS_SECRET_PATH: str = 'docs'
|
||||
VAULT_JWT_KID_PATH: str = 'jwt/kid'
|
||||
VAULT_JWT_KIDS_PREFIX: str = 'jwt/kids'
|
||||
JWT_KEYS_REFRESH_SECONDS: int = 300
|
||||
JWT_ALGORITHM: str = 'RS256'
|
||||
JWT_AUDIENCE: str | None = None
|
||||
JWT_ISSUER: str | None = None
|
||||
RABBIT_URL: str = 'amqp://guest:guest@localhost:5672/'
|
||||
RABBIT_CRYPTO_TRANSFER_COMPLETED_QUEUE: str = 'crypto_transfer_completed'
|
||||
RABBIT_KYC_VERIFIED_QUEUE: str = 'kyc_verified'
|
||||
RABBIT_PUBLISH_PERSIST: bool = True
|
||||
DATABASE_URL: str = 'postgresql+asyncpg://postgres:postgres@localhost:5432/kyc'
|
||||
DATABASE_POOL_SIZE: int = 5
|
||||
DATABASE_MAX_OVERFLOW: int = 10
|
||||
DATABASE_POOL_TIMEOUT: int = 30
|
||||
DATABASE_POOL_RECYCLE: int = 1800
|
||||
DATABASE_ECHO: bool = False
|
||||
EXCLUDED_PATHS: tuple[str,...] = ('/docs','/redoc','/openapi.json','/ping')
|
||||
BEORG_TIMEOUT: int = 15
|
||||
BEORG_PROCESS_INFO: list[dict[str,Any]] = Field(default_factory=lambda: [
|
||||
{
|
||||
'key': 'SELFIE1',
|
||||
'type': 'SELFIE',
|
||||
'options': {
|
||||
'stages': [
|
||||
'biometry_liveness',
|
||||
],
|
||||
},
|
||||
'attempts': 3,
|
||||
},
|
||||
{
|
||||
'key': 'PASSPORT1',
|
||||
'type': 'PASSPORT',
|
||||
'options': {
|
||||
'stages': [
|
||||
'verification',
|
||||
'biometry_match',
|
||||
],
|
||||
'relation': {
|
||||
'biometry_match': 'SELFIE1',
|
||||
},
|
||||
},
|
||||
'attempts': 3,
|
||||
},
|
||||
])
|
||||
|
||||
|
||||
@property
|
||||
def BEORG_PROJECT_ID(self) -> str:
|
||||
return self._get_beorg_secret('project_id','BEORG_PROJECT_ID')
|
||||
|
||||
|
||||
@property
|
||||
def BEORG_MACHINE_UID(self) -> str:
|
||||
return self._get_beorg_secret('machine_uid','BEORG_MACHINE_UID')
|
||||
|
||||
|
||||
@property
|
||||
def BEORG_TOKEN(self) -> str:
|
||||
return self._get_beorg_secret('token','BEORG_TOKEN')
|
||||
|
||||
|
||||
def _get_beorg_secret(self,*keys: str) -> str:
|
||||
for key in keys:
|
||||
value = self._vault_beorg_secrets.get(key)
|
||||
if value is not None:
|
||||
return str(value)
|
||||
return ''
|
||||
|
||||
|
||||
def model_post_init(self,__context: Any) -> None:
|
||||
if not self.VAULT_ROLE_ID or not self.VAULT_SECRET_ID:
|
||||
return
|
||||
|
||||
client = VaultClient(
|
||||
addr=self.VAULT_ADDR,
|
||||
role_id=self.VAULT_ROLE_ID,
|
||||
secret_id=self.VAULT_SECRET_ID,
|
||||
namespace=self.VAULT_NAMESPACE,
|
||||
mount_point=self.VAULT_MOUNT_POINT,
|
||||
)
|
||||
object.__setattr__(self,'_vault_beorg_secrets',client.read_many(self.VAULT_BEORG_SECRET_PATH))
|
||||
secrets = client.read_many(
|
||||
self.VAULT_APP_SECRET_PATH,
|
||||
self.VAULT_BEORG_SECRET_PATH,
|
||||
self.VAULT_DATABASE_SECRET_PATH,
|
||||
self.VAULT_JWT_SECRET_PATH,
|
||||
self.VAULT_RABBIT_SECRET_PATH,
|
||||
self.VAULT_DOCS_SECRET_PATH,
|
||||
)
|
||||
for field in type(self).model_fields:
|
||||
if field.startswith('VAULT_') or field == 'KEYDB_URL':
|
||||
continue
|
||||
value = secrets.get(field,secrets.get(field.lower()))
|
||||
if value is None:
|
||||
continue
|
||||
if field == 'BEORG_PROCESS_INFO' and isinstance(value,str):
|
||||
value = json.loads(value)
|
||||
object.__setattr__(self,field,value)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_settings() -> Settings:
|
||||
return Settings()
|
||||
|
||||
|
||||
settings = get_settings()
|
||||
4
src/infrastructure/context_vars/__init__.py
Normal file
4
src/infrastructure/context_vars/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from contextvars import ContextVar
|
||||
|
||||
|
||||
trace_id_var: ContextVar[str] = ContextVar('trace_id',default='N/A')
|
||||
0
src/infrastructure/context_vars/trace_id.py
Normal file
0
src/infrastructure/context_vars/trace_id.py
Normal file
1
src/infrastructure/database/__init__.py
Normal file
1
src/infrastructure/database/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.database.unit_of_work import UnitOfWork
|
||||
22
src/infrastructure/database/context.py
Normal file
22
src/infrastructure/database/context.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
from sqlalchemy.ext.asyncio.engine import create_async_engine
|
||||
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||
from typing import AsyncGenerator
|
||||
from src.infrastructure.config import settings
|
||||
|
||||
|
||||
engine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
pool_size=settings.DATABASE_POOL_SIZE,
|
||||
max_overflow=settings.DATABASE_MAX_OVERFLOW,
|
||||
pool_timeout=settings.DATABASE_POOL_TIMEOUT,
|
||||
pool_recycle=settings.DATABASE_POOL_RECYCLE,
|
||||
echo=settings.DATABASE_ECHO
|
||||
)
|
||||
|
||||
async_session_maker = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
|
||||
|
||||
|
||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with async_session_maker() as session:
|
||||
yield session
|
||||
0
src/infrastructure/database/decorators/__init__.py
Normal file
0
src/infrastructure/database/decorators/__init__.py
Normal file
5
src/infrastructure/database/models/__init__.py
Normal file
5
src/infrastructure/database/models/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from src.infrastructure.database.models.base import Base
|
||||
from src.infrastructure.database.models.user import UserModel
|
||||
from src.infrastructure.database.models.kyc import KycModel
|
||||
|
||||
__all__ = ['Base','UserModel','KycModel']
|
||||
5
src/infrastructure/database/models/base.py
Normal file
5
src/infrastructure/database/models/base.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
24
src/infrastructure/database/models/kyc.py
Normal file
24
src/infrastructure/database/models/kyc.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from sqlalchemy import Boolean,DateTime,ForeignKey,JSON,String,Text
|
||||
from sqlalchemy.orm import Mapped,mapped_column
|
||||
from src.infrastructure.database.models.base import Base
|
||||
from src.infrastructure.database.models.mixins import AuditTimestampsMixin,SoftDeleteMixin,UlidPrimaryKeyMixin
|
||||
|
||||
|
||||
class KycModel(Base,UlidPrimaryKeyMixin,AuditTimestampsMixin,SoftDeleteMixin):
|
||||
__tablename__ = 'kyc'
|
||||
|
||||
user_id: Mapped[str] = mapped_column(String(26),ForeignKey('users.id',ondelete='CASCADE'),nullable=False,index=True)
|
||||
user_token: Mapped[str | None] = mapped_column(String(255),nullable=True,index=True)
|
||||
client_user_token: Mapped[str | None] = mapped_column(String(255),nullable=True,index=True)
|
||||
link: Mapped[str | None] = mapped_column(Text,nullable=True)
|
||||
qr_code: Mapped[str | None] = mapped_column(Text,nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(32),nullable=False,server_default='started',default='started',index=True)
|
||||
done_state: Mapped[bool | None] = mapped_column(Boolean,nullable=True)
|
||||
set_id: Mapped[str | None] = mapped_column(String(255),nullable=True,index=True)
|
||||
error: Mapped[str | None] = mapped_column(Text,nullable=True)
|
||||
result_data: Mapped[Any | None] = mapped_column(JSON,nullable=True)
|
||||
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True),nullable=False,index=True)
|
||||
completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True),nullable=True)
|
||||
3
src/infrastructure/database/models/mixins/__init__.py
Normal file
3
src/infrastructure/database/models/mixins/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from src.infrastructure.database.models.mixins.audit import AuditTimestampsMixin
|
||||
from src.infrastructure.database.models.mixins.soft_delete import SoftDeleteMixin
|
||||
from src.infrastructure.database.models.mixins.ulid import UlidPrimaryKeyMixin
|
||||
12
src/infrastructure/database/models/mixins/audit.py
Normal file
12
src/infrastructure/database/models/mixins/audit.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from datetime import datetime,timezone
|
||||
from sqlalchemy import DateTime
|
||||
from sqlalchemy.orm import Mapped,mapped_column
|
||||
|
||||
|
||||
def _utc_now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
class AuditTimestampsMixin:
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True),default=_utc_now,nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True),default=_utc_now,onupdate=_utc_now,nullable=False)
|
||||
6
src/infrastructure/database/models/mixins/soft_delete.py
Normal file
6
src/infrastructure/database/models/mixins/soft_delete.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from sqlalchemy import Boolean
|
||||
from sqlalchemy.orm import Mapped,mapped_column
|
||||
|
||||
|
||||
class SoftDeleteMixin:
|
||||
is_deleted: Mapped[bool] = mapped_column(Boolean,nullable=False,server_default='false',default=False)
|
||||
7
src/infrastructure/database/models/mixins/ulid.py
Normal file
7
src/infrastructure/database/models/mixins/ulid.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy.orm import Mapped,mapped_column
|
||||
from ulid import ULID
|
||||
|
||||
|
||||
class UlidPrimaryKeyMixin:
|
||||
id: Mapped[str] = mapped_column(String(26),primary_key=True,default=lambda: str(ULID()))
|
||||
29
src/infrastructure/database/models/user.py
Normal file
29
src/infrastructure/database/models/user.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import Boolean, Date, DateTime, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from src.infrastructure.database.models.base import Base
|
||||
from src.infrastructure.database.models.mixins import AuditTimestampsMixin, SoftDeleteMixin, UlidPrimaryKeyMixin
|
||||
|
||||
|
||||
class UserModel(Base, UlidPrimaryKeyMixin, AuditTimestampsMixin, SoftDeleteMixin):
|
||||
__tablename__ = 'users'
|
||||
|
||||
email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True, index=True)
|
||||
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
|
||||
last_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
first_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
middle_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
birth_date: Mapped[Date | None] = mapped_column(Date, nullable=True)
|
||||
|
||||
crypto_wallet: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
phone: Mapped[str | None] = mapped_column(String(16), nullable=True)
|
||||
|
||||
bik: Mapped[str | None] = mapped_column(String(9), nullable=True)
|
||||
account_number: Mapped[str | None] = mapped_column(String(20), nullable=True)
|
||||
card_number: Mapped[str | None] = mapped_column(String(19), nullable=True)
|
||||
inn: Mapped[str | None] = mapped_column(String(12), nullable=True)
|
||||
|
||||
kyc_verified: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default='false', default=False)
|
||||
kyc_verified_at: Mapped[DateTime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
2
src/infrastructure/database/repositories/__init__.py
Normal file
2
src/infrastructure/database/repositories/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from src.infrastructure.database.repositories.kyc_repository import KycRepository
|
||||
from src.infrastructure.database.repositories.user_repository import UserRepository
|
||||
120
src/infrastructure/database/repositories/kyc_repository.py
Normal file
120
src/infrastructure/database/repositories/kyc_repository.py
Normal file
@@ -0,0 +1,120 @@
|
||||
from __future__ import annotations
|
||||
from datetime import datetime,timezone
|
||||
from typing import Any
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from src.application.abstractions.repositories import IKycRepository
|
||||
from src.application.domain.entities import KycEntity
|
||||
from src.infrastructure.database.models.kyc import KycModel
|
||||
|
||||
|
||||
class KycRepository(IKycRepository):
|
||||
|
||||
def __init__(self,session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
|
||||
async def create_started_session(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
user_token: str | None,
|
||||
client_user_token: str | None,
|
||||
link: str | None,
|
||||
qr_code: str | None,
|
||||
expires_at: datetime,
|
||||
error: str | None,
|
||||
) -> None:
|
||||
kyc = KycModel(
|
||||
user_id=user_id,
|
||||
user_token=user_token,
|
||||
client_user_token=client_user_token,
|
||||
link=link,
|
||||
qr_code=qr_code,
|
||||
status='started',
|
||||
expires_at=expires_at,
|
||||
error=error,
|
||||
)
|
||||
self._session.add(kyc)
|
||||
await self._session.flush()
|
||||
|
||||
|
||||
async def update_session_result(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
user_token: str,
|
||||
status: str,
|
||||
done_state: bool | None,
|
||||
set_id: str | None,
|
||||
result_data: Any,
|
||||
error: str | None,
|
||||
) -> None:
|
||||
result = await self._session.execute(
|
||||
select(KycModel)
|
||||
.where(KycModel.user_id == user_id,KycModel.user_token == user_token)
|
||||
.order_by(KycModel.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
kyc = result.scalar_one_or_none()
|
||||
if kyc is None:
|
||||
return
|
||||
|
||||
kyc.status = status
|
||||
kyc.done_state = done_state
|
||||
kyc.set_id = set_id
|
||||
kyc.result_data = result_data
|
||||
kyc.error = error
|
||||
kyc.completed_at = datetime.now(timezone.utc)
|
||||
await self._session.flush()
|
||||
|
||||
|
||||
async def expire_started_sessions(self,*,user_id: str,now: datetime) -> None:
|
||||
result = await self._session.execute(
|
||||
select(KycModel)
|
||||
.where(
|
||||
KycModel.user_id == user_id,
|
||||
KycModel.status == 'started',
|
||||
KycModel.expires_at <= now,
|
||||
)
|
||||
)
|
||||
for kyc in result.scalars():
|
||||
kyc.status = 'expired'
|
||||
await self._session.flush()
|
||||
|
||||
|
||||
async def get_active_session(self,*,user_id: str,now: datetime) -> KycEntity | None:
|
||||
result = await self._session.execute(
|
||||
select(KycModel)
|
||||
.where(
|
||||
KycModel.user_id == user_id,
|
||||
KycModel.status == 'started',
|
||||
KycModel.expires_at > now,
|
||||
)
|
||||
.order_by(KycModel.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
kyc = result.scalar_one_or_none()
|
||||
if kyc is None:
|
||||
return None
|
||||
return self._to_entity(kyc)
|
||||
|
||||
|
||||
def _to_entity(self,kyc: KycModel) -> KycEntity:
|
||||
return KycEntity(
|
||||
id=kyc.id,
|
||||
user_id=kyc.user_id,
|
||||
user_token=kyc.user_token,
|
||||
client_user_token=kyc.client_user_token,
|
||||
link=kyc.link,
|
||||
qr_code=kyc.qr_code,
|
||||
status=kyc.status,
|
||||
done_state=kyc.done_state,
|
||||
set_id=kyc.set_id,
|
||||
error=kyc.error,
|
||||
result_data=kyc.result_data,
|
||||
expires_at=kyc.expires_at,
|
||||
completed_at=kyc.completed_at,
|
||||
created_at=kyc.created_at,
|
||||
updated_at=kyc.updated_at,
|
||||
)
|
||||
82
src/infrastructure/database/repositories/user_repository.py
Normal file
82
src/infrastructure/database/repositories/user_repository.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from __future__ import annotations
|
||||
from datetime import date,datetime,timezone
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from src.application.abstractions.repositories import IUserRepository
|
||||
from src.application.domain.entities import UserEntity
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.database.models.user import UserModel
|
||||
|
||||
|
||||
class UserRepository(IUserRepository):
|
||||
|
||||
def __init__(self,session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
|
||||
async def create_user(self,email: str,password_hash: str) -> UserEntity:
|
||||
user = UserModel(email=email,password_hash=password_hash)
|
||||
self._session.add(user)
|
||||
await self._session.flush()
|
||||
return self._to_entity(user)
|
||||
|
||||
|
||||
async def get_user_by_email(self,email: str) -> UserEntity:
|
||||
result = await self._session.execute(select(UserModel).where(UserModel.email == email))
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None:
|
||||
raise ApplicationException(status_code=404,message='User not found')
|
||||
return self._to_entity(user)
|
||||
|
||||
|
||||
async def exists_by_email(self,email: str) -> bool:
|
||||
result = await self._session.execute(select(UserModel.id).where(UserModel.email == email))
|
||||
return result.scalar_one_or_none() is not None
|
||||
|
||||
|
||||
async def update_kyc_data(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
first_name: str,
|
||||
last_name: str,
|
||||
birth_date: date,
|
||||
middle_name: str | None,
|
||||
inn: str | None,
|
||||
) -> UserEntity:
|
||||
user = await self._session.get(UserModel,user_id)
|
||||
if user is None:
|
||||
raise ApplicationException(status_code=404,message='User not found')
|
||||
|
||||
user.first_name = first_name
|
||||
user.last_name = last_name
|
||||
user.middle_name = middle_name
|
||||
user.birth_date = birth_date
|
||||
user.inn = inn
|
||||
user.kyc_verified = True
|
||||
user.kyc_verified_at = datetime.now(timezone.utc)
|
||||
await self._session.flush()
|
||||
return self._to_entity(user)
|
||||
|
||||
|
||||
def _to_entity(self,user: UserModel) -> UserEntity:
|
||||
return UserEntity(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
password_hash=user.password_hash,
|
||||
first_name=user.first_name,
|
||||
middle_name=user.middle_name,
|
||||
last_name=user.last_name,
|
||||
birth_date=user.birth_date,
|
||||
crypto_wallet=user.crypto_wallet,
|
||||
phone=user.phone,
|
||||
bik=user.bik,
|
||||
account_number=user.account_number,
|
||||
card_number=user.card_number,
|
||||
inn=user.inn,
|
||||
kyc_verified=user.kyc_verified,
|
||||
is_deleted=user.is_deleted,
|
||||
created_at=user.created_at,
|
||||
updated_at=user.updated_at,
|
||||
kyc_verified_at=user.kyc_verified_at,
|
||||
)
|
||||
62
src/infrastructure/database/unit_of_work.py
Normal file
62
src/infrastructure/database/unit_of_work.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession,async_sessionmaker
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.abstractions.repositories import IKycRepository,IUserRepository
|
||||
from src.application.contracts import ILogger
|
||||
from src.infrastructure.database.repositories import KycRepository,UserRepository
|
||||
|
||||
|
||||
|
||||
class UnitOfWork(IUnitOfWork):
|
||||
def __init__(self,session_factory: async_sessionmaker[AsyncSession],logger: ILogger):
|
||||
self.session_factory = session_factory
|
||||
self._session: AsyncSession | None = None
|
||||
self._user_repository: IUserRepository | None = None
|
||||
self._kyc_repository: IKycRepository | None = None
|
||||
self._logger: ILogger = logger
|
||||
|
||||
async def __aenter__(self):
|
||||
self._session = self.session_factory()
|
||||
self._user_repository = None
|
||||
self._kyc_repository = None
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self._session is None:
|
||||
return
|
||||
if exc_type:
|
||||
self._logger.error(str(exc_val))
|
||||
await self._session.rollback()
|
||||
self._logger.error(f'Rollback: str{exc_val})')
|
||||
else:
|
||||
await self._session.flush()
|
||||
await self._session.commit()
|
||||
self._logger.debug('Commit')
|
||||
await self._session.close()
|
||||
|
||||
|
||||
async def commit(self) -> None:
|
||||
if self._session is not None:
|
||||
await self._session.commit()
|
||||
|
||||
|
||||
async def rollback(self) -> None:
|
||||
if self._session is not None:
|
||||
await self._session.rollback()
|
||||
|
||||
|
||||
@property
|
||||
def user_repository(self) -> IUserRepository:
|
||||
if self._session is None:
|
||||
raise RuntimeError('UnitOfWork session is not initialized')
|
||||
if self._user_repository is None:
|
||||
self._user_repository = UserRepository(session=self._session)
|
||||
return self._user_repository
|
||||
|
||||
|
||||
@property
|
||||
def kyc_repository(self) -> IKycRepository:
|
||||
if self._session is None:
|
||||
raise RuntimeError('UnitOfWork session is not initialized')
|
||||
if self._kyc_repository is None:
|
||||
self._kyc_repository = KycRepository(session=self._session)
|
||||
return self._kyc_repository
|
||||
4
src/infrastructure/logger/__init__.py
Normal file
4
src/infrastructure/logger/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from src.infrastructure.logger.logger import Logger
|
||||
|
||||
|
||||
logger = Logger()
|
||||
129
src/infrastructure/logger/logger.py
Normal file
129
src/infrastructure/logger/logger.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import traceback
|
||||
import inspect
|
||||
import sys
|
||||
import orjson
|
||||
from datetime import datetime
|
||||
from typing import Callable, Optional, Any
|
||||
from ulid import ULID
|
||||
from src.application.contracts import ILogger
|
||||
from src.application.domain.enums import LogFormat, LogLevel
|
||||
from src.infrastructure.context_vars import trace_id_var
|
||||
|
||||
|
||||
class Logger(ILogger):
|
||||
_instance = None
|
||||
__default_format = LogFormat.JSON
|
||||
|
||||
def __new__(cls, *args: Any, **kwargs: Any) -> "Logger":
|
||||
if cls._instance is None:
|
||||
cls._instance = super(Logger, cls).__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
log_format: LogFormat = __default_format,
|
||||
min_level: LogLevel = LogLevel.INFO,
|
||||
id_generator: Optional[Callable[[], str]] = lambda: str(ULID()),
|
||||
instance_id: str = "N/A",
|
||||
):
|
||||
self.log_format = log_format
|
||||
self.min_level = min_level
|
||||
self.id_generator = id_generator
|
||||
self.instance_id = instance_id
|
||||
|
||||
def set_instance_id(self, instance_id: str) -> None:
|
||||
self.instance_id = instance_id
|
||||
|
||||
def get_instance_id(self) -> str:
|
||||
return self.instance_id
|
||||
|
||||
def set_format(self, log_format: LogFormat) -> None:
|
||||
if not isinstance(log_format, LogFormat):
|
||||
raise ValueError("Log format must be an instance of LogFormat enum")
|
||||
self.log_format = log_format
|
||||
|
||||
def set_min_level(self, level: LogLevel) -> None:
|
||||
self.min_level = level
|
||||
|
||||
def new_trace_id(self) -> str:
|
||||
trace_id = str(ULID()) if self.id_generator is None else self.id_generator()
|
||||
trace_id_var.set(trace_id)
|
||||
return trace_id
|
||||
|
||||
def set_trace_id(self, trace_id: str) -> None:
|
||||
trace_id_var.set(trace_id)
|
||||
|
||||
def get_trace_id(self) -> str:
|
||||
return trace_id_var.get()
|
||||
|
||||
def clear_trace_id(self) -> None:
|
||||
trace_id_var.set("N/A")
|
||||
|
||||
def _prepare_log_data(self, level: LogLevel, message: str) -> dict[str, Any]:
|
||||
current_frame = inspect.currentframe()
|
||||
if (
|
||||
current_frame
|
||||
and current_frame.f_back
|
||||
and current_frame.f_back.f_back
|
||||
and current_frame.f_back.f_back.f_back
|
||||
):
|
||||
frame = current_frame.f_back.f_back.f_back
|
||||
filename = frame.f_code.co_filename
|
||||
line_number = frame.f_lineno
|
||||
else:
|
||||
filename = "unknown"
|
||||
line_number = 0
|
||||
|
||||
log_data = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"level": level.name,
|
||||
"instance_id": self.instance_id,
|
||||
"file": filename,
|
||||
"line": line_number,
|
||||
"trace_id": trace_id_var.get(),
|
||||
"message": message,
|
||||
}
|
||||
|
||||
if level == LogLevel.EXCEPTION:
|
||||
log_data["exception"] = traceback.format_exc()
|
||||
|
||||
return log_data
|
||||
|
||||
def _log(self, level: LogLevel, message: str) -> None:
|
||||
if level >= self.min_level:
|
||||
log_data = self._prepare_log_data(level, message)
|
||||
|
||||
if self.log_format == LogFormat.JSON:
|
||||
log_message = orjson.dumps(log_data).decode()
|
||||
else:
|
||||
log_message = (
|
||||
f"{log_data['timestamp']} - {log_data['level']} - "
|
||||
f"{log_data['instance_id']} - {log_data['trace_id']} - "
|
||||
f"{log_data['file']}:{log_data['line']} - "
|
||||
f"{log_data['message']}"
|
||||
)
|
||||
if "exception" in log_data:
|
||||
log_message += f"\nTraceback:\n{log_data['exception']}"
|
||||
|
||||
self._write(log_message)
|
||||
|
||||
def _write(self, message: str) -> None:
|
||||
sys.stdout.write(message + "\n")
|
||||
|
||||
def debug(self, message: str) -> None:
|
||||
self._log(LogLevel.DEBUG, message)
|
||||
|
||||
def info(self, message: str) -> None:
|
||||
self._log(LogLevel.INFO, message)
|
||||
|
||||
def warning(self, message: str) -> None:
|
||||
self._log(LogLevel.WARNING, message)
|
||||
|
||||
def error(self, message: str) -> None:
|
||||
self._log(LogLevel.ERROR, message)
|
||||
|
||||
def critical(self, message: str) -> None:
|
||||
self._log(LogLevel.CRITICAL, message)
|
||||
|
||||
def exception(self, message: str) -> None:
|
||||
self._log(LogLevel.EXCEPTION, message)
|
||||
1
src/infrastructure/messanger/__init__.py
Normal file
1
src/infrastructure/messanger/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.messanger.rabbit_client import RabbitClient
|
||||
72
src/infrastructure/messanger/rabbit_client.py
Normal file
72
src/infrastructure/messanger/rabbit_client.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from typing import Any, Mapping
|
||||
from faststream.rabbit import RabbitBroker
|
||||
from src.application.contracts import IQueueMessanger
|
||||
from src.infrastructure.config import settings
|
||||
|
||||
|
||||
class RabbitClient(IQueueMessanger):
|
||||
def __init__(self) -> None:
|
||||
self._broker = RabbitBroker(
|
||||
settings.RABBIT_URL,
|
||||
)
|
||||
self._connected = False
|
||||
|
||||
async def connect(self) -> None:
|
||||
if self._connected:
|
||||
return
|
||||
await self._broker.connect()
|
||||
self._connected = True
|
||||
|
||||
async def close(self) -> None:
|
||||
if not self._connected:
|
||||
return
|
||||
await self._broker.close()
|
||||
self._connected = False
|
||||
|
||||
async def _ensure_connected(self) -> None:
|
||||
if not self._connected:
|
||||
await self.connect()
|
||||
|
||||
async def publish_to_queue(
|
||||
self,
|
||||
queue: str,
|
||||
message: Any,
|
||||
*,
|
||||
persist: bool | None = None,
|
||||
headers: Mapping[str, Any] | None = None,
|
||||
correlation_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> None:
|
||||
await self._ensure_connected()
|
||||
|
||||
await self._broker.publish(
|
||||
message,
|
||||
queue=queue,
|
||||
persist=settings.RABBIT_PUBLISH_PERSIST if persist is None else persist,
|
||||
headers=headers,
|
||||
correlation_id=correlation_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
async def publish(
|
||||
self,
|
||||
message: Any,
|
||||
*,
|
||||
exchange: str,
|
||||
routing_key: str,
|
||||
persist: bool | None = None,
|
||||
headers: Mapping[str, Any] | None = None,
|
||||
correlation_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> None:
|
||||
await self._ensure_connected()
|
||||
|
||||
await self._broker.publish(
|
||||
message,
|
||||
exchange=exchange,
|
||||
routing_key=routing_key,
|
||||
persist=settings.RABBIT_PUBLISH_PERSIST if persist is None else persist,
|
||||
headers=headers,
|
||||
correlation_id=correlation_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
0
src/infrastructure/security/__init__.py
Normal file
0
src/infrastructure/security/__init__.py
Normal file
0
src/infrastructure/security/csrf.py
Normal file
0
src/infrastructure/security/csrf.py
Normal file
0
src/infrastructure/security/hash.py
Normal file
0
src/infrastructure/security/hash.py
Normal file
109
src/infrastructure/security/jwt.py
Normal file
109
src/infrastructure/security/jwt.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from __future__ import annotations
|
||||
from jose import jwt, ExpiredSignatureError, JWTError
|
||||
from src.application.contracts import ILogger, IJwtService
|
||||
from src.application.domain.dto import AccessTokenPayload
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.config.settings import settings
|
||||
from src.infrastructure.vault import JwtKeyStore
|
||||
|
||||
|
||||
class JwtService(IJwtService):
|
||||
def __init__(self, logger: ILogger, key_store: JwtKeyStore) -> None:
|
||||
self._logger = logger
|
||||
self._key_store = key_store
|
||||
|
||||
async def decode_access_token(self, token: str) -> AccessTokenPayload:
|
||||
payload = await self._decode_and_verify(token)
|
||||
|
||||
if payload.get('type') != 'access':
|
||||
self._logger.warning(f'Access token invalid type received_type={payload.get("type")}')
|
||||
raise ApplicationException(status_code=401, message='Invalid token type')
|
||||
|
||||
try:
|
||||
return AccessTokenPayload(
|
||||
sub=str(payload['sub']),
|
||||
type='access',
|
||||
sid=str(payload['sid']),
|
||||
iat=int(payload['iat']),
|
||||
nbf=int(payload['nbf']),
|
||||
exp=int(payload['exp']),
|
||||
iss=payload.get('iss'),
|
||||
aud=payload.get('aud'),
|
||||
)
|
||||
except KeyError as exception:
|
||||
self._logger.warning(f'Access token missing claim error={str(exception)}')
|
||||
raise ApplicationException(status_code=401, message=f'Missing token claim: {exception}')
|
||||
|
||||
async def _decode_and_verify(self, token: str) -> dict:
|
||||
kid: str | None = None
|
||||
try:
|
||||
header = jwt.get_unverified_header(token)
|
||||
|
||||
kid = header.get('kid')
|
||||
if not kid:
|
||||
self._logger.warning(f'JWT header missing kid header={header}')
|
||||
raise ApplicationException(status_code=401, message='Missing token header: kid')
|
||||
|
||||
received_alg = header.get('alg')
|
||||
if received_alg != settings.JWT_ALGORITHM:
|
||||
self._logger.warning(f'JWT invalid algorithm kid={kid} received_alg={received_alg} expected_alg={settings.JWT_ALGORITHM}')
|
||||
raise ApplicationException(status_code=401, message='Invalid token algorithm')
|
||||
|
||||
public_pem = await self._key_store.get_public_key_for_kid(str(kid))
|
||||
|
||||
if not public_pem:
|
||||
self._logger.info(f'JWT kid miss kid={kid} forcing keystore refresh')
|
||||
await self._key_store.refresh()
|
||||
public_pem = await self._key_store.get_public_key_for_kid(str(kid))
|
||||
|
||||
if not public_pem:
|
||||
self._logger.warning(f'JWT unknown kid kid={kid}')
|
||||
raise ApplicationException(status_code=401, message='Unknown token kid')
|
||||
|
||||
options = {
|
||||
'verify_signature': True,
|
||||
'verify_exp': True,
|
||||
'verify_nbf': True,
|
||||
'verify_iat': True,
|
||||
'verify_aud': bool(settings.JWT_AUDIENCE),
|
||||
'verify_iss': bool(settings.JWT_ISSUER),
|
||||
'require_exp': True,
|
||||
'require_iat': True,
|
||||
'require_nbf': True,
|
||||
'require_sub': True,
|
||||
'leeway': 10,
|
||||
}
|
||||
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
public_pem,
|
||||
algorithms=[settings.JWT_ALGORITHM],
|
||||
audience=settings.JWT_AUDIENCE or None,
|
||||
issuer=settings.JWT_ISSUER or None,
|
||||
options=options,
|
||||
)
|
||||
|
||||
if 'sid' not in payload:
|
||||
self._logger.warning(f'JWT missing sid claim kid={kid}')
|
||||
raise ApplicationException(status_code=401, message='Missing token claim: sid')
|
||||
|
||||
if 'type' not in payload:
|
||||
self._logger.warning(f'JWT missing type claim kid={kid}')
|
||||
raise ApplicationException(status_code=401, message='Missing token claim: type')
|
||||
|
||||
return payload
|
||||
|
||||
except ExpiredSignatureError as exception:
|
||||
self._logger.info(f'JWT expired kid={kid} error={str(exception)}')
|
||||
raise ApplicationException(status_code=401, message='Token expired')
|
||||
|
||||
except ApplicationException:
|
||||
raise
|
||||
|
||||
except JWTError as exception:
|
||||
self._logger.warning(f'JWT decode failed kid={kid} error={str(exception)}')
|
||||
raise ApplicationException(status_code=401, message='Invalid token')
|
||||
|
||||
except Exception as exception:
|
||||
self._logger.error(f'Unexpected JWT decode error kid={kid} error={str(exception)}')
|
||||
raise ApplicationException(status_code=500, message='JWT decode failed')
|
||||
5
src/infrastructure/utils/__init__.py
Normal file
5
src/infrastructure/utils/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from ulid import ULID
|
||||
|
||||
|
||||
def generate_instance_id() -> str:
|
||||
return str(ULID())
|
||||
0
src/infrastructure/utils/instance_id.py
Normal file
0
src/infrastructure/utils/instance_id.py
Normal file
101
src/infrastructure/vault/__init__.py
Normal file
101
src/infrastructure/vault/__init__.py
Normal file
@@ -0,0 +1,101 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
import hvac
|
||||
|
||||
|
||||
class VaultClient:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
addr: str,
|
||||
role_id: str,
|
||||
secret_id: str,
|
||||
namespace: str | None,
|
||||
mount_point: str,
|
||||
) -> None:
|
||||
self._mount_point = mount_point
|
||||
self._client = hvac.Client(url=addr,namespace=namespace)
|
||||
self._client.auth.approle.login(role_id=role_id,secret_id=secret_id)
|
||||
|
||||
|
||||
def read_secret(self,path: str) -> dict[str,Any]:
|
||||
secret = self._client.secrets.kv.v2.read_secret_version(
|
||||
path=path,
|
||||
mount_point=self._mount_point,
|
||||
)
|
||||
return dict(secret.get('data',{}).get('data',{}))
|
||||
|
||||
|
||||
def read_many(self,*paths: str) -> dict[str,Any]:
|
||||
result: dict[str,Any] = {}
|
||||
for path in paths:
|
||||
if not path:
|
||||
continue
|
||||
try:
|
||||
result.update(self.read_secret(path))
|
||||
except hvac.exceptions.InvalidPath:
|
||||
continue
|
||||
return result
|
||||
|
||||
|
||||
class JwtKeyStore:
|
||||
_instance: 'JwtKeyStore | None' = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vault_addr: str = '',
|
||||
vault_role_id: str = '',
|
||||
vault_secret_id: str = '',
|
||||
vault_namespace: str | None = None,
|
||||
mount_point: str = '',
|
||||
kid_path: str = '',
|
||||
kids_prefix: str = '',
|
||||
) -> None:
|
||||
self._keys: dict[str,str] = {}
|
||||
self._kid_path = kid_path
|
||||
self._kids_prefix = kids_prefix
|
||||
self._vault_client: VaultClient | None = None
|
||||
if vault_addr and vault_role_id and vault_secret_id:
|
||||
self._vault_client = VaultClient(
|
||||
addr=vault_addr,
|
||||
role_id=vault_role_id,
|
||||
secret_id=vault_secret_id,
|
||||
namespace=vault_namespace,
|
||||
mount_point=mount_point,
|
||||
)
|
||||
JwtKeyStore._instance = self
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> 'JwtKeyStore':
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
|
||||
async def refresh(self) -> None:
|
||||
if self._vault_client is None:
|
||||
return None
|
||||
|
||||
current = self._vault_client.read_secret(self._kid_path)
|
||||
kid = current.get('kid')
|
||||
if kid:
|
||||
key_data = self._vault_client.read_secret(f'{self._kids_prefix}/{kid}')
|
||||
public_key = key_data.get('public_key') or key_data.get('public')
|
||||
if public_key:
|
||||
self._keys[str(kid)] = str(public_key)
|
||||
return None
|
||||
|
||||
|
||||
async def get_public_key_for_kid(self,kid: str) -> str | None:
|
||||
return self._keys.get(kid)
|
||||
|
||||
|
||||
def start_jwt_keys_scheduler(jwt_store: JwtKeyStore,refresh_seconds: int) -> AsyncIOScheduler:
|
||||
scheduler = AsyncIOScheduler()
|
||||
scheduler.add_job(jwt_store.refresh,'interval',seconds=refresh_seconds)
|
||||
scheduler.start()
|
||||
return scheduler
|
||||
0
src/infrastructure/vault/keys.py
Normal file
0
src/infrastructure/vault/keys.py
Normal file
0
src/infrastructure/vault/scheduler.py
Normal file
0
src/infrastructure/vault/scheduler.py
Normal file
0
src/infrastructure/vault/utils.py
Normal file
0
src/infrastructure/vault/utils.py
Normal file
119
src/main.py
Normal file
119
src/main.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from __future__ import annotations
|
||||
from contextlib import asynccontextmanager
|
||||
import secrets
|
||||
from typing import AsyncGenerator
|
||||
from fastapi import Depends, FastAPI, status
|
||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.cache import create_redis_client
|
||||
from src.infrastructure.config.settings import get_settings
|
||||
from src.infrastructure.vault import JwtKeyStore, start_jwt_keys_scheduler
|
||||
from src.infrastructure.utils import generate_instance_id
|
||||
from src.infrastructure.logger import logger
|
||||
from src.infrastructure.config import settings
|
||||
from src.presentation.handlers import application_exception_handler, unhandled_exception_handler
|
||||
from src.presentation.messaging import crypto_transfer_router
|
||||
from src.presentation.middleware import TraceIDMiddleware, SecurityHeadersMiddleware
|
||||
from src.presentation.routing import kyc_router
|
||||
|
||||
security = HTTPBasic()
|
||||
|
||||
|
||||
async def verify_credentials(credentials: HTTPBasicCredentials = Depends(security)) -> HTTPBasicCredentials:
|
||||
user_ok = secrets.compare_digest(credentials.username, settings.DOCS_USERNAME)
|
||||
pass_ok = secrets.compare_digest(credentials.password, settings.DOCS_PASSWORD)
|
||||
if not (user_ok and pass_ok):
|
||||
raise ApplicationException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
message='Unauthorized',
|
||||
headers={'WWW-Authenticate': 'Basic'},
|
||||
)
|
||||
return credentials
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
|
||||
instance_id = generate_instance_id()
|
||||
logger.set_instance_id(instance_id)
|
||||
logger.info(f'Users service instance started with id {instance_id}')
|
||||
|
||||
app.state.redis = create_redis_client(settings.KEYDB_URL)
|
||||
|
||||
jwt_store = JwtKeyStore(
|
||||
vault_addr=settings.VAULT_ADDR,
|
||||
vault_role_id=settings.VAULT_ROLE_ID,
|
||||
vault_secret_id=settings.VAULT_SECRET_ID,
|
||||
vault_namespace=settings.VAULT_NAMESPACE,
|
||||
mount_point=settings.VAULT_MOUNT_POINT,
|
||||
kid_path=settings.VAULT_JWT_KID_PATH,
|
||||
kids_prefix=settings.VAULT_JWT_KIDS_PREFIX,
|
||||
)
|
||||
|
||||
await jwt_store.refresh()
|
||||
|
||||
jwt_scheduler = start_jwt_keys_scheduler(jwt_store, refresh_seconds=settings.JWT_KEYS_REFRESH_SECONDS)
|
||||
|
||||
app.state.jwt_key_store = jwt_store
|
||||
app.state.jwt_keys_scheduler = jwt_scheduler
|
||||
yield
|
||||
await app.state.redis.aclose()
|
||||
logger.info(f'Users service instance ended with id {instance_id}')
|
||||
|
||||
|
||||
app: FastAPI = FastAPI(
|
||||
redoc_url=None,
|
||||
docs_url=None,
|
||||
lifespan=lifespan,
|
||||
title='Elcsa Users Service'
|
||||
)
|
||||
|
||||
app.add_exception_handler(ApplicationException, application_exception_handler)
|
||||
app.add_exception_handler(Exception, unhandled_exception_handler)
|
||||
|
||||
app.include_router(kyc_router)
|
||||
app.include_router(crypto_transfer_router)
|
||||
|
||||
|
||||
# Added middleware
|
||||
app.add_middleware(TraceIDMiddleware, logger=logger)
|
||||
app.add_middleware(
|
||||
SecurityHeadersMiddleware,
|
||||
hsts=True,
|
||||
hsts_preload=False,
|
||||
frame_options='DENY',
|
||||
referrer_policy='strict-origin-when-cross-origin',
|
||||
content_security_policy="default-src 'self'; frame-ancestors 'none'; base-uri 'self'; object-src 'none'",
|
||||
)
|
||||
|
||||
|
||||
@app.get('/docs', include_in_schema=False)
|
||||
async def custom_swagger_ui_html(_credentials: HTTPBasicCredentials = Depends(verify_credentials)) -> HTMLResponse:
|
||||
'''Custom Swagger documentation, optionally protected with basic authentication.'''
|
||||
return get_swagger_ui_html(
|
||||
openapi_url=getattr(app, 'openapi_url', '/openapi.json'),
|
||||
title=getattr(app, 'title', 'FastAPI') + ' - Swagger UI',
|
||||
oauth2_redirect_url=getattr(app, 'swagger_ui_oauth2_redirect_url', None),
|
||||
swagger_js_url='https://unpkg.com/swagger-ui-dist@5/swagger-ui-bundle.js',
|
||||
swagger_css_url='https://unpkg.com/swagger-ui-dist@5/swagger-ui.css',
|
||||
)
|
||||
|
||||
|
||||
@app.get('/redoc', include_in_schema=False)
|
||||
async def custom_redoc_html(_credentials: HTTPBasicCredentials = Depends(verify_credentials)) -> HTMLResponse:
|
||||
'''Custom ReDoc documentation, optionally protected with basic authentication.'''
|
||||
return get_redoc_html(
|
||||
openapi_url=getattr(app, 'openapi_url', '/openapi.json'),
|
||||
title=getattr(app, 'title', 'FastAPI') + ' - ReDoc',
|
||||
redoc_js_url='https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js',
|
||||
)
|
||||
|
||||
|
||||
@app.post('/ping')
|
||||
async def ping() -> dict[str, str]:
|
||||
return {
|
||||
'message': 'pong',
|
||||
'status': 'ok',
|
||||
}
|
||||
1
src/presentation/decorators/__init__.py
Normal file
1
src/presentation/decorators/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.presentation.decorators.auth import require_access_token
|
||||
36
src/presentation/decorators/auth.py
Normal file
36
src/presentation/decorators/auth.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from fastapi import Depends, Request
|
||||
from fastapi.security.utils import get_authorization_scheme_param
|
||||
from src.application.contracts import IJwtService
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.application.domain.dto import AccessTokenPayload, AuthContext
|
||||
from src.presentation.dependencies import get_jwt_service
|
||||
|
||||
|
||||
def _extract_access_token(request: Request) -> str | None:
|
||||
token = request.cookies.get('access_token')
|
||||
|
||||
if token:
|
||||
return token
|
||||
|
||||
auth = request.headers.get('Authorization')
|
||||
if auth:
|
||||
scheme, param = get_authorization_scheme_param(auth)
|
||||
if scheme.lower() == 'bearer' and param:
|
||||
return param
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def require_access_token(
|
||||
request: Request,
|
||||
jwt_service: IJwtService = Depends(get_jwt_service),
|
||||
) -> AuthContext:
|
||||
token = _extract_access_token(request)
|
||||
if not token:
|
||||
raise ApplicationException(status_code=401, message='Not authenticated')
|
||||
|
||||
payload: AccessTokenPayload = await jwt_service.decode_access_token(token)
|
||||
if payload.type != 'access':
|
||||
raise ApplicationException(status_code=401, message='Invalid token type')
|
||||
|
||||
return AuthContext(user_id=payload.sub, sid=payload.sid, token=payload)
|
||||
45
src/presentation/decorators/cache.py
Normal file
45
src/presentation/decorators/cache.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
from typing import Any,Awaitable,Callable
|
||||
from fastapi import Request
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from src.presentation.dependencies.cache import get_cache_remote
|
||||
from src.presentation.dependencies.logger import get_logger
|
||||
|
||||
|
||||
def cached(*, prefix: str) -> Callable:
|
||||
|
||||
|
||||
def decorator(func: Callable[..., Awaitable[Any]]):
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
logger = get_logger()
|
||||
|
||||
request = kwargs.get('request')
|
||||
if not isinstance(request, Request):
|
||||
for a in args:
|
||||
if isinstance(a, Request):
|
||||
request = a
|
||||
break
|
||||
|
||||
auth = kwargs.get('auth')
|
||||
user_id = getattr(auth, 'user_id', None) if auth else None
|
||||
|
||||
if request is None or user_id is None:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
cache_key = f'{prefix}:{user_id}'
|
||||
|
||||
try:
|
||||
cache = get_cache_remote(request)
|
||||
hit = await cache.get_user(user_id)
|
||||
if hit is not None:
|
||||
logger.debug(f'Cache hit key={cache_key}')
|
||||
return ORJSONResponse(status_code=200, content=hit)
|
||||
except Exception as e:
|
||||
logger.warning(f'Cache read failed key={cache_key} error={e}')
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
61
src/presentation/decorators/csrf.py
Normal file
61
src/presentation/decorators/csrf.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from __future__ import annotations
|
||||
import inspect
|
||||
from functools import wraps
|
||||
from typing import Callable, Awaitable, Any, Optional, Annotated
|
||||
from fastapi import Request, Header
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.security import CsrfService
|
||||
|
||||
|
||||
def csrf_protect(
|
||||
expected_subject_getter: Optional[Callable[[Request], Optional[str]]] = None,
|
||||
):
|
||||
def decorator(func: Callable[..., Awaitable[Any]]):
|
||||
sig = inspect.signature(func)
|
||||
params = list(sig.parameters.values())
|
||||
|
||||
has_request = any(p.annotation is Request or p.name == 'request' for p in params)
|
||||
if not has_request:
|
||||
raise RuntimeError('csrf_protect requires endpoint to accept `request: Request`')
|
||||
|
||||
has_header = any(p.name == 'x_csrf_token' for p in params)
|
||||
if not has_header:
|
||||
params.append(
|
||||
inspect.Parameter(
|
||||
name='x_csrf_token',
|
||||
kind=inspect.Parameter.KEYWORD_ONLY,
|
||||
default=None,
|
||||
annotation=Annotated[str | None, Header(alias='X-CSRF-Token')],
|
||||
)
|
||||
)
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
request: Request | None = kwargs.get('request')
|
||||
if request is None:
|
||||
for arg in args:
|
||||
if isinstance(arg, Request):
|
||||
request = arg
|
||||
break
|
||||
|
||||
if request is None:
|
||||
raise ApplicationException(
|
||||
status_code=500,
|
||||
message='Request is required for CSRF protection',
|
||||
)
|
||||
|
||||
csrf = CsrfService()
|
||||
|
||||
cookie_token, _ = csrf.extract(request.cookies, request.headers)
|
||||
header_token = kwargs.get('x_csrf_token')
|
||||
|
||||
expected_subject = expected_subject_getter(request) if expected_subject_getter else None
|
||||
csrf.verify_pair(cookie_token, header_token, expected_subject)
|
||||
|
||||
kwargs.pop('x_csrf_token', None)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
wrapper.__signature__ = sig.replace(parameters=params)
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
171
src/presentation/decorators/rate_limit.py
Normal file
171
src/presentation/decorators/rate_limit.py
Normal file
@@ -0,0 +1,171 @@
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
import inspect
|
||||
import hashlib
|
||||
from typing import Any, Awaitable, Callable, Literal, Optional, Protocol, runtime_checkable
|
||||
from fastapi import Request
|
||||
from redis.asyncio.client import Redis
|
||||
from src.application.contracts import ILogger
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.logger import get_logger
|
||||
from src.presentation.dependencies import get_redis
|
||||
|
||||
|
||||
def _find_request(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Request:
|
||||
req = kwargs.get('request')
|
||||
if isinstance(req, Request):
|
||||
return req
|
||||
for a in args:
|
||||
if isinstance(a, Request):
|
||||
return a
|
||||
raise RuntimeError('rate_limit decorator requires fastapi.Request argument')
|
||||
|
||||
|
||||
def _client_ip(request: Request) -> str:
|
||||
xff = request.headers.get('x-forwarded-for')
|
||||
if xff:
|
||||
return xff.split(',')[0].strip()
|
||||
if request.client:
|
||||
return request.client.host
|
||||
return 'unknown'
|
||||
|
||||
|
||||
_LUA_INCR_EXPIRE_TTL = '''
|
||||
local key = KEYS[1]
|
||||
local window = tonumber(ARGV[1])
|
||||
|
||||
local current = redis.call('INCR', key)
|
||||
if current == 1 then
|
||||
redis.call('EXPIRE', key, window)
|
||||
end
|
||||
|
||||
local ttl = redis.call('TTL', key)
|
||||
return { current, ttl }
|
||||
'''
|
||||
|
||||
|
||||
Scope = Literal['ip', 'device', 'user', 'key']
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class KeyBuilder1(Protocol):
|
||||
def __call__(self, request: Request) -> str: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class KeyBuilder3(Protocol):
|
||||
def __call__(self, request: Request, args: tuple[Any, ...], kwargs: dict[str, Any]) -> str: ...
|
||||
|
||||
|
||||
KeyBuilder = KeyBuilder1 | KeyBuilder3
|
||||
|
||||
|
||||
def _call_key_builder(builder: KeyBuilder, request: Request, args: tuple[Any, ...], kwargs: dict[str, Any]) -> str:
|
||||
try:
|
||||
sig = inspect.signature(builder)
|
||||
if len(sig.parameters) >= 3:
|
||||
return builder(request, args, kwargs)
|
||||
return builder(request)
|
||||
except Exception as e:
|
||||
try:
|
||||
return builder(request, args, kwargs)
|
||||
except Exception:
|
||||
raise e
|
||||
|
||||
def _email_rl_key(request: Request, args: tuple[Any, ...], kwargs: dict[str, Any]) -> str:
|
||||
|
||||
body = kwargs.get('body')
|
||||
if body is None and args:
|
||||
for a in args:
|
||||
if hasattr(a, 'email'):
|
||||
body = a
|
||||
break
|
||||
|
||||
email = (getattr(body, 'email', '') or '').strip().lower()
|
||||
if not email:
|
||||
email = _client_ip(request)
|
||||
|
||||
digest = hashlib.sha256(email.encode('utf-8')).hexdigest()[:24]
|
||||
return f'email:{digest}'
|
||||
|
||||
def rate_limit(
|
||||
*,
|
||||
limit: int,
|
||||
window_seconds: int,
|
||||
scope: Scope = 'ip',
|
||||
key_prefix: str = 'rl',
|
||||
key_builder: Optional[KeyBuilder] = None,
|
||||
fail_open: bool = True,
|
||||
) -> Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable[Any]]]:
|
||||
|
||||
if limit <= 0:
|
||||
raise ValueError('rate_limit: limit must be > 0')
|
||||
if window_seconds <= 0:
|
||||
raise ValueError('rate_limit: window_seconds must be > 0')
|
||||
if scope == 'key' and not key_builder:
|
||||
raise ValueError('rate_limit: scope="key" requires key_builder')
|
||||
|
||||
def decorator(func: Callable[..., Awaitable[Any]]):
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any):
|
||||
request = _find_request(args, kwargs)
|
||||
logger: ILogger = get_logger()
|
||||
|
||||
if scope == 'ip':
|
||||
ident = _client_ip(request)
|
||||
elif scope == 'device':
|
||||
ident = request.cookies.get('device_id') or _client_ip(request)
|
||||
elif scope == 'user':
|
||||
user = getattr(request.state, 'user', None)
|
||||
user_id = getattr(user, 'id', None) if user else None
|
||||
ident = str(user_id) if user_id else _client_ip(request)
|
||||
else:
|
||||
try:
|
||||
ident = _call_key_builder(key_builder, request, args, kwargs) # type: ignore[arg-type]
|
||||
except Exception as e:
|
||||
logger.error(f'RateLimit key_builder failed error={str(e)}')
|
||||
raise ApplicationException(500, 'Rate limiter key_builder failed')
|
||||
|
||||
route = request.url.path
|
||||
method = request.method
|
||||
redis_key = f'{key_prefix}:{scope}:{method}:{route}:{ident}'
|
||||
|
||||
logger.debug(f'RateLimit check key={redis_key} limit={limit} window={window_seconds}')
|
||||
|
||||
try:
|
||||
redis: Redis = get_redis(request)
|
||||
|
||||
result = await redis.eval(
|
||||
_LUA_INCR_EXPIRE_TTL,
|
||||
1,
|
||||
redis_key,
|
||||
str(window_seconds),
|
||||
)
|
||||
|
||||
count = int(result[0])
|
||||
ttl_raw = int(result[1]) if result and len(result) > 1 else window_seconds
|
||||
ttl = window_seconds if ttl_raw < 0 else ttl_raw
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'RateLimit redis failure key={redis_key} error={str(e)}')
|
||||
|
||||
if fail_open:
|
||||
logger.warning(f'RateLimit fail-open activated key={redis_key}')
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
raise ApplicationException(503, 'Rate limiter unavailable')
|
||||
|
||||
if count > limit:
|
||||
retry_after = max(ttl, 0)
|
||||
logger.warning(f'RateLimit exceeded key={redis_key} count={count} limit={limit} retry_after={retry_after}')
|
||||
raise ApplicationException(
|
||||
status_code=429,
|
||||
message='Too Many Requests',
|
||||
headers={'Retry-After': str(retry_after)},
|
||||
)
|
||||
|
||||
logger.debug(f'RateLimit passed key={redis_key} count={count}')
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
1
src/presentation/dependencies/__init__.py
Normal file
1
src/presentation/dependencies/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.presentation.dependencies.security import get_jwt_service
|
||||
28
src/presentation/dependencies/cache.py
Normal file
28
src/presentation/dependencies/cache.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
from fastapi import Depends,Request
|
||||
from redis.asyncio.client import Redis
|
||||
|
||||
from src.application.contracts import ICache
|
||||
from src.infrastructure.cache import KeydbCache
|
||||
|
||||
|
||||
def get_redis_remote(request: Request) -> Redis:
|
||||
return request.app.state.redis
|
||||
|
||||
|
||||
def get_redis(request: Request) -> Redis:
|
||||
return request.app.state.redis
|
||||
|
||||
|
||||
def get_cache_remote(redis_client: Redis = Depends(get_redis_remote)) -> ICache:
|
||||
return KeydbCache(redis_client)
|
||||
|
||||
|
||||
def get_remote_cache(redis_client: Redis = Depends(get_redis_remote)) -> ICache:
|
||||
return KeydbCache(redis_client)
|
||||
|
||||
|
||||
def get_cache(redis_client: Redis = Depends(get_redis)) -> ICache:
|
||||
return KeydbCache(redis_client)
|
||||
60
src/presentation/dependencies/commands.py
Normal file
60
src/presentation/dependencies/commands.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
from fastapi import Depends
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.commands import CompleteKycCommand,GetKycSessionCommand,PassKycCommand
|
||||
from src.application.contracts import IBeorgService,ICache,ILogger,IQueueMessanger
|
||||
from src.infrastructure.config import settings
|
||||
from src.infrastructure.beorg import BeorgService
|
||||
from src.presentation.dependencies.cache import get_cache
|
||||
from src.presentation.dependencies.logger import get_logger
|
||||
from src.presentation.dependencies.queue_messanger import get_rabbit
|
||||
from src.presentation.dependencies.unit_of_work import get_unit_of_work
|
||||
|
||||
|
||||
def get_beorg_service() -> IBeorgService:
|
||||
return BeorgService(
|
||||
project_id=settings.BEORG_PROJECT_ID,
|
||||
machine_uid=settings.BEORG_MACHINE_UID,
|
||||
token=settings.BEORG_TOKEN,
|
||||
process_info=settings.BEORG_PROCESS_INFO,
|
||||
timeout=settings.BEORG_TIMEOUT,
|
||||
)
|
||||
|
||||
|
||||
def get_pass_kyc_command(
|
||||
logger: ILogger = Depends(get_logger),
|
||||
unit_of_work: IUnitOfWork = Depends(get_unit_of_work),
|
||||
cache: ICache = Depends(get_cache),
|
||||
beorg_service: IBeorgService = Depends(get_beorg_service),
|
||||
) -> PassKycCommand:
|
||||
return PassKycCommand(
|
||||
unit_of_work=unit_of_work,
|
||||
logger=logger,
|
||||
cache=cache,
|
||||
beorg_service=beorg_service,
|
||||
)
|
||||
|
||||
|
||||
def get_kyc_session_command(
|
||||
unit_of_work: IUnitOfWork = Depends(get_unit_of_work),
|
||||
) -> GetKycSessionCommand:
|
||||
return GetKycSessionCommand(
|
||||
unit_of_work=unit_of_work,
|
||||
)
|
||||
|
||||
|
||||
def get_complete_kyc_command(
|
||||
logger: ILogger = Depends(get_logger),
|
||||
unit_of_work: IUnitOfWork = Depends(get_unit_of_work),
|
||||
cache: ICache = Depends(get_cache),
|
||||
beorg_service: IBeorgService = Depends(get_beorg_service),
|
||||
queue_messanger: IQueueMessanger = Depends(get_rabbit),
|
||||
) -> CompleteKycCommand:
|
||||
return CompleteKycCommand(
|
||||
unit_of_work=unit_of_work,
|
||||
logger=logger,
|
||||
cache=cache,
|
||||
beorg_service=beorg_service,
|
||||
queue_messanger=queue_messanger,
|
||||
verified_queue=settings.RABBIT_KYC_VERIFIED_QUEUE,
|
||||
)
|
||||
7
src/presentation/dependencies/logger.py
Normal file
7
src/presentation/dependencies/logger.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from functools import lru_cache
|
||||
from src.application.contracts import ILogger
|
||||
from src.infrastructure.logger import logger
|
||||
|
||||
@lru_cache
|
||||
def get_logger() -> ILogger:
|
||||
return logger
|
||||
8
src/presentation/dependencies/queue_messanger.py
Normal file
8
src/presentation/dependencies/queue_messanger.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from functools import lru_cache
|
||||
from src.application.contracts import IQueueMessanger
|
||||
from src.infrastructure.messanger import RabbitClient
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_rabbit() -> IQueueMessanger:
|
||||
return RabbitClient()
|
||||
16
src/presentation/dependencies/security.py
Normal file
16
src/presentation/dependencies/security.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from functools import lru_cache
|
||||
from fastapi import Depends
|
||||
from src.application.contracts import IJwtService,ILogger
|
||||
from src.infrastructure.security.jwt import JwtService
|
||||
from src.infrastructure.vault import JwtKeyStore
|
||||
from src.presentation.dependencies.logger import get_logger
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _jwt_service(logger: ILogger) -> IJwtService:
|
||||
key_store = JwtKeyStore.get_instance()
|
||||
return JwtService(logger=logger, key_store=key_store)
|
||||
|
||||
|
||||
def get_jwt_service(logger: ILogger = Depends(get_logger)) -> IJwtService:
|
||||
return _jwt_service(logger)
|
||||
10
src/presentation/dependencies/unit_of_work.py
Normal file
10
src/presentation/dependencies/unit_of_work.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from fastapi import Depends
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.contracts import ILogger
|
||||
from src.infrastructure.database.context import async_session_maker
|
||||
from src.infrastructure.database.unit_of_work import UnitOfWork
|
||||
from src.presentation.dependencies.logger import get_logger
|
||||
|
||||
|
||||
def get_unit_of_work(logger: ILogger = Depends(get_logger)) -> IUnitOfWork:
|
||||
return UnitOfWork(session_factory=async_session_maker,logger=logger)
|
||||
18
src/presentation/handlers/__init__.py
Normal file
18
src/presentation/handlers/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from fastapi import Request
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
|
||||
|
||||
async def application_exception_handler(request: Request,exception: ApplicationException) -> ORJSONResponse:
|
||||
return ORJSONResponse(
|
||||
status_code=exception.status_code,
|
||||
content={'detail': exception.message},
|
||||
headers=exception.headers,
|
||||
)
|
||||
|
||||
|
||||
async def unhandled_exception_handler(request: Request,exception: Exception) -> ORJSONResponse:
|
||||
return ORJSONResponse(
|
||||
status_code=500,
|
||||
content={'detail': 'Internal server error'},
|
||||
)
|
||||
0
src/presentation/handlers/application_handler.py
Normal file
0
src/presentation/handlers/application_handler.py
Normal file
0
src/presentation/handlers/unhandled_handler.py
Normal file
0
src/presentation/handlers/unhandled_handler.py
Normal file
1
src/presentation/messaging/__init__.py
Normal file
1
src/presentation/messaging/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.presentation.messaging.crypto_transfer import crypto_transfer_router
|
||||
39
src/presentation/messaging/crypto_transfer.py
Normal file
39
src/presentation/messaging/crypto_transfer.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from fastapi import Depends
|
||||
import orjson
|
||||
from faststream.rabbit.fastapi import RabbitMessage,RabbitRouter
|
||||
from pydantic import BaseModel
|
||||
from src.application.contracts import ILogger
|
||||
from src.infrastructure.config import settings
|
||||
from src.infrastructure.context_vars import trace_id_var
|
||||
from src.presentation.dependencies.logger import get_logger
|
||||
|
||||
|
||||
crypto_transfer_router=RabbitRouter(settings.RABBIT_URL)
|
||||
|
||||
|
||||
class CryptoTransferCompletedMessage(BaseModel):
|
||||
user_id: str
|
||||
order_id: str
|
||||
trace_id: str
|
||||
message_id: str
|
||||
|
||||
|
||||
@crypto_transfer_router.subscriber(settings.RABBIT_CRYPTO_TRANSFER_COMPLETED_QUEUE)
|
||||
async def crypto_transfer_completed_handler(
|
||||
msg_body: CryptoTransferCompletedMessage,
|
||||
message: RabbitMessage,
|
||||
logger: ILogger = Depends(get_logger),
|
||||
) -> None:
|
||||
trace_id=msg_body.trace_id
|
||||
token=trace_id_var.set(trace_id)
|
||||
try:
|
||||
payload=msg_body.model_dump(mode='json')
|
||||
logger.info(orjson.dumps({
|
||||
'event':'crypto_transfer_completed_received',
|
||||
'payload':payload,
|
||||
'rabbit_message_id':message.message_id,
|
||||
'rabbit_correlation_id':message.correlation_id,
|
||||
},default=str).decode())
|
||||
finally:
|
||||
trace_id_var.reset(token)
|
||||
|
||||
2
src/presentation/middleware/__init__.py
Normal file
2
src/presentation/middleware/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from src.presentation.middleware.trace_id import TraceIDMiddleware
|
||||
from src.presentation.middleware.security_headers import SecurityHeadersMiddleware
|
||||
41
src/presentation/middleware/security_headers.py
Normal file
41
src/presentation/middleware/security_headers.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from starlette.types import ASGIApp,Message,Receive,Scope,Send
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
*,
|
||||
hsts: bool = True,
|
||||
hsts_preload: bool = False,
|
||||
frame_options: str = 'DENY',
|
||||
referrer_policy: str = 'strict-origin-when-cross-origin',
|
||||
content_security_policy: str | None = None,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self._headers = {
|
||||
'x-content-type-options': 'nosniff',
|
||||
'x-frame-options': frame_options,
|
||||
'referrer-policy': referrer_policy,
|
||||
}
|
||||
if hsts:
|
||||
value = 'max-age=31536000; includeSubDomains'
|
||||
if hsts_preload:
|
||||
value = f'{value}; preload'
|
||||
self._headers['strict-transport-security'] = value
|
||||
if content_security_policy:
|
||||
self._headers['content-security-policy'] = content_security_policy
|
||||
|
||||
|
||||
async def __call__(self,scope: Scope,receive: Receive,send: Send) -> None:
|
||||
|
||||
async def send_wrapper(message: Message) -> None:
|
||||
if message['type'] == 'http.response.start':
|
||||
headers = list(message.get('headers',[]))
|
||||
for key,value in self._headers.items():
|
||||
headers.append((key.encode(),value.encode()))
|
||||
message['headers'] = headers
|
||||
await send(message)
|
||||
|
||||
await self.app(scope,receive,send_wrapper)
|
||||
135
src/presentation/middleware/trace_id.py
Normal file
135
src/presentation/middleware/trace_id.py
Normal file
@@ -0,0 +1,135 @@
|
||||
from __future__ import annotations
|
||||
from typing import Optional
|
||||
from contextvars import Token
|
||||
from starlette.requests import Request
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
from ulid import ULID
|
||||
from src.application.contracts import ILogger
|
||||
from src.infrastructure.config import settings
|
||||
from src.infrastructure.context_vars import trace_id_var
|
||||
|
||||
|
||||
class TraceIDMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
logger: ILogger,
|
||||
response_header_name: str = "X-Trace-ID",
|
||||
attach_response_header: bool = True,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.logger = logger
|
||||
self.response_header_name = response_header_name
|
||||
self.attach_response_header = attach_response_header
|
||||
|
||||
def _is_excluded(self, path: str) -> bool:
|
||||
return any(path == p or path.startswith(p.rstrip("/") + "/") for p in settings.EXCLUDED_PATHS)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
request = Request(scope)
|
||||
|
||||
if self._is_excluded(request.url.path):
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
trace_id = request.headers.get("X-Trace-ID") or request.headers.get("X-Request-ID")
|
||||
if not trace_id:
|
||||
trace_id = str(ULID())
|
||||
|
||||
request.state.trace_id = trace_id
|
||||
|
||||
token: Token = trace_id_var.set(trace_id)
|
||||
|
||||
self.logger.debug(f"Request started: {request.method} {request.url} - TraceID: {trace_id}")
|
||||
|
||||
status_code_holder: dict[str, Optional[int]] = {"status": None}
|
||||
|
||||
async def send_wrapper(message: Message) -> None:
|
||||
if message["type"] == "http.response.start":
|
||||
status_code_holder["status"] = int(message["status"])
|
||||
|
||||
if self.attach_response_header:
|
||||
headers = list(message.get("headers", []))
|
||||
headers.append((self.response_header_name.lower().encode(), trace_id.encode()))
|
||||
message["headers"] = headers
|
||||
await send(message)
|
||||
|
||||
try:
|
||||
await self.app(scope, receive, send_wrapper)
|
||||
finally:
|
||||
status = status_code_holder["status"]
|
||||
status_part = f"{status}" if status is not None else "unknown"
|
||||
self.logger.debug(
|
||||
f"Request finished: {request.method} {request.url} - TraceID: {trace_id} - Status: {status_part}"
|
||||
)
|
||||
trace_id_var.reset(token)
|
||||
|
||||
|
||||
# from __future__ import annotations
|
||||
# from typing import Optional
|
||||
# from starlette.requests import Request
|
||||
# from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
# from ulid import ULID
|
||||
# from src.application.contracts import ILogger
|
||||
# from src.infrastructure.config.settings import settings
|
||||
#
|
||||
#
|
||||
# class TraceIDMiddleware:
|
||||
# def __init__(
|
||||
# self,
|
||||
# app: ASGIApp,
|
||||
# logger: ILogger,
|
||||
# response_header_name: str = 'X-Trace-ID',
|
||||
# attach_response_header: bool = True,
|
||||
# ) -> None:
|
||||
# self.app = app
|
||||
# self.logger = logger
|
||||
# self.response_header_name = response_header_name
|
||||
# self.attach_response_header = attach_response_header
|
||||
#
|
||||
# def _is_excluded(self, path: str) -> bool:
|
||||
# return any(path == p or path.startswith(p.rstrip('/') + '/') for p in settings.EXCLUDED_PATHS)
|
||||
#
|
||||
# async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
# if scope['type'] != 'http':
|
||||
# await self.app(scope, receive, send)
|
||||
# return
|
||||
#
|
||||
# request = Request(scope)
|
||||
#
|
||||
# if self._is_excluded(request.url.path):
|
||||
# await self.app(scope, receive, send)
|
||||
# return
|
||||
#
|
||||
# trace_id = request.headers.get('X-Trace-ID') or request.headers.get('X-Request-ID')
|
||||
# if not trace_id:
|
||||
# trace_id = str(ULID())
|
||||
#
|
||||
# request.state.trace_id = trace_id
|
||||
# self.logger.set_trace_id(trace_id)
|
||||
#
|
||||
# self.logger.debug(f'Request started: {request.method} {request.url} - TraceID: {trace_id}')
|
||||
#
|
||||
# status_code_holder: dict[str, Optional[int]] = {'status': None}
|
||||
#
|
||||
# async def send_wrapper(message: Message) -> None:
|
||||
# if message['type'] == 'http.response.start':
|
||||
# status_code_holder['status'] = int(message['status'])
|
||||
#
|
||||
# if self.attach_response_header:
|
||||
# headers = list(message.get('headers', []))
|
||||
# headers.append((self.response_header_name.lower().encode(), trace_id.encode()))
|
||||
# message['headers'] = headers
|
||||
# await send(message)
|
||||
#
|
||||
# try:
|
||||
# await self.app(scope, receive, send_wrapper)
|
||||
# finally:
|
||||
# status = status_code_holder['status']
|
||||
# status_part = f'{status}' if status is not None else 'unknown'
|
||||
# self.logger.debug(f'Request finished: {request.method} {request.url} - TraceID: {trace_id} - Status: {status_part}')
|
||||
# self.logger.clear_trace_id()
|
||||
1
src/presentation/routing/__init__.py
Normal file
1
src/presentation/routing/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.presentation.routing.kyc import kyc_router
|
||||
37
src/presentation/routing/kyc.py
Normal file
37
src/presentation/routing/kyc.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from fastapi import APIRouter,Depends
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from src.application.commands import CompleteKycCommand,GetKycSessionCommand,PassKycCommand
|
||||
from src.application.domain.dto import AuthContext
|
||||
from src.presentation.decorators.auth import require_access_token
|
||||
from src.presentation.dependencies.commands import get_complete_kyc_command,get_kyc_session_command,get_pass_kyc_command
|
||||
|
||||
|
||||
kyc_router = APIRouter(prefix='/kyc', tags=['Kyc'])
|
||||
|
||||
|
||||
@kyc_router.post('/create')
|
||||
async def create_kyc(
|
||||
auth: AuthContext = Depends(require_access_token),
|
||||
command: PassKycCommand = Depends(get_pass_kyc_command),
|
||||
) -> ORJSONResponse:
|
||||
user_id = auth.user_id
|
||||
result = await command(user_id=user_id)
|
||||
return ORJSONResponse(result.model_dump())
|
||||
|
||||
|
||||
@kyc_router.get('/session')
|
||||
async def get_kyc_session(
|
||||
auth: AuthContext = Depends(require_access_token),
|
||||
command: GetKycSessionCommand = Depends(get_kyc_session_command),
|
||||
) -> ORJSONResponse:
|
||||
result = await command(user_id=auth.user_id)
|
||||
return ORJSONResponse(result.model_dump())
|
||||
|
||||
|
||||
@kyc_router.post('/complete')
|
||||
async def complete_kyc(
|
||||
auth: AuthContext = Depends(require_access_token),
|
||||
command: CompleteKycCommand = Depends(get_complete_kyc_command),
|
||||
) -> ORJSONResponse:
|
||||
result = await command(user_id=auth.user_id)
|
||||
return ORJSONResponse(result.model_dump())
|
||||
0
src/presentation/schemas/__init__.py
Normal file
0
src/presentation/schemas/__init__.py
Normal file
Reference in New Issue
Block a user