This commit is contained in:
2026-06-03 13:49:16 +03:00
commit 284a5fa468
138 changed files with 6660 additions and 0 deletions

145
.gitignore vendored Normal file
View File

@@ -0,0 +1,145 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
generate_password_hash.py
# C extensions
*.so
*.pyd
*.dll
# Distribution / packaging
.Python
build/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache/
.pytest_cache/
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
# Type checkers / linters
.mypy_cache/
.dmypy.json
dmypy.json
.pyre/
.pytype/
.ruff_cache/
# Jupyter Notebook
.ipynb_checkpoints/
# Environments
.env
.env.*
.venv/
venv/
ENV/
env/
env.bak/
venv.bak/
# Poetry
poetry.lock
# Pipenv
Pipfile.lock
# Hatch
.hatch/
# pyenv
.python-version
# Logs
*.log
logs/
# Local databases
*.sqlite3
*.db
# Secrets / credentials
secrets.json
credentials.json
*.pem
*.key
*.crt
# OS generated files
.DS_Store
Thumbs.db
Desktop.ini
# PyCharm / IntelliJ IDEA
.idea/
*.iml
out/
# VS Code (optional)
.vscode/
# Temporary files
*.tmp
*.temp
*.swp
*.swo
*~
# Sphinx docs
docs/_build/
# mkdocs
site/
# celery
celerybeat-schedule
celerybeat.pid
# mypy compiled cache
.mypy_cache/
# pyinstaller
*.manifest
*.spec
# pytest debug
pytestdebug.log
# Local config overrides
config.local.py
settings.local.py
# Vault / local dev secrets
.env.vault
vault.token
.env
.dockerignore
/sql

25
Dockerfile Normal file
View File

@@ -0,0 +1,25 @@
FROM ghcr.io/astral-sh/uv:python3.12-bookworm AS builder
WORKDIR /app
COPY pyproject.toml uv.lock ./
RUN uv sync --frozen --no-dev
COPY src ./src
FROM ghcr.io/astral-sh/uv:python3.12-bookworm AS runtime
WORKDIR /app
COPY --from=builder /app/.venv /app/.venv
COPY --from=builder /app/src /app/src
ENV PATH="/app/.venv/bin:$PATH" \
PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1 \
PYTHONPATH=/app
EXPOSE 8001
CMD ["sh", "-c", "python -m granian --interface asgi ${APP_MODULE:-src.main:app} --host ${APP_HOST:-0.0.0.0} --port ${APP_PORT:-8001} --workers ${APP_WORKERS:-2} --loop uvloop"]

57
docker-compose.yml Normal file
View File

@@ -0,0 +1,57 @@
services:
b2b:
container_name: b2b-service
build:
context: .
dockerfile: Dockerfile
ports:
- "8001:8001"
environment:
PYTHONUNBUFFERED: "1"
APP_MODULE: "src.main:app"
APP_HOST: "0.0.0.0"
APP_PORT: "8001"
APP_WORKERS: "2"
env_file:
- .env
depends_on:
b2b_keydb:
condition: service_healthy
restart: no
b2b_keydb:
image: eqalpha/keydb
container_name: b2b_keydb
restart: no
expose:
- "6379"
volumes:
- b2b_keydb_data:/data
command:
- keydb-server
- --requirepass
- ${REDIS_PASSWORD}
- --dir
- /data
- --appendonly
- "yes"
- --appendfsync
- everysec
- --save
- "900"
- "1"
- --save
- "300"
- "10"
- --save
- "60"
- "10000"
healthcheck:
test: ["CMD", "redis-cli", "-a", "${REDIS_PASSWORD}", "ping"]
interval: 5s
timeout: 2s
retries: 20
volumes:
b2b_keydb_data:

20
pyproject.toml Normal file
View File

@@ -0,0 +1,20 @@
[project]
name = "b2b-service"
version = "0.1.0"
description = "B2B purchase requests API for legal entity client users"
requires-python = "==3.12.*"
dependencies = [
"apscheduler==3.11.2",
"asyncpg==0.31.0",
"dotenv==0.9.9",
"fastapi==0.128.7",
"pydantic-settings==2.12.0",
"python-jose==3.5.0",
"python-ulid==3.1.0",
"sqlalchemy==2.0.46",
"uvloop==0.22.1; platform_system != 'Windows'",
"granian==2.6.1",
"hvac==2.4.0",
"redis==7.2.0",
"orjson==3.11.7",
]

View File

@@ -0,0 +1 @@
from src.application.abstractions.i_unit_of_work import IUnitOfWork

View File

@@ -0,0 +1,25 @@
from __future__ import annotations
from typing import Protocol, runtime_checkable
from src.application.abstractions.repositories import (
IUserRepository,
ILegalEntityRepository,
IPurchaseRequestRepository,
)
@runtime_checkable
class IUnitOfWork(Protocol):
async def __aenter__(self) -> "IUnitOfWork": ...
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: ...
async def commit(self) -> None: ...
async def rollback(self) -> None: ...
@property
def user_repository(self) -> IUserRepository: ...
@property
def legal_entity_repository(self) -> ILegalEntityRepository: ...
@property
def purchase_request_repository(self) -> IPurchaseRequestRepository: ...

View File

@@ -0,0 +1,3 @@
from src.application.abstractions.repositories.i_user_repository import IUserRepository
from src.application.abstractions.repositories.i_legal_entity_repository import ILegalEntityRepository
from src.application.abstractions.repositories.i_purchase_request_repository import IPurchaseRequestRepository

View File

@@ -0,0 +1,9 @@
from abc import ABC, abstractmethod
from src.application.domain.entities.legal_entity import LegalEntityEntity
class ILegalEntityRepository(ABC):
@abstractmethod
async def get_by_user_id(self, user_id: str) -> LegalEntityEntity | None:
raise NotImplementedError

View File

@@ -0,0 +1,37 @@
from abc import ABC, abstractmethod
from decimal import Decimal
from src.application.domain.entities.purchase_request import PurchaseRequestEntity
class IPurchaseRequestRepository(ABC):
@abstractmethod
async def create(
self,
*,
organization_id: str,
usdt_amount: Decimal,
comment: str | None,
target_wallet_chain: str | None,
target_wallet_address: str | None,
) -> PurchaseRequestEntity:
raise NotImplementedError
@abstractmethod
async def get_by_id(self, request_id: str) -> PurchaseRequestEntity:
raise NotImplementedError
@abstractmethod
async def list_by_organization(
self,
*,
organization_id: str,
status: str | None,
limit: int,
offset: int,
) -> list[PurchaseRequestEntity]:
raise NotImplementedError
@abstractmethod
async def count_by_organization(self, *, organization_id: str, status: str | None) -> int:
raise NotImplementedError

View File

@@ -0,0 +1,59 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from datetime import datetime
from src.application.domain.entities import SessionEntity
class ISessionRepository(ABC):
@abstractmethod
async def get_by_sid(self, sid: str) -> SessionEntity | None:
raise NotImplementedError
@abstractmethod
async def get_by_user_device(self, user_id: str, device_id: str) -> SessionEntity | None:
raise NotImplementedError
@abstractmethod
async def upsert_by_device(
self,
user_id: str,
device_id: str,
sid: str,
refresh_jti_hash: str,
refresh_expires_at: datetime,
user_agent: str | None,
ip: str | None,
now: datetime,
) -> SessionEntity:
raise NotImplementedError
@abstractmethod
async def revoke_by_sid(self, sid: str, now: datetime) -> None:
raise NotImplementedError
@abstractmethod
async def rotate_refresh(
self,
sid: str,
new_jti_hash: str,
new_refresh_expires_at: datetime,
now: datetime,
ip: str | None,
user_agent: str | None,
) -> None:
raise NotImplementedError
@abstractmethod
async def rotate_refresh_if_match(
self,
*,
sid: str,
old_jti_hash: str,
new_jti_hash: str,
new_refresh_expires_at: datetime,
now: datetime,
ip: str | None,
user_agent: str | None,
) -> bool:
raise NotImplementedError

View File

@@ -0,0 +1,11 @@
from __future__ import annotations
from abc import ABC
from abc import abstractmethod
from src.application.domain.entities import UserEntity
class IUserRepository(ABC):
@abstractmethod
async def get_user_by_id(self, user_id: str) -> UserEntity:
raise NotImplementedError

View File

@@ -0,0 +1,5 @@
from src.application.commands.purchase_request_commands import (
CreatePurchaseRequestCommand,
GetPurchaseRequestCommand,
ListPurchaseRequestsCommand,
)

View File

@@ -0,0 +1,63 @@
from src.application.abstractions import IUnitOfWork
from src.application.contracts import IHashService, ILogger, ICache
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.database.decorators import transactional
class ChangeEmailCompleteCommand:
def __init__(
self,
unit_of_work: IUnitOfWork,
hash_service: IHashService,
cache: ICache,
logger: ILogger,
):
self._unit_of_work = unit_of_work
self._hash_service = hash_service
self._cache = cache
self._logger = logger
@transactional
async def __call__(self, *, user_id: str, code: str) -> bool:
code = (code or '').strip()
NEW_USER_PREFIX = 'change_email:new_user:'
NEW_CODE_PREFIX = 'change_email:new_code:'
new_user_key = f'{NEW_USER_PREFIX}{user_id}'
new_code_key = f'{NEW_CODE_PREFIX}{code}'
cached_user_id = await self._cache.get(new_code_key)
if not cached_user_id:
self._logger.info(f'Change email complete failed: code not found (user_id={user_id})')
raise ApplicationException(400, 'Invalid or expired code')
if cached_user_id != user_id:
self._logger.info(f'Change email complete failed: code-user mismatch (user_id={user_id})')
raise ApplicationException(400, 'Invalid or expired code')
raw_value = await self._cache.get(new_user_key)
if not raw_value:
self._logger.info(f'Change email complete failed: user key missing (user_id={user_id})')
raise ApplicationException(400, 'Invalid or expired code')
separator_idx = raw_value.index(':')
code_hash = raw_value[:separator_idx]
new_email = raw_value[separator_idx + 1:]
ok = await self._hash_service.verify(hashed_value=code_hash, plain_value=code)
if not ok:
self._logger.info(f'Change email complete failed: code hash mismatch (user_id={user_id})')
raise ApplicationException(400, 'Invalid or expired code')
user = await self._unit_of_work.user_repository.set_email(user_id=user_id, email=new_email)
await self._cache.set_user(user_id, user)
try:
await self._cache.delete(new_code_key)
await self._cache.delete(new_user_key)
except Exception as e:
self._logger.warning(f'Change email complete cleanup failed (user_id={user_id}): {e}')
self._logger.info(f'Email changed for user_id={user_id}')
return True

View File

@@ -0,0 +1,145 @@
import secrets
from datetime import datetime, timezone
from ulid import ULID
from src.application.abstractions import IUnitOfWork
from src.application.contracts import IHashService, ICache, ILogger, IQueueMessanger
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.config import settings
from src.infrastructure.context_vars import trace_id_var
from src.infrastructure.database.decorators import transactional
class ChangeEmailConfirmOldCommand:
def __init__(
self,
hash_service: IHashService,
cache: ICache,
unit_of_work: IUnitOfWork,
logger: ILogger,
messanger: IQueueMessanger,
):
self._hash_service = hash_service
self._unit_of_work = unit_of_work
self._cache = cache
self._logger = logger
self._messanger = messanger
@transactional
async def __call__(self, *, user_id: str, code: str, new_email: str) -> bool:
TTL = 300
MAX_ATTEMPTS = 20
OLD_USER_PREFIX = 'change_email:old_user:'
OLD_CODE_PREFIX = 'change_email:old_code:'
NEW_USER_PREFIX = 'change_email:new_user:'
NEW_CODE_PREFIX = 'change_email:new_code:'
code = (code or '').strip()
old_user_key = f'{OLD_USER_PREFIX}{user_id}'
old_code_key = f'{OLD_CODE_PREFIX}{code}'
cached_user_id = await self._cache.get(old_code_key)
if not cached_user_id:
self._logger.info(f'Change email confirm-old failed: code not found (user_id={user_id})')
raise ApplicationException(400, 'Invalid or expired code')
if cached_user_id != user_id:
self._logger.info(f'Change email confirm-old failed: code-user mismatch (user_id={user_id})')
raise ApplicationException(400, 'Invalid or expired code')
code_hash = await self._cache.get(old_user_key)
if not code_hash:
self._logger.info(f'Change email confirm-old failed: user key missing (user_id={user_id})')
raise ApplicationException(400, 'Invalid or expired code')
ok = await self._hash_service.verify(hashed_value=code_hash, plain_value=code)
if not ok:
self._logger.info(f'Change email confirm-old failed: code hash mismatch (user_id={user_id})')
raise ApplicationException(400, 'Invalid or expired code')
user = await self._unit_of_work.user_repository.get_user_by_id(user_id=user_id)
if user.email and user.email.lower() == new_email.lower():
self._logger.info(f'Change email confirm-old failed: new email same as current (user_id={user_id})')
raise ApplicationException(400, 'New email must differ from the current one')
email_taken = await self._unit_of_work.user_repository.email_exists(email=new_email)
if email_taken:
self._logger.info(f'Change email confirm-old failed: new email already taken (user_id={user_id})')
raise ApplicationException(409, 'Email already in use')
try:
await self._cache.delete(old_code_key)
await self._cache.delete(old_user_key)
except Exception as e:
self._logger.warning(f'Change email confirm-old cleanup failed (user_id={user_id}): {e}')
trace_id = trace_id_var.get()
if not trace_id or trace_id == 'N/A':
trace_id = None
new_user_key = f'{NEW_USER_PREFIX}{user_id}'
for _ in range(MAX_ATTEMPTS):
new_code = f'{secrets.randbelow(1_000_000):06d}'
new_code_key = f'{NEW_CODE_PREFIX}{new_code}'
new_code_hash = await self._hash_service.hash(new_code)
reserved = await self._cache.set_nx(new_code_key, user_id, ttl=TTL)
if not reserved:
continue
saved = await self._cache.set(new_user_key, f'{new_code_hash}:{new_email}', ttl=TTL)
if not saved:
await self._cache.delete(new_code_key)
self._logger.error(f'Change email confirm-old failed: cannot save new code hash for user_id={user_id}')
raise ApplicationException(503, 'Temporary error. Please try again.')
message_id = str(ULID())
now = datetime.now(timezone.utc).isoformat()
metadata = {
'trace_id': trace_id,
'source': 'user-service',
'timestamp': now,
'message_id': message_id,
}
payload = {
'email': new_email,
'code': new_code,
'ttl_seconds': TTL,
}
message = {
'event': 'change_email_new',
'payload': payload,
'metadata': metadata,
}
self._logger.info(f'Change email new code created for user_id={user_id}')
try:
await self._messanger.publish_to_queue(
queue=settings.RABBIT_EMAIL_CODE_QUEUE,
message=message,
persist=True,
correlation_id=trace_id,
message_id=message_id,
headers={'trace_id': trace_id} if trace_id else None,
)
except Exception as exception:
try:
await self._cache.delete(new_user_key)
await self._cache.delete(new_code_key)
except Exception as rollback_err:
self._logger.error(f'Publish failed and rollback cache failed for user_id={user_id}: {str(rollback_err)}')
self._logger.error(f'Failed to publish change email new code for user_id={user_id}: {str(exception)}')
raise ApplicationException(503, 'Temporary error. Please try again.')
return True
self._logger.error(f'Change email confirm-old failed: code space exhausted for user_id={user_id}')
raise ApplicationException(503, 'Temporary error. Please try again.')

View File

@@ -0,0 +1,126 @@
import secrets
from datetime import datetime, timezone
from ulid import ULID
from src.application.abstractions import IUnitOfWork
from src.application.contracts import IHashService, ICache, ILogger, IQueueMessanger
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.config import settings
from src.infrastructure.context_vars import trace_id_var
from src.infrastructure.database.decorators import transactional
class ChangeEmailStartCommand:
def __init__(
self,
hash_service: IHashService,
cache: ICache,
unit_of_work: IUnitOfWork,
logger: ILogger,
messanger: IQueueMessanger,
):
self._hash_service = hash_service
self._unit_of_work = unit_of_work
self._cache = cache
self._logger = logger
self._messanger = messanger
@transactional
async def __call__(self, user_id: str) -> bool:
TTL = 300
LOCK_TTL = 30
MAX_ATTEMPTS = 20
USER_PREFIX = 'change_email:old_user:'
CODE_PREFIX = 'change_email:old_code:'
LOCK_PREFIX = 'change_email:lock:'
user = await self._unit_of_work.user_repository.get_user_by_id(user_id=user_id)
if not user.email:
self._logger.warning(f'User {user_id} does not have an email address')
raise ApplicationException(404, f'User {user_id} does not have an email address')
trace_id = trace_id_var.get()
if not trace_id or trace_id == 'N/A':
trace_id = None
lock_key = f'{LOCK_PREFIX}{user_id}'
locked = await self._cache.set_nx(lock_key, '1', ttl=LOCK_TTL)
if not locked:
self._logger.info(f'Change email throttled by lock (user_id={user_id})')
raise ApplicationException(429, 'Too many requests. Please wait.')
try:
user_key = f'{USER_PREFIX}{user_id}'
existing = await self._cache.get(user_key)
if existing:
self._logger.info(f'Change email denied: code already exists for user_id={user_id}')
raise ApplicationException(429, 'Code already sent. Please wait before retrying.')
for _ in range(MAX_ATTEMPTS):
code = f'{secrets.randbelow(1_000_000):06d}'
code_key = f'{CODE_PREFIX}{code}'
code_hash = await self._hash_service.hash(code)
reserved = await self._cache.set_nx(code_key, user_id, ttl=TTL)
if not reserved:
continue
saved = await self._cache.set(user_key, code_hash, ttl=TTL)
if not saved:
await self._cache.delete(code_key)
self._logger.error(f'Change email failed: cannot save code hash for user_id={user_id}')
raise ApplicationException(503, 'Temporary error. Please try again.')
message_id = str(ULID())
now = datetime.now(timezone.utc).isoformat()
metadata = {
'trace_id': trace_id,
'source': 'user-service',
'timestamp': now,
'message_id': message_id,
}
payload = {
'email': user.email,
'code': code,
'ttl_seconds': TTL,
}
message = {
'event': 'change_email_old',
'payload': payload,
'metadata': metadata,
}
self._logger.info(f'Change email old code created for user_id={user_id}')
try:
await self._messanger.publish_to_queue(
queue=settings.RABBIT_EMAIL_CODE_QUEUE,
message=message,
persist=True,
correlation_id=trace_id,
message_id=message_id,
headers={'trace_id': trace_id} if trace_id else None,
)
except Exception as exception:
try:
await self._cache.delete(user_key)
await self._cache.delete(code_key)
except Exception as rollback_err:
self._logger.error(f'Publish failed and rollback cache failed for user_id={user_id}: {str(rollback_err)}')
self._logger.error(f'Failed to publish change email old code for user_id={user_id}: {str(exception)}')
raise ApplicationException(503, 'Temporary error. Please try again.')
return True
self._logger.error(f'Change email failed: code space exhausted for user_id={user_id}')
raise ApplicationException(503, 'Temporary error. Please try again.')
finally:
await self._cache.delete(lock_key)

