Initial commit
This commit is contained in:
6
src/application/contracts/__init__.py
Normal file
6
src/application/contracts/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from src.application.contracts.i_logger import ILogger
|
||||
from src.application.contracts.i_jwt_service import IJwtService
|
||||
from src.application.contracts.i_csrf_service import ICsrfService
|
||||
from src.application.contracts.i_cache import ICache
|
||||
from src.application.contracts.i_hash_service import IHashService
|
||||
from src.application.contracts.i_queue_messanger import IQueueMessanger
|
||||
30
src/application/contracts/i_cache.py
Normal file
30
src/application/contracts/i_cache.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from src.application.domain.entities.user import UserEntity
|
||||
|
||||
|
||||
class ICache(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def set(self, key: str, value: str, ttl: int) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def set_nx(self, key: str, value: str, ttl: int) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get(self, key: str) -> str | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def delete(self, key: str) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_user(self, user_id: str) -> dict | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def set_user(self, user_id: str, user: UserEntity, ttl: int = 300) -> None:
|
||||
raise NotImplementedError
|
||||
26
src/application/contracts/i_csrf_service.py
Normal file
26
src/application/contracts/i_csrf_service.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional, Mapping
|
||||
|
||||
|
||||
class ICsrfService(ABC):
|
||||
@abstractmethod
|
||||
def issue(self, subject: Optional[str] = None) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def verify(self, token: str, expected_subject: Optional[str] = None) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def extract(self, cookies: Mapping[str, str], headers: Mapping[str, str]) -> tuple[Optional[str], Optional[str]]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def verify_pair(
|
||||
self,
|
||||
cookie_token: Optional[str],
|
||||
header_token: Optional[str],
|
||||
expected_subject: Optional[str] = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
12
src/application/contracts/i_hash_service.py
Normal file
12
src/application/contracts/i_hash_service.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class IHashService(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def hash(self, value: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def verify(self, hashed_value: str, plain_value: str) -> bool:
|
||||
raise NotImplementedError
|
||||
10
src/application/contracts/i_jwt_service.py
Normal file
10
src/application/contracts/i_jwt_service.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from src.application.domain.dto import AccessTokenPayload
|
||||
|
||||
|
||||
class IJwtService(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def decode_access_token(self, token: str) -> AccessTokenPayload:
|
||||
raise NotImplementedError
|
||||
68
src/application/contracts/i_logger.py
Normal file
68
src/application/contracts/i_logger.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from typing import Protocol, Optional, Callable
|
||||
from src.application.domain.enums.log_format import LogFormat
|
||||
from src.application.domain.enums.log_level import LogLevel
|
||||
|
||||
|
||||
class ILogger(Protocol):
|
||||
"""Interface for synchronous logger with ContextVar support for trace_id."""
|
||||
|
||||
log_format: LogFormat
|
||||
min_level: LogLevel
|
||||
id_generator: Optional[Callable[[], str]]
|
||||
instance_id: str
|
||||
|
||||
def set_format(self, log_format: LogFormat) -> None:
|
||||
"""Set log format using LogFormat enum"""
|
||||
...
|
||||
|
||||
def set_min_level(self, level: LogLevel) -> None:
|
||||
"""Set minimum log level"""
|
||||
...
|
||||
|
||||
def new_trace_id(self) -> str:
|
||||
"""Create and set new trace_id in context"""
|
||||
...
|
||||
|
||||
def set_trace_id(self, trace_id: str) -> None:
|
||||
"""Set existing trace_id in context"""
|
||||
...
|
||||
|
||||
def get_trace_id(self) -> str:
|
||||
"""Get current trace_id from context"""
|
||||
...
|
||||
|
||||
def clear_trace_id(self) -> None:
|
||||
"""Clear the trace_id in the context"""
|
||||
...
|
||||
|
||||
def set_instance_id(self, instance_id: str) -> None:
|
||||
"""Set service instance id (ULID recommended)"""
|
||||
...
|
||||
|
||||
def get_instance_id(self) -> str:
|
||||
"""Get current service instance id"""
|
||||
...
|
||||
|
||||
def debug(self, message: str) -> None:
|
||||
"""Log debug message"""
|
||||
...
|
||||
|
||||
def info(self, message: str) -> None:
|
||||
"""Log info message"""
|
||||
...
|
||||
|
||||
def warning(self, message: str) -> None:
|
||||
"""Log warning message"""
|
||||
...
|
||||
|
||||
def error(self, message: str) -> None:
|
||||
"""Log error message"""
|
||||
...
|
||||
|
||||
def critical(self, message: str) -> None:
|
||||
"""Log critical message"""
|
||||
...
|
||||
|
||||
def exception(self, message: str) -> None:
|
||||
"""Log exception with traceback"""
|
||||
...
|
||||
40
src/application/contracts/i_queue_messanger.py
Normal file
40
src/application/contracts/i_queue_messanger.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Mapping, Any
|
||||
|
||||
|
||||
class IQueueMessanger(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def connect(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def publish_to_queue(
|
||||
self,
|
||||
queue: str,
|
||||
message: Any,
|
||||
*,
|
||||
persist: bool = True,
|
||||
headers: Mapping[str, Any] | None = None,
|
||||
correlation_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def publish(
|
||||
self,
|
||||
message: Any,
|
||||
*,
|
||||
exchange: str,
|
||||
routing_key: str,
|
||||
persist: bool = True,
|
||||
headers: Mapping[str, Any] | None = None,
|
||||
correlation_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
2
src/application/domain/dto/__init__.py
Normal file
2
src/application/domain/dto/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from src.application.domain.dto.token import AccessTokenPayload, AuthContext
|
||||
from src.application.domain.dto.keys import JwtPublicKey, JwtPublicKeySet
|
||||
20
src/application/domain/dto/keys.py
Normal file
20
src/application/domain/dto/keys.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Dict
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class JwtPublicKey:
|
||||
kid: str
|
||||
public_key_pem: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class JwtPublicKeySet:
|
||||
active: JwtPublicKey
|
||||
previous: Optional[JwtPublicKey] = None
|
||||
|
||||
def public_keys_by_kid(self) -> Dict[str, str]:
|
||||
out = {self.active.kid: self.active.public_key_pem}
|
||||
if self.previous:
|
||||
out[self.previous.kid] = self.previous.public_key_pem
|
||||
return out
|
||||
18
src/application/domain/dto/token.py
Normal file
18
src/application/domain/dto/token.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AccessTokenPayload(BaseModel):
|
||||
sub: str
|
||||
type: str
|
||||
sid: str
|
||||
iat: int
|
||||
nbf: int
|
||||
exp: int
|
||||
iss: str | None = None
|
||||
aud: str | None = None
|
||||
|
||||
|
||||
class AuthContext(BaseModel):
|
||||
user_id: str
|
||||
sid: str
|
||||
token: AccessTokenPayload
|
||||
5
src/application/domain/entities/__init__.py
Normal file
5
src/application/domain/entities/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from src.application.domain.entities.user import UserEntity
|
||||
from src.application.domain.entities.session import SessionEntity
|
||||
|
||||
|
||||
__all__ = ['UserEntity', 'SessionEntity']
|
||||
20
src/application/domain/entities/session.py
Normal file
20
src/application/domain/entities/session.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class SessionEntity:
|
||||
sid: str
|
||||
user_id: str
|
||||
device_id: str
|
||||
|
||||
revoked_at: datetime | None
|
||||
last_seen_at: datetime
|
||||
|
||||
refresh_jti_hash: str | None
|
||||
refresh_expires_at: datetime | None
|
||||
|
||||
user_agent: str | None = None
|
||||
first_ip: str | None = None
|
||||
last_ip: str | None = None
|
||||
30
src/application/domain/entities/user.py
Normal file
30
src/application/domain/entities/user.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from datetime import date, datetime
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class UserEntity:
|
||||
id: str | None = None
|
||||
email: str | None = None
|
||||
password_hash: str | None = None
|
||||
|
||||
first_name: str | None = None
|
||||
middle_name: str | None = None
|
||||
last_name: str | None = None
|
||||
birth_date: date | None = None
|
||||
|
||||
crypto_wallet: str | None = None
|
||||
phone: str | None = None
|
||||
|
||||
bik: str | None = None
|
||||
account_number: str | None = None
|
||||
card_number: str | None = None
|
||||
inn: str | None = None
|
||||
|
||||
kyc_verified: bool | None = None
|
||||
is_deleted: bool | None = None
|
||||
|
||||
created_at: datetime | None = None
|
||||
updated_at: datetime | None = None
|
||||
kyc_verified_at: datetime | None = None
|
||||
2
src/application/domain/enums/__init__.py
Normal file
2
src/application/domain/enums/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from src.application.domain.enums.log_level import LogLevel
|
||||
from src.application.domain.enums.log_format import LogFormat
|
||||
7
src/application/domain/enums/log_format.py
Normal file
7
src/application/domain/enums/log_format.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class LogFormat(Enum):
|
||||
"""Enum for supported log formats"""
|
||||
TEXT = 'text'
|
||||
JSON = 'json'
|
||||
54
src/application/domain/enums/log_level.py
Normal file
54
src/application/domain/enums/log_level.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class LogLevel(Enum):
|
||||
DEBUG = 10
|
||||
INFO = 20
|
||||
WARNING = 30
|
||||
ERROR = 40
|
||||
CRITICAL = 50
|
||||
EXCEPTION = 60
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"[{self.value}, '{self.name}']"
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if isinstance(other, LogLevel):
|
||||
return self.value == other.value
|
||||
if isinstance(other, int):
|
||||
return self.value == other
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other: object) -> bool:
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __lt__(self, other: object) -> bool:
|
||||
if isinstance(other, LogLevel):
|
||||
return self.value < other.value
|
||||
if isinstance(other, int):
|
||||
return self.value < other
|
||||
return NotImplemented
|
||||
|
||||
def __le__(self, other: object) -> bool:
|
||||
if isinstance(other, LogLevel):
|
||||
return self.value <= other.value
|
||||
if isinstance(other, int):
|
||||
return self.value <= other
|
||||
return NotImplemented
|
||||
|
||||
def __gt__(self, other: object) -> bool:
|
||||
if isinstance(other, LogLevel):
|
||||
return self.value > other.value
|
||||
if isinstance(other, int):
|
||||
return self.value > other
|
||||
return NotImplemented
|
||||
|
||||
def __ge__(self, other: object) -> bool:
|
||||
if isinstance(other, LogLevel):
|
||||
return self.value >= other.value
|
||||
if isinstance(other, int):
|
||||
return self.value >= other
|
||||
return NotImplemented
|
||||
1
src/application/domain/exceptions/__init__.py
Normal file
1
src/application/domain/exceptions/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.application.domain.exceptions.application_exceptions import ApplicationException
|
||||
18
src/application/domain/exceptions/application_exceptions.py
Normal file
18
src/application/domain/exceptions/application_exceptions.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from __future__ import annotations
|
||||
from typing import Mapping
|
||||
|
||||
|
||||
class ApplicationException(Exception):
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int,
|
||||
message: str,
|
||||
headers: Mapping[str, str] | None = None,
|
||||
):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.headers = headers
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.status_code}: {self.message}"
|
||||
0
src/command/__init__.py
Normal file
0
src/command/__init__.py
Normal file
29
src/command/create_order_command.py
Normal file
29
src/command/create_order_command.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from src.infrastructure.database.decorators import transactional
|
||||
from src.presentation.schemas.order import CreateOrder
|
||||
|
||||
|
||||
class UserLoginStartCommand:
|
||||
def __init__(
|
||||
self,
|
||||
hash_service: IHashService,
|
||||
cache: ICache,
|
||||
unit_of_work: IUnitOfWork,
|
||||
logger: ILogger,
|
||||
messanger: IQueueMessanger,
|
||||
):
|
||||
self._hash_service = hash_service
|
||||
self._unit_of_work = unit_of_work
|
||||
self._cache = cache
|
||||
self._logger = logger
|
||||
self._messanger = messanger
|
||||
|
||||
|
||||
@transactional
|
||||
async def __call__(self, payment_data: CreateOrder) -> bool:
|
||||
|
||||
|
||||
metadata: dict = {
|
||||
'user_id': str(payment_data.user_id),
|
||||
}
|
||||
|
||||
|
||||
2
src/infrastructure/cache/__init__.py
vendored
Normal file
2
src/infrastructure/cache/__init__.py
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
from src.infrastructure.cache.client import create_redis_client
|
||||
from src.infrastructure.cache.keydb_client import KeydbCache
|
||||
16
src/infrastructure/cache/client.py
vendored
Normal file
16
src/infrastructure/cache/client.py
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
import redis.asyncio as redis
|
||||
from redis.asyncio.client import Redis
|
||||
from src.infrastructure.config import settings
|
||||
|
||||
|
||||
def create_redis_client() -> Redis:
|
||||
return redis.from_url(
|
||||
settings.REDIS_URL,
|
||||
max_connections=50,
|
||||
decode_responses=True,
|
||||
socket_timeout=5,
|
||||
socket_connect_timeout=5,
|
||||
health_check_interval=30,
|
||||
retry_on_timeout=True,
|
||||
socket_keepalive=True,
|
||||
)
|
||||
52
src/infrastructure/cache/keydb_client.py
vendored
Normal file
52
src/infrastructure/cache/keydb_client.py
vendored
Normal file
@@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
import orjson
|
||||
from redis.asyncio.client import Redis
|
||||
from src.application.contracts import ICache
|
||||
from src.application.domain.entities.user import UserEntity
|
||||
|
||||
|
||||
class KeydbCache(ICache):
|
||||
USER_PREFIX = 'user:me'
|
||||
|
||||
def __init__(self, redis_client: Redis):
|
||||
self._r = redis_client
|
||||
|
||||
async def set(self, key: str, value: str, ttl: int) -> bool:
|
||||
return bool(await self._r.set(key, value, ex=ttl))
|
||||
|
||||
async def set_nx(self, key: str, value: str, ttl: int) -> bool:
|
||||
return bool(await self._r.set(key, value, ex=ttl, nx=True))
|
||||
|
||||
async def get(self, key: str) -> str | None:
|
||||
return await self._r.get(key)
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
return (await self._r.delete(key)) > 0
|
||||
|
||||
async def get_user(self, user_id: str) -> dict | None:
|
||||
raw = await self._r.get(f'{self.USER_PREFIX}:{user_id}')
|
||||
if raw is None:
|
||||
return None
|
||||
return orjson.loads(raw)
|
||||
|
||||
async def set_user(self, user_id: str, user: UserEntity, ttl: int = 300) -> None:
|
||||
data = orjson.dumps({
|
||||
'id': user.id,
|
||||
'email': user.email,
|
||||
'first_name': user.first_name,
|
||||
'middle_name': user.middle_name,
|
||||
'last_name': user.last_name,
|
||||
'birth_date': str(user.birth_date) if user.birth_date else None,
|
||||
'crypto_wallet': user.crypto_wallet,
|
||||
'phone': user.phone,
|
||||
'bik': user.bik,
|
||||
'account_number': user.account_number,
|
||||
'card_number': user.card_number,
|
||||
'inn': user.inn,
|
||||
'kyc_verified': user.kyc_verified,
|
||||
'is_deleted': user.is_deleted,
|
||||
'created_at': user.created_at.isoformat() if user.created_at else None,
|
||||
'updated_at': user.updated_at.isoformat() if user.updated_at else None,
|
||||
'kyc_verified_at': user.kyc_verified_at.isoformat() if user.kyc_verified_at else None,
|
||||
})
|
||||
await self._r.set(f'{self.USER_PREFIX}:{user_id}', data, ex=ttl)
|
||||
1
src/infrastructure/config/__init__.py
Normal file
1
src/infrastructure/config/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.config.settings import settings
|
||||
155
src/infrastructure/config/settings.py
Normal file
155
src/infrastructure/config/settings.py
Normal file
@@ -0,0 +1,155 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import List, Literal
|
||||
import os
|
||||
from dotenv import load_dotenv, find_dotenv
|
||||
from pydantic import Field, model_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from src.infrastructure.vault import create_hvac_client, read_kv2_secret
|
||||
|
||||
env_file = find_dotenv(".env")
|
||||
if env_file:
|
||||
load_dotenv(env_file)
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
VAULT_ADDR: str = Field(default="http://localhost:8200")
|
||||
VAULT_TOKEN: str = Field(..., description="Vault token is required")
|
||||
VAULT_MOUNT_POINT: str = Field(default="secrets")
|
||||
|
||||
VAULT_JWT_KID_PATH: str = "jwt/kid"
|
||||
VAULT_JWT_KIDS_PREFIX: str = "jwt/kids"
|
||||
JWT_KEYS_REFRESH_SECONDS: int = 3600
|
||||
|
||||
DATABASE_HOST: str
|
||||
DATABASE_PORT: int = Field(default=5432, ge=1, le=65535)
|
||||
DATABASE_NAME: str
|
||||
DATABASE_USER: str
|
||||
DATABASE_PASSWORD: str
|
||||
|
||||
DATABASE_POOL_SIZE: int = 10
|
||||
DATABASE_MAX_OVERFLOW: int = 20
|
||||
DATABASE_POOL_TIMEOUT: int = 30
|
||||
DATABASE_POOL_RECYCLE: int = 3600
|
||||
DATABASE_ECHO: bool = False
|
||||
|
||||
CSRF_SECRET_KEY: str = Field(
|
||||
default="change-me-change-me-change-me-change-me",
|
||||
min_length=32,
|
||||
)
|
||||
|
||||
CSRF_COOKIE_SECURE: bool = False
|
||||
CSRF_COOKIE_HTTPONLY: bool = True
|
||||
CSRF_COOKIE_SAMESITE: Literal["Lax", "Strict", "None"] = "Lax"
|
||||
CSRF_COOKIE_PATH: str = "/"
|
||||
CSRF_COOKIE_DOMAIN: str | None = None
|
||||
|
||||
DOCS_USERNAME: str = "admin"
|
||||
DOCS_PASSWORD: str = "admin"
|
||||
|
||||
JWT_ACCESS_TTL_SECONDS: int = 15 * 60
|
||||
JWT_REFRESH_TTL_SECONDS: int = 30 * 24 * 60 * 60
|
||||
JWT_ISSUER: str | None = None
|
||||
JWT_AUDIENCE: str | None = None
|
||||
JWT_ALGORITHM: str = "RS256"
|
||||
|
||||
REDIS_HOST: str = "localhost"
|
||||
REDIS_PORT: int = 6379
|
||||
REDIS_PASSWORD: str | None = None
|
||||
REDIS_DB: int = 0
|
||||
|
||||
RABBIT_HOST: str = "localhost"
|
||||
RABBIT_PORT: int = 5672
|
||||
RABBIT_USER: str = "guest"
|
||||
RABBIT_PASSWORD: str = "guest"
|
||||
RABBIT_VHOST: str = "/"
|
||||
|
||||
RABBIT_PUBLISH_PERSIST: bool = True
|
||||
RABBIT_CONNECT_TIMEOUT: int = 5
|
||||
RABBIT_EMAIL_CODE_QUEUE: str = "email.verification_code"
|
||||
|
||||
LOG_LEVEL: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
|
||||
LOG_FORMAT: Literal["JSON", "TEXT"] = "TEXT"
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
case_sensitive=True,
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def load_from_vault(cls, data: dict):
|
||||
addr = data.get("VAULT_ADDR") or os.getenv("VAULT_ADDR") or "http://localhost:8200"
|
||||
token = data.get("VAULT_TOKEN") or os.getenv("VAULT_TOKEN")
|
||||
mount = data.get("VAULT_MOUNT_POINT") or os.getenv("VAULT_MOUNT_POINT") or "secrets"
|
||||
|
||||
if not token:
|
||||
raise RuntimeError("VAULT_TOKEN is required")
|
||||
|
||||
client = create_hvac_client(url=addr, token=token, timeout=5)
|
||||
|
||||
def safe_read(path: str) -> dict:
|
||||
try:
|
||||
return read_kv2_secret(client=client, mount_point=mount, path=path)
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
database = safe_read("database")
|
||||
rabbitmq = safe_read("rabbitmq")
|
||||
csrf = safe_read("csrf")
|
||||
|
||||
if database:
|
||||
required = ["HOST", "NAME", "USER", "PASSWORD", "PORT"]
|
||||
missing = [k for k in required if k not in database]
|
||||
if missing:
|
||||
raise RuntimeError(f"Vault database secret missing keys {missing}")
|
||||
|
||||
data["DATABASE_HOST"] = database["HOST"]
|
||||
data["DATABASE_PORT"] = database["PORT"]
|
||||
data["DATABASE_NAME"] = database["NAME"]
|
||||
data["DATABASE_USER"] = database["USER"]
|
||||
data["DATABASE_PASSWORD"] = database["PASSWORD"]
|
||||
|
||||
if rabbitmq:
|
||||
data["RABBIT_HOST"] = rabbitmq.get("HOST", data.get("RABBIT_HOST"))
|
||||
data["RABBIT_PORT"] = rabbitmq.get("PORT", data.get("RABBIT_PORT"))
|
||||
data["RABBIT_USER"] = rabbitmq.get("USER", data.get("RABBIT_USER"))
|
||||
data["RABBIT_PASSWORD"] = rabbitmq.get("PASSWORD", data.get("RABBIT_PASSWORD"))
|
||||
data["RABBIT_VHOST"] = rabbitmq.get("VHOST", data.get("RABBIT_VHOST"))
|
||||
|
||||
if csrf:
|
||||
data["CSRF_SECRET_KEY"] = csrf.get("KEY", data.get("CSRF_SECRET_KEY"))
|
||||
|
||||
return data
|
||||
|
||||
@property
|
||||
def DATABASE_URL(self) -> str:
|
||||
return (
|
||||
f"postgresql+asyncpg://{self.DATABASE_USER}:{self.DATABASE_PASSWORD}"
|
||||
f"@{self.DATABASE_HOST}:{self.DATABASE_PORT}/{self.DATABASE_NAME}"
|
||||
)
|
||||
|
||||
@property
|
||||
def REDIS_URL(self) -> str:
|
||||
auth = f":{self.REDIS_PASSWORD}@" if self.REDIS_PASSWORD else ""
|
||||
return f"redis://{auth}{self.REDIS_HOST}:{self.REDIS_PORT}/{self.REDIS_DB}"
|
||||
|
||||
@property
|
||||
def RABBIT_URL(self) -> str:
|
||||
vhost = "%2F" if self.RABBIT_VHOST == "/" else self.RABBIT_VHOST.lstrip("/")
|
||||
return f"amqp://{self.RABBIT_USER}:{self.RABBIT_PASSWORD}@{self.RABBIT_HOST}:{self.RABBIT_PORT}/{vhost}"
|
||||
|
||||
@property
|
||||
def EXCLUDED_PATHS(self) -> List[str]:
|
||||
return ["/docs", "/redoc", "/openapi.json", "/ping", "/health"]
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_settings() -> Settings:
|
||||
return Settings()
|
||||
|
||||
|
||||
settings = get_settings()
|
||||
1
src/infrastructure/context_vars/__init__.py
Normal file
1
src/infrastructure/context_vars/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.context_vars.trace_id import trace_id_var
|
||||
4
src/infrastructure/context_vars/trace_id.py
Normal file
4
src/infrastructure/context_vars/trace_id.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from contextvars import ContextVar
|
||||
|
||||
|
||||
trace_id_var: ContextVar[str] = ContextVar('trace_id', default='N/A')
|
||||
1
src/infrastructure/database/__init__.py
Normal file
1
src/infrastructure/database/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.database.unit_of_work import UnitOfWork
|
||||
22
src/infrastructure/database/context.py
Normal file
22
src/infrastructure/database/context.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
from sqlalchemy.ext.asyncio.engine import create_async_engine
|
||||
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||
from typing import AsyncGenerator
|
||||
from src.infrastructure.config import settings
|
||||
|
||||
|
||||
engine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
pool_size=settings.DATABASE_POOL_SIZE,
|
||||
max_overflow=settings.DATABASE_MAX_OVERFLOW,
|
||||
pool_timeout=settings.DATABASE_POOL_TIMEOUT,
|
||||
pool_recycle=settings.DATABASE_POOL_RECYCLE,
|
||||
echo=settings.DATABASE_ECHO
|
||||
)
|
||||
|
||||
async_session_maker = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
|
||||
|
||||
|
||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with async_session_maker() as session:
|
||||
yield session
|
||||
1
src/infrastructure/database/decorators/__init__.py
Normal file
1
src/infrastructure/database/decorators/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.database.decorators.transactional import transactional
|
||||
15
src/infrastructure/database/decorators/transactional.py
Normal file
15
src/infrastructure/database/decorators/transactional.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
from functools import wraps
|
||||
from typing import Callable, Awaitable, TypeVar, ParamSpec
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def transactional(method: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
||||
@wraps(method)
|
||||
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
async with self._unit_of_work:
|
||||
return await method(self, *args, **kwargs)
|
||||
return wrapper
|
||||
6
src/infrastructure/database/models/__init__.py
Normal file
6
src/infrastructure/database/models/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from src.infrastructure.database.models.base import Base
|
||||
from src.infrastructure.database.models.user import UserModel
|
||||
from src.infrastructure.database.models.sessions import Session
|
||||
|
||||
__all__ = ['Base', 'UserModel', 'Session']
|
||||
|
||||
19
src/infrastructure/database/models/base.py
Normal file
19
src/infrastructure/database/models/base.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
|
||||
class Base(AsyncAttrs, DeclarativeBase):
|
||||
__abstract__ = True
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
attributes = ', '.join(f"{col.name}={getattr(self, col.name, None)!r}"
|
||||
for col in self.__table__.columns)
|
||||
return f"<{class_name}({attributes})>"
|
||||
|
||||
def __str__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
attributes = ', '.join(f"{col.name}={getattr(self, col.name)}"
|
||||
for col in self.__table__.columns
|
||||
if getattr(self, col.name) is not None)
|
||||
return f"{class_name}({attributes})"
|
||||
3
src/infrastructure/database/models/mixins/__init__.py
Normal file
3
src/infrastructure/database/models/mixins/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from src.infrastructure.database.models.mixins.audit import AuditTimestampsMixin
|
||||
from src.infrastructure.database.models.mixins.ulid import UlidPrimaryKeyMixin
|
||||
from src.infrastructure.database.models.mixins.soft_delete import SoftDeleteMixin
|
||||
16
src/infrastructure/database/models/mixins/audit.py
Normal file
16
src/infrastructure/database/models/mixins/audit.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from sqlalchemy import DateTime, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
|
||||
class AuditTimestampsMixin:
|
||||
created_at: Mapped[DateTime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=func.now(),
|
||||
)
|
||||
updated_at: Mapped[DateTime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
6
src/infrastructure/database/models/mixins/soft_delete.py
Normal file
6
src/infrastructure/database/models/mixins/soft_delete.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from sqlalchemy import Boolean
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
|
||||
class SoftDeleteMixin:
|
||||
is_deleted: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default='false', default=False)
|
||||
8
src/infrastructure/database/models/mixins/ulid.py
Normal file
8
src/infrastructure/database/models/mixins/ulid.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from ulid import ULID
|
||||
|
||||
|
||||
class UlidPrimaryKeyMixin:
|
||||
|
||||
id: Mapped[str] = mapped_column(String(26), primary_key=True, default=lambda: str(ULID()))
|
||||
50
src/infrastructure/database/models/sessions.py
Normal file
50
src/infrastructure/database/models/sessions.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy import String, DateTime, ForeignKey, Index
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from ulid import ULID
|
||||
from src.infrastructure.database.models import Base
|
||||
from src.infrastructure.database.models.mixins import UlidPrimaryKeyMixin, AuditTimestampsMixin
|
||||
|
||||
|
||||
class Session(Base, UlidPrimaryKeyMixin, AuditTimestampsMixin):
|
||||
__tablename__ = "sessions"
|
||||
|
||||
sid: Mapped[str] = mapped_column(
|
||||
String(26),
|
||||
unique=True,
|
||||
index=True,
|
||||
nullable=False,
|
||||
default=lambda: str(ULID()),
|
||||
)
|
||||
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(26),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
index=True,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
device_id: Mapped[str] = mapped_column(
|
||||
String(26),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
user_agent: Mapped[str | None] = mapped_column(String(500))
|
||||
first_ip: Mapped[str | None] = mapped_column(String(64))
|
||||
last_ip: Mapped[str | None] = mapped_column(String(64))
|
||||
|
||||
last_seen_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
revoked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||
|
||||
refresh_jti_hash: Mapped[str | None] = mapped_column(String(255))
|
||||
refresh_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||
|
||||
|
||||
Index("ux_sessions_user_device", Session.user_id, Session.device_id, unique=True)
|
||||
Index("ix_sessions_user_active", Session.user_id, Session.revoked_at)
|
||||
28
src/infrastructure/database/models/user.py
Normal file
28
src/infrastructure/database/models/user.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
from sqlalchemy import Boolean, Date, String, DateTime
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from src.infrastructure.database.models.base import Base
|
||||
from src.infrastructure.database.models.mixins import UlidPrimaryKeyMixin, AuditTimestampsMixin, SoftDeleteMixin
|
||||
|
||||
|
||||
class UserModel(Base, UlidPrimaryKeyMixin, AuditTimestampsMixin, SoftDeleteMixin):
|
||||
__tablename__ = 'users'
|
||||
|
||||
email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True, index=True)
|
||||
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
|
||||
last_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
first_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
middle_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
birth_date: Mapped[Date | None] = mapped_column(Date, nullable=True)
|
||||
|
||||
crypto_wallet: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
phone: Mapped[str | None] = mapped_column(String(16), nullable=True)
|
||||
|
||||
bik: Mapped[str | None] = mapped_column(String(9), nullable=True)
|
||||
account_number: Mapped[str | None] = mapped_column(String(20), nullable=True)
|
||||
card_number: Mapped[str | None] = mapped_column(String(19), nullable=True)
|
||||
inn: Mapped[str | None] = mapped_column(String(12), nullable=True)
|
||||
|
||||
kyc_verified: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default='false', default=False)
|
||||
kyc_verified_at: Mapped[DateTime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
1
src/infrastructure/database/repositories/__init__.py
Normal file
1
src/infrastructure/database/repositories/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.database.repositories.user_repository import UserRepository
|
||||
114
src/infrastructure/database/repositories/user_repository.py
Normal file
114
src/infrastructure/database/repositories/user_repository.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from __future__ import annotations
|
||||
from fastapi import status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
|
||||
from src.application.contracts import ILogger
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.application.abstractions.repositories import IUserRepository
|
||||
from src.application.domain.entities import UserEntity
|
||||
from src.infrastructure.database.models import UserModel
|
||||
|
||||
|
||||
class UserRepository(IUserRepository):
|
||||
def __init__(self, session: AsyncSession, logger: ILogger):
|
||||
self._session = session
|
||||
self._logger = logger
|
||||
|
||||
async def create_user(self, email: str, password_hash: str) -> UserEntity:
|
||||
user = UserModel(email=email, password_hash=password_hash)
|
||||
self._session.add(user)
|
||||
try:
|
||||
await self._session.flush()
|
||||
return UserEntity(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
created_at=user.created_at,
|
||||
kyc_verified=user.kyc_verified,
|
||||
is_deleted=user.is_deleted
|
||||
)
|
||||
|
||||
except IntegrityError:
|
||||
self._logger.error(f'User already exists with email {user.email}')
|
||||
raise ApplicationException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
message='User with this email already exists',
|
||||
)
|
||||
|
||||
except SQLAlchemyError as exception:
|
||||
self._logger.exception(str(exception))
|
||||
raise ApplicationException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
message=f'Database error: {str(exception)}',
|
||||
)
|
||||
|
||||
async def get_user_by_email(self, email: str) -> UserEntity:
|
||||
try:
|
||||
stmt = (
|
||||
select(UserModel)
|
||||
.where(
|
||||
UserModel.email == email,
|
||||
UserModel.is_deleted.is_(False),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
user: UserModel | None = result.scalar_one_or_none()
|
||||
|
||||
if user is None:
|
||||
self._logger.warning(f'User not found with email {email}')
|
||||
raise ApplicationException(status_code=status.HTTP_404_NOT_FOUND, message='User not found',)
|
||||
|
||||
return UserEntity(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
password_hash=user.password_hash,
|
||||
first_name=user.first_name,
|
||||
middle_name=user.middle_name,
|
||||
last_name=user.last_name,
|
||||
birth_date=user.birth_date,
|
||||
crypto_wallet=user.crypto_wallet,
|
||||
phone=user.phone,
|
||||
bik=user.bik,
|
||||
account_number=user.account_number,
|
||||
card_number=user.card_number,
|
||||
inn=user.inn,
|
||||
kyc_verified_at=user.kyc_verified_at,
|
||||
kyc_verified=user.kyc_verified,
|
||||
is_deleted=user.is_deleted,
|
||||
created_at=user.created_at,
|
||||
updated_at=user.updated_at,
|
||||
)
|
||||
|
||||
except ApplicationException:
|
||||
raise
|
||||
|
||||
except SQLAlchemyError as exception:
|
||||
self._logger.exception(str(exception))
|
||||
raise ApplicationException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
message=f'Database error: {str(exception)}',
|
||||
)
|
||||
|
||||
async def exists_by_email(self, email: str) -> bool:
|
||||
try:
|
||||
stmt = (
|
||||
select(UserModel.id)
|
||||
.where(
|
||||
UserModel.email == email,
|
||||
UserModel.is_deleted.is_(False),
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
return result.scalar_one_or_none() is not None
|
||||
|
||||
except SQLAlchemyError as exception:
|
||||
self._logger.exception(str(exception))
|
||||
raise ApplicationException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
message=f'Database error: {str(exception)}',
|
||||
)
|
||||
|
||||
|
||||
42
src/infrastructure/database/unit_of_work.py
Normal file
42
src/infrastructure/database/unit_of_work.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.abstractions.repositories import IUserRepository, ISessionRepository
|
||||
from src.application.contracts import ILogger
|
||||
from src.infrastructure.database.repositories import UserRepository, SessionRepository
|
||||
|
||||
|
||||
|
||||
class UnitOfWork(IUnitOfWork):
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession], logger: ILogger):
|
||||
self.session_factory = session_factory
|
||||
self._session: AsyncSession = None
|
||||
self._user_repository: IUserRepository = None
|
||||
self._session_repository: ISessionRepository = None
|
||||
self._logger: ILogger = logger
|
||||
|
||||
async def __aenter__(self):
|
||||
self._session = self.session_factory()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if exc_type:
|
||||
self._logger.error(str(exc_val))
|
||||
await self._session.rollback()
|
||||
self._logger.error(f'Rollback: str{exc_val})')
|
||||
else:
|
||||
await self._session.flush()
|
||||
await self._session.commit()
|
||||
self._logger.debug('Commit')
|
||||
await self._session.close()
|
||||
|
||||
@property
|
||||
def user_repository(self) -> IUserRepository:
|
||||
if self._user_repository is None:
|
||||
self._user_repository = UserRepository(session=self._session, logger=self._logger)
|
||||
return self._user_repository
|
||||
|
||||
@property
|
||||
def session_repository(self) -> ISessionRepository:
|
||||
if self._session_repository is None:
|
||||
self._session_repository = SessionRepository(session=self._session, logger=self._logger)
|
||||
return self._session_repository
|
||||
28
src/infrastructure/logger/__init__.py
Normal file
28
src/infrastructure/logger/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from src.application.contracts import ILogger
|
||||
from src.application.domain.enums import LogFormat
|
||||
from src.application.domain.enums import LogLevel
|
||||
from src.infrastructure.config.settings import settings
|
||||
from src.infrastructure.logger.logger import Logger
|
||||
|
||||
log_levels = {
|
||||
'DEBUG': LogLevel.DEBUG,
|
||||
'INFO': LogLevel.INFO,
|
||||
'WARNING': LogLevel.WARNING,
|
||||
'ERROR': LogLevel.ERROR,
|
||||
'CRITICAL': LogLevel.CRITICAL,
|
||||
'EXCEPTION': LogLevel.EXCEPTION,
|
||||
}
|
||||
|
||||
log_formats = {
|
||||
'JSON': LogFormat.JSON,
|
||||
'TEXT': LogFormat.TEXT,
|
||||
}
|
||||
|
||||
logger = Logger(
|
||||
min_level=log_levels.get(settings.LOG_LEVEL, LogLevel.INFO),
|
||||
log_format=log_formats.get(settings.LOG_FORMAT, LogFormat.JSON),
|
||||
)
|
||||
|
||||
|
||||
def get_logger() -> ILogger:
|
||||
return logger
|
||||
129
src/infrastructure/logger/logger.py
Normal file
129
src/infrastructure/logger/logger.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import traceback
|
||||
import inspect
|
||||
import sys
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Callable, Optional, Any
|
||||
from ulid import ULID
|
||||
from src.application.contracts import ILogger
|
||||
from src.application.domain.enums import LogFormat, LogLevel
|
||||
from src.infrastructure.context_vars import trace_id_var
|
||||
|
||||
|
||||
class Logger(ILogger):
|
||||
_instance = None
|
||||
__default_format = LogFormat.JSON
|
||||
|
||||
def __new__(cls, *args: Any, **kwargs: Any) -> "Logger":
|
||||
if cls._instance is None:
|
||||
cls._instance = super(Logger, cls).__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
log_format: LogFormat = __default_format,
|
||||
min_level: LogLevel = LogLevel.INFO,
|
||||
id_generator: Optional[Callable[[], str]] = lambda: str(ULID()),
|
||||
instance_id: str = "N/A",
|
||||
):
|
||||
self.log_format = log_format
|
||||
self.min_level = min_level
|
||||
self.id_generator = id_generator
|
||||
self.instance_id = instance_id
|
||||
|
||||
def set_instance_id(self, instance_id: str) -> None:
|
||||
self.instance_id = instance_id
|
||||
|
||||
def get_instance_id(self) -> str:
|
||||
return self.instance_id
|
||||
|
||||
def set_format(self, log_format: LogFormat) -> None:
|
||||
if not isinstance(log_format, LogFormat):
|
||||
raise ValueError("Log format must be an instance of LogFormat enum")
|
||||
self.log_format = log_format
|
||||
|
||||
def set_min_level(self, level: LogLevel) -> None:
|
||||
self.min_level = level
|
||||
|
||||
def new_trace_id(self) -> str:
|
||||
trace_id = str(ULID()) if self.id_generator is None else self.id_generator()
|
||||
trace_id_var.set(trace_id)
|
||||
return trace_id
|
||||
|
||||
def set_trace_id(self, trace_id: str) -> None:
|
||||
trace_id_var.set(trace_id)
|
||||
|
||||
def get_trace_id(self) -> str:
|
||||
return trace_id_var.get()
|
||||
|
||||
def clear_trace_id(self) -> None:
|
||||
trace_id_var.set("N/A")
|
||||
|
||||
def _prepare_log_data(self, level: LogLevel, message: str) -> dict[str, Any]:
|
||||
current_frame = inspect.currentframe()
|
||||
if (
|
||||
current_frame
|
||||
and current_frame.f_back
|
||||
and current_frame.f_back.f_back
|
||||
and current_frame.f_back.f_back.f_back
|
||||
):
|
||||
frame = current_frame.f_back.f_back.f_back
|
||||
filename = frame.f_code.co_filename
|
||||
line_number = frame.f_lineno
|
||||
else:
|
||||
filename = "unknown"
|
||||
line_number = 0
|
||||
|
||||
log_data = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"level": level.name,
|
||||
"instance_id": self.instance_id,
|
||||
"file": filename,
|
||||
"line": line_number,
|
||||
"trace_id": trace_id_var.get(),
|
||||
"message": message,
|
||||
}
|
||||
|
||||
if level == LogLevel.EXCEPTION:
|
||||
log_data["exception"] = traceback.format_exc()
|
||||
|
||||
return log_data
|
||||
|
||||
def _log(self, level: LogLevel, message: str) -> None:
|
||||
if level >= self.min_level:
|
||||
log_data = self._prepare_log_data(level, message)
|
||||
|
||||
if self.log_format == LogFormat.JSON:
|
||||
log_message = json.dumps(log_data, ensure_ascii=False)
|
||||
else:
|
||||
log_message = (
|
||||
f"{log_data['timestamp']} - {log_data['level']} - "
|
||||
f"{log_data['instance_id']} - {log_data['trace_id']} - "
|
||||
f"{log_data['file']}:{log_data['line']} - "
|
||||
f"{log_data['message']}"
|
||||
)
|
||||
if "exception" in log_data:
|
||||
log_message += f"\nTraceback:\n{log_data['exception']}"
|
||||
|
||||
self._write(log_message)
|
||||
|
||||
def _write(self, message: str) -> None:
|
||||
sys.stdout.write(message + "\n")
|
||||
|
||||
def debug(self, message: str) -> None:
|
||||
self._log(LogLevel.DEBUG, message)
|
||||
|
||||
def info(self, message: str) -> None:
|
||||
self._log(LogLevel.INFO, message)
|
||||
|
||||
def warning(self, message: str) -> None:
|
||||
self._log(LogLevel.WARNING, message)
|
||||
|
||||
def error(self, message: str) -> None:
|
||||
self._log(LogLevel.ERROR, message)
|
||||
|
||||
def critical(self, message: str) -> None:
|
||||
self._log(LogLevel.CRITICAL, message)
|
||||
|
||||
def exception(self, message: str) -> None:
|
||||
self._log(LogLevel.EXCEPTION, message)
|
||||
3
src/infrastructure/security/__init__.py
Normal file
3
src/infrastructure/security/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from src.infrastructure.security.jwt import JwtService
|
||||
from src.infrastructure.security.csrf import CsrfService
|
||||
from src.infrastructure.security.hash import HashService
|
||||
81
src/infrastructure/security/csrf.py
Normal file
81
src/infrastructure/security/csrf.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from __future__ import annotations
|
||||
import secrets
|
||||
from typing import Any, Optional, Mapping
|
||||
from itsdangerous import URLSafeTimedSerializer, SignatureExpired, BadSignature
|
||||
from src.application.contracts import ICsrfService
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.config.settings import settings
|
||||
|
||||
|
||||
class CsrfService(ICsrfService):
|
||||
COOKIE_NAME = 'csrf_token'
|
||||
HEADER_NAME = 'X-CSRF-Token'
|
||||
SALT = 'csrf'
|
||||
TTL_SECONDS = 3600
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._serializer = URLSafeTimedSerializer(
|
||||
secret_key=settings.CSRF_SECRET_KEY,
|
||||
salt=self.SALT,
|
||||
)
|
||||
|
||||
@property
|
||||
def cookie_name(self) -> str:
|
||||
return self.COOKIE_NAME
|
||||
|
||||
@property
|
||||
def header_name(self) -> str:
|
||||
return self.HEADER_NAME
|
||||
|
||||
@property
|
||||
def ttl_seconds(self) -> int:
|
||||
return self.TTL_SECONDS
|
||||
|
||||
def issue(self, subject: Optional[str] = None) -> str:
|
||||
payload = {
|
||||
'sub': subject,
|
||||
'nonce': secrets.token_urlsafe(32),
|
||||
}
|
||||
return self._serializer.dumps(payload)
|
||||
|
||||
def verify(self, token: str, expected_subject: Optional[str] = None) -> dict[str, Any]:
|
||||
try:
|
||||
data = self._serializer.loads(token, max_age=self.TTL_SECONDS)
|
||||
except SignatureExpired:
|
||||
raise ApplicationException(
|
||||
status_code=403,
|
||||
message='CSRF token expired',
|
||||
)
|
||||
except BadSignature:
|
||||
raise ApplicationException(
|
||||
status_code=403,
|
||||
message='CSRF token invalid',
|
||||
)
|
||||
|
||||
if expected_subject is not None and data.get('sub') != expected_subject:
|
||||
raise ApplicationException(
|
||||
status_code=403,
|
||||
message='CSRF token subject mismatch',
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
def extract(self, cookies: Mapping[str, str], headers: Mapping[str, str]) -> tuple[Optional[str], Optional[str]]:
|
||||
cookie_token = cookies.get(self.COOKIE_NAME)
|
||||
header_token = headers.get(self.HEADER_NAME)
|
||||
return cookie_token, header_token
|
||||
|
||||
def verify_pair(self, cookie_token: Optional[str], header_token: Optional[str], expected_subject: Optional[str] = None) -> None:
|
||||
if not cookie_token or not header_token:
|
||||
raise ApplicationException(
|
||||
status_code=403,
|
||||
message='CSRF token missing',
|
||||
)
|
||||
|
||||
if not secrets.compare_digest(cookie_token, header_token):
|
||||
raise ApplicationException(
|
||||
status_code=403,
|
||||
message='CSRF token mismatch',
|
||||
)
|
||||
|
||||
self.verify(cookie_token, expected_subject=expected_subject)
|
||||
17
src/infrastructure/security/hash.py
Normal file
17
src/infrastructure/security/hash.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import bcrypt
|
||||
from src.application.contracts import IHashService, ILogger
|
||||
|
||||
|
||||
class HashService(IHashService):
|
||||
|
||||
def __init__(self, logger: ILogger):
|
||||
self._logger = logger
|
||||
|
||||
async def hash(self, value: str) -> str:
|
||||
hashed_value = bcrypt.hashpw(value.encode(), bcrypt.gensalt())
|
||||
self._logger.info(f'Hash value {hashed_value.decode()}')
|
||||
return hashed_value.decode()
|
||||
|
||||
async def verify(self, hashed_value: str, plain_value: str) -> bool:
|
||||
self._logger.info(f'Hash value {hashed_value[:10]}')
|
||||
return bcrypt.checkpw(plain_value.encode(), hashed_value.encode())
|
||||
109
src/infrastructure/security/jwt.py
Normal file
109
src/infrastructure/security/jwt.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from __future__ import annotations
|
||||
from jose import jwt, ExpiredSignatureError, JWTError
|
||||
from src.application.contracts import ILogger, IJwtService
|
||||
from src.application.domain.dto import AccessTokenPayload
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.config.settings import settings
|
||||
from src.infrastructure.vault import JwtKeyStore
|
||||
|
||||
|
||||
class JwtService(IJwtService):
|
||||
def __init__(self, logger: ILogger, key_store: JwtKeyStore) -> None:
|
||||
self._logger = logger
|
||||
self._key_store = key_store
|
||||
|
||||
async def decode_access_token(self, token: str) -> AccessTokenPayload:
|
||||
payload = await self._decode_and_verify(token)
|
||||
|
||||
if payload.get('type') != 'access':
|
||||
self._logger.warning(f'Access token invalid type received_type={payload.get('type')}')
|
||||
raise ApplicationException(status_code=401, message='Invalid token type')
|
||||
|
||||
try:
|
||||
return AccessTokenPayload(
|
||||
sub=str(payload['sub']),
|
||||
type='access',
|
||||
sid=str(payload['sid']),
|
||||
iat=int(payload['iat']),
|
||||
nbf=int(payload['nbf']),
|
||||
exp=int(payload['exp']),
|
||||
iss=payload.get('iss'),
|
||||
aud=payload.get('aud'),
|
||||
)
|
||||
except KeyError as exception:
|
||||
self._logger.warning(f'Access token missing claim error={str(exception)}')
|
||||
raise ApplicationException(status_code=401, message=f'Missing token claim: {exception}')
|
||||
|
||||
async def _decode_and_verify(self, token: str) -> dict:
|
||||
kid: str | None = None
|
||||
try:
|
||||
header = jwt.get_unverified_header(token)
|
||||
|
||||
kid = header.get('kid')
|
||||
if not kid:
|
||||
self._logger.warning(f'JWT header missing kid header={header}')
|
||||
raise ApplicationException(status_code=401, message='Missing token header: kid')
|
||||
|
||||
received_alg = header.get('alg')
|
||||
if received_alg != settings.JWT_ALGORITHM:
|
||||
self._logger.warning(f'JWT invalid algorithm kid={kid} received_alg={received_alg} expected_alg={settings.JWT_ALGORITHM}')
|
||||
raise ApplicationException(status_code=401, message='Invalid token algorithm')
|
||||
|
||||
public_pem = await self._key_store.get_public_key_for_kid(str(kid))
|
||||
|
||||
if not public_pem:
|
||||
self._logger.info(f'JWT kid miss kid={kid} forcing keystore refresh')
|
||||
await self._key_store.refresh()
|
||||
public_pem = await self._key_store.get_public_key_for_kid(str(kid))
|
||||
|
||||
if not public_pem:
|
||||
self._logger.warning(f'JWT unknown kid kid={kid}')
|
||||
raise ApplicationException(status_code=401, message='Unknown token kid')
|
||||
|
||||
options = {
|
||||
'verify_signature': True,
|
||||
'verify_exp': True,
|
||||
'verify_nbf': True,
|
||||
'verify_iat': True,
|
||||
'verify_aud': bool(settings.JWT_AUDIENCE),
|
||||
'verify_iss': bool(settings.JWT_ISSUER),
|
||||
'require_exp': True,
|
||||
'require_iat': True,
|
||||
'require_nbf': True,
|
||||
'require_sub': True,
|
||||
'leeway': 10,
|
||||
}
|
||||
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
public_pem,
|
||||
algorithms=[settings.JWT_ALGORITHM],
|
||||
audience=settings.JWT_AUDIENCE or None,
|
||||
issuer=settings.JWT_ISSUER or None,
|
||||
options=options,
|
||||
)
|
||||
|
||||
if 'sid' not in payload:
|
||||
self._logger.warning(f'JWT missing sid claim kid={kid}')
|
||||
raise ApplicationException(status_code=401, message='Missing token claim: sid')
|
||||
|
||||
if 'type' not in payload:
|
||||
self._logger.warning(f'JWT missing type claim kid={kid}')
|
||||
raise ApplicationException(status_code=401, message='Missing token claim: type')
|
||||
|
||||
return payload
|
||||
|
||||
except ExpiredSignatureError as exception:
|
||||
self._logger.info(f'JWT expired kid={kid} error={str(exception)}')
|
||||
raise ApplicationException(status_code=401, message='Token expired')
|
||||
|
||||
except ApplicationException:
|
||||
raise
|
||||
|
||||
except JWTError as exception:
|
||||
self._logger.warning(f'JWT decode failed kid={kid} error={str(exception)}')
|
||||
raise ApplicationException(status_code=401, message='Invalid token')
|
||||
|
||||
except Exception as exception:
|
||||
self._logger.error(f'Unexpected JWT decode error kid={kid} error={str(exception)}')
|
||||
raise ApplicationException(status_code=500, message='JWT decode failed')
|
||||
1
src/infrastructure/utils/__init__.py
Normal file
1
src/infrastructure/utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.utils.instance_id import generate_instance_id
|
||||
14
src/infrastructure/utils/instance_id.py
Normal file
14
src/infrastructure/utils/instance_id.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from ulid import ULID
|
||||
|
||||
|
||||
def generate_instance_id() -> str:
|
||||
"""
|
||||
Generate a process-wide instance id in ULID format.
|
||||
|
||||
ULID is 26 chars (Crockford Base32) and lexicographically sortable by time.
|
||||
"""
|
||||
|
||||
|
||||
return str(ULID())
|
||||
|
||||
|
||||
3
src/infrastructure/vault/__init__.py
Normal file
3
src/infrastructure/vault/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from src.infrastructure.vault.utils import read_kv2_secret, create_hvac_client
|
||||
from src.infrastructure.vault.keys import JwtKeyStore
|
||||
from src.infrastructure.vault.scheduler import start_jwt_keys_scheduler
|
||||
113
src/infrastructure/vault/keys.py
Normal file
113
src/infrastructure/vault/keys.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from src.application.domain.dto import JwtPublicKeySet, JwtPublicKey
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.vault import create_hvac_client, read_kv2_secret
|
||||
|
||||
|
||||
class JwtKeyStore:
|
||||
|
||||
_instance: 'JwtKeyStore | None' = None
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vault_addr: str,
|
||||
vault_token: str,
|
||||
mount_point: str,
|
||||
kid_path: str = 'jwt/kid',
|
||||
kids_prefix: str = 'jwt/kids',
|
||||
timeout_seconds: int = 5,
|
||||
refresh_ttl_seconds: int = 60,
|
||||
):
|
||||
if getattr(self, '_initialized', False):
|
||||
return
|
||||
|
||||
self._vault_addr = vault_addr
|
||||
self._vault_token = vault_token
|
||||
self._timeout = timeout_seconds
|
||||
|
||||
self._mount = mount_point
|
||||
self._kid_path = kid_path
|
||||
self._kids_prefix = kids_prefix
|
||||
|
||||
self._refresh_ttl_seconds = refresh_ttl_seconds
|
||||
|
||||
self._lock = asyncio.Lock()
|
||||
self._keyset: JwtPublicKeySet | None = None
|
||||
self._last_refresh_at: datetime | None = None
|
||||
|
||||
self._initialized = True
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> 'JwtKeyStore':
|
||||
if cls._instance is None:
|
||||
raise ApplicationException(status_code=500, message='JwtKeyStore not initialized')
|
||||
return cls._instance
|
||||
|
||||
def _read_keyset_sync(self) -> JwtPublicKeySet:
|
||||
client = create_hvac_client(url=self._vault_addr, token=self._vault_token, timeout=self._timeout)
|
||||
|
||||
kids = read_kv2_secret(client=client, mount_point=self._mount, path=self._kid_path)
|
||||
active_kid = kids.get('active')
|
||||
previous_kid = kids.get('previous')
|
||||
|
||||
if not active_kid:
|
||||
raise RuntimeError('Vault jwt/kid secret missing "active"')
|
||||
|
||||
active = self._read_public_key_sync(client, str(active_kid))
|
||||
|
||||
previous = None
|
||||
if previous_kid and previous_kid != active_kid:
|
||||
previous = self._read_public_key_sync(client, str(previous_kid))
|
||||
|
||||
return JwtPublicKeySet(active=active, previous=previous)
|
||||
|
||||
def _read_public_key_sync(self, client, kid: str) -> JwtPublicKey:
|
||||
data = read_kv2_secret(
|
||||
client=client,
|
||||
mount_point=self._mount,
|
||||
path=f'{self._kids_prefix}/{kid}',
|
||||
)
|
||||
pub = data.get('public_key')
|
||||
if not pub:
|
||||
raise RuntimeError(f'Vault jwt/kids/{kid} missing public_key')
|
||||
return JwtPublicKey(kid=kid, public_key_pem=pub)
|
||||
|
||||
async def refresh(self) -> JwtPublicKeySet:
|
||||
keyset = await asyncio.to_thread(self._read_keyset_sync)
|
||||
async with self._lock:
|
||||
self._keyset = keyset
|
||||
self._last_refresh_at = datetime.now(timezone.utc)
|
||||
return keyset
|
||||
|
||||
async def get_public_key_for_kid(self, kid: str) -> str | None:
|
||||
ks = await self._get_or_refresh()
|
||||
return ks.public_keys_by_kid().get(kid)
|
||||
|
||||
async def last_refresh_at(self) -> datetime | None:
|
||||
async with self._lock:
|
||||
return self._last_refresh_at
|
||||
|
||||
async def _get_or_refresh(self) -> JwtPublicKeySet:
|
||||
async with self._lock:
|
||||
ks = self._keyset
|
||||
last = self._last_refresh_at
|
||||
|
||||
if ks is None:
|
||||
return await self.refresh()
|
||||
|
||||
if last is None:
|
||||
return await self.refresh()
|
||||
|
||||
age = (datetime.now(timezone.utc) - last).total_seconds()
|
||||
if age >= self._refresh_ttl_seconds:
|
||||
return await self.refresh()
|
||||
|
||||
return ks
|
||||
23
src/infrastructure/vault/scheduler.py
Normal file
23
src/infrastructure/vault/scheduler.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
from src.infrastructure.vault import JwtKeyStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def start_jwt_keys_scheduler(store: JwtKeyStore, *, refresh_seconds: int = 3600) -> AsyncIOScheduler:
|
||||
scheduler = AsyncIOScheduler()
|
||||
scheduler.add_job(
|
||||
store.refresh,
|
||||
trigger=IntervalTrigger(seconds=refresh_seconds),
|
||||
id="jwt_keys_refresh",
|
||||
replace_existing=True,
|
||||
max_instances=1,
|
||||
coalesce=True,
|
||||
misfire_grace_time=60,
|
||||
)
|
||||
scheduler.start()
|
||||
logger.info("JWT keys scheduler started (interval=%s seconds)", refresh_seconds)
|
||||
return scheduler
|
||||
17
src/infrastructure/vault/utils.py
Normal file
17
src/infrastructure/vault/utils.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from __future__ import annotations
|
||||
import hvac
|
||||
|
||||
|
||||
def create_hvac_client(*, url: str, token: str, timeout: int = 5) -> hvac.Client:
|
||||
client = hvac.Client(url=url, token=token, timeout=timeout)
|
||||
if not client.is_authenticated():
|
||||
raise RuntimeError("Vault authentication failed. Check VAULT_ADDR / VAULT_TOKEN")
|
||||
return client
|
||||
|
||||
|
||||
def read_kv2_secret(*, client: hvac.Client, mount_point: str, path: str) -> dict:
|
||||
secret = client.secrets.kv.v2.read_secret_version(
|
||||
mount_point=mount_point,
|
||||
path=path,
|
||||
)
|
||||
return secret["data"]["data"]
|
||||
115
src/main.py
Normal file
115
src/main.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from __future__ import annotations
|
||||
from contextlib import asynccontextmanager
|
||||
import secrets
|
||||
from typing import AsyncGenerator
|
||||
from fastapi import Depends, FastAPI, status
|
||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.cache import create_redis_client
|
||||
from src.infrastructure.config.settings import get_settings
|
||||
from src.infrastructure.vault import JwtKeyStore, start_jwt_keys_scheduler
|
||||
from src.infrastructure.utils import generate_instance_id
|
||||
from src.infrastructure.logger import logger
|
||||
from src.infrastructure.config import settings
|
||||
from src.presentation.handlers import application_exception_handler, unhandled_exception_handler
|
||||
from src.presentation.middleware import TraceIDMiddleware, SecurityHeadersMiddleware
|
||||
from src.presentation.routing import order_router
|
||||
|
||||
security = HTTPBasic()
|
||||
|
||||
|
||||
async def verify_credentials(credentials: HTTPBasicCredentials = Depends(security)) -> HTTPBasicCredentials:
|
||||
user_ok = secrets.compare_digest(credentials.username, settings.DOCS_USERNAME)
|
||||
pass_ok = secrets.compare_digest(credentials.password, settings.DOCS_PASSWORD)
|
||||
if not (user_ok and pass_ok):
|
||||
raise ApplicationException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
message='Unauthorized',
|
||||
headers={'WWW-Authenticate': 'Basic'},
|
||||
)
|
||||
return credentials
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
|
||||
instance_id = generate_instance_id()
|
||||
logger.set_instance_id(instance_id)
|
||||
logger.info(f'Users service instance started with id {instance_id}')
|
||||
|
||||
app.state.redis = create_redis_client()
|
||||
|
||||
jwt_store = JwtKeyStore(
|
||||
vault_addr=settings.VAULT_ADDR,
|
||||
vault_token=settings.VAULT_TOKEN,
|
||||
mount_point=settings.VAULT_MOUNT_POINT,
|
||||
kid_path=settings.VAULT_JWT_KID_PATH,
|
||||
kids_prefix=settings.VAULT_JWT_KIDS_PREFIX,
|
||||
)
|
||||
|
||||
await jwt_store.refresh()
|
||||
|
||||
jwt_scheduler = start_jwt_keys_scheduler(jwt_store, refresh_seconds=settings.JWT_KEYS_REFRESH_SECONDS)
|
||||
|
||||
app.state.jwt_key_store = jwt_store
|
||||
app.state.jwt_keys_scheduler = jwt_scheduler
|
||||
yield
|
||||
await app.state.redis.aclose()
|
||||
logger.info(f'Users service instance ended with id {instance_id}')
|
||||
|
||||
|
||||
app: FastAPI = FastAPI(
|
||||
redoc_url=None,
|
||||
docs_url=None,
|
||||
lifespan=lifespan,
|
||||
title='Elcsa Users Service'
|
||||
)
|
||||
|
||||
app.add_exception_handler(ApplicationException, application_exception_handler)
|
||||
app.add_exception_handler(Exception, unhandled_exception_handler)
|
||||
|
||||
app.include_router(order_router)
|
||||
|
||||
|
||||
# Added middleware
|
||||
app.add_middleware(TraceIDMiddleware, logger=logger)
|
||||
app.add_middleware(
|
||||
SecurityHeadersMiddleware,
|
||||
hsts=True,
|
||||
hsts_preload=False,
|
||||
frame_options='DENY',
|
||||
referrer_policy='strict-origin-when-cross-origin',
|
||||
content_security_policy="default-src 'self'; frame-ancestors 'none'; base-uri 'self'; object-src 'none'",
|
||||
)
|
||||
|
||||
|
||||
@app.get('/docs', include_in_schema=False)
|
||||
async def custom_swagger_ui_html(_credentials: HTTPBasicCredentials = Depends(verify_credentials)) -> HTMLResponse:
|
||||
'''Custom Swagger documentation, optionally protected with basic authentication.'''
|
||||
return get_swagger_ui_html(
|
||||
openapi_url=getattr(app, 'openapi_url', '/openapi.json'),
|
||||
title=getattr(app, 'title', 'FastAPI') + ' - Swagger UI',
|
||||
oauth2_redirect_url=getattr(app, 'swagger_ui_oauth2_redirect_url', None),
|
||||
swagger_js_url='https://unpkg.com/swagger-ui-dist@5/swagger-ui-bundle.js',
|
||||
swagger_css_url='https://unpkg.com/swagger-ui-dist@5/swagger-ui.css',
|
||||
)
|
||||
|
||||
|
||||
@app.get('/redoc', include_in_schema=False)
|
||||
async def custom_redoc_html(_credentials: HTTPBasicCredentials = Depends(verify_credentials)) -> HTMLResponse:
|
||||
'''Custom ReDoc documentation, optionally protected with basic authentication.'''
|
||||
return get_redoc_html(
|
||||
openapi_url=getattr(app, 'openapi_url', '/openapi.json'),
|
||||
title=getattr(app, 'title', 'FastAPI') + ' - ReDoc',
|
||||
redoc_js_url='https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js',
|
||||
)
|
||||
|
||||
|
||||
@app.post('/ping')
|
||||
async def ping() -> dict[str, str]:
|
||||
return {
|
||||
'message': 'pong',
|
||||
'status': 'ok',
|
||||
}
|
||||
4
src/presentation/decorators/__init__.py
Normal file
4
src/presentation/decorators/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from src.presentation.decorators.csrf import csrf_protect
|
||||
from src.presentation.decorators.rate_limit import rate_limit, _email_rl_key as email_rl_key
|
||||
from src.presentation.decorators.auth import require_access_token
|
||||
from src.presentation.decorators.cache import cached
|
||||
36
src/presentation/decorators/auth.py
Normal file
36
src/presentation/decorators/auth.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from fastapi import Depends, Request
|
||||
from fastapi.security.utils import get_authorization_scheme_param
|
||||
from src.application.contracts import IJwtService
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.application.domain.dto import AccessTokenPayload, AuthContext
|
||||
from src.presentation.dependencies import get_jwt_service
|
||||
|
||||
|
||||
def _extract_access_token(request: Request) -> str | None:
|
||||
token = request.cookies.get('access_token')
|
||||
|
||||
if token:
|
||||
return token
|
||||
|
||||
auth = request.headers.get('Authorization')
|
||||
if auth:
|
||||
scheme, param = get_authorization_scheme_param(auth)
|
||||
if scheme.lower() == 'bearer' and param:
|
||||
return param
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def require_access_token(
|
||||
request: Request,
|
||||
jwt_service: IJwtService = Depends(get_jwt_service),
|
||||
) -> AuthContext:
|
||||
token = _extract_access_token(request)
|
||||
if not token:
|
||||
raise ApplicationException(status_code=401, message='Not authenticated')
|
||||
|
||||
payload: AccessTokenPayload = await jwt_service.decode_access_token(token)
|
||||
if payload.type != 'access':
|
||||
raise ApplicationException(status_code=401, message='Invalid token type')
|
||||
|
||||
return AuthContext(user_id=payload.sub, sid=payload.sid, token=payload)
|
||||
46
src/presentation/decorators/cache.py
Normal file
46
src/presentation/decorators/cache.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
from typing import Any, Awaitable, Callable
|
||||
from fastapi import Request
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from src.infrastructure.cache import KeydbCache
|
||||
from src.infrastructure.logger import get_logger
|
||||
from src.presentation.dependencies.cache import get_redis
|
||||
|
||||
|
||||
def cached(*, prefix: str) -> Callable:
|
||||
|
||||
def decorator(func: Callable[..., Awaitable[Any]]):
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
logger = get_logger()
|
||||
|
||||
request = kwargs.get('request')
|
||||
if not isinstance(request, Request):
|
||||
for a in args:
|
||||
if isinstance(a, Request):
|
||||
request = a
|
||||
break
|
||||
|
||||
auth = kwargs.get('auth')
|
||||
user_id = getattr(auth, 'user_id', None) if auth else None
|
||||
|
||||
if request is None or user_id is None:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
cache_key = f'{prefix}:{user_id}'
|
||||
|
||||
try:
|
||||
redis = get_redis(request)
|
||||
cache = KeydbCache(redis)
|
||||
hit = await cache.get_user(user_id)
|
||||
if hit is not None:
|
||||
logger.debug(f'Cache hit key={cache_key}')
|
||||
return ORJSONResponse(status_code=200, content=hit)
|
||||
except Exception as e:
|
||||
logger.warning(f'Cache read failed key={cache_key} error={e}')
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
61
src/presentation/decorators/csrf.py
Normal file
61
src/presentation/decorators/csrf.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from __future__ import annotations
|
||||
import inspect
|
||||
from functools import wraps
|
||||
from typing import Callable, Awaitable, Any, Optional, Annotated
|
||||
from fastapi import Request, Header
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.security import CsrfService
|
||||
|
||||
|
||||
def csrf_protect(
|
||||
expected_subject_getter: Optional[Callable[[Request], Optional[str]]] = None,
|
||||
):
|
||||
def decorator(func: Callable[..., Awaitable[Any]]):
|
||||
sig = inspect.signature(func)
|
||||
params = list(sig.parameters.values())
|
||||
|
||||
has_request = any(p.annotation is Request or p.name == 'request' for p in params)
|
||||
if not has_request:
|
||||
raise RuntimeError('csrf_protect requires endpoint to accept `request: Request`')
|
||||
|
||||
has_header = any(p.name == 'x_csrf_token' for p in params)
|
||||
if not has_header:
|
||||
params.append(
|
||||
inspect.Parameter(
|
||||
name='x_csrf_token',
|
||||
kind=inspect.Parameter.KEYWORD_ONLY,
|
||||
default=None,
|
||||
annotation=Annotated[str | None, Header(alias='X-CSRF-Token')],
|
||||
)
|
||||
)
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
request: Request | None = kwargs.get('request')
|
||||
if request is None:
|
||||
for arg in args:
|
||||
if isinstance(arg, Request):
|
||||
request = arg
|
||||
break
|
||||
|
||||
if request is None:
|
||||
raise ApplicationException(
|
||||
status_code=500,
|
||||
message='Request is required for CSRF protection',
|
||||
)
|
||||
|
||||
csrf = CsrfService()
|
||||
|
||||
cookie_token, _ = csrf.extract(request.cookies, request.headers)
|
||||
header_token = kwargs.get('x_csrf_token')
|
||||
|
||||
expected_subject = expected_subject_getter(request) if expected_subject_getter else None
|
||||
csrf.verify_pair(cookie_token, header_token, expected_subject)
|
||||
|
||||
kwargs.pop('x_csrf_token', None)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
wrapper.__signature__ = sig.replace(parameters=params)
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
171
src/presentation/decorators/rate_limit.py
Normal file
171
src/presentation/decorators/rate_limit.py
Normal file
@@ -0,0 +1,171 @@
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
import inspect
|
||||
import hashlib
|
||||
from typing import Any, Awaitable, Callable, Literal, Optional, Protocol, runtime_checkable
|
||||
from fastapi import Request
|
||||
from redis.asyncio.client import Redis
|
||||
from src.application.contracts import ILogger
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.logger import get_logger
|
||||
from src.presentation.dependencies import get_redis
|
||||
|
||||
|
||||
def _find_request(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Request:
|
||||
req = kwargs.get('request')
|
||||
if isinstance(req, Request):
|
||||
return req
|
||||
for a in args:
|
||||
if isinstance(a, Request):
|
||||
return a
|
||||
raise RuntimeError('rate_limit decorator requires fastapi.Request argument')
|
||||
|
||||
|
||||
def _client_ip(request: Request) -> str:
|
||||
xff = request.headers.get('x-forwarded-for')
|
||||
if xff:
|
||||
return xff.split(',')[0].strip()
|
||||
if request.client:
|
||||
return request.client.host
|
||||
return 'unknown'
|
||||
|
||||
|
||||
_LUA_INCR_EXPIRE_TTL = '''
|
||||
local key = KEYS[1]
|
||||
local window = tonumber(ARGV[1])
|
||||
|
||||
local current = redis.call('INCR', key)
|
||||
if current == 1 then
|
||||
redis.call('EXPIRE', key, window)
|
||||
end
|
||||
|
||||
local ttl = redis.call('TTL', key)
|
||||
return { current, ttl }
|
||||
'''
|
||||
|
||||
|
||||
Scope = Literal['ip', 'device', 'user', 'key']
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class KeyBuilder1(Protocol):
|
||||
def __call__(self, request: Request) -> str: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class KeyBuilder3(Protocol):
|
||||
def __call__(self, request: Request, args: tuple[Any, ...], kwargs: dict[str, Any]) -> str: ...
|
||||
|
||||
|
||||
KeyBuilder = KeyBuilder1 | KeyBuilder3
|
||||
|
||||
|
||||
def _call_key_builder(builder: KeyBuilder, request: Request, args: tuple[Any, ...], kwargs: dict[str, Any]) -> str:
|
||||
try:
|
||||
sig = inspect.signature(builder)
|
||||
if len(sig.parameters) >= 3:
|
||||
return builder(request, args, kwargs)
|
||||
return builder(request)
|
||||
except Exception as e:
|
||||
try:
|
||||
return builder(request, args, kwargs)
|
||||
except Exception:
|
||||
raise e
|
||||
|
||||
def _email_rl_key(request: Request, args: tuple[Any, ...], kwargs: dict[str, Any]) -> str:
|
||||
|
||||
body = kwargs.get('body')
|
||||
if body is None and args:
|
||||
for a in args:
|
||||
if hasattr(a, 'email'):
|
||||
body = a
|
||||
break
|
||||
|
||||
email = (getattr(body, 'email', '') or '').strip().lower()
|
||||
if not email:
|
||||
email = _client_ip(request)
|
||||
|
||||
digest = hashlib.sha256(email.encode('utf-8')).hexdigest()[:24]
|
||||
return f'email:{digest}'
|
||||
|
||||
def rate_limit(
|
||||
*,
|
||||
limit: int,
|
||||
window_seconds: int,
|
||||
scope: Scope = 'ip',
|
||||
key_prefix: str = 'rl',
|
||||
key_builder: Optional[KeyBuilder] = None,
|
||||
fail_open: bool = True,
|
||||
) -> Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable[Any]]]:
|
||||
|
||||
if limit <= 0:
|
||||
raise ValueError('rate_limit: limit must be > 0')
|
||||
if window_seconds <= 0:
|
||||
raise ValueError('rate_limit: window_seconds must be > 0')
|
||||
if scope == 'key' and not key_builder:
|
||||
raise ValueError('rate_limit: scope="key" requires key_builder')
|
||||
|
||||
def decorator(func: Callable[..., Awaitable[Any]]):
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any):
|
||||
request = _find_request(args, kwargs)
|
||||
logger: ILogger = get_logger()
|
||||
|
||||
if scope == 'ip':
|
||||
ident = _client_ip(request)
|
||||
elif scope == 'device':
|
||||
ident = request.cookies.get('device_id') or _client_ip(request)
|
||||
elif scope == 'user':
|
||||
user = getattr(request.state, 'user', None)
|
||||
user_id = getattr(user, 'id', None) if user else None
|
||||
ident = str(user_id) if user_id else _client_ip(request)
|
||||
else:
|
||||
try:
|
||||
ident = _call_key_builder(key_builder, request, args, kwargs) # type: ignore[arg-type]
|
||||
except Exception as e:
|
||||
logger.error(f'RateLimit key_builder failed error={str(e)}')
|
||||
raise ApplicationException(500, 'Rate limiter key_builder failed')
|
||||
|
||||
route = request.url.path
|
||||
method = request.method
|
||||
redis_key = f'{key_prefix}:{scope}:{method}:{route}:{ident}'
|
||||
|
||||
logger.debug(f'RateLimit check key={redis_key} limit={limit} window={window_seconds}')
|
||||
|
||||
try:
|
||||
redis: Redis = get_redis(request)
|
||||
|
||||
result = await redis.eval(
|
||||
_LUA_INCR_EXPIRE_TTL,
|
||||
1,
|
||||
redis_key,
|
||||
str(window_seconds),
|
||||
)
|
||||
|
||||
count = int(result[0])
|
||||
ttl_raw = int(result[1]) if result and len(result) > 1 else window_seconds
|
||||
ttl = window_seconds if ttl_raw < 0 else ttl_raw
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'RateLimit redis failure key={redis_key} error={str(e)}')
|
||||
|
||||
if fail_open:
|
||||
logger.warning(f'RateLimit fail-open activated key={redis_key}')
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
raise ApplicationException(503, 'Rate limiter unavailable')
|
||||
|
||||
if count > limit:
|
||||
retry_after = max(ttl, 0)
|
||||
logger.warning(f'RateLimit exceeded key={redis_key} count={count} limit={limit} retry_after={retry_after}')
|
||||
raise ApplicationException(
|
||||
status_code=429,
|
||||
message='Too Many Requests',
|
||||
headers={'Retry-After': str(retry_after)},
|
||||
)
|
||||
|
||||
logger.debug(f'RateLimit passed key={redis_key} count={count}')
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
16
src/presentation/dependencies/__init__.py
Normal file
16
src/presentation/dependencies/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from src.presentation.dependencies.commands import (
|
||||
get_get_me_command,
|
||||
get_set_phone_command,
|
||||
get_set_crypto_wallet_start_command,
|
||||
get_set_crypto_wallet_complete_command,
|
||||
get_update_bank_details_start_command,
|
||||
get_update_bank_details_complete_command,
|
||||
get_change_password_start_command,
|
||||
get_change_password_complete_command,
|
||||
get_change_email_start_command,
|
||||
get_change_email_confirm_old_command,
|
||||
get_change_email_complete_command,
|
||||
)
|
||||
from src.presentation.dependencies.security import get_jwt_service
|
||||
from src.presentation.dependencies.cache import get_redis, get_cache
|
||||
from src.presentation.dependencies.queue_messanger import get_rabbit
|
||||
12
src/presentation/dependencies/cache.py
Normal file
12
src/presentation/dependencies/cache.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from fastapi import Depends, Request
|
||||
from redis.asyncio.client import Redis
|
||||
from src.application.contracts import ICache
|
||||
from src.infrastructure.cache import KeydbCache
|
||||
|
||||
|
||||
def get_redis(request: Request) -> Redis:
|
||||
return request.app.state.redis
|
||||
|
||||
|
||||
def get_cache(redis_client: Redis = Depends(get_redis)) -> ICache:
|
||||
return KeydbCache(redis_client)
|
||||
161
src/presentation/dependencies/commands.py
Normal file
161
src/presentation/dependencies/commands.py
Normal file
@@ -0,0 +1,161 @@
|
||||
from fastapi import Depends
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.commands import GetMeCommand, SetPhoneCommand, SetCryptoWalletStartCommand, SetCryptoWalletCompleteCommand, UpdateBankDetailsStartCommand, UpdateBankDetailsCompleteCommand, ChangePasswordStartCommand, ChangePasswordCompleteCommand, ChangeEmailStartCommand, ChangeEmailConfirmOldCommand, ChangeEmailCompleteCommand
|
||||
from src.application.contracts import ILogger, ICache, IQueueMessanger, IHashService
|
||||
from src.presentation.dependencies.cache import get_cache
|
||||
from src.presentation.dependencies.logger import get_logger
|
||||
from src.presentation.dependencies.queue_messanger import get_rabbit
|
||||
from src.presentation.dependencies.security import get_hash_service
|
||||
from src.presentation.dependencies.unit_of_work import get_unit_of_work
|
||||
|
||||
|
||||
def get_get_me_command(
|
||||
logger: ILogger = Depends(get_logger),
|
||||
unit_of_work: IUnitOfWork = Depends(get_unit_of_work),
|
||||
cache: ICache = Depends(get_cache),
|
||||
) -> GetMeCommand:
|
||||
return GetMeCommand(logger=logger, unit_of_work=unit_of_work, cache=cache)
|
||||
|
||||
|
||||
def get_set_phone_command(
|
||||
logger: ILogger = Depends(get_logger),
|
||||
unit_of_work: IUnitOfWork = Depends(get_unit_of_work),
|
||||
cache: ICache = Depends(get_cache),
|
||||
) -> SetPhoneCommand:
|
||||
return SetPhoneCommand(logger=logger, unit_of_work=unit_of_work, cache=cache)
|
||||
|
||||
|
||||
def get_set_crypto_wallet_start_command(
|
||||
logger: ILogger = Depends(get_logger),
|
||||
unit_of_work: IUnitOfWork = Depends(get_unit_of_work),
|
||||
cache: ICache = Depends(get_cache),
|
||||
messanger: IQueueMessanger = Depends(get_rabbit),
|
||||
hash_service: IHashService = Depends(get_hash_service),
|
||||
) -> SetCryptoWalletStartCommand:
|
||||
return SetCryptoWalletStartCommand(
|
||||
logger=logger,
|
||||
unit_of_work=unit_of_work,
|
||||
cache=cache,
|
||||
messanger=messanger,
|
||||
hash_service=hash_service,
|
||||
)
|
||||
|
||||
|
||||
def get_set_crypto_wallet_complete_command(
|
||||
logger: ILogger = Depends(get_logger),
|
||||
unit_of_work: IUnitOfWork = Depends(get_unit_of_work),
|
||||
cache: ICache = Depends(get_cache),
|
||||
hash_service: IHashService = Depends(get_hash_service),
|
||||
) -> SetCryptoWalletCompleteCommand:
|
||||
return SetCryptoWalletCompleteCommand(
|
||||
logger=logger,
|
||||
unit_of_work=unit_of_work,
|
||||
cache=cache,
|
||||
hash_service=hash_service,
|
||||
)
|
||||
|
||||
|
||||
def get_change_password_start_command(
|
||||
logger: ILogger = Depends(get_logger),
|
||||
unit_of_work: IUnitOfWork = Depends(get_unit_of_work),
|
||||
cache: ICache = Depends(get_cache),
|
||||
messanger: IQueueMessanger = Depends(get_rabbit),
|
||||
hash_service: IHashService = Depends(get_hash_service),
|
||||
) -> ChangePasswordStartCommand:
|
||||
return ChangePasswordStartCommand(
|
||||
logger=logger,
|
||||
unit_of_work=unit_of_work,
|
||||
cache=cache,
|
||||
messanger=messanger,
|
||||
hash_service=hash_service,
|
||||
)
|
||||
|
||||
|
||||
def get_change_password_complete_command(
|
||||
logger: ILogger = Depends(get_logger),
|
||||
unit_of_work: IUnitOfWork = Depends(get_unit_of_work),
|
||||
cache: ICache = Depends(get_cache),
|
||||
hash_service: IHashService = Depends(get_hash_service),
|
||||
) -> ChangePasswordCompleteCommand:
|
||||
return ChangePasswordCompleteCommand(
|
||||
logger=logger,
|
||||
unit_of_work=unit_of_work,
|
||||
cache=cache,
|
||||
hash_service=hash_service,
|
||||
)
|
||||
|
||||
|
||||
def get_change_email_start_command(
|
||||
logger: ILogger = Depends(get_logger),
|
||||
unit_of_work: IUnitOfWork = Depends(get_unit_of_work),
|
||||
cache: ICache = Depends(get_cache),
|
||||
messanger: IQueueMessanger = Depends(get_rabbit),
|
||||
hash_service: IHashService = Depends(get_hash_service),
|
||||
) -> ChangeEmailStartCommand:
|
||||
return ChangeEmailStartCommand(
|
||||
logger=logger,
|
||||
unit_of_work=unit_of_work,
|
||||
cache=cache,
|
||||
messanger=messanger,
|
||||
hash_service=hash_service,
|
||||
)
|
||||
|
||||
|
||||
def get_change_email_confirm_old_command(
|
||||
logger: ILogger = Depends(get_logger),
|
||||
unit_of_work: IUnitOfWork = Depends(get_unit_of_work),
|
||||
cache: ICache = Depends(get_cache),
|
||||
messanger: IQueueMessanger = Depends(get_rabbit),
|
||||
hash_service: IHashService = Depends(get_hash_service),
|
||||
) -> ChangeEmailConfirmOldCommand:
|
||||
return ChangeEmailConfirmOldCommand(
|
||||
logger=logger,
|
||||
unit_of_work=unit_of_work,
|
||||
cache=cache,
|
||||
messanger=messanger,
|
||||
hash_service=hash_service,
|
||||
)
|
||||
|
||||
|
||||
def get_change_email_complete_command(
|
||||
logger: ILogger = Depends(get_logger),
|
||||
unit_of_work: IUnitOfWork = Depends(get_unit_of_work),
|
||||
cache: ICache = Depends(get_cache),
|
||||
hash_service: IHashService = Depends(get_hash_service),
|
||||
) -> ChangeEmailCompleteCommand:
|
||||
return ChangeEmailCompleteCommand(
|
||||
logger=logger,
|
||||
unit_of_work=unit_of_work,
|
||||
cache=cache,
|
||||
hash_service=hash_service,
|
||||
)
|
||||
|
||||
|
||||
def get_update_bank_details_start_command(
|
||||
logger: ILogger = Depends(get_logger),
|
||||
unit_of_work: IUnitOfWork = Depends(get_unit_of_work),
|
||||
cache: ICache = Depends(get_cache),
|
||||
messanger: IQueueMessanger = Depends(get_rabbit),
|
||||
hash_service: IHashService = Depends(get_hash_service),
|
||||
) -> UpdateBankDetailsStartCommand:
|
||||
return UpdateBankDetailsStartCommand(
|
||||
logger=logger,
|
||||
unit_of_work=unit_of_work,
|
||||
cache=cache,
|
||||
messanger=messanger,
|
||||
hash_service=hash_service,
|
||||
)
|
||||
|
||||
|
||||
def get_update_bank_details_complete_command(
|
||||
logger: ILogger = Depends(get_logger),
|
||||
unit_of_work: IUnitOfWork = Depends(get_unit_of_work),
|
||||
cache: ICache = Depends(get_cache),
|
||||
hash_service: IHashService = Depends(get_hash_service),
|
||||
) -> UpdateBankDetailsCompleteCommand:
|
||||
return UpdateBankDetailsCompleteCommand(
|
||||
logger=logger,
|
||||
unit_of_work=unit_of_work,
|
||||
cache=cache,
|
||||
hash_service=hash_service,
|
||||
)
|
||||
7
src/presentation/dependencies/logger.py
Normal file
7
src/presentation/dependencies/logger.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from functools import lru_cache
|
||||
from src.application.contracts import ILogger
|
||||
from src.infrastructure.logger import logger
|
||||
|
||||
@lru_cache
|
||||
def get_logger() -> ILogger:
|
||||
return logger
|
||||
8
src/presentation/dependencies/queue_messanger.py
Normal file
8
src/presentation/dependencies/queue_messanger.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from functools import lru_cache
|
||||
from src.application.contracts import IQueueMessanger
|
||||
from src.infrastructure.messanger import RabbitClient
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_rabbit() -> IQueueMessanger:
|
||||
return RabbitClient()
|
||||
25
src/presentation/dependencies/security.py
Normal file
25
src/presentation/dependencies/security.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from functools import lru_cache
|
||||
from fastapi import Depends
|
||||
from src.application.contracts import IJwtService, ILogger, IHashService
|
||||
from src.infrastructure.security import JwtService, HashService
|
||||
from src.infrastructure.vault import JwtKeyStore
|
||||
from src.presentation.dependencies.logger import get_logger
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _hash_service(logger: ILogger) -> IHashService:
|
||||
return HashService(logger=logger)
|
||||
|
||||
|
||||
def get_hash_service(logger: ILogger = Depends(get_logger)) -> IHashService:
|
||||
return _hash_service(logger)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _jwt_service(logger: ILogger) -> IJwtService:
|
||||
key_store = JwtKeyStore.get_instance()
|
||||
return JwtService(logger=logger, key_store=key_store)
|
||||
|
||||
|
||||
def get_jwt_service(logger: ILogger = Depends(get_logger)) -> IJwtService:
|
||||
return _jwt_service(logger)
|
||||
10
src/presentation/dependencies/unit_of_work.py
Normal file
10
src/presentation/dependencies/unit_of_work.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from fastapi import Depends
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.contracts import ILogger
|
||||
from src.infrastructure.database import UnitOfWork
|
||||
from src.infrastructure.database.context import async_session_maker
|
||||
from src.infrastructure.logger import get_logger
|
||||
|
||||
|
||||
def get_unit_of_work(logger: ILogger = Depends(get_logger)) -> IUnitOfWork:
|
||||
return UnitOfWork(session_factory=async_session_maker, logger=logger)
|
||||
2
src/presentation/handlers/__init__.py
Normal file
2
src/presentation/handlers/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from src.presentation.handlers.unhandled_handler import unhandled_exception_handler
|
||||
from src.presentation.handlers.application_handler import application_exception_handler
|
||||
17
src/presentation/handlers/application_handler.py
Normal file
17
src/presentation/handlers/application_handler.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from fastapi import Request
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
|
||||
|
||||
async def application_exception_handler(_request: Request, exc: ApplicationException) -> ORJSONResponse:
|
||||
detail = exc.message
|
||||
if 500 <= exc.status_code:
|
||||
detail = "Internal Server Error"
|
||||
|
||||
return ORJSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={"detail": detail},
|
||||
headers=dict(exc.headers) if exc.headers else None,
|
||||
)
|
||||
|
||||
|
||||
12
src/presentation/handlers/unhandled_handler.py
Normal file
12
src/presentation/handlers/unhandled_handler.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from fastapi import Request
|
||||
from starlette import status
|
||||
from src.infrastructure.logger import logger
|
||||
|
||||
|
||||
async def unhandled_exception_handler(_request: Request, exc: Exception) -> ORJSONResponse:
|
||||
logger.exception(f'Unhandled exception: {type(exc).__name__}')
|
||||
return ORJSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={'detail': 'Internal Server Error'},
|
||||
)
|
||||
2
src/presentation/middleware/__init__.py
Normal file
2
src/presentation/middleware/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from src.presentation.middleware.trace_id import TraceIDMiddleware
|
||||
from src.presentation.middleware.security_headers import SecurityHeadersMiddleware
|
||||
51
src/presentation/middleware/security_headers.py
Normal file
51
src/presentation/middleware/security_headers.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
def __init__(
|
||||
self,
|
||||
app,
|
||||
*,
|
||||
hsts: bool = True,
|
||||
hsts_max_age: int = 31536000, # 1 год
|
||||
hsts_include_subdomains: bool = True,
|
||||
hsts_preload: bool = False,
|
||||
frame_options: str = 'DENY', # или 'SAMEORIGIN'
|
||||
referrer_policy: str = 'strict-origin-when-cross-origin',
|
||||
content_security_policy: str | None = None,
|
||||
):
|
||||
super().__init__(app)
|
||||
self.hsts = hsts
|
||||
self.hsts_max_age = hsts_max_age
|
||||
self.hsts_include_subdomains = hsts_include_subdomains
|
||||
self.hsts_preload = hsts_preload
|
||||
self.frame_options = frame_options
|
||||
self.referrer_policy = referrer_policy
|
||||
self.csp = content_security_policy
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
response: Response = await call_next(request)
|
||||
|
||||
if request.url.path in ('/docs', '/redoc', '/openapi.json'):
|
||||
return response
|
||||
|
||||
if self.hsts and request.url.scheme == 'https':
|
||||
hsts = f'max-age={self.hsts_max_age}'
|
||||
if self.hsts_include_subdomains:
|
||||
hsts += '; includeSubDomains'
|
||||
if self.hsts_preload:
|
||||
hsts += '; preload'
|
||||
response.headers['Strict-Transport-Security'] = hsts
|
||||
|
||||
response.headers['X-Content-Type-Options'] = 'nosniff'
|
||||
|
||||
response.headers['X-Frame-Options'] = self.frame_options
|
||||
|
||||
response.headers['Referrer-Policy'] = self.referrer_policy
|
||||
|
||||
if self.csp:
|
||||
response.headers['Content-Security-Policy'] = self.csp
|
||||
|
||||
return response
|
||||
135
src/presentation/middleware/trace_id.py
Normal file
135
src/presentation/middleware/trace_id.py
Normal file
@@ -0,0 +1,135 @@
|
||||
from __future__ import annotations
|
||||
from typing import Optional
|
||||
from contextvars import Token
|
||||
from starlette.requests import Request
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
from ulid import ULID
|
||||
from src.application.contracts import ILogger
|
||||
from src.infrastructure.config import settings
|
||||
from src.infrastructure.context_vars import trace_id_var
|
||||
|
||||
|
||||
class TraceIDMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
logger: ILogger,
|
||||
response_header_name: str = "X-Trace-ID",
|
||||
attach_response_header: bool = True,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.logger = logger
|
||||
self.response_header_name = response_header_name
|
||||
self.attach_response_header = attach_response_header
|
||||
|
||||
def _is_excluded(self, path: str) -> bool:
|
||||
return any(path == p or path.startswith(p.rstrip("/") + "/") for p in settings.EXCLUDED_PATHS)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
request = Request(scope)
|
||||
|
||||
if self._is_excluded(request.url.path):
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
trace_id = request.headers.get("X-Trace-ID") or request.headers.get("X-Request-ID")
|
||||
if not trace_id:
|
||||
trace_id = str(ULID())
|
||||
|
||||
request.state.trace_id = trace_id
|
||||
|
||||
token: Token = trace_id_var.set(trace_id)
|
||||
|
||||
self.logger.debug(f"Request started: {request.method} {request.url} - TraceID: {trace_id}")
|
||||
|
||||
status_code_holder: dict[str, Optional[int]] = {"status": None}
|
||||
|
||||
async def send_wrapper(message: Message) -> None:
|
||||
if message["type"] == "http.response.start":
|
||||
status_code_holder["status"] = int(message["status"])
|
||||
|
||||
if self.attach_response_header:
|
||||
headers = list(message.get("headers", []))
|
||||
headers.append((self.response_header_name.lower().encode(), trace_id.encode()))
|
||||
message["headers"] = headers
|
||||
await send(message)
|
||||
|
||||
try:
|
||||
await self.app(scope, receive, send_wrapper)
|
||||
finally:
|
||||
status = status_code_holder["status"]
|
||||
status_part = f"{status}" if status is not None else "unknown"
|
||||
self.logger.debug(
|
||||
f"Request finished: {request.method} {request.url} - TraceID: {trace_id} - Status: {status_part}"
|
||||
)
|
||||
trace_id_var.reset(token)
|
||||
|
||||
|
||||
# from __future__ import annotations
|
||||
# from typing import Optional
|
||||
# from starlette.requests import Request
|
||||
# from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
# from ulid import ULID
|
||||
# from src.application.contracts import ILogger
|
||||
# from src.infrastructure.config.settings import settings
|
||||
#
|
||||
#
|
||||
# class TraceIDMiddleware:
|
||||
# def __init__(
|
||||
# self,
|
||||
# app: ASGIApp,
|
||||
# logger: ILogger,
|
||||
# response_header_name: str = 'X-Trace-ID',
|
||||
# attach_response_header: bool = True,
|
||||
# ) -> None:
|
||||
# self.app = app
|
||||
# self.logger = logger
|
||||
# self.response_header_name = response_header_name
|
||||
# self.attach_response_header = attach_response_header
|
||||
#
|
||||
# def _is_excluded(self, path: str) -> bool:
|
||||
# return any(path == p or path.startswith(p.rstrip('/') + '/') for p in settings.EXCLUDED_PATHS)
|
||||
#
|
||||
# async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
# if scope['type'] != 'http':
|
||||
# await self.app(scope, receive, send)
|
||||
# return
|
||||
#
|
||||
# request = Request(scope)
|
||||
#
|
||||
# if self._is_excluded(request.url.path):
|
||||
# await self.app(scope, receive, send)
|
||||
# return
|
||||
#
|
||||
# trace_id = request.headers.get('X-Trace-ID') or request.headers.get('X-Request-ID')
|
||||
# if not trace_id:
|
||||
# trace_id = str(ULID())
|
||||
#
|
||||
# request.state.trace_id = trace_id
|
||||
# self.logger.set_trace_id(trace_id)
|
||||
#
|
||||
# self.logger.debug(f'Request started: {request.method} {request.url} - TraceID: {trace_id}')
|
||||
#
|
||||
# status_code_holder: dict[str, Optional[int]] = {'status': None}
|
||||
#
|
||||
# async def send_wrapper(message: Message) -> None:
|
||||
# if message['type'] == 'http.response.start':
|
||||
# status_code_holder['status'] = int(message['status'])
|
||||
#
|
||||
# if self.attach_response_header:
|
||||
# headers = list(message.get('headers', []))
|
||||
# headers.append((self.response_header_name.lower().encode(), trace_id.encode()))
|
||||
# message['headers'] = headers
|
||||
# await send(message)
|
||||
#
|
||||
# try:
|
||||
# await self.app(scope, receive, send_wrapper)
|
||||
# finally:
|
||||
# status = status_code_holder['status']
|
||||
# status_part = f'{status}' if status is not None else 'unknown'
|
||||
# self.logger.debug(f'Request finished: {request.method} {request.url} - TraceID: {trace_id} - Status: {status_part}')
|
||||
# self.logger.clear_trace_id()
|
||||
1
src/presentation/routing/__init__.py
Normal file
1
src/presentation/routing/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.presentation.routing.order import order_router
|
||||
114
src/presentation/routing/order.py
Normal file
114
src/presentation/routing/order.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import json
|
||||
from decimal import Decimal
|
||||
from urllib.parse import parse_qs
|
||||
import aiohttp
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from ulid import ULID
|
||||
from src.application.contracts import ILogger
|
||||
from src.application.domain.dto import AuthContext
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.presentation.decorators import csrf_protect, require_access_token
|
||||
from src.presentation.dependencies.logger import get_logger
|
||||
from src.presentation.schemas.order import CreateOrder
|
||||
|
||||
|
||||
order_router = APIRouter(prefix='/order', tags=['orders'])
|
||||
|
||||
ITPAY_API_BASE = 'https://api.gw.itpay.ru'
|
||||
ITPAY_AUTHORIZATION = 'Token REPLACE_WITH_JWT_FROM_ITPAY_DASHBOARD'
|
||||
HARDCODED_USDT_TO_RUB = Decimal('100')
|
||||
HARDCODED_GAS_RUB = Decimal('15')
|
||||
HARDCODED_OUR_COMMISSION_RUB = Decimal('25')
|
||||
|
||||
|
||||
def _amount_rub_for_itpay(amount_usdt: Decimal) -> Decimal:
|
||||
return (amount_usdt * HARDCODED_USDT_TO_RUB + HARDCODED_GAS_RUB + HARDCODED_OUR_COMMISSION_RUB).quantize(Decimal('0.01'))
|
||||
|
||||
|
||||
|
||||
@order_router.post('/create')
|
||||
#@csrf_protect()
|
||||
async def create_order(
|
||||
request: Request,
|
||||
body: CreateOrder,
|
||||
#auth: AuthContext = Depends(require_access_token),
|
||||
logger: ILogger = Depends(get_logger),
|
||||
) -> ORJSONResponse:
|
||||
amount_rub = _amount_rub_for_itpay(body.amount_usdt)
|
||||
amount_str = str(amount_rub)
|
||||
client_payment_id = str(ULID())
|
||||
payload = {
|
||||
'amount': amount_str,
|
||||
'client_payment_id': client_payment_id,
|
||||
'description': f'USDT {body.amount_usdt}',
|
||||
'metadata': {
|
||||
'user_id': '01KPSYW27JZ26HBDR3QS5J6VMS',
|
||||
'amount_usdt': str(body.amount_usdt),
|
||||
'rate': str(HARDCODED_USDT_TO_RUB),
|
||||
'gas_rub': str(HARDCODED_GAS_RUB),
|
||||
'commission_rub': str(HARDCODED_OUR_COMMISSION_RUB),
|
||||
},
|
||||
}
|
||||
url = f'{ITPAY_API_BASE}/v1/payments'
|
||||
headers = {
|
||||
'Authorization': ITPAY_AUTHORIZATION,
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json',
|
||||
}
|
||||
try:
|
||||
timeout = aiohttp.ClientTimeout(total=30)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.post(url, json=payload, headers=headers) as resp:
|
||||
response_text = await resp.text()
|
||||
try:
|
||||
response_json = json.loads(response_text)
|
||||
except json.JSONDecodeError:
|
||||
response_json = {'raw': response_text}
|
||||
if resp.status >= 400:
|
||||
logger.warning(f'itpay payments POST {resp.status} {response_text}')
|
||||
raise ApplicationException(status_code=502, message='Payment provider error')
|
||||
except ApplicationException:
|
||||
raise
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(str(e))
|
||||
raise ApplicationException(status_code=502, message='Payment provider unreachable')
|
||||
return ORJSONResponse(
|
||||
content={
|
||||
'itpay': response_json,
|
||||
'client_payment_id': client_payment_id,
|
||||
'amount_usdt': str(body.amount_usdt),
|
||||
'amount_rub': amount_str,
|
||||
'hardcoded': {
|
||||
'usdt_to_rub': str(HARDCODED_USDT_TO_RUB),
|
||||
'gas_rub': str(HARDCODED_GAS_RUB),
|
||||
'commission_rub': str(HARDCODED_OUR_COMMISSION_RUB),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
|
||||
@order_router.post('/webhook/itpay')
|
||||
async def itpay_webhook(request: Request, logger: ILogger = Depends(get_logger)) -> ORJSONResponse:
|
||||
raw = await request.body()
|
||||
ct = (request.headers.get('content-type') or '').lower()
|
||||
if 'application/json' in ct:
|
||||
try:
|
||||
parsed = json.loads(raw.decode('utf-8'))
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
parsed = raw.decode('utf-8', errors='replace')
|
||||
elif 'application/x-www-form-urlencoded' in ct:
|
||||
decoded = raw.decode('utf-8', errors='replace')
|
||||
qs = parse_qs(decoded, keep_blank_values=True)
|
||||
parsed = {k: (vals[0] if len(vals) == 1 else vals) for k, vals in qs.items()}
|
||||
else:
|
||||
parsed = raw.decode('utf-8', errors='replace')
|
||||
log_payload = {
|
||||
'method': request.method,
|
||||
'url': str(request.url),
|
||||
'headers': {k: v for k, v in request.headers.items()},
|
||||
'body': parsed,
|
||||
}
|
||||
logger.info(json.dumps(log_payload, ensure_ascii=False, default=str))
|
||||
return ORJSONResponse(content={'status': 0})
|
||||
0
src/presentation/schemas/__init__.py
Normal file
0
src/presentation/schemas/__init__.py
Normal file
6
src/presentation/schemas/order.py
Normal file
6
src/presentation/schemas/order.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from decimal import Decimal
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CreateOrder(BaseModel):
|
||||
amount_usdt: Decimal = Field(gt=0)
|
||||
Reference in New Issue
Block a user