View File

@@ -0,0 +1,81 @@
from src.application.abstractions import IUnitOfWork
from src.application.contracts import IHashService, ILogger, ICache
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.database.decorators import transactional
class ChangePasswordCompleteCommand:
def __init__(
self,
unit_of_work: IUnitOfWork,
hash_service: IHashService,
cache: ICache,
logger: ILogger,
):
self._unit_of_work = unit_of_work
self._hash_service = hash_service
self._cache = cache
self._logger = logger
@transactional
async def __call__(
self,
*,
user_id: str,
code: str,
new_password: str,
confirm_password: str,
) -> bool:
code = (code or '').strip()
USER_PREFIX = 'change_password:user:'
CODE_PREFIX = 'change_password:code:'
user_key = f'{USER_PREFIX}{user_id}'
code_key = f'{CODE_PREFIX}{code}'
if new_password != confirm_password:
self._logger.info(f'Change password failed: passwords do not match (user_id={user_id})')
raise ApplicationException(400, 'Passwords do not match')
cached_user_id = await self._cache.get(code_key)
if not cached_user_id:
self._logger.info(f'Change password failed: code not found (user_id={user_id})')
raise ApplicationException(400, 'Invalid or expired code')
if cached_user_id != user_id:
self._logger.info(f'Change password failed: code-user mismatch (user_id={user_id})')
raise ApplicationException(400, 'Invalid or expired code')
code_hash = await self._cache.get(user_key)
if not code_hash:
self._logger.info(f'Change password failed: user key missing (user_id={user_id})')
raise ApplicationException(400, 'Invalid or expired code')
ok = await self._hash_service.verify(hashed_value=code_hash, plain_value=code)
if not ok:
self._logger.info(f'Change password failed: code hash mismatch (user_id={user_id})')
raise ApplicationException(400, 'Invalid or expired code')
current_password_hash = await self._unit_of_work.user_repository.get_password_hash(user_id=user_id)
is_same = await self._hash_service.verify(hashed_value=current_password_hash, plain_value=new_password)
if is_same:
self._logger.info(f'Change password failed: new password same as current (user_id={user_id})')
raise ApplicationException(400, 'New password must differ from the current one')
new_password_hash = await self._hash_service.hash(new_password)
user = await self._unit_of_work.user_repository.set_password(
user_id=user_id,
password_hash=new_password_hash,
)
await self._cache.set_user(user_id, user)
try:
await self._cache.delete(code_key)
await self._cache.delete(user_key)
except Exception as e:
self._logger.warning(f'Change password cleanup failed (user_id={user_id}): {e}')
self._logger.info(f'Password changed for user_id={user_id}')
return True

View File

@@ -0,0 +1,126 @@
import secrets
from datetime import datetime, timezone
from ulid import ULID
from src.application.abstractions import IUnitOfWork
from src.application.contracts import IHashService, ICache, ILogger, IQueueMessanger
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.config import settings
from src.infrastructure.context_vars import trace_id_var
from src.infrastructure.database.decorators import transactional
class ChangePasswordStartCommand:
def __init__(
self,
hash_service: IHashService,
cache: ICache,
unit_of_work: IUnitOfWork,
logger: ILogger,
messanger: IQueueMessanger,
):
self._hash_service = hash_service
self._unit_of_work = unit_of_work
self._cache = cache
self._logger = logger
self._messanger = messanger
@transactional
async def __call__(self, user_id: str) -> bool:
TTL = 300
LOCK_TTL = 30
MAX_ATTEMPTS = 20
USER_PREFIX = 'change_password:user:'
CODE_PREFIX = 'change_password:code:'
LOCK_PREFIX = 'change_password:lock:'
user = await self._unit_of_work.user_repository.get_user_by_id(user_id=user_id)
if not user.email:
self._logger.warning(f'User {user_id} does not have an email address')
raise ApplicationException(404, f'User {user_id} does not have an email address')
trace_id = trace_id_var.get()
if not trace_id or trace_id == 'N/A':
trace_id = None
lock_key = f'{LOCK_PREFIX}{user_id}'
locked = await self._cache.set_nx(lock_key, '1', ttl=LOCK_TTL)
if not locked:
self._logger.info(f'Change password throttled by lock (user_id={user_id})')
raise ApplicationException(429, 'Too many requests. Please wait.')
try:
user_key = f'{USER_PREFIX}{user_id}'
existing = await self._cache.get(user_key)
if existing:
self._logger.info(f'Change password denied: code already exists for user_id={user_id}')
raise ApplicationException(429, 'Code already sent. Please wait before retrying.')
for _ in range(MAX_ATTEMPTS):
code = f'{secrets.randbelow(1_000_000):06d}'
code_key = f'{CODE_PREFIX}{code}'
code_hash = await self._hash_service.hash(code)
reserved = await self._cache.set_nx(code_key, user_id, ttl=TTL)
if not reserved:
continue
saved = await self._cache.set(user_key, code_hash, ttl=TTL)
if not saved:
await self._cache.delete(code_key)
self._logger.error(f'Change password failed: cannot save code hash for user_id={user_id}')
raise ApplicationException(503, 'Temporary error. Please try again.')
message_id = str(ULID())
now = datetime.now(timezone.utc).isoformat()
metadata = {
'trace_id': trace_id,
'source': 'user-service',
'timestamp': now,
'message_id': message_id,
}
payload = {
'email': user.email,
'code': code,
'ttl_seconds': TTL,
}
message = {
'event': 'change_password',
'payload': payload,
'metadata': metadata,
}
self._logger.info(f'Change password code created for user_id={user_id}')
try:
await self._messanger.publish_to_queue(
queue=settings.RABBIT_EMAIL_CODE_QUEUE,
message=message,
persist=True,
correlation_id=trace_id,
message_id=message_id,
headers={'trace_id': trace_id} if trace_id else None,
)
except Exception as exception:
try:
await self._cache.delete(user_key)
await self._cache.delete(code_key)
except Exception as rollback_err:
self._logger.error(f'Publish failed and rollback cache failed for user_id={user_id}: {str(rollback_err)}')
self._logger.error(f'Failed to publish change password email for user_id={user_id}: {str(exception)}')
raise ApplicationException(503, 'Temporary error. Please try again.')
return True
self._logger.error(f'Change password failed: code space exhausted for user_id={user_id}')
raise ApplicationException(503, 'Temporary error. Please try again.')
finally:
await self._cache.delete(lock_key)

View File

@@ -0,0 +1,56 @@
from __future__ import annotations
from botocore.exceptions import ClientError
from src.application.abstractions import IUnitOfWork
from src.application.contracts import ICache, ILogger, IS3
from src.application.domain.entities import UserEntity
from src.infrastructure.database.decorators import transactional
class DeleteAvatarCommand:
def __init__(self, unit_of_work: IUnitOfWork, logger: ILogger, cache: ICache, s3: IS3):
self._unit_of_work = unit_of_work
self._logger = logger
self._cache = cache
self._s3 = s3
@transactional
async def _load_user(self, user_id: str) -> UserEntity:
user = await self._unit_of_work.user_repository.get_user_by_id(user_id)
self._logger.debug(f'DeleteAvatar _load_user user_id={user_id} has_avatar_link={bool(user.avatar_link)}')
return user
async def __call__(self, user_id: str) -> UserEntity:
prior = await self._load_user(user_id)
link = prior.avatar_link
self._logger.info(f'DeleteAvatar start user_id={user_id} had_link={bool(link)}')
if link:
key = self._s3.object_key_from_public_url(link)
self._logger.debug(f'DeleteAvatar parsed_object_key user_id={user_id} has_key={bool(key)}')
if not key:
self._logger.warning(
f'DeleteAvatar could not parse avatar URL for S3 user_id={user_id} link_len={len(link)}'
)
if key:
self._logger.info(f'DeleteAvatar S3 delete start user_id={user_id} key={key}')
try:
await self._s3.delete_object(key=key)
self._logger.info(f'DeleteAvatar S3 delete done user_id={user_id} key={key}')
except ClientError as exc:
code = exc.response.get('Error', {}).get('Code', '')
if code not in ('NoSuchKey', '404'):
self._logger.warning(f'DeleteAvatar S3 delete failed user_id={user_id} code={code}: {exc}')
else:
self._logger.debug(f'DeleteAvatar S3 object already absent user_id={user_id} code={code}')
user = await self._clear_avatar_link(user_id)
self._logger.debug(f'DeleteAvatar DB cleared user_id={user_id} entity_has_link={bool(user.avatar_link)}')
await self._cache.set_user(user_id, user)
self._logger.debug(f'DeleteAvatar cache updated user_id={user_id}')
self._logger.info(f'Avatar removed user_id={user_id}')
return user
@transactional
async def _clear_avatar_link(self, user_id: str) -> UserEntity:
self._logger.debug(f'DeleteAvatar DB transaction set_avatar_link user_id={user_id} link=None')
return await self._unit_of_work.user_repository.set_avatar_link(user_id, None)

View File

@@ -0,0 +1,83 @@
from src.application.abstractions import IUnitOfWork
from src.application.contracts import IHashService, ILogger, ICache
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.database.decorators import transactional
class ForgotPasswordCompleteCommand:
def __init__(
self,
unit_of_work: IUnitOfWork,
hash_service: IHashService,
cache: ICache,
logger: ILogger,
):
self._unit_of_work = unit_of_work
self._hash_service = hash_service
self._cache = cache
self._logger = logger
@staticmethod
def _normalize_email(email: str) -> str:
return email.strip().lower()
@transactional
async def __call__(
self,
*,
email: str,
code: str,
new_password: str,
confirm_password: str,
) -> bool:
code = (code or '').strip()
normalized = self._normalize_email(email)
EMAIL_PREFIX = 'forgot_password:email:'
CODE_PREFIX = 'forgot_password:code:'
if new_password != confirm_password:
self._logger.info('Forgot password failed: passwords do not match')
raise ApplicationException(400, 'Passwords do not match')
code_key = f'{CODE_PREFIX}{code}'
cached_email = await self._cache.get(code_key)
if not cached_email:
self._logger.info('Forgot password failed: code not found')
raise ApplicationException(400, 'Invalid or expired code')
if cached_email != normalized:
self._logger.info('Forgot password failed: code-email mismatch')
raise ApplicationException(400, 'Invalid or expired code')
email_key = f'{EMAIL_PREFIX}{normalized}'
code_hash = await self._cache.get(email_key)
if not code_hash:
self._logger.info('Forgot password failed: email key missing')
raise ApplicationException(400, 'Invalid or expired code')
ok = await self._hash_service.verify(hashed_value=code_hash, plain_value=code)
if not ok:
self._logger.info('Forgot password failed: code hash mismatch')
raise ApplicationException(400, 'Invalid or expired code')
user = await self._unit_of_work.user_repository.get_user_by_email(normalized)
if user is None:
self._logger.info('Forgot password failed: user not found after valid code')
raise ApplicationException(400, 'Invalid or expired code')
new_password_hash = await self._hash_service.hash(new_password)
user = await self._unit_of_work.user_repository.set_password(
user_id=user.id,
password_hash=new_password_hash,
)
await self._cache.set_user(user.id, user)
try:
await self._cache.delete(code_key)
await self._cache.delete(email_key)
except Exception as e:
self._logger.warning(f'Forgot password cleanup failed (user_id={user.id}): {e}')
self._logger.info(f'Password reset via forgot flow for user_id={user.id}')
return True

View File

@@ -0,0 +1,132 @@
import secrets
from datetime import datetime, timezone
from ulid import ULID
from src.application.abstractions import IUnitOfWork
from src.application.contracts import IHashService, ICache, ILogger, IQueueMessanger
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.config import settings
from src.infrastructure.context_vars import trace_id_var
from src.infrastructure.database.decorators import transactional
class ForgotPasswordStartCommand:
def __init__(
self,
hash_service: IHashService,
cache: ICache,
unit_of_work: IUnitOfWork,
logger: ILogger,
messanger: IQueueMessanger,
):
self._hash_service = hash_service
self._unit_of_work = unit_of_work
self._cache = cache
self._logger = logger
self._messanger = messanger
@staticmethod
def _normalize_email(email: str) -> str:
return email.strip().lower()
@transactional
async def __call__(self, email: str) -> bool:
TTL = 300
LOCK_TTL = 30
MAX_ATTEMPTS = 20
EMAIL_PREFIX = 'forgot_password:email:'
CODE_PREFIX = 'forgot_password:code:'
LOCK_PREFIX = 'forgot_password:lock:'
normalized = self._normalize_email(email)
user = await self._unit_of_work.user_repository.get_user_by_email(normalized)
if user is None:
self._logger.info(f'Forgot password start: no user for email hash lookup')
return True
trace_id = trace_id_var.get()
if not trace_id or trace_id == 'N/A':
trace_id = None
lock_key = f'{LOCK_PREFIX}{normalized}'
locked = await self._cache.set_nx(lock_key, '1', ttl=LOCK_TTL)
if not locked:
self._logger.info(f'Forgot password throttled by lock (user_id={user.id})')
raise ApplicationException(429, 'Too many requests. Please wait.')
try:
email_key = f'{EMAIL_PREFIX}{normalized}'
existing = await self._cache.get(email_key)
if existing:
self._logger.info(f'Forgot password denied: code already exists for user_id={user.id}')
raise ApplicationException(429, 'Code already sent. Please wait before retrying.')
for _ in range(MAX_ATTEMPTS):
code = f'{secrets.randbelow(1_000_000):06d}'
code_key = f'{CODE_PREFIX}{code}'
code_hash = await self._hash_service.hash(code)
reserved = await self._cache.set_nx(code_key, normalized, ttl=TTL)
if not reserved:
continue
saved = await self._cache.set(email_key, code_hash, ttl=TTL)
if not saved:
await self._cache.delete(code_key)
self._logger.error(f'Forgot password failed: cannot save code hash for user_id={user.id}')
raise ApplicationException(503, 'Temporary error. Please try again.')
message_id = str(ULID())
now = datetime.now(timezone.utc).isoformat()
metadata = {
'trace_id': trace_id,
'source': 'user-service',
'timestamp': now,
'message_id': message_id,
}
payload = {
'email': normalized,
'code': code,
'ttl_seconds': TTL,
}
message = {
'event': 'forgot_password',
'payload': payload,
'metadata': metadata,
}
self._logger.info(f'Forgot password code created for user_id={user.id}')
try:
await self._messanger.publish_to_queue(
queue=settings.RABBIT_EMAIL_CODE_QUEUE,
message=message,
persist=True,
correlation_id=trace_id,
message_id=message_id,
headers={'trace_id': trace_id} if trace_id else None,
)
except Exception as exception:
try:
await self._cache.delete(email_key)
await self._cache.delete(code_key)
except Exception as rollback_err:
self._logger.error(
f'Publish failed and rollback cache failed for user_id={user.id}: {str(rollback_err)}'
)
self._logger.error(f'Failed to publish forgot password email for user_id={user.id}: {str(exception)}')
raise ApplicationException(503, 'Temporary error. Please try again.')
return True
self._logger.error(f'Forgot password failed: code space exhausted for user_id={user.id}')
raise ApplicationException(503, 'Temporary error. Please try again.')
finally:
await self._cache.delete(lock_key)

View File

@@ -0,0 +1,17 @@
from src.application.abstractions import IUnitOfWork
from src.application.contracts import ILogger, ICache
from src.application.domain.entities import UserEntity
from src.infrastructure.database.decorators import transactional
class GetMeCommand:
def __init__(self, unit_of_work: IUnitOfWork, logger: ILogger, cache: ICache):
self._unit_of_work = unit_of_work
self._logger = logger
self._cache = cache
@transactional
async def __call__(self, user_id: str) -> UserEntity:
user = await self._unit_of_work.user_repository.get_user_by_id(user_id=user_id)
self._logger.info(f'User ID: {user.id}')
return user

View File

@@ -0,0 +1,18 @@
from __future__ import annotations
from src.application.abstractions import IUnitOfWork
from src.application.domain.entities.legal_entity import LegalEntityEntity
from src.application.domain.enums.account_type import AccountType
from src.application.domain.exceptions import ForbiddenException, NotFoundException
async def require_legal_entity(user_id: str, unit_of_work: IUnitOfWork) -> LegalEntityEntity:
user = await unit_of_work.user_repository.get_user_by_id(user_id)
if user.account_type != AccountType.LEGAL_ENTITY:
raise ForbiddenException(message='B2B access is available for legal entity accounts only')
legal_entity = await unit_of_work.legal_entity_repository.get_by_user_id(user_id)
if legal_entity is None:
raise NotFoundException(message='Legal entity profile not found')
return legal_entity

View File

@@ -0,0 +1,79 @@
from __future__ import annotations
from decimal import Decimal
from src.application.abstractions import IUnitOfWork
from src.application.commands.legal_entity_guard import require_legal_entity
from src.application.contracts import ILogger
from src.application.domain.entities.purchase_request import PurchaseRequestEntity
from src.application.domain.exceptions import NotFoundException
from src.infrastructure.database.decorators import transactional
class CreatePurchaseRequestCommand:
def __init__(self, unit_of_work: IUnitOfWork, logger: ILogger):
self._unit_of_work = unit_of_work
self._logger = logger
@transactional
async def __call__(
self,
user_id: str,
*,
usdt_amount: Decimal,
comment: str | None = None,
target_wallet_chain: str | None = None,
target_wallet_address: str | None = None,
) -> PurchaseRequestEntity:
legal_entity = await require_legal_entity(user_id, self._unit_of_work)
item = await self._unit_of_work.purchase_request_repository.create(
organization_id=legal_entity.id,
usdt_amount=usdt_amount,
comment=comment,
target_wallet_chain=target_wallet_chain,
target_wallet_address=target_wallet_address,
)
self._logger.info(f'Purchase request created id={item.id} organization_id={legal_entity.id}')
return item
class ListPurchaseRequestsCommand:
def __init__(self, unit_of_work: IUnitOfWork, logger: ILogger):
self._unit_of_work = unit_of_work
self._logger = logger
@transactional
async def __call__(
self,
user_id: str,
*,
status: str | None = None,
limit: int = 50,
offset: int = 0,
) -> tuple[list[PurchaseRequestEntity], int]:
legal_entity = await require_legal_entity(user_id, self._unit_of_work)
items = await self._unit_of_work.purchase_request_repository.list_by_organization(
organization_id=legal_entity.id,
status=status,
limit=limit,
offset=offset,
)
total = await self._unit_of_work.purchase_request_repository.count_by_organization(
organization_id=legal_entity.id,
status=status,
)
return items, total
class GetPurchaseRequestCommand:
def __init__(self, unit_of_work: IUnitOfWork, logger: ILogger):
self._unit_of_work = unit_of_work
self._logger = logger
@transactional
async def __call__(self, user_id: str, request_id: str) -> PurchaseRequestEntity:
legal_entity = await require_legal_entity(user_id, self._unit_of_work)
item = await self._unit_of_work.purchase_request_repository.get_by_id(request_id)
if item.organization_id != legal_entity.id:
raise NotFoundException(message='Purchase request not found')
return item

View File

@@ -0,0 +1,100 @@
from __future__ import annotations
from datetime import datetime, timezone
from PIL import UnidentifiedImageError
from ulid import ULID
from botocore.exceptions import ClientError
from src.application.abstractions import IUnitOfWork
from src.application.contracts import ICache, ILogger, IS3
from src.application.domain.entities import UserEntity
from src.application.domain.exceptions import BadRequestException, ServiceUnavailableException
from src.infrastructure.config import settings
from src.infrastructure.database.decorators import transactional
from src.infrastructure.media.webp import image_bytes_to_webp
class SetAvatarCommand:
def __init__(self, unit_of_work: IUnitOfWork, logger: ILogger, cache: ICache, s3: IS3):
self._unit_of_work = unit_of_work
self._logger = logger
self._cache = cache
self._s3 = s3
@transactional
async def _load_user(self, user_id: str) -> UserEntity:
user = await self._unit_of_work.user_repository.get_user_by_id(user_id)
self._logger.debug(f'Avatar _load_user user_id={user_id} has_avatar_link={bool(user.avatar_link)}')
return user
async def __call__(self, user_id: str, image_bytes: bytes) -> tuple[UserEntity, int]:
prior = await self._load_user(user_id)
old_link = prior.avatar_link
self._logger.info(
f'SetAvatar start user_id={user_id} input_bytes={len(image_bytes)} had_previous_link={bool(old_link)}'
)
try:
webp_bytes = image_bytes_to_webp(image_bytes)
except UnidentifiedImageError as exc:
raise BadRequestException(message='Unsupported or corrupt image') from exc
except Exception as exc:
self._logger.exception(str(exc))
raise BadRequestException(message='Could not process image') from exc
self._logger.debug(f'SetAvatar webp_ready bytes={len(webp_bytes)}')
pid = user_id.replace('/', '').replace('.', '_')
name_id = str(ULID())
ts = int(datetime.now(timezone.utc).timestamp() * 1000)
prefix = settings.S3_AVATAR_KEY_PREFIX.strip().strip('/')
fname = f'{name_id}_{pid}_{ts}.webp'
object_key = f'{prefix}/{fname}' if prefix else fname
self._logger.info(f'SetAvatar S3 upload start user_id={user_id} key={object_key} webp_bytes={len(webp_bytes)}')
try:
url = await self._s3.upload_bytes(key=object_key, body=webp_bytes, content_type='image/webp')
except ClientError as exc:
self._logger.exception(str(exc))
raise ServiceUnavailableException(message='S3 upload failed') from exc
self._logger.info(f'SetAvatar S3 upload done user_id={user_id} key={object_key} public_url_len={len(url)}')
user = await self._save_avatar_link(user_id, url)
self._logger.info(
f'SetAvatar DB updated user_id={user_id} key={object_key} '
f'entity_avatar_link_len={len(user.avatar_link or "")}'
)
await self._cache.set_user(user_id, user)
self._logger.debug(f'SetAvatar cache updated user_id={user_id}')
if old_link:
old_key = self._s3.object_key_from_public_url(old_link)
if not old_key:
self._logger.warning(
f'SetAvatar could not parse old avatar URL for S3 delete user_id={user_id} '
f'old_link_len={len(old_link)}'
)
elif old_key == object_key:
self._logger.debug(f'SetAvatar skip delete same object key user_id={user_id} key={object_key}')
else:
self._logger.info(f'SetAvatar S3 delete old object user_id={user_id} old_key={old_key}')
try:
await self._s3.delete_object(key=old_key)
self._logger.info(f'SetAvatar S3 old object removed user_id={user_id} old_key={old_key}')
except ClientError as exc:
code = exc.response.get('Error', {}).get('Code', '')
if code not in ('NoSuchKey', '404'):
self._logger.warning(f'S3 delete old avatar failed user_id={user_id} code={code}: {exc}')
else:
self._logger.debug(f'SetAvatar old object already gone user_id={user_id} code={code}')
self._logger.info(f'Avatar set for user_id={user_id} key={object_key}')
return user, len(webp_bytes)
@transactional
async def _save_avatar_link(self, user_id: str, avatar_link: str) -> UserEntity:
self._logger.debug(f'SetAvatar DB transaction set_avatar_link user_id={user_id} link_len={len(avatar_link)}')
return await self._unit_of_work.user_repository.set_avatar_link(user_id, avatar_link)

View File

@@ -0,0 +1,68 @@
from src.application.abstractions import IUnitOfWork
from src.application.contracts import IHashService, ILogger, ICache
from src.application.domain.entities import UserEntity
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.database.decorators import transactional
class SetEncryptedMnemonicCompleteCommand:
def __init__(
self,
unit_of_work: IUnitOfWork,
hash_service: IHashService,
cache: ICache,
logger: ILogger,
):
self._unit_of_work = unit_of_work
self._hash_service = hash_service
self._cache = cache
self._logger = logger
@transactional
async def __call__(self, *, user_id: str, code: str, encrypted_mnemonic: str) -> UserEntity:
code = (code or '').strip()
USER_PREFIX = 'encrypted_mnemonic:user:'
CODE_PREFIX = 'encrypted_mnemonic:code:'
user_key = f'{USER_PREFIX}{user_id}'
code_key = f'{CODE_PREFIX}{code}'
user = await self._unit_of_work.user_repository.get_user_by_id(user_id=user_id)
if user.encrypted_mnemonic is not None:
self._logger.info(f'Encrypted mnemonic already set for user_id={user_id}')
raise ApplicationException(409, 'Encrypted mnemonic already set and cannot be changed')
cached_user_id = await self._cache.get(code_key)
if not cached_user_id:
self._logger.info(f'Encrypted mnemonic set failed: code not found (user_id={user_id})')
raise ApplicationException(400, 'Invalid or expired code')
if cached_user_id != user_id:
self._logger.info(f'Encrypted mnemonic set failed: code-user mismatch (user_id={user_id})')
raise ApplicationException(400, 'Invalid or expired code')
code_hash = await self._cache.get(user_key)
if not code_hash:
self._logger.info(f'Encrypted mnemonic set failed: user key missing (user_id={user_id})')
raise ApplicationException(400, 'Invalid or expired code')
ok = await self._hash_service.verify(hashed_value=code_hash, plain_value=code)
if not ok:
self._logger.info(f'Encrypted mnemonic set failed: code hash mismatch (user_id={user_id})')
raise ApplicationException(400, 'Invalid or expired code')
user = await self._unit_of_work.user_repository.set_encrypted_mnemonic(
user_id=user_id,
encrypted_mnemonic=encrypted_mnemonic,
)
await self._cache.set_user(user_id, user)
try:
await self._cache.delete(code_key)
await self._cache.delete(user_key)
except Exception as e:
self._logger.warning(f'Encrypted mnemonic set cleanup failed (user_id={user_id}): {e}')
self._logger.info(f'Encrypted mnemonic set for user_id={user_id}')
return user

View File

@@ -0,0 +1,130 @@
import secrets
from datetime import datetime, timezone
from ulid import ULID
from src.application.abstractions import IUnitOfWork
from src.application.contracts import IHashService, ICache, ILogger, IQueueMessanger
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.config import settings
from src.infrastructure.context_vars import trace_id_var
from src.infrastructure.database.decorators import transactional
class SetEncryptedMnemonicStartCommand:
def __init__(
self,
hash_service: IHashService,
cache: ICache,
unit_of_work: IUnitOfWork,
logger: ILogger,
messanger: IQueueMessanger,
):
self._hash_service = hash_service
self._unit_of_work = unit_of_work
self._cache = cache
self._logger = logger
self._messanger = messanger
@transactional
async def __call__(self, user_id: str) -> bool:
TTL = 300
LOCK_TTL = 30
MAX_ATTEMPTS = 20
USER_PREFIX = 'encrypted_mnemonic:user:'
CODE_PREFIX = 'encrypted_mnemonic:code:'
LOCK_PREFIX = 'encrypted_mnemonic:lock:'
user = await self._unit_of_work.user_repository.get_user_by_id(user_id=user_id)
if user.encrypted_mnemonic is not None:
self._logger.info(f'Encrypted mnemonic already set for user_id={user_id}')
raise ApplicationException(409, 'Encrypted mnemonic already set and cannot be changed')
if not user.email:
self._logger.warning(f'User {user_id} does not have an email address')
raise ApplicationException(404, f'User {user_id} does not have an email address')
trace_id = trace_id_var.get()
if not trace_id or trace_id == 'N/A':
trace_id = None
lock_key = f'{LOCK_PREFIX}{user_id}'
locked = await self._cache.set_nx(lock_key, '1', ttl=LOCK_TTL)
if not locked:
self._logger.info(f'Encrypted mnemonic set throttled by lock (user_id={user_id})')
raise ApplicationException(429, 'Too many requests. Please wait.')
try:
user_key = f'{USER_PREFIX}{user_id}'
existing = await self._cache.get(user_key)
if existing:
self._logger.info(f'Encrypted mnemonic set denied: code already exists for user_id={user_id}')
raise ApplicationException(429, 'Code already sent. Please wait before retrying.')
for _ in range(MAX_ATTEMPTS):
code = f'{secrets.randbelow(1_000_000):06d}'
code_key = f'{CODE_PREFIX}{code}'
code_hash = await self._hash_service.hash(code)
reserved = await self._cache.set_nx(code_key, user_id, ttl=TTL)
if not reserved:
continue
saved = await self._cache.set(user_key, code_hash, ttl=TTL)
if not saved:
await self._cache.delete(code_key)
self._logger.error(f'Encrypted mnemonic set failed: cannot save code hash for user_id={user_id}')
raise ApplicationException(503, 'Temporary error. Please try again.')
message_id = str(ULID())
now = datetime.now(timezone.utc).isoformat()
metadata = {
'trace_id': trace_id,
'source': 'user-service',
'timestamp': now,
'message_id': message_id,
}
payload = {
'email': user.email,
'code': code,
'ttl_seconds': TTL,
}
message = {
'event': 'encrypted_mnemonic_set',
'payload': payload,
'metadata': metadata,
}
self._logger.info(f'Encrypted mnemonic set code created for user_id={user_id}')
try:
await self._messanger.publish_to_queue(
queue=settings.RABBIT_EMAIL_CODE_QUEUE,
message=message,
persist=True,
correlation_id=trace_id,
message_id=message_id,
headers={'trace_id': trace_id} if trace_id else None,
)
except Exception as exception:
try:
await self._cache.delete(user_key)
await self._cache.delete(code_key)
except Exception as rollback_err:
self._logger.error(f'Publish failed and rollback cache failed for user_id={user_id}: {str(rollback_err)}')
self._logger.error(f'Failed to publish encrypted mnemonic set email for user_id={user_id}: {str(exception)}')
raise ApplicationException(503, 'Temporary error. Please try again.')
return True
self._logger.error(f'Encrypted mnemonic set failed: code space exhausted for user_id={user_id}')
raise ApplicationException(503, 'Temporary error. Please try again.')
finally:
await self._cache.delete(lock_key)

View File

@@ -0,0 +1,18 @@
from src.application.abstractions import IUnitOfWork
from src.application.contracts import ILogger, ICache
from src.application.domain.entities import UserEntity
from src.infrastructure.database.decorators import transactional
class SetPhoneCommand:
def __init__(self, unit_of_work: IUnitOfWork, logger: ILogger, cache: ICache):
self._unit_of_work = unit_of_work
self._logger = logger
self._cache = cache
@transactional
async def __call__(self, user_id: str, phone: str) -> UserEntity:
user = await self._unit_of_work.user_repository.set_phone(user_id=user_id, phone=phone)
await self._cache.set_user(user_id, user)
self._logger.info(f'Set phone for user {user_id}')
return user

View File

@@ -0,0 +1,76 @@
from src.application.abstractions import IUnitOfWork
from src.application.contracts import IHashService, ILogger, ICache
from src.application.domain.entities import UserEntity
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.database.decorators import transactional
class UpdateBankDetailsCompleteCommand:
def __init__(
self,
unit_of_work: IUnitOfWork,
hash_service: IHashService,
cache: ICache,
logger: ILogger,
):
self._unit_of_work = unit_of_work
self._hash_service = hash_service
self._cache = cache
self._logger = logger
@transactional
async def __call__(
self,
*,
user_id: str,
code: str,
passport_data: str | None = None,
inn: str | None = None,
erc20: str | None = None,
) -> UserEntity:
code = (code or '').strip()
USER_PREFIX = 'bank_details:user:'
CODE_PREFIX = 'bank_details:code:'
user_key = f'{USER_PREFIX}{user_id}'
code_key = f'{CODE_PREFIX}{code}'
cached_user_id = await self._cache.get(code_key)
if not cached_user_id:
self._logger.info(f'Bank details update failed: code not found (user_id={user_id})')
raise ApplicationException(400, 'Invalid or expired code')
if cached_user_id != user_id:
self._logger.info(f'Bank details update failed: code-user mismatch (user_id={user_id})')
raise ApplicationException(400, 'Invalid or expired code')
code_hash = await self._cache.get(user_key)
if not code_hash:
self._logger.info(f'Bank details update failed: user key missing (user_id={user_id})')
raise ApplicationException(400, 'Invalid or expired code')
ok = await self._hash_service.verify(hashed_value=code_hash, plain_value=code)
if not ok:
self._logger.info(f'Bank details update failed: code hash mismatch (user_id={user_id})')
raise ApplicationException(400, 'Invalid or expired code')
fields = {}
if passport_data is not None:
fields['passport_data'] = passport_data
if inn is not None:
fields['inn'] = inn
if erc20 is not None:
fields['erc20'] = erc20
user = await self._unit_of_work.user_repository.set_bank_details(user_id, **fields)
await self._cache.set_user(user_id, user)
try:
await self._cache.delete(code_key)
await self._cache.delete(user_key)
except Exception as e:
self._logger.warning(f'Bank details update cleanup failed (user_id={user_id}): {e}')
self._logger.info(f'Bank details updated for user_id={user_id}, fields={list(fields.keys())}')
return user

View File

@@ -0,0 +1,126 @@
import secrets
from datetime import datetime, timezone
from ulid import ULID
from src.application.abstractions import IUnitOfWork
from src.application.contracts import IHashService, ICache, ILogger, IQueueMessanger
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.config import settings
from src.infrastructure.context_vars import trace_id_var
from src.infrastructure.database.decorators import transactional
class UpdateBankDetailsStartCommand:
def __init__(
self,
hash_service: IHashService,
cache: ICache,
unit_of_work: IUnitOfWork,
logger: ILogger,
messanger: IQueueMessanger,
):
self._hash_service = hash_service
self._unit_of_work = unit_of_work
self._cache = cache
self._logger = logger
self._messanger = messanger
@transactional
async def __call__(self, user_id: str) -> bool:
TTL = 300
LOCK_TTL = 30
MAX_ATTEMPTS = 20
USER_PREFIX = 'bank_details:user:'
CODE_PREFIX = 'bank_details:code:'
LOCK_PREFIX = 'bank_details:lock:'
user = await self._unit_of_work.user_repository.get_user_by_id(user_id=user_id)
if not user.email:
self._logger.warning(f'User {user_id} does not have an email address')
raise ApplicationException(status_code=404, message=f'User {user_id} does not have an email address')
trace_id = trace_id_var.get()
if not trace_id or trace_id == 'N/A':
trace_id = None
lock_key = f'{LOCK_PREFIX}{user_id}'
locked = await self._cache.set_nx(lock_key, '1', ttl=LOCK_TTL)
if not locked:
self._logger.info(f'Bank details update throttled by lock (user_id={user_id})')
raise ApplicationException(429, 'Too many requests. Please wait.')
try:
user_key = f'{USER_PREFIX}{user_id}'
existing = await self._cache.get(user_key)
if existing:
self._logger.info(f'Bank details update denied: code already exists for user_id={user_id}')
raise ApplicationException(429, 'Code already sent. Please wait before retrying.')
for _ in range(MAX_ATTEMPTS):
code = f'{secrets.randbelow(1_000_000):06d}'
code_key = f'{CODE_PREFIX}{code}'
code_hash = await self._hash_service.hash(code)
reserved = await self._cache.set_nx(code_key, user_id, ttl=TTL)
if not reserved:
continue
saved = await self._cache.set(user_key, code_hash, ttl=TTL)
if not saved:
await self._cache.delete(code_key)
self._logger.error(f'Bank details update failed: cannot save code hash for user_id={user_id}')
raise ApplicationException(503, 'Temporary error. Please try again.')
message_id = str(ULID())
now = datetime.now(timezone.utc).isoformat()
metadata = {
'trace_id': trace_id,
'source': 'user-service',
'timestamp': now,
'message_id': message_id,
}
payload = {
'email': user.email,
'code': code,
'ttl_seconds': TTL,
}
message = {
'event': 'bank_details_update',
'payload': payload,
'metadata': metadata,
}
self._logger.info(f'Bank details update code created for user_id={user_id}')
try:
await self._messanger.publish_to_queue(
queue=settings.RABBIT_EMAIL_CODE_QUEUE,
message=message,
persist=True,
correlation_id=trace_id,
message_id=message_id,
headers={'trace_id': trace_id} if trace_id else None,
)
except Exception as exception:
try:
await self._cache.delete(user_key)
await self._cache.delete(code_key)
except Exception as rollback_err:
self._logger.error(f'Publish failed and rollback cache failed for user_id={user_id}: {str(rollback_err)}')
self._logger.error(f'Failed to publish bank details update email for user_id={user_id}: {str(exception)}')
raise ApplicationException(503, 'Temporary error. Please try again.')
return True
self._logger.error(f'Bank details update failed: code space exhausted for user_id={user_id}')
raise ApplicationException(503, 'Temporary error. Please try again.')
finally:
await self._cache.delete(lock_key)

View File

@@ -0,0 +1,7 @@
from src.application.contracts.i_logger import ILogger
from src.application.contracts.i_jwt_service import IJwtService
from src.application.contracts.i_csrf_service import ICsrfService
from src.application.contracts.i_cache import ICache
from src.application.contracts.i_hash_service import IHashService
from src.application.contracts.i_queue_messanger import IQueueMessanger
from src.application.contracts.i_s3 import IS3

View File

@@ -0,0 +1,30 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from src.application.domain.entities.user import UserEntity
class ICache(ABC):
@abstractmethod
async def set(self, key: str, value: str, ttl: int) -> bool:
raise NotImplementedError
@abstractmethod
async def set_nx(self, key: str, value: str, ttl: int) -> bool:
raise NotImplementedError
@abstractmethod
async def get(self, key: str) -> str | None:
raise NotImplementedError
@abstractmethod
async def delete(self, key: str) -> bool:
raise NotImplementedError
@abstractmethod
async def get_user(self, user_id: str) -> dict | None:
raise NotImplementedError
@abstractmethod
async def set_user(self, user_id: str, user: UserEntity, ttl: int = 300) -> None:
raise NotImplementedError

View File

@@ -0,0 +1,26 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Optional, Mapping
class ICsrfService(ABC):
@abstractmethod
def issue(self, subject: Optional[str] = None) -> str:
raise NotImplementedError
@abstractmethod
def verify(self, token: str, expected_subject: Optional[str] = None) -> dict[str, Any]:
raise NotImplementedError
@abstractmethod
def extract(self, cookies: Mapping[str, str], headers: Mapping[str, str]) -> tuple[Optional[str], Optional[str]]:
raise NotImplementedError
@abstractmethod
def verify_pair(
self,
cookie_token: Optional[str],
header_token: Optional[str],
expected_subject: Optional[str] = None,
) -> None:
raise NotImplementedError

View File

@@ -0,0 +1,12 @@
from abc import ABC, abstractmethod
class IHashService(ABC):
@abstractmethod
async def hash(self, value: str) -> str:
raise NotImplementedError
@abstractmethod
async def verify(self, hashed_value: str, plain_value: str) -> bool:
raise NotImplementedError

View File

@@ -0,0 +1,10 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from src.application.domain.dto import AccessTokenPayload
class IJwtService(ABC):
@abstractmethod
async def decode_access_token(self, token: str) -> AccessTokenPayload:
raise NotImplementedError

View File

@@ -0,0 +1,68 @@
from typing import Protocol, Optional, Callable
from src.application.domain.enums.log_format import LogFormat
from src.application.domain.enums.log_level import LogLevel
class ILogger(Protocol):
"""Interface for synchronous logger with ContextVar support for trace_id."""
log_format: LogFormat
min_level: LogLevel
id_generator: Optional[Callable[[], str]]
instance_id: str
def set_format(self, log_format: LogFormat) -> None:
"""Set log format using LogFormat enum"""
...
def set_min_level(self, level: LogLevel) -> None:
"""Set minimum log level"""
...
def new_trace_id(self) -> str:
"""Create and set new trace_id in context"""
...
def set_trace_id(self, trace_id: str) -> None:
"""Set existing trace_id in context"""
...
def get_trace_id(self) -> str:
"""Get current trace_id from context"""
...
def clear_trace_id(self) -> None:
"""Clear the trace_id in the context"""
...
def set_instance_id(self, instance_id: str) -> None:
"""Set service instance id (ULID recommended)"""
...
def get_instance_id(self) -> str:
"""Get current service instance id"""
...
def debug(self, message: str) -> None:
"""Log debug message"""
...
def info(self, message: str) -> None:
"""Log info message"""
...
def warning(self, message: str) -> None:
"""Log warning message"""
...
def error(self, message: str) -> None:
"""Log error message"""
...
def critical(self, message: str) -> None:
"""Log critical message"""
...
def exception(self, message: str) -> None:
"""Log exception with traceback"""
...

View File

@@ -0,0 +1,40 @@
from abc import ABC, abstractmethod
from typing import Mapping, Any
class IQueueMessanger(ABC):
@abstractmethod
async def connect(self) -> None:
raise NotImplementedError
@abstractmethod
async def close(self) -> None:
raise NotImplementedError
@abstractmethod
async def publish_to_queue(
self,
queue: str,
message: Any,
*,
persist: bool = True,
headers: Mapping[str, Any] | None = None,
correlation_id: str | None = None,
message_id: str | None = None,
) -> None:
raise NotImplementedError
@abstractmethod
async def publish(
self,
message: Any,
*,
exchange: str,
routing_key: str,
persist: bool = True,
headers: Mapping[str, Any] | None = None,
correlation_id: str | None = None,
message_id: str | None = None,
) -> None:
raise NotImplementedError

View File

@@ -0,0 +1,17 @@
from __future__ import annotations
from typing import Protocol, runtime_checkable
@runtime_checkable
class IS3(Protocol):
async def upload_bytes(self, *, key: str, body: bytes, content_type: str) -> str:
...
async def delete_object(self, *, key: str) -> None:
...
def object_key_from_public_url(self, url: str) -> str | None:
...

View File

@@ -0,0 +1,2 @@
from src.application.domain.dto.token import AccessTokenPayload, AuthContext
from src.application.domain.dto.keys import JwtPublicKey, JwtPublicKeySet

View File

@@ -0,0 +1,20 @@
from dataclasses import dataclass
from typing import Optional, Dict
@dataclass(frozen=True)
class JwtPublicKey:
kid: str
public_key_pem: str
@dataclass(frozen=True)
class JwtPublicKeySet:
active: JwtPublicKey
previous: Optional[JwtPublicKey] = None
def public_keys_by_kid(self) -> Dict[str, str]:
out = {self.active.kid: self.active.public_key_pem}
if self.previous:
out[self.previous.kid] = self.previous.public_key_pem
return out

View File

@@ -0,0 +1,18 @@
from pydantic import BaseModel
class AccessTokenPayload(BaseModel):
sub: str
type: str
sid: str
iat: int
nbf: int
exp: int
iss: str | None = None
aud: str | None = None
class AuthContext(BaseModel):
user_id: str
sid: str
token: AccessTokenPayload

View File

@@ -0,0 +1,5 @@
from src.application.domain.entities.user import UserEntity
from src.application.domain.entities.session import SessionEntity
__all__ = ['UserEntity', 'SessionEntity']

View File

@@ -0,0 +1,24 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime
from typing import Any
@dataclass(slots=True)
class LegalEntityEntity:
id: str
user_id: str
name: str
inn: str
status: str
short_name: str | None = None
ogrn: str | None = None
kpp: str | None = None
legal_address: str | None = None
actual_address: str | None = None
bank_details: dict[str, Any] | None = None
contact_person: str | None = None
contact_phone: str | None = None
kyc_verified: bool = True
kyc_verified_at: datetime | None = None

View File

@@ -0,0 +1,24 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime
from decimal import Decimal
@dataclass(slots=True)
class PurchaseRequestEntity:
id: str
organization_id: str
status: str
usdt_amount: Decimal
rub_amount: Decimal | None
exchange_rate: Decimal | None
service_fee_percent: Decimal | None
comment: str | None
admin_comment: str | None
target_wallet_chain: str | None
target_wallet_address: str | None
tx_hash: str | None
created_at: datetime | None = None
updated_at: datetime | None = None
completed_at: datetime | None = None

View File

@@ -0,0 +1,20 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime
@dataclass(slots=True)
class SessionEntity:
sid: str
user_id: str
device_id: str
revoked_at: datetime | None
last_seen_at: datetime
refresh_jti_hash: str | None
refresh_expires_at: datetime | None
user_agent: str | None = None
first_ip: str | None = None
last_ip: str | None = None

View File

@@ -0,0 +1,31 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import date, datetime
@dataclass(slots=True)
class UserEntity:
id: str | None = None
email: str | None = None
password_hash: str | None = None
first_name: str | None = None
middle_name: str | None = None
last_name: str | None = None
birth_date: date | None = None
encrypted_mnemonic: str | None = None
phone: str | None = None
passport_data: str | None = None
inn: str | None = None
erc20: str | None = None
avatar_link: str | None = None
kyc_verified: bool | None = None
is_deleted: bool | None = None
account_type: str | None = None
created_at: datetime | None = None
updated_at: datetime | None = None
kyc_verified_at: datetime | None = None

View File

@@ -0,0 +1,2 @@
from src.application.domain.enums.log_level import LogLevel
from src.application.domain.enums.log_format import LogFormat

View File

@@ -0,0 +1,6 @@
from enum import StrEnum
class AccountType(StrEnum):
INDIVIDUAL = 'individual'
LEGAL_ENTITY = 'legal_entity'

View File

@@ -0,0 +1,7 @@
from enum import Enum
class LogFormat(Enum):
"""Enum for supported log formats"""
TEXT = 'text'
JSON = 'json'

View File

@@ -0,0 +1,54 @@
from enum import Enum
class LogLevel(Enum):
DEBUG = 10
INFO = 20
WARNING = 30
ERROR = 40
CRITICAL = 50
EXCEPTION = 60
def __str__(self) -> str:
return self.name
def __repr__(self) -> str:
return f"[{self.value}, '{self.name}']"
def __eq__(self, other: object) -> bool:
if isinstance(other, LogLevel):
return self.value == other.value
if isinstance(other, int):
return self.value == other
return NotImplemented
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
def __lt__(self, other: object) -> bool:
if isinstance(other, LogLevel):
return self.value < other.value
if isinstance(other, int):
return self.value < other
return NotImplemented
def __le__(self, other: object) -> bool:
if isinstance(other, LogLevel):
return self.value <= other.value
if isinstance(other, int):
return self.value <= other
return NotImplemented
def __gt__(self, other: object) -> bool:
if isinstance(other, LogLevel):
return self.value > other.value
if isinstance(other, int):
return self.value > other
return NotImplemented
def __ge__(self, other: object) -> bool:
if isinstance(other, LogLevel):
return self.value >= other.value
if isinstance(other, int):
return self.value >= other
return NotImplemented

View File

@@ -0,0 +1,11 @@
from src.application.domain.exceptions.application_exceptions import (
ApplicationException,
BadRequestException,
ConflictException,
ForbiddenException,
InternalException,
NotFoundException,
ServiceUnavailableException,
TooManyRequestsException,
UnauthorizedException,
)

View File

@@ -0,0 +1,59 @@
from __future__ import annotations
from typing import Mapping
class ApplicationException(Exception):
def __init__(
self,
status_code: int,
message: str,
headers: Mapping[str, str] | None = None,
):
super().__init__(message)
self.status_code = status_code
self.message = message
self.headers = headers
def __str__(self) -> str:
return f'{self.status_code}: {self.message}'
class BadRequestException(ApplicationException):
def __init__(self, message: str, headers: Mapping[str, str] | None = None):
super().__init__(400, message, headers)
class UnauthorizedException(ApplicationException):
def __init__(self, message: str = 'Unauthorized', headers: Mapping[str, str] | None = None):
super().__init__(401, message, headers)
class ForbiddenException(ApplicationException):
def __init__(self, message: str = 'Forbidden', headers: Mapping[str, str] | None = None):
super().__init__(403, message, headers)
class NotFoundException(ApplicationException):
def __init__(self, message: str = 'Not found', headers: Mapping[str, str] | None = None):
super().__init__(404, message, headers)
class ConflictException(ApplicationException):
def __init__(self, message: str, headers: Mapping[str, str] | None = None):
super().__init__(409, message, headers)
class TooManyRequestsException(ApplicationException):
def __init__(self, message: str, headers: Mapping[str, str] | None = None):
super().__init__(429, message, headers)
class ServiceUnavailableException(ApplicationException):
def __init__(self, message: str, headers: Mapping[str, str] | None = None):
super().__init__(503, message, headers)
class InternalException(ApplicationException):
def __init__(self, message: str = 'Internal Server Error', headers: Mapping[str, str] | None = None):
super().__init__(500, message, headers)

View File

@@ -0,0 +1,21 @@
import re
SPECIAL_CHARS = '!@#$%^&*()_+-=.,:;?/[]{}<>'
def validate_password_strength(password: str) -> str:
if re.search(r'\s', password):
raise ValueError('Password must not contain whitespace')
if len(password) < 12:
raise ValueError('Password must be at least 12 characters')
if not re.search(r'[a-z]', password):
raise ValueError('Password must contain at least one lowercase letter')
if not re.search(r'[A-Z]', password):
raise ValueError('Password must contain at least one uppercase letter')
if not re.search(r'\d', password):
raise ValueError('Password must contain at least one digit')
if not any(c in SPECIAL_CHARS for c in password):
raise ValueError(
'Password must contain at least one special character from: !@#$%^&*()_+-=.,:;?/[]{}<>'
)
return password

2
src/infrastructure/cache/__init__.py vendored Normal file
View File

@@ -0,0 +1,2 @@
from src.infrastructure.cache.client import create_redis_client
from src.infrastructure.cache.keydb_client import KeydbCache

16
src/infrastructure/cache/client.py vendored Normal file
View File

@@ -0,0 +1,16 @@
import redis.asyncio as redis
from redis.asyncio.client import Redis
from src.infrastructure.config import settings
def create_redis_client() -> Redis:
return redis.from_url(
settings.REDIS_URL,
max_connections=50,
decode_responses=True,
socket_timeout=5,
socket_connect_timeout=5,
health_check_interval=30,
retry_on_timeout=True,
socket_keepalive=True,
)

View File

@@ -0,0 +1,52 @@
from __future__ import annotations
import orjson
from redis.asyncio.client import Redis
from src.application.contracts import ICache
from src.application.domain.entities.user import UserEntity
class KeydbCache(ICache):
USER_PREFIX = 'user:me'
def __init__(self, redis_client: Redis):
self._r = redis_client
async def set(self, key: str, value: str, ttl: int) -> bool:
return bool(await self._r.set(key, value, ex=ttl))
async def set_nx(self, key: str, value: str, ttl: int) -> bool:
return bool(await self._r.set(key, value, ex=ttl, nx=True))
async def get(self, key: str) -> str | None:
return await self._r.get(key)
async def delete(self, key: str) -> bool:
return (await self._r.delete(key)) > 0
async def get_user(self, user_id: str) -> dict | None:
raw = await self._r.get(f'{self.USER_PREFIX}:{user_id}')
if raw is None:
return None
return orjson.loads(raw)
async def set_user(self, user_id: str, user: UserEntity, ttl: int = 300) -> None:
data = orjson.dumps({
'id': user.id,
'email': user.email,
'first_name': user.first_name,
'middle_name': user.middle_name,
'last_name': user.last_name,
'birth_date': str(user.birth_date) if user.birth_date else None,
'encrypted_mnemonic': user.encrypted_mnemonic,
'phone': user.phone,
'passport_data': user.passport_data,
'inn': user.inn,
'erc20': user.erc20,
'avatar_link': user.avatar_link,
'kyc_verified': user.kyc_verified,
'is_deleted': user.is_deleted,
'created_at': user.created_at.isoformat() if user.created_at else None,
'updated_at': user.updated_at.isoformat() if user.updated_at else None,
'kyc_verified_at': user.kyc_verified_at.isoformat() if user.kyc_verified_at else None,
})
await self._r.set(f'{self.USER_PREFIX}:{user_id}', data, ex=ttl)

View File

@@ -0,0 +1 @@
from src.infrastructure.config.settings import settings

View File

@@ -0,0 +1,311 @@
from __future__ import annotations
from functools import lru_cache
from typing import Any, List, Literal, Mapping
from urllib.parse import quote
from dotenv import find_dotenv, load_dotenv
from pydantic import Field, PrivateAttr
from pydantic_settings import BaseSettings, SettingsConfigDict
from src.infrastructure.vault.client import VaultClient
env_file = find_dotenv('.env')
if env_file:
load_dotenv(env_file)
def _as_int(value: object, default: int) -> int:
if value is None:
return default
if isinstance(value, int):
return value
return int(str(value).strip())
class Settings(BaseSettings):
model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8', case_sensitive=True, extra='ignore')
_vault_database_secrets: dict[str, Any] = PrivateAttr(default_factory=dict)
VAULT_ADDR: str = 'https://corp.vault.elcsa.ru'
VAULT_ROLE_ID: str = ''
VAULT_SECRET_ID: str = ''
VAULT_NAMESPACE: str | None = None
VAULT_MOUNT_POINT: str = 'dev-secrets'
VAULT_DATABASE_SECRET_PATH: str = 'database'
VAULT_RABBIT_SECRET_PATH: str = 'rabbitmq'
VAULT_CSRF_SECRET_PATH: str = 'csrf'
VAULT_DOCS_SECRET_PATH: str = 'docs'
VAULT_JWT_KID_PATH: str = 'jwt/kid'
VAULT_JWT_KIDS_PREFIX: str = 'jwt/kids'
VAULT_S3_SECRET_PATH: str = 's3/avatars'
DATABASE_URL_DIRECT: str | None = Field(default=None, validation_alias='DATABASE_URL')
DATABASE_HOST: str = ''
DATABASE_PORT: int = Field(default=5432, ge=1, le=65535)
DATABASE_NAME: str = ''
DATABASE_USER: str = ''
DATABASE_PASSWORD: str = ''
DATABASE_POOL_SIZE: int = 10
DATABASE_MAX_OVERFLOW: int = 20
DATABASE_POOL_TIMEOUT: int = 30
DATABASE_POOL_RECYCLE: int = 3600
DATABASE_ECHO: bool = False
CSRF_SECRET_KEY: str = Field(
default='change-me-change-me-change-me-change-me',
min_length=32,
)
CSRF_COOKIE_SECURE: bool = False
CSRF_COOKIE_HTTPONLY: bool = True
CSRF_COOKIE_SAMESITE: Literal['Lax', 'Strict', 'None'] = 'Lax'
CSRF_COOKIE_PATH: str = '/'
CSRF_COOKIE_DOMAIN: str | None = None
DOCS_USERNAME: str = 'admin'
DOCS_PASSWORD: str = 'admin'
JWT_ACCESS_TTL_SECONDS: int = 15 * 60
JWT_REFRESH_TTL_SECONDS: int = 30 * 24 * 60 * 60
JWT_ISSUER: str | None = None
JWT_AUDIENCE: str | None = None
JWT_ALGORITHM: str = 'RS256'
JWT_KEYS_REFRESH_SECONDS: int = 3600
REDIS_HOST: str = 'keydb'
REDIS_PORT: int = 6379
REDIS_PASSWORD: str | None = None
REDIS_DB: int = 0
RABBIT_HOST: str = 'localhost'
RABBIT_PORT: int = 5672
RABBIT_USER: str = 'guest'
RABBIT_PASSWORD: str = 'guest'
RABBIT_VHOST: str = '/'
RABBIT_PUBLISH_PERSIST: bool = True
RABBIT_CONNECT_TIMEOUT: int = 5
RABBIT_EMAIL_CODE_QUEUE: str = 'email.verification_code'
S3_BUCKET: str = ''
S3_REGION: str = 'us-east-1'
S3_ACCESS_KEY_ID: str = ''
S3_SECRET_ACCESS_KEY: str = ''
S3_ENDPOINT_URL: str = ''
S3_PUBLIC_BASE_URL: str = ''
S3_REGRU_PUBLIC_WEBSITE_HOST: bool = False
S3_AVATAR_KEY_PREFIX: str = 'avatars'
LOG_LEVEL: Literal['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] = 'INFO'
LOG_FORMAT: Literal['JSON', 'TEXT'] = 'JSON'
def _get_vault_secret(self, secrets: dict[str, Any], *keys: str) -> str:
for key in keys:
value = secrets.get(key)
if value is not None and str(value).strip() != '':
return str(value)
return ''
def _reset_s3_config(self) -> None:
object.__setattr__(self, 'S3_BUCKET', '')
object.__setattr__(self, 'S3_ACCESS_KEY_ID', '')
object.__setattr__(self, 'S3_SECRET_ACCESS_KEY', '')
object.__setattr__(self, 'S3_ENDPOINT_URL', '')
object.__setattr__(self, 'S3_PUBLIC_BASE_URL', '')
object.__setattr__(self, 'S3_REGION', 'us-east-1')
object.__setattr__(self, 'S3_REGRU_PUBLIC_WEBSITE_HOST', False)
object.__setattr__(self, 'S3_AVATAR_KEY_PREFIX', 'avatars')
@staticmethod
def _vault_kv(mapping: Mapping[str, Any], *keys: str) -> Any:
for k in keys:
if k in mapping and mapping[k] is not None:
return mapping[k]
return None
def _apply_s3_from_vault_secret(self, s3: dict[str, Any]) -> None:
bucket_name = (
self._vault_kv(s3, 'bucket_name', 'BUCKET_NAME', 'bucket')
or self._vault_kv(s3, 'S3_BUCKET', 'bucketName')
)
endpoint_url = (
self._vault_kv(s3, 's3_endpoint_url', 'S3_ENDPOINT_URL', 'endpoint_url', 'ENDPOINT_URL')
or self._vault_kv(s3, 'endpoint')
)
ak = (
self._vault_kv(s3, 's3_access_key_id', 'S3_ACCESS_KEY_ID', 'ACCESS_KEY_ID', 'access_key_id')
or self._vault_kv(s3, 'AWS_ACCESS_KEY_ID')
)
sk = (
self._vault_kv(s3, 's3_secret_access_key', 'S3_SECRET_ACCESS_KEY', 'SECRET_ACCESS_KEY')
or self._vault_kv(s3, 'AWS_SECRET_ACCESS_KEY')
)
if bucket_name is None or str(bucket_name).strip() == '':
raise ValueError('Vault S3 secret must contain bucket_name')
if endpoint_url is None or str(endpoint_url).strip() == '':
raise ValueError('Vault S3 secret must contain s3_endpoint_url')
if ak is None or str(ak).strip() == '':
raise ValueError('Vault S3 secret must contain s3_access_key_id')
if sk is None or str(sk).strip() == '':
raise ValueError('Vault S3 secret must contain s3_secret_access_key')
object.__setattr__(self, 'S3_BUCKET', str(bucket_name).strip())
object.__setattr__(self, 'S3_ENDPOINT_URL', str(endpoint_url).strip())
object.__setattr__(self, 'S3_ACCESS_KEY_ID', str(ak).strip())
object.__setattr__(self, 'S3_SECRET_ACCESS_KEY', str(sk).strip())
region = (
self._vault_kv(s3, 's3_region', 'S3_REGION', 'region')
)
if region is not None and str(region).strip() != '':
object.__setattr__(self, 'S3_REGION', str(region).strip())
public_base = (
self._vault_kv(s3, 's3_public_base_url', 'S3_PUBLIC_BASE_URL', 'public_base_url')
or self._vault_kv(s3, 'public_url')
)
if public_base is not None and str(public_base).strip() != '':
object.__setattr__(self, 'S3_PUBLIC_BASE_URL', str(public_base).strip())
prefix = self._vault_kv(s3, 'avatar_key_prefix', 'S3_AVATAR_KEY_PREFIX', 's3_avatar_key_prefix')
if prefix is not None and str(prefix).strip() != '':
object.__setattr__(self, 'S3_AVATAR_KEY_PREFIX', str(prefix).strip())
rf = (
self._vault_kv(s3, 's3_reg_ru_public_website_host', 'S3_REGRU_PUBLIC_WEBSITE_HOST')
)
if rf is not None:
v = str(rf).strip().lower()
if v in {'1', 'true', 'yes', 'on'}:
object.__setattr__(self, 'S3_REGRU_PUBLIC_WEBSITE_HOST', True)
elif v in {'0', 'false', 'no', 'off'}:
object.__setattr__(self, 'S3_REGRU_PUBLIC_WEBSITE_HOST', False)
def model_post_init(self, __context: Any) -> None:
self._reset_s3_config()
if not self.VAULT_ROLE_ID.strip() or not self.VAULT_SECRET_ID.strip():
if not self.DATABASE_URL:
raise ValueError(
'Set VAULT_ROLE_ID and VAULT_SECRET_ID for Vault, or set DATABASE_URL '
'(or DATABASE_HOST, DATABASE_USER, DATABASE_PASSWORD, DATABASE_NAME) in the environment',
)
return
client = VaultClient(
addr=self.VAULT_ADDR,
role_id=self.VAULT_ROLE_ID,
secret_id=self.VAULT_SECRET_ID,
namespace=self.VAULT_NAMESPACE,
mount_point=self.VAULT_MOUNT_POINT,
)
db = client.read_secret(self.VAULT_DATABASE_SECRET_PATH)
object.__setattr__(self, '_vault_database_secrets', db)
def kv(d: dict[str, Any], *keys: str) -> Any:
for k in keys:
if k in d and d[k] is not None:
return d[k]
return None
if kv(db, 'HOST', 'host') is not None:
object.__setattr__(self, 'DATABASE_HOST', str(kv(db, 'HOST', 'host')))
if kv(db, 'PORT', 'port') is not None:
object.__setattr__(self, 'DATABASE_PORT', _as_int(kv(db, 'PORT', 'port'), self.DATABASE_PORT))
if kv(db, 'NAME', 'name') is not None:
object.__setattr__(self, 'DATABASE_NAME', str(kv(db, 'NAME', 'name')))
if kv(db, 'USER', 'user') is not None:
object.__setattr__(self, 'DATABASE_USER', str(kv(db, 'USER', 'user')))
if kv(db, 'PASSWORD', 'password') is not None:
object.__setattr__(self, 'DATABASE_PASSWORD', str(kv(db, 'PASSWORD', 'password')))
rabbit = client.read_secret_optional(self.VAULT_RABBIT_SECRET_PATH)
if rabbit:
if kv(rabbit, 'HOST', 'host') is not None:
object.__setattr__(self, 'RABBIT_HOST', str(kv(rabbit, 'HOST', 'host')))
if kv(rabbit, 'PORT', 'port') is not None:
object.__setattr__(self, 'RABBIT_PORT', _as_int(kv(rabbit, 'PORT', 'port'), self.RABBIT_PORT))
if kv(rabbit, 'USER', 'user') is not None:
object.__setattr__(self, 'RABBIT_USER', str(kv(rabbit, 'USER', 'user')))
if kv(rabbit, 'PASSWORD', 'password') is not None:
object.__setattr__(self, 'RABBIT_PASSWORD', str(kv(rabbit, 'PASSWORD', 'password')))
if kv(rabbit, 'VHOST', 'vhost') is not None:
object.__setattr__(self, 'RABBIT_VHOST', str(kv(rabbit, 'VHOST', 'vhost')))
csrf = client.read_secret_optional(self.VAULT_CSRF_SECRET_PATH)
if csrf and kv(csrf, 'KEY', 'key') is not None:
key = str(kv(csrf, 'KEY', 'key'))
if len(key) >= 32:
object.__setattr__(self, 'CSRF_SECRET_KEY', key)
docs = client.read_secret_optional(self.VAULT_DOCS_SECRET_PATH)
if docs:
u = docs.get('DOCS_USERNAME') or docs.get('USERNAME')
p = docs.get('DOCS_PASSWORD') or docs.get('PASSWORD')
if u is not None:
object.__setattr__(self, 'DOCS_USERNAME', str(u))
if p is not None:
object.__setattr__(self, 'DOCS_PASSWORD', str(p))
s3_rel_path = self.VAULT_S3_SECRET_PATH.strip()
if s3_rel_path:
s3_secret_data = client.read_secret_optional(s3_rel_path)
if s3_secret_data:
self._apply_s3_from_vault_secret(s3_secret_data)
if not self.DATABASE_URL:
raise ValueError('Database URL could not be built from Vault database secret')
@property
def DATABASE_URL(self) -> str:
direct = (self.DATABASE_URL_DIRECT or '').strip()
if direct:
return direct
ready_url = self._get_vault_secret(
self._vault_database_secrets,
'DATABASE_URL',
'database_url',
)
if ready_url:
return ready_url
host = self._get_vault_secret(self._vault_database_secrets, 'host', 'HOST')
port = self._get_vault_secret(self._vault_database_secrets, 'port', 'PORT') or str(self.DATABASE_PORT)
user = self._get_vault_secret(self._vault_database_secrets, 'user', 'USER')
password = self._get_vault_secret(self._vault_database_secrets, 'password', 'PASSWORD')
name = self._get_vault_secret(self._vault_database_secrets, 'name', 'NAME', 'database', 'DATABASE')
if not host or not user or not password or not name:
h = (self.DATABASE_HOST or '').strip()
u = (self.DATABASE_USER or '').strip()
p = (self.DATABASE_PASSWORD or '').strip()
n = (self.DATABASE_NAME or '').strip()
if h and u and p and n:
quoted_user = quote(u, safe='')
quoted_password = quote(p, safe='')
po = str(self.DATABASE_PORT)
return f'postgresql+asyncpg://{quoted_user}:{quoted_password}@{h}:{po}/{n}'
return ''
quoted_user = quote(user, safe='')
quoted_password = quote(password, safe='')
return f'postgresql+asyncpg://{quoted_user}:{quoted_password}@{host}:{port}/{name}'
@property
def REDIS_URL(self) -> str:
auth = f':{self.REDIS_PASSWORD}@' if self.REDIS_PASSWORD else ''
return f'redis://{auth}{self.REDIS_HOST}:{self.REDIS_PORT}/{self.REDIS_DB}'
@property
def RABBIT_URL(self) -> str:
vhost = '%2F' if self.RABBIT_VHOST == '/' else self.RABBIT_VHOST.lstrip('/')
return f'amqp://{self.RABBIT_USER}:{self.RABBIT_PASSWORD}@{self.RABBIT_HOST}:{self.RABBIT_PORT}/{vhost}'
@property
def EXCLUDED_PATHS(self) -> List[str]:
return ['/docs', '/redoc', '/openapi.json', '/ping', '/health']
@lru_cache(maxsize=1)
def get_settings() -> Settings:
return Settings()
settings = get_settings()

View File

@@ -0,0 +1 @@
from src.infrastructure.context_vars.trace_id import trace_id_var

View File

@@ -0,0 +1,4 @@
from contextvars import ContextVar
trace_id_var: ContextVar[str] = ContextVar('trace_id', default='N/A')

View File

@@ -0,0 +1 @@
from src.infrastructure.database.unit_of_work import UnitOfWork

View File

@@ -0,0 +1,22 @@
from sqlalchemy.ext.asyncio import async_sessionmaker
from sqlalchemy.ext.asyncio.engine import create_async_engine
from sqlalchemy.ext.asyncio.session import AsyncSession
from typing import AsyncGenerator
from src.infrastructure.config import settings
engine = create_async_engine(
settings.DATABASE_URL,
pool_size=settings.DATABASE_POOL_SIZE,
max_overflow=settings.DATABASE_MAX_OVERFLOW,
pool_timeout=settings.DATABASE_POOL_TIMEOUT,
pool_recycle=settings.DATABASE_POOL_RECYCLE,
echo=settings.DATABASE_ECHO
)
async_session_maker = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
async def get_session() -> AsyncGenerator[AsyncSession, None]:
async with async_session_maker() as session:
yield session

View File

@@ -0,0 +1 @@
from src.infrastructure.database.decorators.transactional import transactional

View File

@@ -0,0 +1,15 @@
from __future__ import annotations
from functools import wraps
from typing import Callable, Awaitable, TypeVar, ParamSpec
P = ParamSpec("P")
R = TypeVar("R")
def transactional(method: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
@wraps(method)
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
async with self._unit_of_work:
return await method(self, *args, **kwargs)
return wrapper

View File

@@ -0,0 +1,6 @@
from src.infrastructure.database.models.base import Base
from src.infrastructure.database.models.user import UserModel
from src.infrastructure.database.models.legal_entity import LegalEntityModel
from src.infrastructure.database.models.purchase_request import PurchaseRequestModel
__all__ = ['Base', 'UserModel', 'LegalEntityModel', 'PurchaseRequestModel']

View File

@@ -0,0 +1,19 @@
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import DeclarativeBase
class Base(AsyncAttrs, DeclarativeBase):
__abstract__ = True
def __repr__(self) -> str:
class_name = self.__class__.__name__
attributes = ', '.join(f"{col.name}={getattr(self, col.name, None)!r}"
for col in self.__table__.columns)
return f"<{class_name}({attributes})>"
def __str__(self) -> str:
class_name = self.__class__.__name__
attributes = ', '.join(f"{col.name}={getattr(self, col.name)}"
for col in self.__table__.columns
if getattr(self, col.name) is not None)
return f"{class_name}({attributes})"

View File

@@ -0,0 +1,32 @@
from __future__ import annotations
from datetime import datetime
from typing import Any
from sqlalchemy import Boolean, DateTime, ForeignKey, String, Text
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column
from src.infrastructure.database.models.base import Base
from src.infrastructure.database.models.mixins import AuditTimestampsMixin, UlidPrimaryKeyMixin
class LegalEntityModel(Base, UlidPrimaryKeyMixin, AuditTimestampsMixin):
__tablename__ = 'legal_entities'
user_id: Mapped[str] = mapped_column(String(26), ForeignKey('users.id', ondelete='RESTRICT'), nullable=False, unique=True, index=True)
name: Mapped[str] = mapped_column(String(512), nullable=False)
short_name: Mapped[str | None] = mapped_column(String(256), nullable=True)
inn: Mapped[str] = mapped_column(String(12), nullable=False, index=True)
ogrn: Mapped[str | None] = mapped_column(String(15), nullable=True)
kpp: Mapped[str | None] = mapped_column(String(9), nullable=True)
legal_address: Mapped[str | None] = mapped_column(Text, nullable=True)
actual_address: Mapped[str | None] = mapped_column(Text, nullable=True)
bank_details: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True)
contact_person: Mapped[str | None] = mapped_column(String(256), nullable=True)
contact_phone: Mapped[str | None] = mapped_column(String(16), nullable=True)
status: Mapped[str] = mapped_column(String(32), nullable=False, server_default='active', default='active')
kyc_verified: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default='true', default=True)
kyc_verified_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
encrypted_mnemonic: Mapped[str | None] = mapped_column(Text, nullable=True)
created_by: Mapped[str | None] = mapped_column(String(26), nullable=True)

View File

@@ -0,0 +1,3 @@
from src.infrastructure.database.models.mixins.audit import AuditTimestampsMixin
from src.infrastructure.database.models.mixins.ulid import UlidPrimaryKeyMixin
from src.infrastructure.database.models.mixins.soft_delete import SoftDeleteMixin

View File

@@ -0,0 +1,16 @@
from sqlalchemy import DateTime, func
from sqlalchemy.orm import Mapped, mapped_column
class AuditTimestampsMixin:
created_at: Mapped[DateTime] = mapped_column(
DateTime(timezone=True),
nullable=False,
server_default=func.now(),
)
updated_at: Mapped[DateTime] = mapped_column(
DateTime(timezone=True),
nullable=False,
server_default=func.now(),
onupdate=func.now(),
)

View File

@@ -0,0 +1,6 @@
from sqlalchemy import Boolean
from sqlalchemy.orm import Mapped, mapped_column
class SoftDeleteMixin:
is_deleted: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default='false', default=False)

View File

@@ -0,0 +1,8 @@
from sqlalchemy import String
from sqlalchemy.orm import Mapped, mapped_column
from ulid import ULID
class UlidPrimaryKeyMixin:
id: Mapped[str] = mapped_column(String(26), primary_key=True, default=lambda: str(ULID()))

View File

@@ -0,0 +1,33 @@
from __future__ import annotations
from datetime import datetime
from decimal import Decimal
from sqlalchemy import DateTime, ForeignKey, Numeric, String, Text
from sqlalchemy.orm import Mapped, mapped_column
from src.infrastructure.database.models.base import Base
from src.infrastructure.database.models.mixins import AuditTimestampsMixin, UlidPrimaryKeyMixin
class PurchaseRequestModel(Base, UlidPrimaryKeyMixin, AuditTimestampsMixin):
__tablename__ = 'purchase_requests'
organization_id: Mapped[str] = mapped_column(
String(26),
ForeignKey('legal_entities.id', ondelete='RESTRICT'),
nullable=False,
index=True,
)
status: Mapped[str] = mapped_column(String(32), nullable=False, server_default='submitted', default='submitted')
usdt_amount: Mapped[Decimal] = mapped_column(Numeric(18, 8), nullable=False)
rub_amount: Mapped[Decimal | None] = mapped_column(Numeric(18, 2), nullable=True)
exchange_rate: Mapped[Decimal | None] = mapped_column(Numeric(18, 8), nullable=True)
service_fee_percent: Mapped[Decimal | None] = mapped_column(Numeric(5, 2), nullable=True)
comment: Mapped[str | None] = mapped_column(Text, nullable=True)
admin_comment: Mapped[str | None] = mapped_column(Text, nullable=True)
target_wallet_chain: Mapped[str | None] = mapped_column(String(16), nullable=True, server_default='ETH')
target_wallet_address: Mapped[str | None] = mapped_column(String(128), nullable=True)
tx_hash: Mapped[str | None] = mapped_column(String(128), nullable=True)
assigned_to: Mapped[str | None] = mapped_column(String(26), nullable=True)
completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)

View File

@@ -0,0 +1,50 @@
from datetime import datetime, timezone
from sqlalchemy import String, DateTime, ForeignKey, Index
from sqlalchemy.orm import Mapped, mapped_column
from ulid import ULID
from src.infrastructure.database.models import Base
from src.infrastructure.database.models.mixins import UlidPrimaryKeyMixin, AuditTimestampsMixin
class Session(Base, UlidPrimaryKeyMixin, AuditTimestampsMixin):
__tablename__ = "sessions"
sid: Mapped[str] = mapped_column(
String(26),
unique=True,
index=True,
nullable=False,
default=lambda: str(ULID()),
)
user_id: Mapped[str] = mapped_column(
String(26),
ForeignKey("users.id", ondelete="CASCADE"),
index=True,
nullable=False,
)
device_id: Mapped[str] = mapped_column(
String(26),
nullable=False,
index=True,
)
user_agent: Mapped[str | None] = mapped_column(String(500))
first_ip: Mapped[str | None] = mapped_column(String(64))
last_ip: Mapped[str | None] = mapped_column(String(64))
last_seen_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
default=lambda: datetime.now(timezone.utc),
)
revoked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
refresh_jti_hash: Mapped[str | None] = mapped_column(String(255))
refresh_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
Index("ux_sessions_user_device", Session.user_id, Session.device_id, unique=True)
Index("ix_sessions_user_active", Session.user_id, Session.revoked_at)

View File

@@ -0,0 +1,31 @@
from __future__ import annotations
from sqlalchemy import Boolean, Date, DateTime, String, Text
from sqlalchemy.orm import Mapped, mapped_column
from src.infrastructure.database.models.base import Base
from src.infrastructure.database.models.mixins import UlidPrimaryKeyMixin, AuditTimestampsMixin, SoftDeleteMixin
class UserModel(Base, UlidPrimaryKeyMixin, AuditTimestampsMixin, SoftDeleteMixin):
__tablename__ = 'users'
email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True, index=True)
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
last_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
first_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
middle_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
birth_date: Mapped[Date | None] = mapped_column(Date, nullable=True)
encrypted_mnemonic: Mapped[str | None] = mapped_column(Text, nullable=True)
phone: Mapped[str | None] = mapped_column(String(16), nullable=True)
passport_data: Mapped[str | None] = mapped_column(String(255), nullable=True)
inn: Mapped[str | None] = mapped_column(String(12), nullable=True)
erc20: Mapped[str | None] = mapped_column(String(255), nullable=True)
avatar_link: Mapped[str | None] = mapped_column(String(2048), nullable=True, default=None)
kyc_verified: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default='false', default=False)
kyc_verified_at: Mapped[DateTime | None] = mapped_column(DateTime(timezone=True), nullable=True)
account_type: Mapped[str] = mapped_column(String(20), nullable=False, server_default='individual', default='individual')

View File

@@ -0,0 +1,3 @@
from src.infrastructure.database.repositories.user_repository import UserRepository
from src.infrastructure.database.repositories.legal_entity_repository import LegalEntityRepository
from src.infrastructure.database.repositories.purchase_request_repository import PurchaseRequestRepository

View File

@@ -0,0 +1,49 @@
from __future__ import annotations
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import SQLAlchemyError
from src.application.abstractions.repositories.i_legal_entity_repository import ILegalEntityRepository
from src.application.contracts import ILogger
from src.application.domain.entities.legal_entity import LegalEntityEntity
from src.application.domain.exceptions import InternalException
from src.infrastructure.database.models.legal_entity import LegalEntityModel
class LegalEntityRepository(ILegalEntityRepository):
def __init__(self, session: AsyncSession, logger: ILogger):
self._session = session
self._logger = logger
@staticmethod
def _to_entity(model: LegalEntityModel) -> LegalEntityEntity:
return LegalEntityEntity(
id=model.id,
user_id=model.user_id,
name=model.name,
inn=model.inn,
status=model.status,
short_name=model.short_name,
ogrn=model.ogrn,
kpp=model.kpp,
legal_address=model.legal_address,
actual_address=model.actual_address,
bank_details=model.bank_details,
contact_person=model.contact_person,
contact_phone=model.contact_phone,
kyc_verified=model.kyc_verified,
kyc_verified_at=model.kyc_verified_at,
)
async def get_by_user_id(self, user_id: str) -> LegalEntityEntity | None:
try:
stmt = select(LegalEntityModel).where(LegalEntityModel.user_id == user_id)
result = await self._session.execute(stmt)
model = result.scalar_one_or_none()
if model is None:
return None
return self._to_entity(model)
except SQLAlchemyError as exc:
self._logger.exception(str(exc))
raise InternalException(message=f'Database error: {exc}') from exc

View File

@@ -0,0 +1,115 @@
from __future__ import annotations
from decimal import Decimal
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import SQLAlchemyError
from ulid import ULID
from src.application.abstractions.repositories import IPurchaseRequestRepository
from src.application.contracts import ILogger
from src.application.domain.entities.purchase_request import PurchaseRequestEntity
from src.application.domain.exceptions import ApplicationException, InternalException, NotFoundException
from src.infrastructure.database.models import PurchaseRequestModel
class PurchaseRequestRepository(IPurchaseRequestRepository):
def __init__(self, session: AsyncSession, logger: ILogger):
self._session = session
self._logger = logger
@staticmethod
def _to_entity(model: PurchaseRequestModel) -> PurchaseRequestEntity:
return PurchaseRequestEntity(
id=model.id,
organization_id=model.organization_id,
status=model.status,
usdt_amount=model.usdt_amount,
rub_amount=model.rub_amount,
exchange_rate=model.exchange_rate,
service_fee_percent=model.service_fee_percent,
comment=model.comment,
admin_comment=model.admin_comment,
target_wallet_chain=model.target_wallet_chain,
target_wallet_address=model.target_wallet_address,
tx_hash=model.tx_hash,
created_at=model.created_at,
updated_at=model.updated_at,
completed_at=model.completed_at,
)
def _apply_filters(self, stmt, *, organization_id: str, status: str | None):
stmt = stmt.where(PurchaseRequestModel.organization_id == organization_id)
if status:
stmt = stmt.where(PurchaseRequestModel.status == status)
return stmt
async def create(
self,
*,
organization_id: str,
usdt_amount: Decimal,
comment: str | None,
target_wallet_chain: str | None,
target_wallet_address: str | None,
) -> PurchaseRequestEntity:
try:
model = PurchaseRequestModel(
id=str(ULID()),
organization_id=organization_id,
status='submitted',
usdt_amount=usdt_amount,
comment=comment,
target_wallet_chain=target_wallet_chain or 'ETH',
target_wallet_address=target_wallet_address,
)
self._session.add(model)
await self._session.flush()
await self._session.refresh(model)
return self._to_entity(model)
except SQLAlchemyError as exc:
self._logger.exception(str(exc))
raise InternalException(message=f'Database error: {exc}') from exc
async def get_by_id(self, request_id: str) -> PurchaseRequestEntity:
try:
res = await self._session.execute(
select(PurchaseRequestModel).where(PurchaseRequestModel.id == request_id)
)
model = res.scalar_one_or_none()
if model is None:
raise NotFoundException(message='Purchase request not found')
return self._to_entity(model)
except ApplicationException:
raise
except SQLAlchemyError as exc:
self._logger.exception(str(exc))
raise InternalException(message=f'Database error: {exc}') from exc
async def list_by_organization(
self,
*,
organization_id: str,
status: str | None,
limit: int,
offset: int,
) -> list[PurchaseRequestEntity]:
try:
stmt = select(PurchaseRequestModel).order_by(PurchaseRequestModel.created_at.desc())
stmt = self._apply_filters(stmt, organization_id=organization_id, status=status)
res = await self._session.execute(stmt.limit(limit).offset(offset))
return [self._to_entity(model) for model in res.scalars().all()]
except SQLAlchemyError as exc:
self._logger.exception(str(exc))
raise InternalException(message=f'Database error: {exc}') from exc
async def count_by_organization(self, *, organization_id: str, status: str | None) -> int:
try:
stmt = select(func.count()).select_from(PurchaseRequestModel)
stmt = self._apply_filters(stmt, organization_id=organization_id, status=status)
res = await self._session.execute(stmt)
return int(res.scalar_one())
except SQLAlchemyError as exc:
self._logger.exception(str(exc))
raise InternalException(message=f'Database error: {exc}') from exc

View File

@@ -0,0 +1,198 @@
from __future__ import annotations
from datetime import datetime
from typing import Optional
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from src.application.contracts import ILogger
from src.application.domain.entities import SessionEntity
from src.application.abstractions.repositories import ISessionRepository
from src.infrastructure.database.models import Session
class SessionRepository(ISessionRepository):
def __init__(self, session: AsyncSession, logger: ILogger):
self._session = session
self._logger = logger
async def get_by_sid(self, sid: str) -> Optional[SessionEntity]:
res = await self._session.execute(select(Session).where(Session.sid == sid))
m = res.scalar_one_or_none()
if m is None:
return None
return SessionEntity(
sid=m.sid,
user_id=m.user_id,
device_id=m.device_id,
revoked_at=m.revoked_at,
last_seen_at=m.last_seen_at,
refresh_jti_hash=m.refresh_jti_hash,
refresh_expires_at=m.refresh_expires_at,
user_agent=m.user_agent,
first_ip=m.first_ip,
last_ip=m.last_ip,
)
async def get_by_user_device(self, user_id: str, device_id: str) -> Optional[SessionEntity]:
res = await self._session.execute(
select(Session).where(Session.user_id == user_id, Session.device_id == device_id)
)
m = res.scalar_one_or_none()
if m is None:
return None
return SessionEntity(
sid=m.sid,
user_id=m.user_id,
device_id=m.device_id,
revoked_at=m.revoked_at,
last_seen_at=m.last_seen_at,
refresh_jti_hash=m.refresh_jti_hash,
refresh_expires_at=m.refresh_expires_at,
user_agent=m.user_agent,
first_ip=m.first_ip,
last_ip=m.last_ip,
)
async def upsert_by_device(
self,
*,
user_id: str,
device_id: str,
sid: str,
refresh_jti_hash: str,
refresh_expires_at: datetime,
user_agent: str | None,
ip: str | None,
now: datetime,
) -> SessionEntity:
res = await self._session.execute(
select(Session).where(Session.user_id == user_id, Session.device_id == device_id)
)
m = res.scalar_one_or_none()
if m is None:
m = Session(
sid=sid,
user_id=user_id,
device_id=device_id,
revoked_at=None,
last_seen_at=now,
refresh_jti_hash=refresh_jti_hash,
refresh_expires_at=refresh_expires_at,
user_agent=user_agent,
first_ip=ip,
last_ip=ip,
)
self._session.add(m)
await self._session.flush()
self._logger.info(f'Session created (user_id={user_id}, device_id={device_id}, sid={sid})')
else:
m.sid = sid
m.revoked_at = None
m.last_seen_at = now
m.refresh_jti_hash = refresh_jti_hash
m.refresh_expires_at = refresh_expires_at
m.user_agent = user_agent
m.last_ip = ip
await self._session.flush()
self._logger.info(f'Session updated (user_id={user_id}, device_id={device_id}, sid={sid})')
return SessionEntity(
sid=m.sid,
user_id=m.user_id,
device_id=m.device_id,
revoked_at=m.revoked_at,
last_seen_at=m.last_seen_at,
refresh_jti_hash=m.refresh_jti_hash,
refresh_expires_at=m.refresh_expires_at,
user_agent=m.user_agent,
first_ip=m.first_ip,
last_ip=m.last_ip,
)
async def revoke_by_sid(self, sid: str, now: datetime) -> None:
# Интерфейс требует None -> просто делаем update и flush
await self._session.execute(
update(Session)
.where(Session.sid == sid, Session.revoked_at.is_(None))
.values(revoked_at=now)
.execution_options(synchronize_session='fetch')
)
await self._session.flush()
async def rotate_refresh(
self,
sid: str,
new_jti_hash: str,
new_refresh_expires_at: datetime,
now: datetime,
ip: str | None,
user_agent: str | None,
) -> None:
values = {
'refresh_jti_hash': new_jti_hash,
'refresh_expires_at': new_refresh_expires_at,
'last_seen_at': now,
'user_agent': user_agent,
}
if ip is not None:
values['last_ip'] = ip
await self._session.execute(
update(Session)
.where(Session.sid == sid, Session.revoked_at.is_(None))
.values(**values)
.execution_options(synchronize_session='fetch')
)
await self._session.flush()
async def touch_last_seen(self, sid: str, *, ip: str | None, now: datetime) -> None:
values = {'last_seen_at': now}
if ip is not None:
values['last_ip'] = ip
await self._session.execute(
update(Session)
.where(Session.sid == sid, Session.revoked_at.is_(None))
.values(**values)
.execution_options(synchronize_session='fetch')
)
await self._session.flush()
async def rotate_refresh_if_match(
self,
*,
sid: str,
old_jti_hash: str,
new_jti_hash: str,
new_refresh_expires_at: datetime,
now: datetime,
ip: str | None,
user_agent: str | None,
) -> bool:
values = {
'refresh_jti_hash': new_jti_hash,
'refresh_expires_at': new_refresh_expires_at,
'last_seen_at': now,
'user_agent': user_agent,
}
if ip is not None:
values['last_ip'] = ip
res = await self._session.execute(
update(Session)
.where(
Session.sid == sid,
Session.revoked_at.is_(None),
Session.refresh_jti_hash == old_jti_hash, # ✅ защита от гонок
)
.values(**values)
.execution_options(synchronize_session='fetch')
)
await self._session.flush()
return (res.rowcount or 0) > 0

View File

@@ -0,0 +1,58 @@
from __future__ import annotations
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import SQLAlchemyError
from src.application.contracts import ILogger
from src.application.domain.exceptions import ApplicationException, InternalException, NotFoundException
from src.application.abstractions.repositories import IUserRepository
from src.application.domain.entities import UserEntity
from src.infrastructure.database.models import UserModel
class UserRepository(IUserRepository):
def __init__(self, session: AsyncSession, logger: ILogger):
self._session = session
self._logger = logger
async def _get_active_user(self, user_id: str) -> UserModel:
stmt = (
select(UserModel)
.where(
UserModel.id == user_id,
UserModel.is_deleted.is_(False),
)
)
result = await self._session.execute(stmt)
user: UserModel | None = result.scalar_one_or_none()
if user is None:
self._logger.warning(f'User not found with user_id {user_id}')
raise NotFoundException(message='User not found')
return user
@staticmethod
def _to_entity(user: UserModel) -> UserEntity:
return UserEntity(
id=user.id,
email=user.email,
account_type=user.account_type,
first_name=user.first_name,
middle_name=user.middle_name,
last_name=user.last_name,
birth_date=user.birth_date,
phone=user.phone,
kyc_verified_at=user.kyc_verified_at,
kyc_verified=user.kyc_verified,
is_deleted=user.is_deleted,
created_at=user.created_at,
updated_at=user.updated_at,
)
async def get_user_by_id(self, user_id: str) -> UserEntity:
try:
user = await self._get_active_user(user_id)
return self._to_entity(user)
except ApplicationException:
raise
except SQLAlchemyError as exception:
self._logger.exception(str(exception))
raise InternalException(message=f'Database error: {str(exception)}')

View File

@@ -0,0 +1,61 @@
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from src.application.abstractions import IUnitOfWork
from src.application.abstractions.repositories import (
IUserRepository,
ILegalEntityRepository,
IPurchaseRequestRepository,
)
from src.application.contracts import ILogger
from src.infrastructure.database.repositories import (
UserRepository,
LegalEntityRepository,
PurchaseRequestRepository,
)
class UnitOfWork(IUnitOfWork):
def __init__(self, session_factory: async_sessionmaker[AsyncSession], logger: ILogger):
self.session_factory = session_factory
self._session: AsyncSession = None
self._user_repository: IUserRepository = None
self._legal_entity_repository: ILegalEntityRepository = None
self._purchase_request_repository: IPurchaseRequestRepository = None
self._logger: ILogger = logger
async def __aenter__(self):
self._logger.debug('UnitOfWork enter')
self._user_repository = None
self._legal_entity_repository = None
self._purchase_request_repository = None
self._session = self.session_factory()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if exc_type:
self._logger.error(f'UnitOfWork rollback_on_error exc_type={exc_type.__name__} exc_val={exc_val!r}')
await self._session.rollback()
self._logger.debug(f'UnitOfWork session rollback done exc_type={exc_type.__name__}')
else:
await self._session.flush()
await self._session.commit()
self._logger.debug('UnitOfWork commit')
await self._session.close()
self._logger.debug('UnitOfWork exit session closed')
@property
def user_repository(self) -> IUserRepository:
if self._user_repository is None:
self._user_repository = UserRepository(session=self._session, logger=self._logger)
return self._user_repository
@property
def legal_entity_repository(self) -> ILegalEntityRepository:
if self._legal_entity_repository is None:
self._legal_entity_repository = LegalEntityRepository(session=self._session, logger=self._logger)
return self._legal_entity_repository
@property
def purchase_request_repository(self) -> IPurchaseRequestRepository:
if self._purchase_request_repository is None:
self._purchase_request_repository = PurchaseRequestRepository(session=self._session, logger=self._logger)
return self._purchase_request_repository

View File

@@ -0,0 +1,28 @@
from src.application.contracts import ILogger
from src.application.domain.enums import LogFormat
from src.application.domain.enums import LogLevel
from src.infrastructure.config.settings import settings
from src.infrastructure.logger.logger import Logger
log_levels = {
'DEBUG': LogLevel.DEBUG,
'INFO': LogLevel.INFO,
'WARNING': LogLevel.WARNING,
'ERROR': LogLevel.ERROR,
'CRITICAL': LogLevel.CRITICAL,
'EXCEPTION': LogLevel.EXCEPTION,
}
log_formats = {
'JSON': LogFormat.JSON,
'TEXT': LogFormat.TEXT,
}
logger = Logger(
min_level=log_levels.get(settings.LOG_LEVEL, LogLevel.INFO),
log_format=log_formats.get(settings.LOG_FORMAT, LogFormat.JSON),
)
def get_logger() -> ILogger:
return logger

View File

@@ -0,0 +1,129 @@
import traceback
import inspect
import sys
import json
from datetime import datetime
from typing import Callable, Optional, Any
from ulid import ULID
from src.application.contracts import ILogger
from src.application.domain.enums import LogFormat, LogLevel
from src.infrastructure.context_vars import trace_id_var
class Logger(ILogger):
_instance = None
__default_format = LogFormat.JSON
def __new__(cls, *args: Any, **kwargs: Any) -> "Logger":
if cls._instance is None:
cls._instance = super(Logger, cls).__new__(cls)
return cls._instance
def __init__(
self,
log_format: LogFormat = __default_format,
min_level: LogLevel = LogLevel.INFO,
id_generator: Optional[Callable[[], str]] = lambda: str(ULID()),
instance_id: str = "N/A",
):
self.log_format = log_format
self.min_level = min_level
self.id_generator = id_generator
self.instance_id = instance_id
def set_instance_id(self, instance_id: str) -> None:
self.instance_id = instance_id
def get_instance_id(self) -> str:
return self.instance_id
def set_format(self, log_format: LogFormat) -> None:
if not isinstance(log_format, LogFormat):
raise ValueError("Log format must be an instance of LogFormat enum")
self.log_format = log_format
def set_min_level(self, level: LogLevel) -> None:
self.min_level = level
def new_trace_id(self) -> str:
trace_id = str(ULID()) if self.id_generator is None else self.id_generator()
trace_id_var.set(trace_id)
return trace_id
def set_trace_id(self, trace_id: str) -> None:
trace_id_var.set(trace_id)
def get_trace_id(self) -> str:
return trace_id_var.get()
def clear_trace_id(self) -> None:
trace_id_var.set("N/A")
def _prepare_log_data(self, level: LogLevel, message: str) -> dict[str, Any]:
current_frame = inspect.currentframe()
if (
current_frame
and current_frame.f_back
and current_frame.f_back.f_back
and current_frame.f_back.f_back.f_back
):
frame = current_frame.f_back.f_back.f_back
filename = frame.f_code.co_filename
line_number = frame.f_lineno
else:
filename = "unknown"
line_number = 0
log_data = {
"timestamp": datetime.now().isoformat(),
"level": level.name,
"instance_id": self.instance_id,
"file": filename,
"line": line_number,
"trace_id": trace_id_var.get(),
"message": message,
}
if level == LogLevel.EXCEPTION:
log_data["exception"] = traceback.format_exc()
return log_data
def _log(self, level: LogLevel, message: str) -> None:
if level >= self.min_level:
log_data = self._prepare_log_data(level, message)
if self.log_format == LogFormat.JSON:
log_message = json.dumps(log_data, ensure_ascii=False)
else:
log_message = (
f"{log_data['timestamp']} - {log_data['level']} - "
f"{log_data['instance_id']} - {log_data['trace_id']} - "
f"{log_data['file']}:{log_data['line']} - "
f"{log_data['message']}"
)
if "exception" in log_data:
log_message += f"\nTraceback:\n{log_data['exception']}"
self._write(log_message)
def _write(self, message: str) -> None:
sys.stdout.write(message + "\n")
def debug(self, message: str) -> None:
self._log(LogLevel.DEBUG, message)
def info(self, message: str) -> None:
self._log(LogLevel.INFO, message)
def warning(self, message: str) -> None:
self._log(LogLevel.WARNING, message)
def error(self, message: str) -> None:
self._log(LogLevel.ERROR, message)
def critical(self, message: str) -> None:
self._log(LogLevel.CRITICAL, message)
def exception(self, message: str) -> None:
self._log(LogLevel.EXCEPTION, message)

View File

@@ -0,0 +1,18 @@
from __future__ import annotations
from io import BytesIO
from PIL import Image
def image_bytes_to_webp(raw: bytes, *, quality: int = 82) -> bytes:
im = Image.open(BytesIO(raw))
if im.mode == 'P':
im = im.convert('RGBA')
elif im.mode == 'LA':
im = im.convert('RGBA')
elif im.mode not in ('RGBA', 'RGB'):
im = im.convert('RGB')
out = BytesIO()
im.save(out, format='WEBP', quality=quality)
return out.getvalue()

View File

@@ -0,0 +1 @@
from src.infrastructure.messanger.rabbit_client import RabbitClient

View File

@@ -0,0 +1,72 @@
from typing import Any, Mapping
from faststream.rabbit import RabbitBroker
from src.application.contracts import IQueueMessanger
from src.infrastructure.config import settings
class RabbitClient(IQueueMessanger):
def __init__(self) -> None:
self._broker = RabbitBroker(
settings.RABBIT_URL,
)
self._connected = False
async def connect(self) -> None:
if self._connected:
return
await self._broker.connect()
self._connected = True
async def close(self) -> None:
if not self._connected:
return
await self._broker.close()
self._connected = False
async def _ensure_connected(self) -> None:
if not self._connected:
await self.connect()
async def publish_to_queue(
self,
queue: str,
message: Any,
*,
persist: bool | None = None,
headers: Mapping[str, Any] | None = None,
correlation_id: str | None = None,
message_id: str | None = None,
) -> None:
await self._ensure_connected()
await self._broker.publish(
message,
queue=queue,
persist=settings.RABBIT_PUBLISH_PERSIST if persist is None else persist,
headers=headers,
correlation_id=correlation_id,
message_id=message_id,
)
async def publish(
self,
message: Any,
*,
exchange: str,
routing_key: str,
persist: bool | None = None,
headers: Mapping[str, Any] | None = None,
correlation_id: str | None = None,
message_id: str | None = None,
) -> None:
await self._ensure_connected()
await self._broker.publish(
message,
exchange=exchange,
routing_key=routing_key,
persist=settings.RABBIT_PUBLISH_PERSIST if persist is None else persist,
headers=headers,
correlation_id=correlation_id,
message_id=message_id,
)

View File

@@ -0,0 +1,3 @@
from src.infrastructure.security.jwt import JwtService
from src.infrastructure.security.csrf import CsrfService
from src.infrastructure.security.hash import HashService

View File

@@ -0,0 +1,81 @@
from __future__ import annotations
import secrets
from typing import Any, Optional, Mapping
from itsdangerous import URLSafeTimedSerializer, SignatureExpired, BadSignature
from src.application.contracts import ICsrfService
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.config.settings import settings
class CsrfService(ICsrfService):
COOKIE_NAME = 'csrf_token'
HEADER_NAME = 'X-CSRF-Token'
SALT = 'csrf'
TTL_SECONDS = 3600
def __init__(self) -> None:
self._serializer = URLSafeTimedSerializer(
secret_key=settings.CSRF_SECRET_KEY,
salt=self.SALT,
)
@property
def cookie_name(self) -> str:
return self.COOKIE_NAME
@property
def header_name(self) -> str:
return self.HEADER_NAME
@property
def ttl_seconds(self) -> int:
return self.TTL_SECONDS
def issue(self, subject: Optional[str] = None) -> str:
payload = {
'sub': subject,
'nonce': secrets.token_urlsafe(32),
}
return self._serializer.dumps(payload)
def verify(self, token: str, expected_subject: Optional[str] = None) -> dict[str, Any]:
try:
data = self._serializer.loads(token, max_age=self.TTL_SECONDS)
except SignatureExpired:
raise ApplicationException(
status_code=403,
message='CSRF token expired',
)
except BadSignature:
raise ApplicationException(
status_code=403,
message='CSRF token invalid',
)
if expected_subject is not None and data.get('sub') != expected_subject:
raise ApplicationException(
status_code=403,
message='CSRF token subject mismatch',
)
return data
def extract(self, cookies: Mapping[str, str], headers: Mapping[str, str]) -> tuple[Optional[str], Optional[str]]:
cookie_token = cookies.get(self.COOKIE_NAME)
header_token = headers.get(self.HEADER_NAME)
return cookie_token, header_token
def verify_pair(self, cookie_token: Optional[str], header_token: Optional[str], expected_subject: Optional[str] = None) -> None:
if not cookie_token or not header_token:
raise ApplicationException(
status_code=403,
message='CSRF token missing',
)
if not secrets.compare_digest(cookie_token, header_token):
raise ApplicationException(
status_code=403,
message='CSRF token mismatch',
)
self.verify(cookie_token, expected_subject=expected_subject)

View File

@@ -0,0 +1,17 @@
import bcrypt
from src.application.contracts import IHashService, ILogger
class HashService(IHashService):
def __init__(self, logger: ILogger):
self._logger = logger
async def hash(self, value: str) -> str:
hashed_value = bcrypt.hashpw(value.encode(), bcrypt.gensalt())
self._logger.info(f'Hash value {hashed_value.decode()}')
return hashed_value.decode()
async def verify(self, hashed_value: str, plain_value: str) -> bool:
self._logger.info(f'Hash value {hashed_value[:10]}')
return bcrypt.checkpw(plain_value.encode(), hashed_value.encode())

View File

@@ -0,0 +1,109 @@
from __future__ import annotations
from jose import jwt, ExpiredSignatureError, JWTError
from src.application.contracts import ILogger, IJwtService
from src.application.domain.dto import AccessTokenPayload
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.config.settings import settings
from src.infrastructure.vault import JwtKeyStore
class JwtService(IJwtService):
def __init__(self, logger: ILogger, key_store: JwtKeyStore) -> None:
self._logger = logger
self._key_store = key_store
async def decode_access_token(self, token: str) -> AccessTokenPayload:
payload = await self._decode_and_verify(token)
if payload.get('type') != 'access':
self._logger.warning(f'Access token invalid type received_type={payload.get('type')}')
raise ApplicationException(status_code=401, message='Invalid token type')
try:
return AccessTokenPayload(
sub=str(payload['sub']),
type='access',
sid=str(payload['sid']),
iat=int(payload['iat']),
nbf=int(payload['nbf']),
exp=int(payload['exp']),
iss=payload.get('iss'),
aud=payload.get('aud'),
)
except KeyError as exception:
self._logger.warning(f'Access token missing claim error={str(exception)}')
raise ApplicationException(status_code=401, message=f'Missing token claim: {exception}')
async def _decode_and_verify(self, token: str) -> dict:
kid: str | None = None
try:
header = jwt.get_unverified_header(token)
kid = header.get('kid')
if not kid:
self._logger.warning(f'JWT header missing kid header={header}')
raise ApplicationException(status_code=401, message='Missing token header: kid')
received_alg = header.get('alg')
if received_alg != settings.JWT_ALGORITHM:
self._logger.warning(f'JWT invalid algorithm kid={kid} received_alg={received_alg} expected_alg={settings.JWT_ALGORITHM}')
raise ApplicationException(status_code=401, message='Invalid token algorithm')
public_pem = await self._key_store.get_public_key_for_kid(str(kid))
if not public_pem:
self._logger.info(f'JWT kid miss kid={kid} forcing keystore refresh')
await self._key_store.refresh()
public_pem = await self._key_store.get_public_key_for_kid(str(kid))
if not public_pem:
self._logger.warning(f'JWT unknown kid kid={kid}')
raise ApplicationException(status_code=401, message='Unknown token kid')
options = {
'verify_signature': True,
'verify_exp': True,
'verify_nbf': True,
'verify_iat': True,
'verify_aud': bool(settings.JWT_AUDIENCE),
'verify_iss': bool(settings.JWT_ISSUER),
'require_exp': True,
'require_iat': True,
'require_nbf': True,
'require_sub': True,
'leeway': 10,
}
payload = jwt.decode(
token,
public_pem,
algorithms=[settings.JWT_ALGORITHM],
audience=settings.JWT_AUDIENCE or None,
issuer=settings.JWT_ISSUER or None,
options=options,
)
if 'sid' not in payload:
self._logger.warning(f'JWT missing sid claim kid={kid}')
raise ApplicationException(status_code=401, message='Missing token claim: sid')
if 'type' not in payload:
self._logger.warning(f'JWT missing type claim kid={kid}')
raise ApplicationException(status_code=401, message='Missing token claim: type')
return payload
except ExpiredSignatureError as exception:
self._logger.info(f'JWT expired kid={kid} error={str(exception)}')
raise ApplicationException(status_code=401, message='Token expired')
except ApplicationException:
raise
except JWTError as exception:
self._logger.warning(f'JWT decode failed kid={kid} error={str(exception)}')
raise ApplicationException(status_code=401, message='Invalid token')
except Exception as exception:
self._logger.error(f'Unexpected JWT decode error kid={kid} error={str(exception)}')
raise ApplicationException(status_code=500, message='JWT decode failed')

View File

@@ -0,0 +1,125 @@
from __future__ import annotations
from aiobotocore.session import get_session
class S3Service:
def __init__(
self,
*,
bucket: str,
region: str,
access_key_id: str | None,
secret_access_key: str | None,
public_base_url: str | None,
endpoint_url: str | None,
use_reg_ru_website_public_host: bool,
):
self._bucket = bucket
self._region = region or 'us-east-1'
self._access_key_id = access_key_id
self._secret_access_key = secret_access_key
pb = (public_base_url or '').strip().rstrip('/')
self._public_base_url = pb if pb else None
self._endpoint_url = endpoint_url.strip().rstrip('/') if endpoint_url and endpoint_url.strip() else None
self._use_reg_ru_website_public_host = use_reg_ru_website_public_host
@staticmethod
def _url_prefix_variants(prefix: str) -> list[str]:
p = prefix.rstrip('/') + '/'
out = [p]
if p.startswith('https://'):
out.append('http://' + p[8:])
elif p.startswith('http://'):
out.append('https://' + p[7:])
return out
def _public_url_prefixes(self) -> list[str]:
acc: list[str] = []
pb = self._public_base_url
if pb:
acc.extend(self._url_prefix_variants(pb))
ep = self._endpoint_url
if ep:
base = f'{ep.rstrip("/")}/{self._bucket}'
acc.extend(self._url_prefix_variants(base))
if ep and self._use_reg_ru_website_public_host and 's3.regru.cloud' in ep.lower():
wh = f'https://{self._bucket}.website.regru.cloud'
acc.extend(self._url_prefix_variants(wh))
if not ep:
if self._region == 'us-east-1':
h = f'https://{self._bucket}.s3.amazonaws.com'
else:
h = f'https://{self._bucket}.s3.{self._region}.amazonaws.com'
acc.extend(self._url_prefix_variants(h))
seen: set[str] = set()
uniq: list[str] = []
for x in sorted(acc, key=len, reverse=True):
if x not in seen:
seen.add(x)
uniq.append(x)
return uniq
def object_key_from_public_url(self, url: str) -> str | None:
u = (url or '').strip()
if not u:
return None
for p in self._public_url_prefixes():
if u.startswith(p):
k = u[len(p):].split('?', 1)[0].split('#', 1)[0]
return k if k else None
return None
def _object_url(self, key: str) -> str:
if self._public_base_url:
return f'{self._public_base_url}/{key}'
endpoint = self._endpoint_url
if endpoint:
if (
self._use_reg_ru_website_public_host
and 's3.regru.cloud' in endpoint.lower()
):
return f'https://{self._bucket}.website.regru.cloud/{key}'
return f'{endpoint}/{self._bucket}/{key}'
region = self._region
if region == 'us-east-1':
host = 's3.amazonaws.com'
else:
host = f's3.{region}.amazonaws.com'
return f'https://{self._bucket}.{host}/{key}'
async def upload_bytes(self, *, key: str, body: bytes, content_type: str) -> str:
session = get_session()
kw: dict[str, object] = {'region_name': self._region}
aid = self._access_key_id
sk = self._secret_access_key
ep = self._endpoint_url
if aid:
kw['aws_access_key_id'] = aid
if sk:
kw['aws_secret_access_key'] = sk
if ep:
kw['endpoint_url'] = ep
async with session.create_client('s3', **kw) as client:
await client.put_object(
Bucket=self._bucket,
Key=key,
Body=body,
ContentType=content_type,
)
return self._object_url(key)
async def delete_object(self, *, key: str) -> None:
session = get_session()
kw: dict[str, object] = {'region_name': self._region}
aid = self._access_key_id
sk = self._secret_access_key
ep = self._endpoint_url
if aid:
kw['aws_access_key_id'] = aid
if sk:
kw['aws_secret_access_key'] = sk
if ep:
kw['endpoint_url'] = ep
async with session.create_client('s3', **kw) as client:
await client.delete_object(Bucket=self._bucket, Key=key)

View File

@@ -0,0 +1 @@
from src.infrastructure.utils.instance_id import generate_instance_id

View File

@@ -0,0 +1,14 @@
from ulid import ULID
def generate_instance_id() -> str:
"""
Generate a process-wide instance id in ULID format.
ULID is 26 chars (Crockford Base32) and lexicographically sortable by time.
"""
return str(ULID())

View File

@@ -0,0 +1,4 @@
from src.infrastructure.vault.client import VaultClient
from src.infrastructure.vault.utils import create_hvac_client, read_kv2_secret
from src.infrastructure.vault.keys import JwtKeyStore
from src.infrastructure.vault.scheduler import start_jwt_keys_scheduler

View File

@@ -0,0 +1,66 @@
from __future__ import annotations
from typing import Any
import hvac
def _vault_token_renew_failed(exception: Exception) -> bool:
if isinstance(exception, (hvac.exceptions.Forbidden, hvac.exceptions.Unauthorized)):
return True
message = getattr(exception, 'message', None) or str(exception)
if isinstance(message, str):
lower = message.lower()
return 'permission denied' in lower or 'invalid token' in lower or '403' in lower
return False
class VaultClient:
def __init__(
self,
*,
addr: str,
role_id: str,
secret_id: str,
namespace: str | None,
mount_point: str,
) -> None:
self._mount_point = mount_point
self._addr = addr
self._role_id = role_id
self._secret_id = secret_id
self._namespace = namespace
self._client = hvac.Client(url=addr, namespace=namespace)
self._approle_login()
def _approle_login(self) -> None:
self._client.auth.approle.login(role_id=self._role_id, secret_id=self._secret_id)
def _renew_or_login(self) -> None:
try:
self._client.auth.token.renew_self()
except Exception:
self._approle_login()
def read_secret(self, path: str) -> dict[str, Any]:
for attempt in range(2):
try:
secret = self._client.secrets.kv.v2.read_secret_version(
path=path,
mount_point=self._mount_point,
)
return dict(secret.get('data', {}).get('data', {}))
except Exception as exc:
if attempt == 0 and _vault_token_renew_failed(exc):
self._renew_or_login()
continue
raise
def read_secret_optional(self, path: str) -> dict[str, Any]:
if not path:
return {}
try:
return self.read_secret(path)
except (hvac.exceptions.InvalidPath, hvac.exceptions.Forbidden, hvac.exceptions.Unauthorized):
return {}

View File

@@ -0,0 +1,111 @@
from __future__ import annotations
import asyncio
from datetime import datetime, timezone
from src.application.domain.dto import JwtPublicKeySet, JwtPublicKey
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.vault.client import VaultClient
class JwtKeyStore:
_instance: 'JwtKeyStore | None' = None
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(
self,
*,
vault_addr: str,
vault_role_id: str,
vault_secret_id: str,
vault_namespace: str | None,
mount_point: str,
kid_path: str = 'jwt/kid',
kids_prefix: str = 'jwt/kids',
refresh_ttl_seconds: int = 60,
):
if getattr(self, '_initialized', False):
return
self._vault_client = VaultClient(
addr=vault_addr,
role_id=vault_role_id,
secret_id=vault_secret_id,
namespace=vault_namespace,
mount_point=mount_point,
)
self._kid_path = kid_path
self._kids_prefix = kids_prefix
self._refresh_ttl_seconds = refresh_ttl_seconds
self._lock = asyncio.Lock()
self._keyset: JwtPublicKeySet | None = None
self._last_refresh_at: datetime | None = None
self._initialized = True
@classmethod
def get_instance(cls) -> 'JwtKeyStore':
if cls._instance is None:
raise ApplicationException(status_code=500, message='JwtKeyStore not initialized')
return cls._instance
def _read_keyset_sync(self) -> JwtPublicKeySet:
kids = self._vault_client.read_secret(self._kid_path)
active_kid = kids.get('active')
previous_kid = kids.get('previous')
if not active_kid:
raise RuntimeError('Vault jwt/kid secret missing "active"')
active = self._read_public_key_sync(str(active_kid))
previous = None
if previous_kid and previous_kid != active_kid:
previous = self._read_public_key_sync(str(previous_kid))
return JwtPublicKeySet(active=active, previous=previous)
def _read_public_key_sync(self, kid: str) -> JwtPublicKey:
data = self._vault_client.read_secret(f'{self._kids_prefix}/{kid}')
pub = data.get('public_key')
if not pub:
raise RuntimeError(f'Vault jwt/kids/{kid} missing public_key')
return JwtPublicKey(kid=kid, public_key_pem=pub)
async def refresh(self) -> JwtPublicKeySet:
keyset = await asyncio.to_thread(self._read_keyset_sync)
async with self._lock:
self._keyset = keyset
self._last_refresh_at = datetime.now(timezone.utc)
return keyset
async def get_public_key_for_kid(self, kid: str) -> str | None:
ks = await self._get_or_refresh()
return ks.public_keys_by_kid().get(kid)
async def last_refresh_at(self) -> datetime | None:
async with self._lock:
return self._last_refresh_at
async def _get_or_refresh(self) -> JwtPublicKeySet:
async with self._lock:
ks = self._keyset
last = self._last_refresh_at
if ks is None:
return await self.refresh()
if last is None:
return await self.refresh()
age = (datetime.now(timezone.utc) - last).total_seconds()
if age >= self._refresh_ttl_seconds:
return await self.refresh()
return ks

View File

@@ -0,0 +1,23 @@
from __future__ import annotations
import logging
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.interval import IntervalTrigger
from src.infrastructure.vault import JwtKeyStore
logger = logging.getLogger(__name__)
def start_jwt_keys_scheduler(store: JwtKeyStore, *, refresh_seconds: int = 3600) -> AsyncIOScheduler:
scheduler = AsyncIOScheduler()
scheduler.add_job(
store.refresh,
trigger=IntervalTrigger(seconds=refresh_seconds),
id="jwt_keys_refresh",
replace_existing=True,
max_instances=1,
coalesce=True,
misfire_grace_time=60,
)
scheduler.start()
logger.info("JWT keys scheduler started (interval=%s seconds)", refresh_seconds)
return scheduler

View File

@@ -0,0 +1,17 @@
from __future__ import annotations
import hvac
def create_hvac_client(*, url: str, token: str, timeout: int = 5) -> hvac.Client:
client = hvac.Client(url=url, token=token, timeout=timeout)
if not client.is_authenticated():
raise RuntimeError("Vault authentication failed. Check VAULT_ADDR / VAULT_TOKEN")
return client
def read_kv2_secret(*, client: hvac.Client, mount_point: str, path: str) -> dict:
secret = client.secrets.kv.v2.read_secret_version(
mount_point=mount_point,
path=path,
)
return secret["data"]["data"]

140
src/main.py Normal file
View File

@@ -0,0 +1,140 @@
from __future__ import annotations
from contextlib import asynccontextmanager
import secrets
from typing import AsyncGenerator
from fastapi import Depends, FastAPI
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.responses import HTMLResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from starlette.middleware.cors import CORSMiddleware
from starlette.exceptions import HTTPException
from fastapi.exceptions import RequestValidationError
from src.application.domain.exceptions import ApplicationException, UnauthorizedException
from src.infrastructure.cache import create_redis_client
from src.infrastructure.vault import JwtKeyStore, start_jwt_keys_scheduler
from src.infrastructure.utils import generate_instance_id
from src.infrastructure.logger import logger
from src.infrastructure.config import settings
from src.presentation.handlers import (
application_exception_handler,
http_exception_handler,
unhandled_exception_handler,
validation_exception_handler,
)
from src.presentation.middleware import TraceIDMiddleware, SecurityHeadersMiddleware
from src.presentation.routing import purchase_requests_router
security = HTTPBasic()
async def verify_credentials(credentials: HTTPBasicCredentials = Depends(security)) -> HTTPBasicCredentials:
user_ok = secrets.compare_digest(credentials.username, settings.DOCS_USERNAME)
pass_ok = secrets.compare_digest(credentials.password, settings.DOCS_PASSWORD)
if not (user_ok and pass_ok):
raise UnauthorizedException(
message='Unauthorized',
headers={'WWW-Authenticate': 'Basic'},
)
return credentials
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
instance_id = generate_instance_id()
logger.set_instance_id(instance_id)
logger.info(f'B2B service instance started with id {instance_id}')
app.state.redis = create_redis_client()
if not settings.VAULT_ROLE_ID.strip() or not settings.VAULT_SECRET_ID.strip():
raise RuntimeError('VAULT_ROLE_ID and VAULT_SECRET_ID must be set')
jwt_store = JwtKeyStore(
vault_addr=settings.VAULT_ADDR,
vault_role_id=settings.VAULT_ROLE_ID,
vault_secret_id=settings.VAULT_SECRET_ID,
vault_namespace=settings.VAULT_NAMESPACE,
mount_point=settings.VAULT_MOUNT_POINT,
kid_path=settings.VAULT_JWT_KID_PATH,
kids_prefix=settings.VAULT_JWT_KIDS_PREFIX,
)
await jwt_store.refresh()
jwt_scheduler = start_jwt_keys_scheduler(jwt_store, refresh_seconds=settings.JWT_KEYS_REFRESH_SECONDS)
app.state.jwt_key_store = jwt_store
app.state.jwt_keys_scheduler = jwt_scheduler
yield
await app.state.redis.aclose()
logger.info(f'B2B service instance ended with id {instance_id}')
app: FastAPI = FastAPI(
redoc_url=None,
docs_url=None,
lifespan=lifespan,
title='B2B Service',
version='1.0.0',
description='Purchase requests API for legal entity client users.',
license_info={
'name': 'MIT',
'url': 'https://opensource.org/licenses/MIT',
},
)
app.add_exception_handler(RequestValidationError, validation_exception_handler)
app.add_exception_handler(HTTPException, http_exception_handler)
app.add_exception_handler(ApplicationException, application_exception_handler)
app.add_exception_handler(Exception, unhandled_exception_handler)
app.include_router(purchase_requests_router)
app.add_middleware(TraceIDMiddleware, logger=logger)
app.add_middleware(
SecurityHeadersMiddleware,
hsts=True,
hsts_preload=False,
frame_options='DENY',
referrer_policy='strict-origin-when-cross-origin',
content_security_policy="default-src 'self'; frame-ancestors 'none'; base-uri 'self'; object-src 'none'",
)
app.add_middleware(
CORSMiddleware,
allow_origins=[],
allow_origin_regex='.*',
allow_credentials=True,
allow_methods=['*'],
allow_headers=['*'],
)
@app.get('/docs', include_in_schema=False)
async def custom_swagger_ui_html(_credentials: HTTPBasicCredentials = Depends(verify_credentials)) -> HTMLResponse:
'''Custom Swagger documentation, optionally protected with basic authentication.'''
return get_swagger_ui_html(
openapi_url=getattr(app, 'openapi_url', '/openapi.json'),
title=getattr(app, 'title', 'FastAPI') + ' - Swagger UI',
oauth2_redirect_url=getattr(app, 'swagger_ui_oauth2_redirect_url', None),
swagger_js_url='https://unpkg.com/swagger-ui-dist@5/swagger-ui-bundle.js',
swagger_css_url='https://unpkg.com/swagger-ui-dist@5/swagger-ui.css',
)
@app.get('/redoc', include_in_schema=False)
async def custom_redoc_html(_credentials: HTTPBasicCredentials = Depends(verify_credentials)) -> HTMLResponse:
'''Custom ReDoc documentation, optionally protected with basic authentication.'''
return get_redoc_html(
openapi_url=getattr(app, 'openapi_url', '/openapi.json'),
title=getattr(app, 'title', 'FastAPI') + ' - ReDoc',
redoc_js_url='https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js',
)
@app.post('/ping')
async def ping() -> dict[str, str]:
return {
'message': 'pong',
'status': 'ok',
}

View File

@@ -0,0 +1,4 @@
from src.presentation.decorators.csrf import csrf_protect
from src.presentation.decorators.rate_limit import rate_limit, _email_rl_key as email_rl_key
from src.presentation.decorators.auth import require_access_token
from src.presentation.decorators.cache import cached

View File

@@ -0,0 +1,36 @@
from fastapi import Depends, Request
from fastapi.security.utils import get_authorization_scheme_param
from src.application.contracts import IJwtService
from src.application.domain.exceptions import UnauthorizedException
from src.application.domain.dto import AccessTokenPayload, AuthContext
from src.presentation.dependencies import get_jwt_service
def _extract_access_token(request: Request) -> str | None:
token = request.cookies.get('access_token')
if token:
return token
auth = request.headers.get('Authorization')
if auth:
scheme, param = get_authorization_scheme_param(auth)
if scheme.lower() == 'bearer' and param:
return param
return None
async def require_access_token(
request: Request,
jwt_service: IJwtService = Depends(get_jwt_service),
) -> AuthContext:
token = _extract_access_token(request)
if not token:
raise UnauthorizedException(message='Not authenticated')
payload: AccessTokenPayload = await jwt_service.decode_access_token(token)
if payload.type != 'access':
raise UnauthorizedException(message='Invalid token type')
return AuthContext(user_id=payload.sub, sid=payload.sid, token=payload)

View File

@@ -0,0 +1,46 @@
from __future__ import annotations
import functools
from typing import Any, Awaitable, Callable
from fastapi import Request
from fastapi.responses import ORJSONResponse
from src.infrastructure.cache import KeydbCache
from src.infrastructure.logger import get_logger
from src.presentation.dependencies.cache import get_redis
def cached(*, prefix: str) -> Callable:
def decorator(func: Callable[..., Awaitable[Any]]):
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
logger = get_logger()
request = kwargs.get('request')
if not isinstance(request, Request):
for a in args:
if isinstance(a, Request):
request = a
break
auth = kwargs.get('auth')
user_id = getattr(auth, 'user_id', None) if auth else None
if request is None or user_id is None:
return await func(*args, **kwargs)
cache_key = f'{prefix}:{user_id}'
try:
redis = get_redis(request)
cache = KeydbCache(redis)
hit = await cache.get_user(user_id)
if hit is not None:
logger.debug(f'Cache hit key={cache_key}')
return ORJSONResponse(status_code=200, content=hit)
except Exception as e:
logger.warning(f'Cache read failed key={cache_key} error={e}')
return await func(*args, **kwargs)
return wrapper
return decorator

Some files were not shown because too many files have changed in this diff Show More