init
This commit is contained in:
145
.gitignore
vendored
Normal file
145
.gitignore
vendored
Normal file
@@ -0,0 +1,145 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
generate_password_hash.py
|
||||
# C extensions
|
||||
*.so
|
||||
*.pyd
|
||||
*.dll
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache/
|
||||
.pytest_cache/
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
|
||||
# Type checkers / linters
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
.pyre/
|
||||
.pytype/
|
||||
.ruff_cache/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints/
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.env.*
|
||||
.venv/
|
||||
venv/
|
||||
ENV/
|
||||
env/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Poetry
|
||||
poetry.lock
|
||||
|
||||
# Pipenv
|
||||
Pipfile.lock
|
||||
|
||||
# Hatch
|
||||
.hatch/
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
logs/
|
||||
|
||||
# Local databases
|
||||
*.sqlite3
|
||||
*.db
|
||||
|
||||
# Secrets / credentials
|
||||
secrets.json
|
||||
credentials.json
|
||||
*.pem
|
||||
*.key
|
||||
*.crt
|
||||
|
||||
# OS generated files
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
Desktop.ini
|
||||
|
||||
# PyCharm / IntelliJ IDEA
|
||||
.idea/
|
||||
*.iml
|
||||
out/
|
||||
|
||||
# VS Code (optional)
|
||||
.vscode/
|
||||
|
||||
# Temporary files
|
||||
*.tmp
|
||||
*.temp
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# Sphinx docs
|
||||
docs/_build/
|
||||
|
||||
# mkdocs
|
||||
site/
|
||||
|
||||
# celery
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# mypy compiled cache
|
||||
.mypy_cache/
|
||||
|
||||
# pyinstaller
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# pytest debug
|
||||
pytestdebug.log
|
||||
|
||||
# Local config overrides
|
||||
config.local.py
|
||||
settings.local.py
|
||||
|
||||
# Vault / local dev secrets
|
||||
.env.vault
|
||||
vault.token
|
||||
|
||||
.env
|
||||
.dockerignore
|
||||
/sql
|
||||
25
Dockerfile
Normal file
25
Dockerfile
Normal file
@@ -0,0 +1,25 @@
|
||||
FROM ghcr.io/astral-sh/uv:python3.12-bookworm AS builder
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY pyproject.toml uv.lock ./
|
||||
RUN uv sync --frozen --no-dev
|
||||
|
||||
COPY src ./src
|
||||
|
||||
|
||||
FROM ghcr.io/astral-sh/uv:python3.12-bookworm AS runtime
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY --from=builder /app/.venv /app/.venv
|
||||
COPY --from=builder /app/src /app/src
|
||||
|
||||
ENV PATH="/app/.venv/bin:$PATH" \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONPATH=/app
|
||||
|
||||
EXPOSE 8001
|
||||
|
||||
CMD ["sh", "-c", "python -m granian --interface asgi ${APP_MODULE:-src.main:app} --host ${APP_HOST:-0.0.0.0} --port ${APP_PORT:-8001} --workers ${APP_WORKERS:-2} --loop uvloop"]
|
||||
57
docker-compose.yml
Normal file
57
docker-compose.yml
Normal file
@@ -0,0 +1,57 @@
|
||||
services:
|
||||
b2b:
|
||||
container_name: b2b-service
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
ports:
|
||||
- "8001:8001"
|
||||
environment:
|
||||
PYTHONUNBUFFERED: "1"
|
||||
APP_MODULE: "src.main:app"
|
||||
APP_HOST: "0.0.0.0"
|
||||
APP_PORT: "8001"
|
||||
APP_WORKERS: "2"
|
||||
env_file:
|
||||
- .env
|
||||
depends_on:
|
||||
b2b_keydb:
|
||||
condition: service_healthy
|
||||
restart: no
|
||||
|
||||
b2b_keydb:
|
||||
image: eqalpha/keydb
|
||||
container_name: b2b_keydb
|
||||
restart: no
|
||||
expose:
|
||||
- "6379"
|
||||
volumes:
|
||||
- b2b_keydb_data:/data
|
||||
command:
|
||||
- keydb-server
|
||||
- --requirepass
|
||||
- ${REDIS_PASSWORD}
|
||||
- --dir
|
||||
- /data
|
||||
- --appendonly
|
||||
- "yes"
|
||||
- --appendfsync
|
||||
- everysec
|
||||
- --save
|
||||
- "900"
|
||||
- "1"
|
||||
- --save
|
||||
- "300"
|
||||
- "10"
|
||||
- --save
|
||||
- "60"
|
||||
- "10000"
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "-a", "${REDIS_PASSWORD}", "ping"]
|
||||
interval: 5s
|
||||
timeout: 2s
|
||||
retries: 20
|
||||
|
||||
|
||||
volumes:
|
||||
b2b_keydb_data:
|
||||
20
pyproject.toml
Normal file
20
pyproject.toml
Normal file
@@ -0,0 +1,20 @@
|
||||
[project]
|
||||
name = "b2b-service"
|
||||
version = "0.1.0"
|
||||
description = "B2B purchase requests API for legal entity client users"
|
||||
requires-python = "==3.12.*"
|
||||
dependencies = [
|
||||
"apscheduler==3.11.2",
|
||||
"asyncpg==0.31.0",
|
||||
"dotenv==0.9.9",
|
||||
"fastapi==0.128.7",
|
||||
"pydantic-settings==2.12.0",
|
||||
"python-jose==3.5.0",
|
||||
"python-ulid==3.1.0",
|
||||
"sqlalchemy==2.0.46",
|
||||
"uvloop==0.22.1; platform_system != 'Windows'",
|
||||
"granian==2.6.1",
|
||||
"hvac==2.4.0",
|
||||
"redis==7.2.0",
|
||||
"orjson==3.11.7",
|
||||
]
|
||||
1
src/application/abstractions/__init__.py
Normal file
1
src/application/abstractions/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.application.abstractions.i_unit_of_work import IUnitOfWork
|
||||
25
src/application/abstractions/i_unit_of_work.py
Normal file
25
src/application/abstractions/i_unit_of_work.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
from typing import Protocol, runtime_checkable
|
||||
from src.application.abstractions.repositories import (
|
||||
IUserRepository,
|
||||
ILegalEntityRepository,
|
||||
IPurchaseRequestRepository,
|
||||
)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class IUnitOfWork(Protocol):
|
||||
async def __aenter__(self) -> "IUnitOfWork": ...
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: ...
|
||||
|
||||
async def commit(self) -> None: ...
|
||||
async def rollback(self) -> None: ...
|
||||
|
||||
@property
|
||||
def user_repository(self) -> IUserRepository: ...
|
||||
|
||||
@property
|
||||
def legal_entity_repository(self) -> ILegalEntityRepository: ...
|
||||
|
||||
@property
|
||||
def purchase_request_repository(self) -> IPurchaseRequestRepository: ...
|
||||
3
src/application/abstractions/repositories/__init__.py
Normal file
3
src/application/abstractions/repositories/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from src.application.abstractions.repositories.i_user_repository import IUserRepository
|
||||
from src.application.abstractions.repositories.i_legal_entity_repository import ILegalEntityRepository
|
||||
from src.application.abstractions.repositories.i_purchase_request_repository import IPurchaseRequestRepository
|
||||
@@ -0,0 +1,9 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from src.application.domain.entities.legal_entity import LegalEntityEntity
|
||||
|
||||
|
||||
class ILegalEntityRepository(ABC):
|
||||
@abstractmethod
|
||||
async def get_by_user_id(self, user_id: str) -> LegalEntityEntity | None:
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,37 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from decimal import Decimal
|
||||
|
||||
from src.application.domain.entities.purchase_request import PurchaseRequestEntity
|
||||
|
||||
|
||||
class IPurchaseRequestRepository(ABC):
|
||||
@abstractmethod
|
||||
async def create(
|
||||
self,
|
||||
*,
|
||||
organization_id: str,
|
||||
usdt_amount: Decimal,
|
||||
comment: str | None,
|
||||
target_wallet_chain: str | None,
|
||||
target_wallet_address: str | None,
|
||||
) -> PurchaseRequestEntity:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_by_id(self, request_id: str) -> PurchaseRequestEntity:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def list_by_organization(
|
||||
self,
|
||||
*,
|
||||
organization_id: str,
|
||||
status: str | None,
|
||||
limit: int,
|
||||
offset: int,
|
||||
) -> list[PurchaseRequestEntity]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def count_by_organization(self, *, organization_id: str, status: str | None) -> int:
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,59 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from src.application.domain.entities import SessionEntity
|
||||
|
||||
|
||||
|
||||
class ISessionRepository(ABC):
|
||||
@abstractmethod
|
||||
async def get_by_sid(self, sid: str) -> SessionEntity | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_by_user_device(self, user_id: str, device_id: str) -> SessionEntity | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def upsert_by_device(
|
||||
self,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
sid: str,
|
||||
refresh_jti_hash: str,
|
||||
refresh_expires_at: datetime,
|
||||
user_agent: str | None,
|
||||
ip: str | None,
|
||||
now: datetime,
|
||||
) -> SessionEntity:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def revoke_by_sid(self, sid: str, now: datetime) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def rotate_refresh(
|
||||
self,
|
||||
sid: str,
|
||||
new_jti_hash: str,
|
||||
new_refresh_expires_at: datetime,
|
||||
now: datetime,
|
||||
ip: str | None,
|
||||
user_agent: str | None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def rotate_refresh_if_match(
|
||||
self,
|
||||
*,
|
||||
sid: str,
|
||||
old_jti_hash: str,
|
||||
new_jti_hash: str,
|
||||
new_refresh_expires_at: datetime,
|
||||
now: datetime,
|
||||
ip: str | None,
|
||||
user_agent: str | None,
|
||||
) -> bool:
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,11 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from src.application.domain.entities import UserEntity
|
||||
|
||||
|
||||
class IUserRepository(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def get_user_by_id(self, user_id: str) -> UserEntity:
|
||||
raise NotImplementedError
|
||||
5
src/application/commands/__init__.py
Normal file
5
src/application/commands/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from src.application.commands.purchase_request_commands import (
|
||||
CreatePurchaseRequestCommand,
|
||||
GetPurchaseRequestCommand,
|
||||
ListPurchaseRequestsCommand,
|
||||
)
|
||||
63
src/application/commands/change_email_complete.py
Normal file
63
src/application/commands/change_email_complete.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.contracts import IHashService, ILogger, ICache
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.database.decorators import transactional
|
||||
|
||||
|
||||
class ChangeEmailCompleteCommand:
|
||||
def __init__(
|
||||
self,
|
||||
unit_of_work: IUnitOfWork,
|
||||
hash_service: IHashService,
|
||||
cache: ICache,
|
||||
logger: ILogger,
|
||||
):
|
||||
self._unit_of_work = unit_of_work
|
||||
self._hash_service = hash_service
|
||||
self._cache = cache
|
||||
self._logger = logger
|
||||
|
||||
@transactional
|
||||
async def __call__(self, *, user_id: str, code: str) -> bool:
|
||||
code = (code or '').strip()
|
||||
|
||||
NEW_USER_PREFIX = 'change_email:new_user:'
|
||||
NEW_CODE_PREFIX = 'change_email:new_code:'
|
||||
|
||||
new_user_key = f'{NEW_USER_PREFIX}{user_id}'
|
||||
new_code_key = f'{NEW_CODE_PREFIX}{code}'
|
||||
|
||||
cached_user_id = await self._cache.get(new_code_key)
|
||||
if not cached_user_id:
|
||||
self._logger.info(f'Change email complete failed: code not found (user_id={user_id})')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
if cached_user_id != user_id:
|
||||
self._logger.info(f'Change email complete failed: code-user mismatch (user_id={user_id})')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
raw_value = await self._cache.get(new_user_key)
|
||||
if not raw_value:
|
||||
self._logger.info(f'Change email complete failed: user key missing (user_id={user_id})')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
separator_idx = raw_value.index(':')
|
||||
code_hash = raw_value[:separator_idx]
|
||||
new_email = raw_value[separator_idx + 1:]
|
||||
|
||||
ok = await self._hash_service.verify(hashed_value=code_hash, plain_value=code)
|
||||
if not ok:
|
||||
self._logger.info(f'Change email complete failed: code hash mismatch (user_id={user_id})')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
user = await self._unit_of_work.user_repository.set_email(user_id=user_id, email=new_email)
|
||||
await self._cache.set_user(user_id, user)
|
||||
|
||||
try:
|
||||
await self._cache.delete(new_code_key)
|
||||
await self._cache.delete(new_user_key)
|
||||
except Exception as e:
|
||||
self._logger.warning(f'Change email complete cleanup failed (user_id={user_id}): {e}')
|
||||
|
||||
self._logger.info(f'Email changed for user_id={user_id}')
|
||||
return True
|
||||
145
src/application/commands/change_email_confirm_old.py
Normal file
145
src/application/commands/change_email_confirm_old.py
Normal file
@@ -0,0 +1,145 @@
|
||||
import secrets
|
||||
from datetime import datetime, timezone
|
||||
from ulid import ULID
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.contracts import IHashService, ICache, ILogger, IQueueMessanger
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.config import settings
|
||||
from src.infrastructure.context_vars import trace_id_var
|
||||
from src.infrastructure.database.decorators import transactional
|
||||
|
||||
|
||||
class ChangeEmailConfirmOldCommand:
|
||||
def __init__(
|
||||
self,
|
||||
hash_service: IHashService,
|
||||
cache: ICache,
|
||||
unit_of_work: IUnitOfWork,
|
||||
logger: ILogger,
|
||||
messanger: IQueueMessanger,
|
||||
):
|
||||
self._hash_service = hash_service
|
||||
self._unit_of_work = unit_of_work
|
||||
self._cache = cache
|
||||
self._logger = logger
|
||||
self._messanger = messanger
|
||||
|
||||
@transactional
|
||||
async def __call__(self, *, user_id: str, code: str, new_email: str) -> bool:
|
||||
TTL = 300
|
||||
MAX_ATTEMPTS = 20
|
||||
|
||||
OLD_USER_PREFIX = 'change_email:old_user:'
|
||||
OLD_CODE_PREFIX = 'change_email:old_code:'
|
||||
NEW_USER_PREFIX = 'change_email:new_user:'
|
||||
NEW_CODE_PREFIX = 'change_email:new_code:'
|
||||
|
||||
code = (code or '').strip()
|
||||
old_user_key = f'{OLD_USER_PREFIX}{user_id}'
|
||||
old_code_key = f'{OLD_CODE_PREFIX}{code}'
|
||||
|
||||
cached_user_id = await self._cache.get(old_code_key)
|
||||
if not cached_user_id:
|
||||
self._logger.info(f'Change email confirm-old failed: code not found (user_id={user_id})')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
if cached_user_id != user_id:
|
||||
self._logger.info(f'Change email confirm-old failed: code-user mismatch (user_id={user_id})')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
code_hash = await self._cache.get(old_user_key)
|
||||
if not code_hash:
|
||||
self._logger.info(f'Change email confirm-old failed: user key missing (user_id={user_id})')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
ok = await self._hash_service.verify(hashed_value=code_hash, plain_value=code)
|
||||
if not ok:
|
||||
self._logger.info(f'Change email confirm-old failed: code hash mismatch (user_id={user_id})')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
user = await self._unit_of_work.user_repository.get_user_by_id(user_id=user_id)
|
||||
|
||||
if user.email and user.email.lower() == new_email.lower():
|
||||
self._logger.info(f'Change email confirm-old failed: new email same as current (user_id={user_id})')
|
||||
raise ApplicationException(400, 'New email must differ from the current one')
|
||||
|
||||
email_taken = await self._unit_of_work.user_repository.email_exists(email=new_email)
|
||||
if email_taken:
|
||||
self._logger.info(f'Change email confirm-old failed: new email already taken (user_id={user_id})')
|
||||
raise ApplicationException(409, 'Email already in use')
|
||||
|
||||
try:
|
||||
await self._cache.delete(old_code_key)
|
||||
await self._cache.delete(old_user_key)
|
||||
except Exception as e:
|
||||
self._logger.warning(f'Change email confirm-old cleanup failed (user_id={user_id}): {e}')
|
||||
|
||||
trace_id = trace_id_var.get()
|
||||
if not trace_id or trace_id == 'N/A':
|
||||
trace_id = None
|
||||
|
||||
new_user_key = f'{NEW_USER_PREFIX}{user_id}'
|
||||
|
||||
for _ in range(MAX_ATTEMPTS):
|
||||
new_code = f'{secrets.randbelow(1_000_000):06d}'
|
||||
new_code_key = f'{NEW_CODE_PREFIX}{new_code}'
|
||||
|
||||
new_code_hash = await self._hash_service.hash(new_code)
|
||||
|
||||
reserved = await self._cache.set_nx(new_code_key, user_id, ttl=TTL)
|
||||
if not reserved:
|
||||
continue
|
||||
|
||||
saved = await self._cache.set(new_user_key, f'{new_code_hash}:{new_email}', ttl=TTL)
|
||||
if not saved:
|
||||
await self._cache.delete(new_code_key)
|
||||
self._logger.error(f'Change email confirm-old failed: cannot save new code hash for user_id={user_id}')
|
||||
raise ApplicationException(503, 'Temporary error. Please try again.')
|
||||
|
||||
message_id = str(ULID())
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
metadata = {
|
||||
'trace_id': trace_id,
|
||||
'source': 'user-service',
|
||||
'timestamp': now,
|
||||
'message_id': message_id,
|
||||
}
|
||||
|
||||
payload = {
|
||||
'email': new_email,
|
||||
'code': new_code,
|
||||
'ttl_seconds': TTL,
|
||||
}
|
||||
|
||||
message = {
|
||||
'event': 'change_email_new',
|
||||
'payload': payload,
|
||||
'metadata': metadata,
|
||||
}
|
||||
|
||||
self._logger.info(f'Change email new code created for user_id={user_id}')
|
||||
|
||||
try:
|
||||
await self._messanger.publish_to_queue(
|
||||
queue=settings.RABBIT_EMAIL_CODE_QUEUE,
|
||||
message=message,
|
||||
persist=True,
|
||||
correlation_id=trace_id,
|
||||
message_id=message_id,
|
||||
headers={'trace_id': trace_id} if trace_id else None,
|
||||
)
|
||||
except Exception as exception:
|
||||
try:
|
||||
await self._cache.delete(new_user_key)
|
||||
await self._cache.delete(new_code_key)
|
||||
except Exception as rollback_err:
|
||||
self._logger.error(f'Publish failed and rollback cache failed for user_id={user_id}: {str(rollback_err)}')
|
||||
|
||||
self._logger.error(f'Failed to publish change email new code for user_id={user_id}: {str(exception)}')
|
||||
raise ApplicationException(503, 'Temporary error. Please try again.')
|
||||
|
||||
return True
|
||||
|
||||
self._logger.error(f'Change email confirm-old failed: code space exhausted for user_id={user_id}')
|
||||
raise ApplicationException(503, 'Temporary error. Please try again.')
|
||||
126
src/application/commands/change_email_start.py
Normal file
126
src/application/commands/change_email_start.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import secrets
|
||||
from datetime import datetime, timezone
|
||||
from ulid import ULID
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.contracts import IHashService, ICache, ILogger, IQueueMessanger
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.config import settings
|
||||
from src.infrastructure.context_vars import trace_id_var
|
||||
from src.infrastructure.database.decorators import transactional
|
||||
|
||||
|
||||
class ChangeEmailStartCommand:
|
||||
def __init__(
|
||||
self,
|
||||
hash_service: IHashService,
|
||||
cache: ICache,
|
||||
unit_of_work: IUnitOfWork,
|
||||
logger: ILogger,
|
||||
messanger: IQueueMessanger,
|
||||
):
|
||||
self._hash_service = hash_service
|
||||
self._unit_of_work = unit_of_work
|
||||
self._cache = cache
|
||||
self._logger = logger
|
||||
self._messanger = messanger
|
||||
|
||||
@transactional
|
||||
async def __call__(self, user_id: str) -> bool:
|
||||
TTL = 300
|
||||
LOCK_TTL = 30
|
||||
MAX_ATTEMPTS = 20
|
||||
|
||||
USER_PREFIX = 'change_email:old_user:'
|
||||
CODE_PREFIX = 'change_email:old_code:'
|
||||
LOCK_PREFIX = 'change_email:lock:'
|
||||
|
||||
user = await self._unit_of_work.user_repository.get_user_by_id(user_id=user_id)
|
||||
|
||||
if not user.email:
|
||||
self._logger.warning(f'User {user_id} does not have an email address')
|
||||
raise ApplicationException(404, f'User {user_id} does not have an email address')
|
||||
|
||||
trace_id = trace_id_var.get()
|
||||
if not trace_id or trace_id == 'N/A':
|
||||
trace_id = None
|
||||
|
||||
lock_key = f'{LOCK_PREFIX}{user_id}'
|
||||
locked = await self._cache.set_nx(lock_key, '1', ttl=LOCK_TTL)
|
||||
if not locked:
|
||||
self._logger.info(f'Change email throttled by lock (user_id={user_id})')
|
||||
raise ApplicationException(429, 'Too many requests. Please wait.')
|
||||
|
||||
try:
|
||||
user_key = f'{USER_PREFIX}{user_id}'
|
||||
|
||||
existing = await self._cache.get(user_key)
|
||||
if existing:
|
||||
self._logger.info(f'Change email denied: code already exists for user_id={user_id}')
|
||||
raise ApplicationException(429, 'Code already sent. Please wait before retrying.')
|
||||
|
||||
for _ in range(MAX_ATTEMPTS):
|
||||
code = f'{secrets.randbelow(1_000_000):06d}'
|
||||
code_key = f'{CODE_PREFIX}{code}'
|
||||
|
||||
code_hash = await self._hash_service.hash(code)
|
||||
|
||||
reserved = await self._cache.set_nx(code_key, user_id, ttl=TTL)
|
||||
if not reserved:
|
||||
continue
|
||||
|
||||
saved = await self._cache.set(user_key, code_hash, ttl=TTL)
|
||||
if not saved:
|
||||
await self._cache.delete(code_key)
|
||||
self._logger.error(f'Change email failed: cannot save code hash for user_id={user_id}')
|
||||
raise ApplicationException(503, 'Temporary error. Please try again.')
|
||||
|
||||
message_id = str(ULID())
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
metadata = {
|
||||
'trace_id': trace_id,
|
||||
'source': 'user-service',
|
||||
'timestamp': now,
|
||||
'message_id': message_id,
|
||||
}
|
||||
|
||||
payload = {
|
||||
'email': user.email,
|
||||
'code': code,
|
||||
'ttl_seconds': TTL,
|
||||
}
|
||||
|
||||
message = {
|
||||
'event': 'change_email_old',
|
||||
'payload': payload,
|
||||
'metadata': metadata,
|
||||
}
|
||||
|
||||
self._logger.info(f'Change email old code created for user_id={user_id}')
|
||||
|
||||
try:
|
||||
await self._messanger.publish_to_queue(
|
||||
queue=settings.RABBIT_EMAIL_CODE_QUEUE,
|
||||
message=message,
|
||||
persist=True,
|
||||
correlation_id=trace_id,
|
||||
message_id=message_id,
|
||||
headers={'trace_id': trace_id} if trace_id else None,
|
||||
)
|
||||
except Exception as exception:
|
||||
try:
|
||||
await self._cache.delete(user_key)
|
||||
await self._cache.delete(code_key)
|
||||
except Exception as rollback_err:
|
||||
self._logger.error(f'Publish failed and rollback cache failed for user_id={user_id}: {str(rollback_err)}')
|
||||
|
||||
self._logger.error(f'Failed to publish change email old code for user_id={user_id}: {str(exception)}')
|
||||
raise ApplicationException(503, 'Temporary error. Please try again.')
|
||||
|
||||
return True
|
||||
|
||||
self._logger.error(f'Change email failed: code space exhausted for user_id={user_id}')
|
||||
raise ApplicationException(503, 'Temporary error. Please try again.')
|
||||
|
||||
finally:
|
||||
await self._cache.delete(lock_key)
|
||||
81
src/application/commands/change_password_complete.py
Normal file
81
src/application/commands/change_password_complete.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.contracts import IHashService, ILogger, ICache
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.database.decorators import transactional
|
||||
|
||||
|
||||
class ChangePasswordCompleteCommand:
|
||||
def __init__(
|
||||
self,
|
||||
unit_of_work: IUnitOfWork,
|
||||
hash_service: IHashService,
|
||||
cache: ICache,
|
||||
logger: ILogger,
|
||||
):
|
||||
self._unit_of_work = unit_of_work
|
||||
self._hash_service = hash_service
|
||||
self._cache = cache
|
||||
self._logger = logger
|
||||
|
||||
@transactional
|
||||
async def __call__(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
code: str,
|
||||
new_password: str,
|
||||
confirm_password: str,
|
||||
) -> bool:
|
||||
code = (code or '').strip()
|
||||
|
||||
USER_PREFIX = 'change_password:user:'
|
||||
CODE_PREFIX = 'change_password:code:'
|
||||
|
||||
user_key = f'{USER_PREFIX}{user_id}'
|
||||
code_key = f'{CODE_PREFIX}{code}'
|
||||
|
||||
if new_password != confirm_password:
|
||||
self._logger.info(f'Change password failed: passwords do not match (user_id={user_id})')
|
||||
raise ApplicationException(400, 'Passwords do not match')
|
||||
|
||||
cached_user_id = await self._cache.get(code_key)
|
||||
if not cached_user_id:
|
||||
self._logger.info(f'Change password failed: code not found (user_id={user_id})')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
if cached_user_id != user_id:
|
||||
self._logger.info(f'Change password failed: code-user mismatch (user_id={user_id})')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
code_hash = await self._cache.get(user_key)
|
||||
if not code_hash:
|
||||
self._logger.info(f'Change password failed: user key missing (user_id={user_id})')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
ok = await self._hash_service.verify(hashed_value=code_hash, plain_value=code)
|
||||
if not ok:
|
||||
self._logger.info(f'Change password failed: code hash mismatch (user_id={user_id})')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
current_password_hash = await self._unit_of_work.user_repository.get_password_hash(user_id=user_id)
|
||||
|
||||
is_same = await self._hash_service.verify(hashed_value=current_password_hash, plain_value=new_password)
|
||||
if is_same:
|
||||
self._logger.info(f'Change password failed: new password same as current (user_id={user_id})')
|
||||
raise ApplicationException(400, 'New password must differ from the current one')
|
||||
|
||||
new_password_hash = await self._hash_service.hash(new_password)
|
||||
user = await self._unit_of_work.user_repository.set_password(
|
||||
user_id=user_id,
|
||||
password_hash=new_password_hash,
|
||||
)
|
||||
await self._cache.set_user(user_id, user)
|
||||
|
||||
try:
|
||||
await self._cache.delete(code_key)
|
||||
await self._cache.delete(user_key)
|
||||
except Exception as e:
|
||||
self._logger.warning(f'Change password cleanup failed (user_id={user_id}): {e}')
|
||||
|
||||
self._logger.info(f'Password changed for user_id={user_id}')
|
||||
return True
|
||||
126
src/application/commands/change_password_start.py
Normal file
126
src/application/commands/change_password_start.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import secrets
|
||||
from datetime import datetime, timezone
|
||||
from ulid import ULID
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.contracts import IHashService, ICache, ILogger, IQueueMessanger
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.config import settings
|
||||
from src.infrastructure.context_vars import trace_id_var
|
||||
from src.infrastructure.database.decorators import transactional
|
||||
|
||||
|
||||
class ChangePasswordStartCommand:
|
||||
def __init__(
|
||||
self,
|
||||
hash_service: IHashService,
|
||||
cache: ICache,
|
||||
unit_of_work: IUnitOfWork,
|
||||
logger: ILogger,
|
||||
messanger: IQueueMessanger,
|
||||
):
|
||||
self._hash_service = hash_service
|
||||
self._unit_of_work = unit_of_work
|
||||
self._cache = cache
|
||||
self._logger = logger
|
||||
self._messanger = messanger
|
||||
|
||||
@transactional
|
||||
async def __call__(self, user_id: str) -> bool:
|
||||
TTL = 300
|
||||
LOCK_TTL = 30
|
||||
MAX_ATTEMPTS = 20
|
||||
|
||||
USER_PREFIX = 'change_password:user:'
|
||||
CODE_PREFIX = 'change_password:code:'
|
||||
LOCK_PREFIX = 'change_password:lock:'
|
||||
|
||||
user = await self._unit_of_work.user_repository.get_user_by_id(user_id=user_id)
|
||||
|
||||
if not user.email:
|
||||
self._logger.warning(f'User {user_id} does not have an email address')
|
||||
raise ApplicationException(404, f'User {user_id} does not have an email address')
|
||||
|
||||
trace_id = trace_id_var.get()
|
||||
if not trace_id or trace_id == 'N/A':
|
||||
trace_id = None
|
||||
|
||||
lock_key = f'{LOCK_PREFIX}{user_id}'
|
||||
locked = await self._cache.set_nx(lock_key, '1', ttl=LOCK_TTL)
|
||||
if not locked:
|
||||
self._logger.info(f'Change password throttled by lock (user_id={user_id})')
|
||||
raise ApplicationException(429, 'Too many requests. Please wait.')
|
||||
|
||||
try:
|
||||
user_key = f'{USER_PREFIX}{user_id}'
|
||||
|
||||
existing = await self._cache.get(user_key)
|
||||
if existing:
|
||||
self._logger.info(f'Change password denied: code already exists for user_id={user_id}')
|
||||
raise ApplicationException(429, 'Code already sent. Please wait before retrying.')
|
||||
|
||||
for _ in range(MAX_ATTEMPTS):
|
||||
code = f'{secrets.randbelow(1_000_000):06d}'
|
||||
code_key = f'{CODE_PREFIX}{code}'
|
||||
|
||||
code_hash = await self._hash_service.hash(code)
|
||||
|
||||
reserved = await self._cache.set_nx(code_key, user_id, ttl=TTL)
|
||||
if not reserved:
|
||||
continue
|
||||
|
||||
saved = await self._cache.set(user_key, code_hash, ttl=TTL)
|
||||
if not saved:
|
||||
await self._cache.delete(code_key)
|
||||
self._logger.error(f'Change password failed: cannot save code hash for user_id={user_id}')
|
||||
raise ApplicationException(503, 'Temporary error. Please try again.')
|
||||
|
||||
message_id = str(ULID())
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
metadata = {
|
||||
'trace_id': trace_id,
|
||||
'source': 'user-service',
|
||||
'timestamp': now,
|
||||
'message_id': message_id,
|
||||
}
|
||||
|
||||
payload = {
|
||||
'email': user.email,
|
||||
'code': code,
|
||||
'ttl_seconds': TTL,
|
||||
}
|
||||
|
||||
message = {
|
||||
'event': 'change_password',
|
||||
'payload': payload,
|
||||
'metadata': metadata,
|
||||
}
|
||||
|
||||
self._logger.info(f'Change password code created for user_id={user_id}')
|
||||
|
||||
try:
|
||||
await self._messanger.publish_to_queue(
|
||||
queue=settings.RABBIT_EMAIL_CODE_QUEUE,
|
||||
message=message,
|
||||
persist=True,
|
||||
correlation_id=trace_id,
|
||||
message_id=message_id,
|
||||
headers={'trace_id': trace_id} if trace_id else None,
|
||||
)
|
||||
except Exception as exception:
|
||||
try:
|
||||
await self._cache.delete(user_key)
|
||||
await self._cache.delete(code_key)
|
||||
except Exception as rollback_err:
|
||||
self._logger.error(f'Publish failed and rollback cache failed for user_id={user_id}: {str(rollback_err)}')
|
||||
|
||||
self._logger.error(f'Failed to publish change password email for user_id={user_id}: {str(exception)}')
|
||||
raise ApplicationException(503, 'Temporary error. Please try again.')
|
||||
|
||||
return True
|
||||
|
||||
self._logger.error(f'Change password failed: code space exhausted for user_id={user_id}')
|
||||
raise ApplicationException(503, 'Temporary error. Please try again.')
|
||||
|
||||
finally:
|
||||
await self._cache.delete(lock_key)
|
||||
56
src/application/commands/delete_avatar.py
Normal file
56
src/application/commands/delete_avatar.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.contracts import ICache, ILogger, IS3
|
||||
from src.application.domain.entities import UserEntity
|
||||
from src.infrastructure.database.decorators import transactional
|
||||
|
||||
|
||||
class DeleteAvatarCommand:
|
||||
def __init__(self, unit_of_work: IUnitOfWork, logger: ILogger, cache: ICache, s3: IS3):
|
||||
self._unit_of_work = unit_of_work
|
||||
self._logger = logger
|
||||
self._cache = cache
|
||||
self._s3 = s3
|
||||
|
||||
@transactional
|
||||
async def _load_user(self, user_id: str) -> UserEntity:
|
||||
user = await self._unit_of_work.user_repository.get_user_by_id(user_id)
|
||||
self._logger.debug(f'DeleteAvatar _load_user user_id={user_id} has_avatar_link={bool(user.avatar_link)}')
|
||||
return user
|
||||
|
||||
async def __call__(self, user_id: str) -> UserEntity:
|
||||
prior = await self._load_user(user_id)
|
||||
link = prior.avatar_link
|
||||
self._logger.info(f'DeleteAvatar start user_id={user_id} had_link={bool(link)}')
|
||||
if link:
|
||||
key = self._s3.object_key_from_public_url(link)
|
||||
self._logger.debug(f'DeleteAvatar parsed_object_key user_id={user_id} has_key={bool(key)}')
|
||||
if not key:
|
||||
self._logger.warning(
|
||||
f'DeleteAvatar could not parse avatar URL for S3 user_id={user_id} link_len={len(link)}'
|
||||
)
|
||||
if key:
|
||||
self._logger.info(f'DeleteAvatar S3 delete start user_id={user_id} key={key}')
|
||||
try:
|
||||
await self._s3.delete_object(key=key)
|
||||
self._logger.info(f'DeleteAvatar S3 delete done user_id={user_id} key={key}')
|
||||
except ClientError as exc:
|
||||
code = exc.response.get('Error', {}).get('Code', '')
|
||||
if code not in ('NoSuchKey', '404'):
|
||||
self._logger.warning(f'DeleteAvatar S3 delete failed user_id={user_id} code={code}: {exc}')
|
||||
else:
|
||||
self._logger.debug(f'DeleteAvatar S3 object already absent user_id={user_id} code={code}')
|
||||
user = await self._clear_avatar_link(user_id)
|
||||
self._logger.debug(f'DeleteAvatar DB cleared user_id={user_id} entity_has_link={bool(user.avatar_link)}')
|
||||
await self._cache.set_user(user_id, user)
|
||||
self._logger.debug(f'DeleteAvatar cache updated user_id={user_id}')
|
||||
self._logger.info(f'Avatar removed user_id={user_id}')
|
||||
return user
|
||||
|
||||
@transactional
|
||||
async def _clear_avatar_link(self, user_id: str) -> UserEntity:
|
||||
self._logger.debug(f'DeleteAvatar DB transaction set_avatar_link user_id={user_id} link=None')
|
||||
return await self._unit_of_work.user_repository.set_avatar_link(user_id, None)
|
||||
83
src/application/commands/forgot_password_complete.py
Normal file
83
src/application/commands/forgot_password_complete.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.contracts import IHashService, ILogger, ICache
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.database.decorators import transactional
|
||||
|
||||
|
||||
class ForgotPasswordCompleteCommand:
|
||||
def __init__(
|
||||
self,
|
||||
unit_of_work: IUnitOfWork,
|
||||
hash_service: IHashService,
|
||||
cache: ICache,
|
||||
logger: ILogger,
|
||||
):
|
||||
self._unit_of_work = unit_of_work
|
||||
self._hash_service = hash_service
|
||||
self._cache = cache
|
||||
self._logger = logger
|
||||
|
||||
@staticmethod
|
||||
def _normalize_email(email: str) -> str:
|
||||
return email.strip().lower()
|
||||
|
||||
@transactional
|
||||
async def __call__(
|
||||
self,
|
||||
*,
|
||||
email: str,
|
||||
code: str,
|
||||
new_password: str,
|
||||
confirm_password: str,
|
||||
) -> bool:
|
||||
code = (code or '').strip()
|
||||
normalized = self._normalize_email(email)
|
||||
|
||||
EMAIL_PREFIX = 'forgot_password:email:'
|
||||
CODE_PREFIX = 'forgot_password:code:'
|
||||
|
||||
if new_password != confirm_password:
|
||||
self._logger.info('Forgot password failed: passwords do not match')
|
||||
raise ApplicationException(400, 'Passwords do not match')
|
||||
|
||||
code_key = f'{CODE_PREFIX}{code}'
|
||||
cached_email = await self._cache.get(code_key)
|
||||
if not cached_email:
|
||||
self._logger.info('Forgot password failed: code not found')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
if cached_email != normalized:
|
||||
self._logger.info('Forgot password failed: code-email mismatch')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
email_key = f'{EMAIL_PREFIX}{normalized}'
|
||||
code_hash = await self._cache.get(email_key)
|
||||
if not code_hash:
|
||||
self._logger.info('Forgot password failed: email key missing')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
ok = await self._hash_service.verify(hashed_value=code_hash, plain_value=code)
|
||||
if not ok:
|
||||
self._logger.info('Forgot password failed: code hash mismatch')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
user = await self._unit_of_work.user_repository.get_user_by_email(normalized)
|
||||
if user is None:
|
||||
self._logger.info('Forgot password failed: user not found after valid code')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
new_password_hash = await self._hash_service.hash(new_password)
|
||||
user = await self._unit_of_work.user_repository.set_password(
|
||||
user_id=user.id,
|
||||
password_hash=new_password_hash,
|
||||
)
|
||||
await self._cache.set_user(user.id, user)
|
||||
|
||||
try:
|
||||
await self._cache.delete(code_key)
|
||||
await self._cache.delete(email_key)
|
||||
except Exception as e:
|
||||
self._logger.warning(f'Forgot password cleanup failed (user_id={user.id}): {e}')
|
||||
|
||||
self._logger.info(f'Password reset via forgot flow for user_id={user.id}')
|
||||
return True
|
||||
132
src/application/commands/forgot_password_start.py
Normal file
132
src/application/commands/forgot_password_start.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import secrets
|
||||
from datetime import datetime, timezone
|
||||
from ulid import ULID
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.contracts import IHashService, ICache, ILogger, IQueueMessanger
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.config import settings
|
||||
from src.infrastructure.context_vars import trace_id_var
|
||||
from src.infrastructure.database.decorators import transactional
|
||||
|
||||
|
||||
class ForgotPasswordStartCommand:
|
||||
def __init__(
|
||||
self,
|
||||
hash_service: IHashService,
|
||||
cache: ICache,
|
||||
unit_of_work: IUnitOfWork,
|
||||
logger: ILogger,
|
||||
messanger: IQueueMessanger,
|
||||
):
|
||||
self._hash_service = hash_service
|
||||
self._unit_of_work = unit_of_work
|
||||
self._cache = cache
|
||||
self._logger = logger
|
||||
self._messanger = messanger
|
||||
|
||||
@staticmethod
|
||||
def _normalize_email(email: str) -> str:
|
||||
return email.strip().lower()
|
||||
|
||||
@transactional
|
||||
async def __call__(self, email: str) -> bool:
|
||||
TTL = 300
|
||||
LOCK_TTL = 30
|
||||
MAX_ATTEMPTS = 20
|
||||
|
||||
EMAIL_PREFIX = 'forgot_password:email:'
|
||||
CODE_PREFIX = 'forgot_password:code:'
|
||||
LOCK_PREFIX = 'forgot_password:lock:'
|
||||
|
||||
normalized = self._normalize_email(email)
|
||||
user = await self._unit_of_work.user_repository.get_user_by_email(normalized)
|
||||
if user is None:
|
||||
self._logger.info(f'Forgot password start: no user for email hash lookup')
|
||||
return True
|
||||
|
||||
trace_id = trace_id_var.get()
|
||||
if not trace_id or trace_id == 'N/A':
|
||||
trace_id = None
|
||||
|
||||
lock_key = f'{LOCK_PREFIX}{normalized}'
|
||||
locked = await self._cache.set_nx(lock_key, '1', ttl=LOCK_TTL)
|
||||
if not locked:
|
||||
self._logger.info(f'Forgot password throttled by lock (user_id={user.id})')
|
||||
raise ApplicationException(429, 'Too many requests. Please wait.')
|
||||
|
||||
try:
|
||||
email_key = f'{EMAIL_PREFIX}{normalized}'
|
||||
|
||||
existing = await self._cache.get(email_key)
|
||||
if existing:
|
||||
self._logger.info(f'Forgot password denied: code already exists for user_id={user.id}')
|
||||
raise ApplicationException(429, 'Code already sent. Please wait before retrying.')
|
||||
|
||||
for _ in range(MAX_ATTEMPTS):
|
||||
code = f'{secrets.randbelow(1_000_000):06d}'
|
||||
code_key = f'{CODE_PREFIX}{code}'
|
||||
|
||||
code_hash = await self._hash_service.hash(code)
|
||||
|
||||
reserved = await self._cache.set_nx(code_key, normalized, ttl=TTL)
|
||||
if not reserved:
|
||||
continue
|
||||
|
||||
saved = await self._cache.set(email_key, code_hash, ttl=TTL)
|
||||
if not saved:
|
||||
await self._cache.delete(code_key)
|
||||
self._logger.error(f'Forgot password failed: cannot save code hash for user_id={user.id}')
|
||||
raise ApplicationException(503, 'Temporary error. Please try again.')
|
||||
|
||||
message_id = str(ULID())
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
metadata = {
|
||||
'trace_id': trace_id,
|
||||
'source': 'user-service',
|
||||
'timestamp': now,
|
||||
'message_id': message_id,
|
||||
}
|
||||
|
||||
payload = {
|
||||
'email': normalized,
|
||||
'code': code,
|
||||
'ttl_seconds': TTL,
|
||||
}
|
||||
|
||||
message = {
|
||||
'event': 'forgot_password',
|
||||
'payload': payload,
|
||||
'metadata': metadata,
|
||||
}
|
||||
|
||||
self._logger.info(f'Forgot password code created for user_id={user.id}')
|
||||
|
||||
try:
|
||||
await self._messanger.publish_to_queue(
|
||||
queue=settings.RABBIT_EMAIL_CODE_QUEUE,
|
||||
message=message,
|
||||
persist=True,
|
||||
correlation_id=trace_id,
|
||||
message_id=message_id,
|
||||
headers={'trace_id': trace_id} if trace_id else None,
|
||||
)
|
||||
except Exception as exception:
|
||||
try:
|
||||
await self._cache.delete(email_key)
|
||||
await self._cache.delete(code_key)
|
||||
except Exception as rollback_err:
|
||||
self._logger.error(
|
||||
f'Publish failed and rollback cache failed for user_id={user.id}: {str(rollback_err)}'
|
||||
)
|
||||
|
||||
self._logger.error(f'Failed to publish forgot password email for user_id={user.id}: {str(exception)}')
|
||||
raise ApplicationException(503, 'Temporary error. Please try again.')
|
||||
|
||||
return True
|
||||
|
||||
self._logger.error(f'Forgot password failed: code space exhausted for user_id={user.id}')
|
||||
raise ApplicationException(503, 'Temporary error. Please try again.')
|
||||
|
||||
finally:
|
||||
await self._cache.delete(lock_key)
|
||||
17
src/application/commands/get_me.py
Normal file
17
src/application/commands/get_me.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.contracts import ILogger, ICache
|
||||
from src.application.domain.entities import UserEntity
|
||||
from src.infrastructure.database.decorators import transactional
|
||||
|
||||
|
||||
class GetMeCommand:
|
||||
def __init__(self, unit_of_work: IUnitOfWork, logger: ILogger, cache: ICache):
|
||||
self._unit_of_work = unit_of_work
|
||||
self._logger = logger
|
||||
self._cache = cache
|
||||
|
||||
@transactional
|
||||
async def __call__(self, user_id: str) -> UserEntity:
|
||||
user = await self._unit_of_work.user_repository.get_user_by_id(user_id=user_id)
|
||||
self._logger.info(f'User ID: {user.id}')
|
||||
return user
|
||||
18
src/application/commands/legal_entity_guard.py
Normal file
18
src/application/commands/legal_entity_guard.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.domain.entities.legal_entity import LegalEntityEntity
|
||||
from src.application.domain.enums.account_type import AccountType
|
||||
from src.application.domain.exceptions import ForbiddenException, NotFoundException
|
||||
|
||||
|
||||
async def require_legal_entity(user_id: str, unit_of_work: IUnitOfWork) -> LegalEntityEntity:
|
||||
user = await unit_of_work.user_repository.get_user_by_id(user_id)
|
||||
if user.account_type != AccountType.LEGAL_ENTITY:
|
||||
raise ForbiddenException(message='B2B access is available for legal entity accounts only')
|
||||
|
||||
legal_entity = await unit_of_work.legal_entity_repository.get_by_user_id(user_id)
|
||||
if legal_entity is None:
|
||||
raise NotFoundException(message='Legal entity profile not found')
|
||||
|
||||
return legal_entity
|
||||
79
src/application/commands/purchase_request_commands.py
Normal file
79
src/application/commands/purchase_request_commands.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.commands.legal_entity_guard import require_legal_entity
|
||||
from src.application.contracts import ILogger
|
||||
from src.application.domain.entities.purchase_request import PurchaseRequestEntity
|
||||
from src.application.domain.exceptions import NotFoundException
|
||||
from src.infrastructure.database.decorators import transactional
|
||||
|
||||
|
||||
class CreatePurchaseRequestCommand:
|
||||
def __init__(self, unit_of_work: IUnitOfWork, logger: ILogger):
|
||||
self._unit_of_work = unit_of_work
|
||||
self._logger = logger
|
||||
|
||||
@transactional
|
||||
async def __call__(
|
||||
self,
|
||||
user_id: str,
|
||||
*,
|
||||
usdt_amount: Decimal,
|
||||
comment: str | None = None,
|
||||
target_wallet_chain: str | None = None,
|
||||
target_wallet_address: str | None = None,
|
||||
) -> PurchaseRequestEntity:
|
||||
legal_entity = await require_legal_entity(user_id, self._unit_of_work)
|
||||
item = await self._unit_of_work.purchase_request_repository.create(
|
||||
organization_id=legal_entity.id,
|
||||
usdt_amount=usdt_amount,
|
||||
comment=comment,
|
||||
target_wallet_chain=target_wallet_chain,
|
||||
target_wallet_address=target_wallet_address,
|
||||
)
|
||||
self._logger.info(f'Purchase request created id={item.id} organization_id={legal_entity.id}')
|
||||
return item
|
||||
|
||||
|
||||
class ListPurchaseRequestsCommand:
|
||||
def __init__(self, unit_of_work: IUnitOfWork, logger: ILogger):
|
||||
self._unit_of_work = unit_of_work
|
||||
self._logger = logger
|
||||
|
||||
@transactional
|
||||
async def __call__(
|
||||
self,
|
||||
user_id: str,
|
||||
*,
|
||||
status: str | None = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[PurchaseRequestEntity], int]:
|
||||
legal_entity = await require_legal_entity(user_id, self._unit_of_work)
|
||||
items = await self._unit_of_work.purchase_request_repository.list_by_organization(
|
||||
organization_id=legal_entity.id,
|
||||
status=status,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
total = await self._unit_of_work.purchase_request_repository.count_by_organization(
|
||||
organization_id=legal_entity.id,
|
||||
status=status,
|
||||
)
|
||||
return items, total
|
||||
|
||||
|
||||
class GetPurchaseRequestCommand:
|
||||
def __init__(self, unit_of_work: IUnitOfWork, logger: ILogger):
|
||||
self._unit_of_work = unit_of_work
|
||||
self._logger = logger
|
||||
|
||||
@transactional
|
||||
async def __call__(self, user_id: str, request_id: str) -> PurchaseRequestEntity:
|
||||
legal_entity = await require_legal_entity(user_id, self._unit_of_work)
|
||||
item = await self._unit_of_work.purchase_request_repository.get_by_id(request_id)
|
||||
if item.organization_id != legal_entity.id:
|
||||
raise NotFoundException(message='Purchase request not found')
|
||||
return item
|
||||
100
src/application/commands/set_avatar.py
Normal file
100
src/application/commands/set_avatar.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from PIL import UnidentifiedImageError
|
||||
from ulid import ULID
|
||||
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.contracts import ICache, ILogger, IS3
|
||||
from src.application.domain.entities import UserEntity
|
||||
from src.application.domain.exceptions import BadRequestException, ServiceUnavailableException
|
||||
from src.infrastructure.config import settings
|
||||
from src.infrastructure.database.decorators import transactional
|
||||
from src.infrastructure.media.webp import image_bytes_to_webp
|
||||
|
||||
|
||||
class SetAvatarCommand:
|
||||
def __init__(self, unit_of_work: IUnitOfWork, logger: ILogger, cache: ICache, s3: IS3):
|
||||
self._unit_of_work = unit_of_work
|
||||
self._logger = logger
|
||||
self._cache = cache
|
||||
self._s3 = s3
|
||||
|
||||
@transactional
|
||||
async def _load_user(self, user_id: str) -> UserEntity:
|
||||
user = await self._unit_of_work.user_repository.get_user_by_id(user_id)
|
||||
self._logger.debug(f'Avatar _load_user user_id={user_id} has_avatar_link={bool(user.avatar_link)}')
|
||||
return user
|
||||
|
||||
async def __call__(self, user_id: str, image_bytes: bytes) -> tuple[UserEntity, int]:
|
||||
prior = await self._load_user(user_id)
|
||||
old_link = prior.avatar_link
|
||||
self._logger.info(
|
||||
f'SetAvatar start user_id={user_id} input_bytes={len(image_bytes)} had_previous_link={bool(old_link)}'
|
||||
)
|
||||
try:
|
||||
webp_bytes = image_bytes_to_webp(image_bytes)
|
||||
except UnidentifiedImageError as exc:
|
||||
raise BadRequestException(message='Unsupported or corrupt image') from exc
|
||||
except Exception as exc:
|
||||
self._logger.exception(str(exc))
|
||||
raise BadRequestException(message='Could not process image') from exc
|
||||
|
||||
self._logger.debug(f'SetAvatar webp_ready bytes={len(webp_bytes)}')
|
||||
|
||||
pid = user_id.replace('/', '').replace('.', '_')
|
||||
name_id = str(ULID())
|
||||
ts = int(datetime.now(timezone.utc).timestamp() * 1000)
|
||||
prefix = settings.S3_AVATAR_KEY_PREFIX.strip().strip('/')
|
||||
fname = f'{name_id}_{pid}_{ts}.webp'
|
||||
object_key = f'{prefix}/{fname}' if prefix else fname
|
||||
|
||||
self._logger.info(f'SetAvatar S3 upload start user_id={user_id} key={object_key} webp_bytes={len(webp_bytes)}')
|
||||
|
||||
try:
|
||||
url = await self._s3.upload_bytes(key=object_key, body=webp_bytes, content_type='image/webp')
|
||||
except ClientError as exc:
|
||||
self._logger.exception(str(exc))
|
||||
raise ServiceUnavailableException(message='S3 upload failed') from exc
|
||||
|
||||
self._logger.info(f'SetAvatar S3 upload done user_id={user_id} key={object_key} public_url_len={len(url)}')
|
||||
|
||||
user = await self._save_avatar_link(user_id, url)
|
||||
self._logger.info(
|
||||
f'SetAvatar DB updated user_id={user_id} key={object_key} '
|
||||
f'entity_avatar_link_len={len(user.avatar_link or "")}'
|
||||
)
|
||||
await self._cache.set_user(user_id, user)
|
||||
self._logger.debug(f'SetAvatar cache updated user_id={user_id}')
|
||||
|
||||
if old_link:
|
||||
old_key = self._s3.object_key_from_public_url(old_link)
|
||||
if not old_key:
|
||||
self._logger.warning(
|
||||
f'SetAvatar could not parse old avatar URL for S3 delete user_id={user_id} '
|
||||
f'old_link_len={len(old_link)}'
|
||||
)
|
||||
elif old_key == object_key:
|
||||
self._logger.debug(f'SetAvatar skip delete same object key user_id={user_id} key={object_key}')
|
||||
else:
|
||||
self._logger.info(f'SetAvatar S3 delete old object user_id={user_id} old_key={old_key}')
|
||||
try:
|
||||
await self._s3.delete_object(key=old_key)
|
||||
self._logger.info(f'SetAvatar S3 old object removed user_id={user_id} old_key={old_key}')
|
||||
except ClientError as exc:
|
||||
code = exc.response.get('Error', {}).get('Code', '')
|
||||
if code not in ('NoSuchKey', '404'):
|
||||
self._logger.warning(f'S3 delete old avatar failed user_id={user_id} code={code}: {exc}')
|
||||
else:
|
||||
self._logger.debug(f'SetAvatar old object already gone user_id={user_id} code={code}')
|
||||
|
||||
self._logger.info(f'Avatar set for user_id={user_id} key={object_key}')
|
||||
return user, len(webp_bytes)
|
||||
|
||||
@transactional
|
||||
async def _save_avatar_link(self, user_id: str, avatar_link: str) -> UserEntity:
|
||||
self._logger.debug(f'SetAvatar DB transaction set_avatar_link user_id={user_id} link_len={len(avatar_link)}')
|
||||
return await self._unit_of_work.user_repository.set_avatar_link(user_id, avatar_link)
|
||||
68
src/application/commands/set_encrypted_mnemonic_complete.py
Normal file
68
src/application/commands/set_encrypted_mnemonic_complete.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.contracts import IHashService, ILogger, ICache
|
||||
from src.application.domain.entities import UserEntity
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.database.decorators import transactional
|
||||
|
||||
|
||||
class SetEncryptedMnemonicCompleteCommand:
|
||||
def __init__(
|
||||
self,
|
||||
unit_of_work: IUnitOfWork,
|
||||
hash_service: IHashService,
|
||||
cache: ICache,
|
||||
logger: ILogger,
|
||||
):
|
||||
self._unit_of_work = unit_of_work
|
||||
self._hash_service = hash_service
|
||||
self._cache = cache
|
||||
self._logger = logger
|
||||
|
||||
@transactional
|
||||
async def __call__(self, *, user_id: str, code: str, encrypted_mnemonic: str) -> UserEntity:
|
||||
code = (code or '').strip()
|
||||
|
||||
USER_PREFIX = 'encrypted_mnemonic:user:'
|
||||
CODE_PREFIX = 'encrypted_mnemonic:code:'
|
||||
|
||||
user_key = f'{USER_PREFIX}{user_id}'
|
||||
code_key = f'{CODE_PREFIX}{code}'
|
||||
|
||||
user = await self._unit_of_work.user_repository.get_user_by_id(user_id=user_id)
|
||||
if user.encrypted_mnemonic is not None:
|
||||
self._logger.info(f'Encrypted mnemonic already set for user_id={user_id}')
|
||||
raise ApplicationException(409, 'Encrypted mnemonic already set and cannot be changed')
|
||||
|
||||
cached_user_id = await self._cache.get(code_key)
|
||||
if not cached_user_id:
|
||||
self._logger.info(f'Encrypted mnemonic set failed: code not found (user_id={user_id})')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
if cached_user_id != user_id:
|
||||
self._logger.info(f'Encrypted mnemonic set failed: code-user mismatch (user_id={user_id})')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
code_hash = await self._cache.get(user_key)
|
||||
if not code_hash:
|
||||
self._logger.info(f'Encrypted mnemonic set failed: user key missing (user_id={user_id})')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
ok = await self._hash_service.verify(hashed_value=code_hash, plain_value=code)
|
||||
if not ok:
|
||||
self._logger.info(f'Encrypted mnemonic set failed: code hash mismatch (user_id={user_id})')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
user = await self._unit_of_work.user_repository.set_encrypted_mnemonic(
|
||||
user_id=user_id,
|
||||
encrypted_mnemonic=encrypted_mnemonic,
|
||||
)
|
||||
await self._cache.set_user(user_id, user)
|
||||
|
||||
try:
|
||||
await self._cache.delete(code_key)
|
||||
await self._cache.delete(user_key)
|
||||
except Exception as e:
|
||||
self._logger.warning(f'Encrypted mnemonic set cleanup failed (user_id={user_id}): {e}')
|
||||
|
||||
self._logger.info(f'Encrypted mnemonic set for user_id={user_id}')
|
||||
return user
|
||||
130
src/application/commands/set_encrypted_mnemonic_start.py
Normal file
130
src/application/commands/set_encrypted_mnemonic_start.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import secrets
|
||||
from datetime import datetime, timezone
|
||||
from ulid import ULID
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.contracts import IHashService, ICache, ILogger, IQueueMessanger
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.config import settings
|
||||
from src.infrastructure.context_vars import trace_id_var
|
||||
from src.infrastructure.database.decorators import transactional
|
||||
|
||||
|
||||
class SetEncryptedMnemonicStartCommand:
|
||||
def __init__(
|
||||
self,
|
||||
hash_service: IHashService,
|
||||
cache: ICache,
|
||||
unit_of_work: IUnitOfWork,
|
||||
logger: ILogger,
|
||||
messanger: IQueueMessanger,
|
||||
):
|
||||
self._hash_service = hash_service
|
||||
self._unit_of_work = unit_of_work
|
||||
self._cache = cache
|
||||
self._logger = logger
|
||||
self._messanger = messanger
|
||||
|
||||
@transactional
|
||||
async def __call__(self, user_id: str) -> bool:
|
||||
TTL = 300
|
||||
LOCK_TTL = 30
|
||||
MAX_ATTEMPTS = 20
|
||||
|
||||
USER_PREFIX = 'encrypted_mnemonic:user:'
|
||||
CODE_PREFIX = 'encrypted_mnemonic:code:'
|
||||
LOCK_PREFIX = 'encrypted_mnemonic:lock:'
|
||||
|
||||
user = await self._unit_of_work.user_repository.get_user_by_id(user_id=user_id)
|
||||
|
||||
if user.encrypted_mnemonic is not None:
|
||||
self._logger.info(f'Encrypted mnemonic already set for user_id={user_id}')
|
||||
raise ApplicationException(409, 'Encrypted mnemonic already set and cannot be changed')
|
||||
|
||||
if not user.email:
|
||||
self._logger.warning(f'User {user_id} does not have an email address')
|
||||
raise ApplicationException(404, f'User {user_id} does not have an email address')
|
||||
|
||||
trace_id = trace_id_var.get()
|
||||
if not trace_id or trace_id == 'N/A':
|
||||
trace_id = None
|
||||
|
||||
lock_key = f'{LOCK_PREFIX}{user_id}'
|
||||
locked = await self._cache.set_nx(lock_key, '1', ttl=LOCK_TTL)
|
||||
if not locked:
|
||||
self._logger.info(f'Encrypted mnemonic set throttled by lock (user_id={user_id})')
|
||||
raise ApplicationException(429, 'Too many requests. Please wait.')
|
||||
|
||||
try:
|
||||
user_key = f'{USER_PREFIX}{user_id}'
|
||||
|
||||
existing = await self._cache.get(user_key)
|
||||
if existing:
|
||||
self._logger.info(f'Encrypted mnemonic set denied: code already exists for user_id={user_id}')
|
||||
raise ApplicationException(429, 'Code already sent. Please wait before retrying.')
|
||||
|
||||
for _ in range(MAX_ATTEMPTS):
|
||||
code = f'{secrets.randbelow(1_000_000):06d}'
|
||||
code_key = f'{CODE_PREFIX}{code}'
|
||||
|
||||
code_hash = await self._hash_service.hash(code)
|
||||
|
||||
reserved = await self._cache.set_nx(code_key, user_id, ttl=TTL)
|
||||
if not reserved:
|
||||
continue
|
||||
|
||||
saved = await self._cache.set(user_key, code_hash, ttl=TTL)
|
||||
if not saved:
|
||||
await self._cache.delete(code_key)
|
||||
self._logger.error(f'Encrypted mnemonic set failed: cannot save code hash for user_id={user_id}')
|
||||
raise ApplicationException(503, 'Temporary error. Please try again.')
|
||||
|
||||
message_id = str(ULID())
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
metadata = {
|
||||
'trace_id': trace_id,
|
||||
'source': 'user-service',
|
||||
'timestamp': now,
|
||||
'message_id': message_id,
|
||||
}
|
||||
|
||||
payload = {
|
||||
'email': user.email,
|
||||
'code': code,
|
||||
'ttl_seconds': TTL,
|
||||
}
|
||||
|
||||
message = {
|
||||
'event': 'encrypted_mnemonic_set',
|
||||
'payload': payload,
|
||||
'metadata': metadata,
|
||||
}
|
||||
|
||||
self._logger.info(f'Encrypted mnemonic set code created for user_id={user_id}')
|
||||
|
||||
try:
|
||||
await self._messanger.publish_to_queue(
|
||||
queue=settings.RABBIT_EMAIL_CODE_QUEUE,
|
||||
message=message,
|
||||
persist=True,
|
||||
correlation_id=trace_id,
|
||||
message_id=message_id,
|
||||
headers={'trace_id': trace_id} if trace_id else None,
|
||||
)
|
||||
except Exception as exception:
|
||||
try:
|
||||
await self._cache.delete(user_key)
|
||||
await self._cache.delete(code_key)
|
||||
except Exception as rollback_err:
|
||||
self._logger.error(f'Publish failed and rollback cache failed for user_id={user_id}: {str(rollback_err)}')
|
||||
|
||||
self._logger.error(f'Failed to publish encrypted mnemonic set email for user_id={user_id}: {str(exception)}')
|
||||
raise ApplicationException(503, 'Temporary error. Please try again.')
|
||||
|
||||
return True
|
||||
|
||||
self._logger.error(f'Encrypted mnemonic set failed: code space exhausted for user_id={user_id}')
|
||||
raise ApplicationException(503, 'Temporary error. Please try again.')
|
||||
|
||||
finally:
|
||||
await self._cache.delete(lock_key)
|
||||
18
src/application/commands/set_phone.py
Normal file
18
src/application/commands/set_phone.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.contracts import ILogger, ICache
|
||||
from src.application.domain.entities import UserEntity
|
||||
from src.infrastructure.database.decorators import transactional
|
||||
|
||||
|
||||
class SetPhoneCommand:
|
||||
def __init__(self, unit_of_work: IUnitOfWork, logger: ILogger, cache: ICache):
|
||||
self._unit_of_work = unit_of_work
|
||||
self._logger = logger
|
||||
self._cache = cache
|
||||
|
||||
@transactional
|
||||
async def __call__(self, user_id: str, phone: str) -> UserEntity:
|
||||
user = await self._unit_of_work.user_repository.set_phone(user_id=user_id, phone=phone)
|
||||
await self._cache.set_user(user_id, user)
|
||||
self._logger.info(f'Set phone for user {user_id}')
|
||||
return user
|
||||
76
src/application/commands/update_bank_details_complete.py
Normal file
76
src/application/commands/update_bank_details_complete.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.contracts import IHashService, ILogger, ICache
|
||||
from src.application.domain.entities import UserEntity
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.database.decorators import transactional
|
||||
|
||||
|
||||
class UpdateBankDetailsCompleteCommand:
|
||||
def __init__(
|
||||
self,
|
||||
unit_of_work: IUnitOfWork,
|
||||
hash_service: IHashService,
|
||||
cache: ICache,
|
||||
logger: ILogger,
|
||||
):
|
||||
self._unit_of_work = unit_of_work
|
||||
self._hash_service = hash_service
|
||||
self._cache = cache
|
||||
self._logger = logger
|
||||
|
||||
@transactional
|
||||
async def __call__(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
code: str,
|
||||
passport_data: str | None = None,
|
||||
inn: str | None = None,
|
||||
erc20: str | None = None,
|
||||
) -> UserEntity:
|
||||
code = (code or '').strip()
|
||||
|
||||
USER_PREFIX = 'bank_details:user:'
|
||||
CODE_PREFIX = 'bank_details:code:'
|
||||
|
||||
user_key = f'{USER_PREFIX}{user_id}'
|
||||
code_key = f'{CODE_PREFIX}{code}'
|
||||
|
||||
cached_user_id = await self._cache.get(code_key)
|
||||
if not cached_user_id:
|
||||
self._logger.info(f'Bank details update failed: code not found (user_id={user_id})')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
if cached_user_id != user_id:
|
||||
self._logger.info(f'Bank details update failed: code-user mismatch (user_id={user_id})')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
code_hash = await self._cache.get(user_key)
|
||||
if not code_hash:
|
||||
self._logger.info(f'Bank details update failed: user key missing (user_id={user_id})')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
ok = await self._hash_service.verify(hashed_value=code_hash, plain_value=code)
|
||||
if not ok:
|
||||
self._logger.info(f'Bank details update failed: code hash mismatch (user_id={user_id})')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
fields = {}
|
||||
if passport_data is not None:
|
||||
fields['passport_data'] = passport_data
|
||||
if inn is not None:
|
||||
fields['inn'] = inn
|
||||
if erc20 is not None:
|
||||
fields['erc20'] = erc20
|
||||
|
||||
user = await self._unit_of_work.user_repository.set_bank_details(user_id, **fields)
|
||||
await self._cache.set_user(user_id, user)
|
||||
|
||||
try:
|
||||
await self._cache.delete(code_key)
|
||||
await self._cache.delete(user_key)
|
||||
except Exception as e:
|
||||
self._logger.warning(f'Bank details update cleanup failed (user_id={user_id}): {e}')
|
||||
|
||||
self._logger.info(f'Bank details updated for user_id={user_id}, fields={list(fields.keys())}')
|
||||
return user
|
||||
126
src/application/commands/update_bank_details_start.py
Normal file
126
src/application/commands/update_bank_details_start.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import secrets
|
||||
from datetime import datetime, timezone
|
||||
from ulid import ULID
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.contracts import IHashService, ICache, ILogger, IQueueMessanger
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.config import settings
|
||||
from src.infrastructure.context_vars import trace_id_var
|
||||
from src.infrastructure.database.decorators import transactional
|
||||
|
||||
|
||||
class UpdateBankDetailsStartCommand:
|
||||
def __init__(
|
||||
self,
|
||||
hash_service: IHashService,
|
||||
cache: ICache,
|
||||
unit_of_work: IUnitOfWork,
|
||||
logger: ILogger,
|
||||
messanger: IQueueMessanger,
|
||||
):
|
||||
self._hash_service = hash_service
|
||||
self._unit_of_work = unit_of_work
|
||||
self._cache = cache
|
||||
self._logger = logger
|
||||
self._messanger = messanger
|
||||
|
||||
@transactional
|
||||
async def __call__(self, user_id: str) -> bool:
|
||||
TTL = 300
|
||||
LOCK_TTL = 30
|
||||
MAX_ATTEMPTS = 20
|
||||
|
||||
USER_PREFIX = 'bank_details:user:'
|
||||
CODE_PREFIX = 'bank_details:code:'
|
||||
LOCK_PREFIX = 'bank_details:lock:'
|
||||
|
||||
user = await self._unit_of_work.user_repository.get_user_by_id(user_id=user_id)
|
||||
|
||||
if not user.email:
|
||||
self._logger.warning(f'User {user_id} does not have an email address')
|
||||
raise ApplicationException(status_code=404, message=f'User {user_id} does not have an email address')
|
||||
|
||||
trace_id = trace_id_var.get()
|
||||
if not trace_id or trace_id == 'N/A':
|
||||
trace_id = None
|
||||
|
||||
lock_key = f'{LOCK_PREFIX}{user_id}'
|
||||
locked = await self._cache.set_nx(lock_key, '1', ttl=LOCK_TTL)
|
||||
if not locked:
|
||||
self._logger.info(f'Bank details update throttled by lock (user_id={user_id})')
|
||||
raise ApplicationException(429, 'Too many requests. Please wait.')
|
||||
|
||||
try:
|
||||
user_key = f'{USER_PREFIX}{user_id}'
|
||||
|
||||
existing = await self._cache.get(user_key)
|
||||
if existing:
|
||||
self._logger.info(f'Bank details update denied: code already exists for user_id={user_id}')
|
||||
raise ApplicationException(429, 'Code already sent. Please wait before retrying.')
|
||||
|
||||
for _ in range(MAX_ATTEMPTS):
|
||||
code = f'{secrets.randbelow(1_000_000):06d}'
|
||||
code_key = f'{CODE_PREFIX}{code}'
|
||||
|
||||
code_hash = await self._hash_service.hash(code)
|
||||
|
||||
reserved = await self._cache.set_nx(code_key, user_id, ttl=TTL)
|
||||
if not reserved:
|
||||
continue
|
||||
|
||||
saved = await self._cache.set(user_key, code_hash, ttl=TTL)
|
||||
if not saved:
|
||||
await self._cache.delete(code_key)
|
||||
self._logger.error(f'Bank details update failed: cannot save code hash for user_id={user_id}')
|
||||
raise ApplicationException(503, 'Temporary error. Please try again.')
|
||||
|
||||
message_id = str(ULID())
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
metadata = {
|
||||
'trace_id': trace_id,
|
||||
'source': 'user-service',
|
||||
'timestamp': now,
|
||||
'message_id': message_id,
|
||||
}
|
||||
|
||||
payload = {
|
||||
'email': user.email,
|
||||
'code': code,
|
||||
'ttl_seconds': TTL,
|
||||
}
|
||||
|
||||
message = {
|
||||
'event': 'bank_details_update',
|
||||
'payload': payload,
|
||||
'metadata': metadata,
|
||||
}
|
||||
|
||||
self._logger.info(f'Bank details update code created for user_id={user_id}')
|
||||
|
||||
try:
|
||||
await self._messanger.publish_to_queue(
|
||||
queue=settings.RABBIT_EMAIL_CODE_QUEUE,
|
||||
message=message,
|
||||
persist=True,
|
||||
correlation_id=trace_id,
|
||||
message_id=message_id,
|
||||
headers={'trace_id': trace_id} if trace_id else None,
|
||||
)
|
||||
except Exception as exception:
|
||||
try:
|
||||
await self._cache.delete(user_key)
|
||||
await self._cache.delete(code_key)
|
||||
except Exception as rollback_err:
|
||||
self._logger.error(f'Publish failed and rollback cache failed for user_id={user_id}: {str(rollback_err)}')
|
||||
|
||||
self._logger.error(f'Failed to publish bank details update email for user_id={user_id}: {str(exception)}')
|
||||
raise ApplicationException(503, 'Temporary error. Please try again.')
|
||||
|
||||
return True
|
||||
|
||||
self._logger.error(f'Bank details update failed: code space exhausted for user_id={user_id}')
|
||||
raise ApplicationException(503, 'Temporary error. Please try again.')
|
||||
|
||||
finally:
|
||||
await self._cache.delete(lock_key)
|
||||
7
src/application/contracts/__init__.py
Normal file
7
src/application/contracts/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from src.application.contracts.i_logger import ILogger
|
||||
from src.application.contracts.i_jwt_service import IJwtService
|
||||
from src.application.contracts.i_csrf_service import ICsrfService
|
||||
from src.application.contracts.i_cache import ICache
|
||||
from src.application.contracts.i_hash_service import IHashService
|
||||
from src.application.contracts.i_queue_messanger import IQueueMessanger
|
||||
from src.application.contracts.i_s3 import IS3
|
||||
30
src/application/contracts/i_cache.py
Normal file
30
src/application/contracts/i_cache.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from src.application.domain.entities.user import UserEntity
|
||||
|
||||
|
||||
class ICache(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def set(self, key: str, value: str, ttl: int) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def set_nx(self, key: str, value: str, ttl: int) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get(self, key: str) -> str | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def delete(self, key: str) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_user(self, user_id: str) -> dict | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def set_user(self, user_id: str, user: UserEntity, ttl: int = 300) -> None:
|
||||
raise NotImplementedError
|
||||
26
src/application/contracts/i_csrf_service.py
Normal file
26
src/application/contracts/i_csrf_service.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional, Mapping
|
||||
|
||||
|
||||
class ICsrfService(ABC):
|
||||
@abstractmethod
|
||||
def issue(self, subject: Optional[str] = None) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def verify(self, token: str, expected_subject: Optional[str] = None) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def extract(self, cookies: Mapping[str, str], headers: Mapping[str, str]) -> tuple[Optional[str], Optional[str]]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def verify_pair(
|
||||
self,
|
||||
cookie_token: Optional[str],
|
||||
header_token: Optional[str],
|
||||
expected_subject: Optional[str] = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
12
src/application/contracts/i_hash_service.py
Normal file
12
src/application/contracts/i_hash_service.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class IHashService(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def hash(self, value: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def verify(self, hashed_value: str, plain_value: str) -> bool:
|
||||
raise NotImplementedError
|
||||
10
src/application/contracts/i_jwt_service.py
Normal file
10
src/application/contracts/i_jwt_service.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from src.application.domain.dto import AccessTokenPayload
|
||||
|
||||
|
||||
class IJwtService(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def decode_access_token(self, token: str) -> AccessTokenPayload:
|
||||
raise NotImplementedError
|
||||
68
src/application/contracts/i_logger.py
Normal file
68
src/application/contracts/i_logger.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from typing import Protocol, Optional, Callable
|
||||
from src.application.domain.enums.log_format import LogFormat
|
||||
from src.application.domain.enums.log_level import LogLevel
|
||||
|
||||
|
||||
class ILogger(Protocol):
|
||||
"""Interface for synchronous logger with ContextVar support for trace_id."""
|
||||
|
||||
log_format: LogFormat
|
||||
min_level: LogLevel
|
||||
id_generator: Optional[Callable[[], str]]
|
||||
instance_id: str
|
||||
|
||||
def set_format(self, log_format: LogFormat) -> None:
|
||||
"""Set log format using LogFormat enum"""
|
||||
...
|
||||
|
||||
def set_min_level(self, level: LogLevel) -> None:
|
||||
"""Set minimum log level"""
|
||||
...
|
||||
|
||||
def new_trace_id(self) -> str:
|
||||
"""Create and set new trace_id in context"""
|
||||
...
|
||||
|
||||
def set_trace_id(self, trace_id: str) -> None:
|
||||
"""Set existing trace_id in context"""
|
||||
...
|
||||
|
||||
def get_trace_id(self) -> str:
|
||||
"""Get current trace_id from context"""
|
||||
...
|
||||
|
||||
def clear_trace_id(self) -> None:
|
||||
"""Clear the trace_id in the context"""
|
||||
...
|
||||
|
||||
def set_instance_id(self, instance_id: str) -> None:
|
||||
"""Set service instance id (ULID recommended)"""
|
||||
...
|
||||
|
||||
def get_instance_id(self) -> str:
|
||||
"""Get current service instance id"""
|
||||
...
|
||||
|
||||
def debug(self, message: str) -> None:
|
||||
"""Log debug message"""
|
||||
...
|
||||
|
||||
def info(self, message: str) -> None:
|
||||
"""Log info message"""
|
||||
...
|
||||
|
||||
def warning(self, message: str) -> None:
|
||||
"""Log warning message"""
|
||||
...
|
||||
|
||||
def error(self, message: str) -> None:
|
||||
"""Log error message"""
|
||||
...
|
||||
|
||||
def critical(self, message: str) -> None:
|
||||
"""Log critical message"""
|
||||
...
|
||||
|
||||
def exception(self, message: str) -> None:
|
||||
"""Log exception with traceback"""
|
||||
...
|
||||
40
src/application/contracts/i_queue_messanger.py
Normal file
40
src/application/contracts/i_queue_messanger.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Mapping, Any
|
||||
|
||||
|
||||
class IQueueMessanger(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def connect(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def publish_to_queue(
|
||||
self,
|
||||
queue: str,
|
||||
message: Any,
|
||||
*,
|
||||
persist: bool = True,
|
||||
headers: Mapping[str, Any] | None = None,
|
||||
correlation_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def publish(
|
||||
self,
|
||||
message: Any,
|
||||
*,
|
||||
exchange: str,
|
||||
routing_key: str,
|
||||
persist: bool = True,
|
||||
headers: Mapping[str, Any] | None = None,
|
||||
correlation_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
17
src/application/contracts/i_s3.py
Normal file
17
src/application/contracts/i_s3.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class IS3(Protocol):
|
||||
async def upload_bytes(self, *, key: str, body: bytes, content_type: str) -> str:
|
||||
...
|
||||
|
||||
|
||||
async def delete_object(self, *, key: str) -> None:
|
||||
...
|
||||
|
||||
|
||||
def object_key_from_public_url(self, url: str) -> str | None:
|
||||
...
|
||||
2
src/application/domain/dto/__init__.py
Normal file
2
src/application/domain/dto/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from src.application.domain.dto.token import AccessTokenPayload, AuthContext
|
||||
from src.application.domain.dto.keys import JwtPublicKey, JwtPublicKeySet
|
||||
20
src/application/domain/dto/keys.py
Normal file
20
src/application/domain/dto/keys.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Dict
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class JwtPublicKey:
|
||||
kid: str
|
||||
public_key_pem: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class JwtPublicKeySet:
|
||||
active: JwtPublicKey
|
||||
previous: Optional[JwtPublicKey] = None
|
||||
|
||||
def public_keys_by_kid(self) -> Dict[str, str]:
|
||||
out = {self.active.kid: self.active.public_key_pem}
|
||||
if self.previous:
|
||||
out[self.previous.kid] = self.previous.public_key_pem
|
||||
return out
|
||||
18
src/application/domain/dto/token.py
Normal file
18
src/application/domain/dto/token.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AccessTokenPayload(BaseModel):
|
||||
sub: str
|
||||
type: str
|
||||
sid: str
|
||||
iat: int
|
||||
nbf: int
|
||||
exp: int
|
||||
iss: str | None = None
|
||||
aud: str | None = None
|
||||
|
||||
|
||||
class AuthContext(BaseModel):
|
||||
user_id: str
|
||||
sid: str
|
||||
token: AccessTokenPayload
|
||||
5
src/application/domain/entities/__init__.py
Normal file
5
src/application/domain/entities/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from src.application.domain.entities.user import UserEntity
|
||||
from src.application.domain.entities.session import SessionEntity
|
||||
|
||||
|
||||
__all__ = ['UserEntity', 'SessionEntity']
|
||||
24
src/application/domain/entities/legal_entity.py
Normal file
24
src/application/domain/entities/legal_entity.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class LegalEntityEntity:
|
||||
id: str
|
||||
user_id: str
|
||||
name: str
|
||||
inn: str
|
||||
status: str
|
||||
short_name: str | None = None
|
||||
ogrn: str | None = None
|
||||
kpp: str | None = None
|
||||
legal_address: str | None = None
|
||||
actual_address: str | None = None
|
||||
bank_details: dict[str, Any] | None = None
|
||||
contact_person: str | None = None
|
||||
contact_phone: str | None = None
|
||||
kyc_verified: bool = True
|
||||
kyc_verified_at: datetime | None = None
|
||||
24
src/application/domain/entities/purchase_request.py
Normal file
24
src/application/domain/entities/purchase_request.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class PurchaseRequestEntity:
|
||||
id: str
|
||||
organization_id: str
|
||||
status: str
|
||||
usdt_amount: Decimal
|
||||
rub_amount: Decimal | None
|
||||
exchange_rate: Decimal | None
|
||||
service_fee_percent: Decimal | None
|
||||
comment: str | None
|
||||
admin_comment: str | None
|
||||
target_wallet_chain: str | None
|
||||
target_wallet_address: str | None
|
||||
tx_hash: str | None
|
||||
created_at: datetime | None = None
|
||||
updated_at: datetime | None = None
|
||||
completed_at: datetime | None = None
|
||||
20
src/application/domain/entities/session.py
Normal file
20
src/application/domain/entities/session.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class SessionEntity:
|
||||
sid: str
|
||||
user_id: str
|
||||
device_id: str
|
||||
|
||||
revoked_at: datetime | None
|
||||
last_seen_at: datetime
|
||||
|
||||
refresh_jti_hash: str | None
|
||||
refresh_expires_at: datetime | None
|
||||
|
||||
user_agent: str | None = None
|
||||
first_ip: str | None = None
|
||||
last_ip: str | None = None
|
||||
31
src/application/domain/entities/user.py
Normal file
31
src/application/domain/entities/user.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from datetime import date, datetime
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class UserEntity:
|
||||
id: str | None = None
|
||||
email: str | None = None
|
||||
password_hash: str | None = None
|
||||
|
||||
first_name: str | None = None
|
||||
middle_name: str | None = None
|
||||
last_name: str | None = None
|
||||
birth_date: date | None = None
|
||||
|
||||
encrypted_mnemonic: str | None = None
|
||||
phone: str | None = None
|
||||
|
||||
passport_data: str | None = None
|
||||
inn: str | None = None
|
||||
erc20: str | None = None
|
||||
avatar_link: str | None = None
|
||||
|
||||
kyc_verified: bool | None = None
|
||||
is_deleted: bool | None = None
|
||||
account_type: str | None = None
|
||||
|
||||
created_at: datetime | None = None
|
||||
updated_at: datetime | None = None
|
||||
kyc_verified_at: datetime | None = None
|
||||
2
src/application/domain/enums/__init__.py
Normal file
2
src/application/domain/enums/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from src.application.domain.enums.log_level import LogLevel
|
||||
from src.application.domain.enums.log_format import LogFormat
|
||||
6
src/application/domain/enums/account_type.py
Normal file
6
src/application/domain/enums/account_type.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class AccountType(StrEnum):
|
||||
INDIVIDUAL = 'individual'
|
||||
LEGAL_ENTITY = 'legal_entity'
|
||||
7
src/application/domain/enums/log_format.py
Normal file
7
src/application/domain/enums/log_format.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class LogFormat(Enum):
|
||||
"""Enum for supported log formats"""
|
||||
TEXT = 'text'
|
||||
JSON = 'json'
|
||||
54
src/application/domain/enums/log_level.py
Normal file
54
src/application/domain/enums/log_level.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class LogLevel(Enum):
|
||||
DEBUG = 10
|
||||
INFO = 20
|
||||
WARNING = 30
|
||||
ERROR = 40
|
||||
CRITICAL = 50
|
||||
EXCEPTION = 60
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"[{self.value}, '{self.name}']"
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if isinstance(other, LogLevel):
|
||||
return self.value == other.value
|
||||
if isinstance(other, int):
|
||||
return self.value == other
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other: object) -> bool:
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __lt__(self, other: object) -> bool:
|
||||
if isinstance(other, LogLevel):
|
||||
return self.value < other.value
|
||||
if isinstance(other, int):
|
||||
return self.value < other
|
||||
return NotImplemented
|
||||
|
||||
def __le__(self, other: object) -> bool:
|
||||
if isinstance(other, LogLevel):
|
||||
return self.value <= other.value
|
||||
if isinstance(other, int):
|
||||
return self.value <= other
|
||||
return NotImplemented
|
||||
|
||||
def __gt__(self, other: object) -> bool:
|
||||
if isinstance(other, LogLevel):
|
||||
return self.value > other.value
|
||||
if isinstance(other, int):
|
||||
return self.value > other
|
||||
return NotImplemented
|
||||
|
||||
def __ge__(self, other: object) -> bool:
|
||||
if isinstance(other, LogLevel):
|
||||
return self.value >= other.value
|
||||
if isinstance(other, int):
|
||||
return self.value >= other
|
||||
return NotImplemented
|
||||
11
src/application/domain/exceptions/__init__.py
Normal file
11
src/application/domain/exceptions/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from src.application.domain.exceptions.application_exceptions import (
|
||||
ApplicationException,
|
||||
BadRequestException,
|
||||
ConflictException,
|
||||
ForbiddenException,
|
||||
InternalException,
|
||||
NotFoundException,
|
||||
ServiceUnavailableException,
|
||||
TooManyRequestsException,
|
||||
UnauthorizedException,
|
||||
)
|
||||
59
src/application/domain/exceptions/application_exceptions.py
Normal file
59
src/application/domain/exceptions/application_exceptions.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Mapping
|
||||
|
||||
|
||||
class ApplicationException(Exception):
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int,
|
||||
message: str,
|
||||
headers: Mapping[str, str] | None = None,
|
||||
):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.headers = headers
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'{self.status_code}: {self.message}'
|
||||
|
||||
|
||||
class BadRequestException(ApplicationException):
|
||||
def __init__(self, message: str, headers: Mapping[str, str] | None = None):
|
||||
super().__init__(400, message, headers)
|
||||
|
||||
|
||||
class UnauthorizedException(ApplicationException):
|
||||
def __init__(self, message: str = 'Unauthorized', headers: Mapping[str, str] | None = None):
|
||||
super().__init__(401, message, headers)
|
||||
|
||||
|
||||
class ForbiddenException(ApplicationException):
|
||||
def __init__(self, message: str = 'Forbidden', headers: Mapping[str, str] | None = None):
|
||||
super().__init__(403, message, headers)
|
||||
|
||||
|
||||
class NotFoundException(ApplicationException):
|
||||
def __init__(self, message: str = 'Not found', headers: Mapping[str, str] | None = None):
|
||||
super().__init__(404, message, headers)
|
||||
|
||||
|
||||
class ConflictException(ApplicationException):
|
||||
def __init__(self, message: str, headers: Mapping[str, str] | None = None):
|
||||
super().__init__(409, message, headers)
|
||||
|
||||
|
||||
class TooManyRequestsException(ApplicationException):
|
||||
def __init__(self, message: str, headers: Mapping[str, str] | None = None):
|
||||
super().__init__(429, message, headers)
|
||||
|
||||
|
||||
class ServiceUnavailableException(ApplicationException):
|
||||
def __init__(self, message: str, headers: Mapping[str, str] | None = None):
|
||||
super().__init__(503, message, headers)
|
||||
|
||||
|
||||
class InternalException(ApplicationException):
|
||||
def __init__(self, message: str = 'Internal Server Error', headers: Mapping[str, str] | None = None):
|
||||
super().__init__(500, message, headers)
|
||||
21
src/application/domain/password_policy.py
Normal file
21
src/application/domain/password_policy.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import re
|
||||
|
||||
SPECIAL_CHARS = '!@#$%^&*()_+-=.,:;?/[]{}<>'
|
||||
|
||||
|
||||
def validate_password_strength(password: str) -> str:
|
||||
if re.search(r'\s', password):
|
||||
raise ValueError('Password must not contain whitespace')
|
||||
if len(password) < 12:
|
||||
raise ValueError('Password must be at least 12 characters')
|
||||
if not re.search(r'[a-z]', password):
|
||||
raise ValueError('Password must contain at least one lowercase letter')
|
||||
if not re.search(r'[A-Z]', password):
|
||||
raise ValueError('Password must contain at least one uppercase letter')
|
||||
if not re.search(r'\d', password):
|
||||
raise ValueError('Password must contain at least one digit')
|
||||
if not any(c in SPECIAL_CHARS for c in password):
|
||||
raise ValueError(
|
||||
'Password must contain at least one special character from: !@#$%^&*()_+-=.,:;?/[]{}<>'
|
||||
)
|
||||
return password
|
||||
2
src/infrastructure/cache/__init__.py
vendored
Normal file
2
src/infrastructure/cache/__init__.py
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
from src.infrastructure.cache.client import create_redis_client
|
||||
from src.infrastructure.cache.keydb_client import KeydbCache
|
||||
16
src/infrastructure/cache/client.py
vendored
Normal file
16
src/infrastructure/cache/client.py
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
import redis.asyncio as redis
|
||||
from redis.asyncio.client import Redis
|
||||
from src.infrastructure.config import settings
|
||||
|
||||
|
||||
def create_redis_client() -> Redis:
|
||||
return redis.from_url(
|
||||
settings.REDIS_URL,
|
||||
max_connections=50,
|
||||
decode_responses=True,
|
||||
socket_timeout=5,
|
||||
socket_connect_timeout=5,
|
||||
health_check_interval=30,
|
||||
retry_on_timeout=True,
|
||||
socket_keepalive=True,
|
||||
)
|
||||
52
src/infrastructure/cache/keydb_client.py
vendored
Normal file
52
src/infrastructure/cache/keydb_client.py
vendored
Normal file
@@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
import orjson
|
||||
from redis.asyncio.client import Redis
|
||||
from src.application.contracts import ICache
|
||||
from src.application.domain.entities.user import UserEntity
|
||||
|
||||
|
||||
class KeydbCache(ICache):
|
||||
USER_PREFIX = 'user:me'
|
||||
|
||||
def __init__(self, redis_client: Redis):
|
||||
self._r = redis_client
|
||||
|
||||
async def set(self, key: str, value: str, ttl: int) -> bool:
|
||||
return bool(await self._r.set(key, value, ex=ttl))
|
||||
|
||||
async def set_nx(self, key: str, value: str, ttl: int) -> bool:
|
||||
return bool(await self._r.set(key, value, ex=ttl, nx=True))
|
||||
|
||||
async def get(self, key: str) -> str | None:
|
||||
return await self._r.get(key)
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
return (await self._r.delete(key)) > 0
|
||||
|
||||
async def get_user(self, user_id: str) -> dict | None:
|
||||
raw = await self._r.get(f'{self.USER_PREFIX}:{user_id}')
|
||||
if raw is None:
|
||||
return None
|
||||
return orjson.loads(raw)
|
||||
|
||||
async def set_user(self, user_id: str, user: UserEntity, ttl: int = 300) -> None:
|
||||
data = orjson.dumps({
|
||||
'id': user.id,
|
||||
'email': user.email,
|
||||
'first_name': user.first_name,
|
||||
'middle_name': user.middle_name,
|
||||
'last_name': user.last_name,
|
||||
'birth_date': str(user.birth_date) if user.birth_date else None,
|
||||
'encrypted_mnemonic': user.encrypted_mnemonic,
|
||||
'phone': user.phone,
|
||||
'passport_data': user.passport_data,
|
||||
'inn': user.inn,
|
||||
'erc20': user.erc20,
|
||||
'avatar_link': user.avatar_link,
|
||||
'kyc_verified': user.kyc_verified,
|
||||
'is_deleted': user.is_deleted,
|
||||
'created_at': user.created_at.isoformat() if user.created_at else None,
|
||||
'updated_at': user.updated_at.isoformat() if user.updated_at else None,
|
||||
'kyc_verified_at': user.kyc_verified_at.isoformat() if user.kyc_verified_at else None,
|
||||
})
|
||||
await self._r.set(f'{self.USER_PREFIX}:{user_id}', data, ex=ttl)
|
||||
1
src/infrastructure/config/__init__.py
Normal file
1
src/infrastructure/config/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.config.settings import settings
|
||||
311
src/infrastructure/config/settings.py
Normal file
311
src/infrastructure/config/settings.py
Normal file
@@ -0,0 +1,311 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import Any, List, Literal, Mapping
|
||||
from urllib.parse import quote
|
||||
|
||||
from dotenv import find_dotenv, load_dotenv
|
||||
from pydantic import Field, PrivateAttr
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from src.infrastructure.vault.client import VaultClient
|
||||
|
||||
env_file = find_dotenv('.env')
|
||||
if env_file:
|
||||
load_dotenv(env_file)
|
||||
|
||||
|
||||
def _as_int(value: object, default: int) -> int:
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
return int(str(value).strip())
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8', case_sensitive=True, extra='ignore')
|
||||
|
||||
_vault_database_secrets: dict[str, Any] = PrivateAttr(default_factory=dict)
|
||||
|
||||
VAULT_ADDR: str = 'https://corp.vault.elcsa.ru'
|
||||
VAULT_ROLE_ID: str = ''
|
||||
VAULT_SECRET_ID: str = ''
|
||||
VAULT_NAMESPACE: str | None = None
|
||||
VAULT_MOUNT_POINT: str = 'dev-secrets'
|
||||
VAULT_DATABASE_SECRET_PATH: str = 'database'
|
||||
VAULT_RABBIT_SECRET_PATH: str = 'rabbitmq'
|
||||
VAULT_CSRF_SECRET_PATH: str = 'csrf'
|
||||
VAULT_DOCS_SECRET_PATH: str = 'docs'
|
||||
VAULT_JWT_KID_PATH: str = 'jwt/kid'
|
||||
VAULT_JWT_KIDS_PREFIX: str = 'jwt/kids'
|
||||
VAULT_S3_SECRET_PATH: str = 's3/avatars'
|
||||
|
||||
DATABASE_URL_DIRECT: str | None = Field(default=None, validation_alias='DATABASE_URL')
|
||||
DATABASE_HOST: str = ''
|
||||
DATABASE_PORT: int = Field(default=5432, ge=1, le=65535)
|
||||
DATABASE_NAME: str = ''
|
||||
DATABASE_USER: str = ''
|
||||
DATABASE_PASSWORD: str = ''
|
||||
|
||||
DATABASE_POOL_SIZE: int = 10
|
||||
DATABASE_MAX_OVERFLOW: int = 20
|
||||
DATABASE_POOL_TIMEOUT: int = 30
|
||||
DATABASE_POOL_RECYCLE: int = 3600
|
||||
DATABASE_ECHO: bool = False
|
||||
|
||||
CSRF_SECRET_KEY: str = Field(
|
||||
default='change-me-change-me-change-me-change-me',
|
||||
min_length=32,
|
||||
)
|
||||
|
||||
CSRF_COOKIE_SECURE: bool = False
|
||||
CSRF_COOKIE_HTTPONLY: bool = True
|
||||
CSRF_COOKIE_SAMESITE: Literal['Lax', 'Strict', 'None'] = 'Lax'
|
||||
CSRF_COOKIE_PATH: str = '/'
|
||||
CSRF_COOKIE_DOMAIN: str | None = None
|
||||
|
||||
DOCS_USERNAME: str = 'admin'
|
||||
DOCS_PASSWORD: str = 'admin'
|
||||
|
||||
JWT_ACCESS_TTL_SECONDS: int = 15 * 60
|
||||
JWT_REFRESH_TTL_SECONDS: int = 30 * 24 * 60 * 60
|
||||
JWT_ISSUER: str | None = None
|
||||
JWT_AUDIENCE: str | None = None
|
||||
JWT_ALGORITHM: str = 'RS256'
|
||||
JWT_KEYS_REFRESH_SECONDS: int = 3600
|
||||
|
||||
REDIS_HOST: str = 'keydb'
|
||||
REDIS_PORT: int = 6379
|
||||
REDIS_PASSWORD: str | None = None
|
||||
REDIS_DB: int = 0
|
||||
|
||||
RABBIT_HOST: str = 'localhost'
|
||||
RABBIT_PORT: int = 5672
|
||||
RABBIT_USER: str = 'guest'
|
||||
RABBIT_PASSWORD: str = 'guest'
|
||||
RABBIT_VHOST: str = '/'
|
||||
|
||||
RABBIT_PUBLISH_PERSIST: bool = True
|
||||
RABBIT_CONNECT_TIMEOUT: int = 5
|
||||
RABBIT_EMAIL_CODE_QUEUE: str = 'email.verification_code'
|
||||
|
||||
S3_BUCKET: str = ''
|
||||
S3_REGION: str = 'us-east-1'
|
||||
S3_ACCESS_KEY_ID: str = ''
|
||||
S3_SECRET_ACCESS_KEY: str = ''
|
||||
S3_ENDPOINT_URL: str = ''
|
||||
S3_PUBLIC_BASE_URL: str = ''
|
||||
S3_REGRU_PUBLIC_WEBSITE_HOST: bool = False
|
||||
S3_AVATAR_KEY_PREFIX: str = 'avatars'
|
||||
|
||||
LOG_LEVEL: Literal['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] = 'INFO'
|
||||
LOG_FORMAT: Literal['JSON', 'TEXT'] = 'JSON'
|
||||
|
||||
def _get_vault_secret(self, secrets: dict[str, Any], *keys: str) -> str:
|
||||
for key in keys:
|
||||
value = secrets.get(key)
|
||||
if value is not None and str(value).strip() != '':
|
||||
return str(value)
|
||||
return ''
|
||||
|
||||
def _reset_s3_config(self) -> None:
|
||||
object.__setattr__(self, 'S3_BUCKET', '')
|
||||
object.__setattr__(self, 'S3_ACCESS_KEY_ID', '')
|
||||
object.__setattr__(self, 'S3_SECRET_ACCESS_KEY', '')
|
||||
object.__setattr__(self, 'S3_ENDPOINT_URL', '')
|
||||
object.__setattr__(self, 'S3_PUBLIC_BASE_URL', '')
|
||||
object.__setattr__(self, 'S3_REGION', 'us-east-1')
|
||||
object.__setattr__(self, 'S3_REGRU_PUBLIC_WEBSITE_HOST', False)
|
||||
object.__setattr__(self, 'S3_AVATAR_KEY_PREFIX', 'avatars')
|
||||
|
||||
@staticmethod
|
||||
def _vault_kv(mapping: Mapping[str, Any], *keys: str) -> Any:
|
||||
for k in keys:
|
||||
if k in mapping and mapping[k] is not None:
|
||||
return mapping[k]
|
||||
return None
|
||||
|
||||
def _apply_s3_from_vault_secret(self, s3: dict[str, Any]) -> None:
|
||||
bucket_name = (
|
||||
self._vault_kv(s3, 'bucket_name', 'BUCKET_NAME', 'bucket')
|
||||
or self._vault_kv(s3, 'S3_BUCKET', 'bucketName')
|
||||
)
|
||||
endpoint_url = (
|
||||
self._vault_kv(s3, 's3_endpoint_url', 'S3_ENDPOINT_URL', 'endpoint_url', 'ENDPOINT_URL')
|
||||
or self._vault_kv(s3, 'endpoint')
|
||||
)
|
||||
ak = (
|
||||
self._vault_kv(s3, 's3_access_key_id', 'S3_ACCESS_KEY_ID', 'ACCESS_KEY_ID', 'access_key_id')
|
||||
or self._vault_kv(s3, 'AWS_ACCESS_KEY_ID')
|
||||
)
|
||||
sk = (
|
||||
self._vault_kv(s3, 's3_secret_access_key', 'S3_SECRET_ACCESS_KEY', 'SECRET_ACCESS_KEY')
|
||||
or self._vault_kv(s3, 'AWS_SECRET_ACCESS_KEY')
|
||||
)
|
||||
if bucket_name is None or str(bucket_name).strip() == '':
|
||||
raise ValueError('Vault S3 secret must contain bucket_name')
|
||||
if endpoint_url is None or str(endpoint_url).strip() == '':
|
||||
raise ValueError('Vault S3 secret must contain s3_endpoint_url')
|
||||
if ak is None or str(ak).strip() == '':
|
||||
raise ValueError('Vault S3 secret must contain s3_access_key_id')
|
||||
if sk is None or str(sk).strip() == '':
|
||||
raise ValueError('Vault S3 secret must contain s3_secret_access_key')
|
||||
object.__setattr__(self, 'S3_BUCKET', str(bucket_name).strip())
|
||||
object.__setattr__(self, 'S3_ENDPOINT_URL', str(endpoint_url).strip())
|
||||
object.__setattr__(self, 'S3_ACCESS_KEY_ID', str(ak).strip())
|
||||
object.__setattr__(self, 'S3_SECRET_ACCESS_KEY', str(sk).strip())
|
||||
region = (
|
||||
self._vault_kv(s3, 's3_region', 'S3_REGION', 'region')
|
||||
)
|
||||
if region is not None and str(region).strip() != '':
|
||||
object.__setattr__(self, 'S3_REGION', str(region).strip())
|
||||
public_base = (
|
||||
self._vault_kv(s3, 's3_public_base_url', 'S3_PUBLIC_BASE_URL', 'public_base_url')
|
||||
or self._vault_kv(s3, 'public_url')
|
||||
)
|
||||
if public_base is not None and str(public_base).strip() != '':
|
||||
object.__setattr__(self, 'S3_PUBLIC_BASE_URL', str(public_base).strip())
|
||||
prefix = self._vault_kv(s3, 'avatar_key_prefix', 'S3_AVATAR_KEY_PREFIX', 's3_avatar_key_prefix')
|
||||
if prefix is not None and str(prefix).strip() != '':
|
||||
object.__setattr__(self, 'S3_AVATAR_KEY_PREFIX', str(prefix).strip())
|
||||
rf = (
|
||||
self._vault_kv(s3, 's3_reg_ru_public_website_host', 'S3_REGRU_PUBLIC_WEBSITE_HOST')
|
||||
)
|
||||
if rf is not None:
|
||||
v = str(rf).strip().lower()
|
||||
if v in {'1', 'true', 'yes', 'on'}:
|
||||
object.__setattr__(self, 'S3_REGRU_PUBLIC_WEBSITE_HOST', True)
|
||||
elif v in {'0', 'false', 'no', 'off'}:
|
||||
object.__setattr__(self, 'S3_REGRU_PUBLIC_WEBSITE_HOST', False)
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
self._reset_s3_config()
|
||||
if not self.VAULT_ROLE_ID.strip() or not self.VAULT_SECRET_ID.strip():
|
||||
if not self.DATABASE_URL:
|
||||
raise ValueError(
|
||||
'Set VAULT_ROLE_ID and VAULT_SECRET_ID for Vault, or set DATABASE_URL '
|
||||
'(or DATABASE_HOST, DATABASE_USER, DATABASE_PASSWORD, DATABASE_NAME) in the environment',
|
||||
)
|
||||
return
|
||||
|
||||
client = VaultClient(
|
||||
addr=self.VAULT_ADDR,
|
||||
role_id=self.VAULT_ROLE_ID,
|
||||
secret_id=self.VAULT_SECRET_ID,
|
||||
namespace=self.VAULT_NAMESPACE,
|
||||
mount_point=self.VAULT_MOUNT_POINT,
|
||||
)
|
||||
|
||||
db = client.read_secret(self.VAULT_DATABASE_SECRET_PATH)
|
||||
object.__setattr__(self, '_vault_database_secrets', db)
|
||||
|
||||
def kv(d: dict[str, Any], *keys: str) -> Any:
|
||||
for k in keys:
|
||||
if k in d and d[k] is not None:
|
||||
return d[k]
|
||||
return None
|
||||
|
||||
if kv(db, 'HOST', 'host') is not None:
|
||||
object.__setattr__(self, 'DATABASE_HOST', str(kv(db, 'HOST', 'host')))
|
||||
if kv(db, 'PORT', 'port') is not None:
|
||||
object.__setattr__(self, 'DATABASE_PORT', _as_int(kv(db, 'PORT', 'port'), self.DATABASE_PORT))
|
||||
if kv(db, 'NAME', 'name') is not None:
|
||||
object.__setattr__(self, 'DATABASE_NAME', str(kv(db, 'NAME', 'name')))
|
||||
if kv(db, 'USER', 'user') is not None:
|
||||
object.__setattr__(self, 'DATABASE_USER', str(kv(db, 'USER', 'user')))
|
||||
if kv(db, 'PASSWORD', 'password') is not None:
|
||||
object.__setattr__(self, 'DATABASE_PASSWORD', str(kv(db, 'PASSWORD', 'password')))
|
||||
|
||||
rabbit = client.read_secret_optional(self.VAULT_RABBIT_SECRET_PATH)
|
||||
if rabbit:
|
||||
if kv(rabbit, 'HOST', 'host') is not None:
|
||||
object.__setattr__(self, 'RABBIT_HOST', str(kv(rabbit, 'HOST', 'host')))
|
||||
if kv(rabbit, 'PORT', 'port') is not None:
|
||||
object.__setattr__(self, 'RABBIT_PORT', _as_int(kv(rabbit, 'PORT', 'port'), self.RABBIT_PORT))
|
||||
if kv(rabbit, 'USER', 'user') is not None:
|
||||
object.__setattr__(self, 'RABBIT_USER', str(kv(rabbit, 'USER', 'user')))
|
||||
if kv(rabbit, 'PASSWORD', 'password') is not None:
|
||||
object.__setattr__(self, 'RABBIT_PASSWORD', str(kv(rabbit, 'PASSWORD', 'password')))
|
||||
if kv(rabbit, 'VHOST', 'vhost') is not None:
|
||||
object.__setattr__(self, 'RABBIT_VHOST', str(kv(rabbit, 'VHOST', 'vhost')))
|
||||
|
||||
csrf = client.read_secret_optional(self.VAULT_CSRF_SECRET_PATH)
|
||||
if csrf and kv(csrf, 'KEY', 'key') is not None:
|
||||
key = str(kv(csrf, 'KEY', 'key'))
|
||||
if len(key) >= 32:
|
||||
object.__setattr__(self, 'CSRF_SECRET_KEY', key)
|
||||
|
||||
docs = client.read_secret_optional(self.VAULT_DOCS_SECRET_PATH)
|
||||
if docs:
|
||||
u = docs.get('DOCS_USERNAME') or docs.get('USERNAME')
|
||||
p = docs.get('DOCS_PASSWORD') or docs.get('PASSWORD')
|
||||
if u is not None:
|
||||
object.__setattr__(self, 'DOCS_USERNAME', str(u))
|
||||
if p is not None:
|
||||
object.__setattr__(self, 'DOCS_PASSWORD', str(p))
|
||||
|
||||
s3_rel_path = self.VAULT_S3_SECRET_PATH.strip()
|
||||
if s3_rel_path:
|
||||
s3_secret_data = client.read_secret_optional(s3_rel_path)
|
||||
if s3_secret_data:
|
||||
self._apply_s3_from_vault_secret(s3_secret_data)
|
||||
|
||||
if not self.DATABASE_URL:
|
||||
raise ValueError('Database URL could not be built from Vault database secret')
|
||||
|
||||
@property
|
||||
def DATABASE_URL(self) -> str:
|
||||
direct = (self.DATABASE_URL_DIRECT or '').strip()
|
||||
if direct:
|
||||
return direct
|
||||
|
||||
ready_url = self._get_vault_secret(
|
||||
self._vault_database_secrets,
|
||||
'DATABASE_URL',
|
||||
'database_url',
|
||||
)
|
||||
if ready_url:
|
||||
return ready_url
|
||||
|
||||
host = self._get_vault_secret(self._vault_database_secrets, 'host', 'HOST')
|
||||
port = self._get_vault_secret(self._vault_database_secrets, 'port', 'PORT') or str(self.DATABASE_PORT)
|
||||
user = self._get_vault_secret(self._vault_database_secrets, 'user', 'USER')
|
||||
password = self._get_vault_secret(self._vault_database_secrets, 'password', 'PASSWORD')
|
||||
name = self._get_vault_secret(self._vault_database_secrets, 'name', 'NAME', 'database', 'DATABASE')
|
||||
if not host or not user or not password or not name:
|
||||
h = (self.DATABASE_HOST or '').strip()
|
||||
u = (self.DATABASE_USER or '').strip()
|
||||
p = (self.DATABASE_PASSWORD or '').strip()
|
||||
n = (self.DATABASE_NAME or '').strip()
|
||||
if h and u and p and n:
|
||||
quoted_user = quote(u, safe='')
|
||||
quoted_password = quote(p, safe='')
|
||||
po = str(self.DATABASE_PORT)
|
||||
return f'postgresql+asyncpg://{quoted_user}:{quoted_password}@{h}:{po}/{n}'
|
||||
return ''
|
||||
quoted_user = quote(user, safe='')
|
||||
quoted_password = quote(password, safe='')
|
||||
return f'postgresql+asyncpg://{quoted_user}:{quoted_password}@{host}:{port}/{name}'
|
||||
|
||||
@property
|
||||
def REDIS_URL(self) -> str:
|
||||
auth = f':{self.REDIS_PASSWORD}@' if self.REDIS_PASSWORD else ''
|
||||
return f'redis://{auth}{self.REDIS_HOST}:{self.REDIS_PORT}/{self.REDIS_DB}'
|
||||
|
||||
@property
|
||||
def RABBIT_URL(self) -> str:
|
||||
vhost = '%2F' if self.RABBIT_VHOST == '/' else self.RABBIT_VHOST.lstrip('/')
|
||||
return f'amqp://{self.RABBIT_USER}:{self.RABBIT_PASSWORD}@{self.RABBIT_HOST}:{self.RABBIT_PORT}/{vhost}'
|
||||
|
||||
@property
|
||||
def EXCLUDED_PATHS(self) -> List[str]:
|
||||
return ['/docs', '/redoc', '/openapi.json', '/ping', '/health']
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_settings() -> Settings:
|
||||
return Settings()
|
||||
|
||||
|
||||
settings = get_settings()
|
||||
1
src/infrastructure/context_vars/__init__.py
Normal file
1
src/infrastructure/context_vars/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.context_vars.trace_id import trace_id_var
|
||||
4
src/infrastructure/context_vars/trace_id.py
Normal file
4
src/infrastructure/context_vars/trace_id.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from contextvars import ContextVar
|
||||
|
||||
|
||||
trace_id_var: ContextVar[str] = ContextVar('trace_id', default='N/A')
|
||||
1
src/infrastructure/database/__init__.py
Normal file
1
src/infrastructure/database/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.database.unit_of_work import UnitOfWork
|
||||
22
src/infrastructure/database/context.py
Normal file
22
src/infrastructure/database/context.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
from sqlalchemy.ext.asyncio.engine import create_async_engine
|
||||
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||
from typing import AsyncGenerator
|
||||
from src.infrastructure.config import settings
|
||||
|
||||
|
||||
engine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
pool_size=settings.DATABASE_POOL_SIZE,
|
||||
max_overflow=settings.DATABASE_MAX_OVERFLOW,
|
||||
pool_timeout=settings.DATABASE_POOL_TIMEOUT,
|
||||
pool_recycle=settings.DATABASE_POOL_RECYCLE,
|
||||
echo=settings.DATABASE_ECHO
|
||||
)
|
||||
|
||||
async_session_maker = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
|
||||
|
||||
|
||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with async_session_maker() as session:
|
||||
yield session
|
||||
1
src/infrastructure/database/decorators/__init__.py
Normal file
1
src/infrastructure/database/decorators/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.database.decorators.transactional import transactional
|
||||
15
src/infrastructure/database/decorators/transactional.py
Normal file
15
src/infrastructure/database/decorators/transactional.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
from functools import wraps
|
||||
from typing import Callable, Awaitable, TypeVar, ParamSpec
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def transactional(method: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
||||
@wraps(method)
|
||||
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
async with self._unit_of_work:
|
||||
return await method(self, *args, **kwargs)
|
||||
return wrapper
|
||||
6
src/infrastructure/database/models/__init__.py
Normal file
6
src/infrastructure/database/models/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from src.infrastructure.database.models.base import Base
|
||||
from src.infrastructure.database.models.user import UserModel
|
||||
from src.infrastructure.database.models.legal_entity import LegalEntityModel
|
||||
from src.infrastructure.database.models.purchase_request import PurchaseRequestModel
|
||||
|
||||
__all__ = ['Base', 'UserModel', 'LegalEntityModel', 'PurchaseRequestModel']
|
||||
19
src/infrastructure/database/models/base.py
Normal file
19
src/infrastructure/database/models/base.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
|
||||
class Base(AsyncAttrs, DeclarativeBase):
|
||||
__abstract__ = True
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
attributes = ', '.join(f"{col.name}={getattr(self, col.name, None)!r}"
|
||||
for col in self.__table__.columns)
|
||||
return f"<{class_name}({attributes})>"
|
||||
|
||||
def __str__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
attributes = ', '.join(f"{col.name}={getattr(self, col.name)}"
|
||||
for col in self.__table__.columns
|
||||
if getattr(self, col.name) is not None)
|
||||
return f"{class_name}({attributes})"
|
||||
32
src/infrastructure/database/models/legal_entity.py
Normal file
32
src/infrastructure/database/models/legal_entity.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from src.infrastructure.database.models.base import Base
|
||||
from src.infrastructure.database.models.mixins import AuditTimestampsMixin, UlidPrimaryKeyMixin
|
||||
|
||||
|
||||
class LegalEntityModel(Base, UlidPrimaryKeyMixin, AuditTimestampsMixin):
|
||||
__tablename__ = 'legal_entities'
|
||||
|
||||
user_id: Mapped[str] = mapped_column(String(26), ForeignKey('users.id', ondelete='RESTRICT'), nullable=False, unique=True, index=True)
|
||||
name: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
short_name: Mapped[str | None] = mapped_column(String(256), nullable=True)
|
||||
inn: Mapped[str] = mapped_column(String(12), nullable=False, index=True)
|
||||
ogrn: Mapped[str | None] = mapped_column(String(15), nullable=True)
|
||||
kpp: Mapped[str | None] = mapped_column(String(9), nullable=True)
|
||||
legal_address: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
actual_address: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
bank_details: Mapped[dict[str, Any] | None] = mapped_column(JSONB, nullable=True)
|
||||
contact_person: Mapped[str | None] = mapped_column(String(256), nullable=True)
|
||||
contact_phone: Mapped[str | None] = mapped_column(String(16), nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(32), nullable=False, server_default='active', default='active')
|
||||
kyc_verified: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default='true', default=True)
|
||||
kyc_verified_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
encrypted_mnemonic: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
created_by: Mapped[str | None] = mapped_column(String(26), nullable=True)
|
||||
3
src/infrastructure/database/models/mixins/__init__.py
Normal file
3
src/infrastructure/database/models/mixins/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from src.infrastructure.database.models.mixins.audit import AuditTimestampsMixin
|
||||
from src.infrastructure.database.models.mixins.ulid import UlidPrimaryKeyMixin
|
||||
from src.infrastructure.database.models.mixins.soft_delete import SoftDeleteMixin
|
||||
16
src/infrastructure/database/models/mixins/audit.py
Normal file
16
src/infrastructure/database/models/mixins/audit.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from sqlalchemy import DateTime, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
|
||||
class AuditTimestampsMixin:
|
||||
created_at: Mapped[DateTime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=func.now(),
|
||||
)
|
||||
updated_at: Mapped[DateTime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
6
src/infrastructure/database/models/mixins/soft_delete.py
Normal file
6
src/infrastructure/database/models/mixins/soft_delete.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from sqlalchemy import Boolean
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
|
||||
class SoftDeleteMixin:
|
||||
is_deleted: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default='false', default=False)
|
||||
8
src/infrastructure/database/models/mixins/ulid.py
Normal file
8
src/infrastructure/database/models/mixins/ulid.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from ulid import ULID
|
||||
|
||||
|
||||
class UlidPrimaryKeyMixin:
|
||||
|
||||
id: Mapped[str] = mapped_column(String(26), primary_key=True, default=lambda: str(ULID()))
|
||||
33
src/infrastructure/database/models/purchase_request.py
Normal file
33
src/infrastructure/database/models/purchase_request.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, Numeric, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from src.infrastructure.database.models.base import Base
|
||||
from src.infrastructure.database.models.mixins import AuditTimestampsMixin, UlidPrimaryKeyMixin
|
||||
|
||||
|
||||
class PurchaseRequestModel(Base, UlidPrimaryKeyMixin, AuditTimestampsMixin):
|
||||
__tablename__ = 'purchase_requests'
|
||||
|
||||
organization_id: Mapped[str] = mapped_column(
|
||||
String(26),
|
||||
ForeignKey('legal_entities.id', ondelete='RESTRICT'),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
status: Mapped[str] = mapped_column(String(32), nullable=False, server_default='submitted', default='submitted')
|
||||
usdt_amount: Mapped[Decimal] = mapped_column(Numeric(18, 8), nullable=False)
|
||||
rub_amount: Mapped[Decimal | None] = mapped_column(Numeric(18, 2), nullable=True)
|
||||
exchange_rate: Mapped[Decimal | None] = mapped_column(Numeric(18, 8), nullable=True)
|
||||
service_fee_percent: Mapped[Decimal | None] = mapped_column(Numeric(5, 2), nullable=True)
|
||||
comment: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
admin_comment: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
target_wallet_chain: Mapped[str | None] = mapped_column(String(16), nullable=True, server_default='ETH')
|
||||
target_wallet_address: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
tx_hash: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
assigned_to: Mapped[str | None] = mapped_column(String(26), nullable=True)
|
||||
completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
50
src/infrastructure/database/models/sessions.py
Normal file
50
src/infrastructure/database/models/sessions.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy import String, DateTime, ForeignKey, Index
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from ulid import ULID
|
||||
from src.infrastructure.database.models import Base
|
||||
from src.infrastructure.database.models.mixins import UlidPrimaryKeyMixin, AuditTimestampsMixin
|
||||
|
||||
|
||||
class Session(Base, UlidPrimaryKeyMixin, AuditTimestampsMixin):
|
||||
__tablename__ = "sessions"
|
||||
|
||||
sid: Mapped[str] = mapped_column(
|
||||
String(26),
|
||||
unique=True,
|
||||
index=True,
|
||||
nullable=False,
|
||||
default=lambda: str(ULID()),
|
||||
)
|
||||
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(26),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
index=True,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
device_id: Mapped[str] = mapped_column(
|
||||
String(26),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
user_agent: Mapped[str | None] = mapped_column(String(500))
|
||||
first_ip: Mapped[str | None] = mapped_column(String(64))
|
||||
last_ip: Mapped[str | None] = mapped_column(String(64))
|
||||
|
||||
last_seen_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
revoked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||
|
||||
refresh_jti_hash: Mapped[str | None] = mapped_column(String(255))
|
||||
refresh_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||
|
||||
|
||||
Index("ux_sessions_user_device", Session.user_id, Session.device_id, unique=True)
|
||||
Index("ix_sessions_user_active", Session.user_id, Session.revoked_at)
|
||||
31
src/infrastructure/database/models/user.py
Normal file
31
src/infrastructure/database/models/user.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import Boolean, Date, DateTime, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from src.infrastructure.database.models.base import Base
|
||||
from src.infrastructure.database.models.mixins import UlidPrimaryKeyMixin, AuditTimestampsMixin, SoftDeleteMixin
|
||||
|
||||
|
||||
class UserModel(Base, UlidPrimaryKeyMixin, AuditTimestampsMixin, SoftDeleteMixin):
|
||||
__tablename__ = 'users'
|
||||
|
||||
email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True, index=True)
|
||||
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
|
||||
last_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
first_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
middle_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
birth_date: Mapped[Date | None] = mapped_column(Date, nullable=True)
|
||||
|
||||
encrypted_mnemonic: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
phone: Mapped[str | None] = mapped_column(String(16), nullable=True)
|
||||
|
||||
passport_data: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
inn: Mapped[str | None] = mapped_column(String(12), nullable=True)
|
||||
erc20: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
avatar_link: Mapped[str | None] = mapped_column(String(2048), nullable=True, default=None)
|
||||
|
||||
kyc_verified: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default='false', default=False)
|
||||
kyc_verified_at: Mapped[DateTime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
account_type: Mapped[str] = mapped_column(String(20), nullable=False, server_default='individual', default='individual')
|
||||
3
src/infrastructure/database/repositories/__init__.py
Normal file
3
src/infrastructure/database/repositories/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from src.infrastructure.database.repositories.user_repository import UserRepository
|
||||
from src.infrastructure.database.repositories.legal_entity_repository import LegalEntityRepository
|
||||
from src.infrastructure.database.repositories.purchase_request_repository import PurchaseRequestRepository
|
||||
@@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from src.application.abstractions.repositories.i_legal_entity_repository import ILegalEntityRepository
|
||||
from src.application.contracts import ILogger
|
||||
from src.application.domain.entities.legal_entity import LegalEntityEntity
|
||||
from src.application.domain.exceptions import InternalException
|
||||
from src.infrastructure.database.models.legal_entity import LegalEntityModel
|
||||
|
||||
|
||||
class LegalEntityRepository(ILegalEntityRepository):
|
||||
def __init__(self, session: AsyncSession, logger: ILogger):
|
||||
self._session = session
|
||||
self._logger = logger
|
||||
|
||||
@staticmethod
|
||||
def _to_entity(model: LegalEntityModel) -> LegalEntityEntity:
|
||||
return LegalEntityEntity(
|
||||
id=model.id,
|
||||
user_id=model.user_id,
|
||||
name=model.name,
|
||||
inn=model.inn,
|
||||
status=model.status,
|
||||
short_name=model.short_name,
|
||||
ogrn=model.ogrn,
|
||||
kpp=model.kpp,
|
||||
legal_address=model.legal_address,
|
||||
actual_address=model.actual_address,
|
||||
bank_details=model.bank_details,
|
||||
contact_person=model.contact_person,
|
||||
contact_phone=model.contact_phone,
|
||||
kyc_verified=model.kyc_verified,
|
||||
kyc_verified_at=model.kyc_verified_at,
|
||||
)
|
||||
|
||||
async def get_by_user_id(self, user_id: str) -> LegalEntityEntity | None:
|
||||
try:
|
||||
stmt = select(LegalEntityModel).where(LegalEntityModel.user_id == user_id)
|
||||
result = await self._session.execute(stmt)
|
||||
model = result.scalar_one_or_none()
|
||||
if model is None:
|
||||
return None
|
||||
return self._to_entity(model)
|
||||
except SQLAlchemyError as exc:
|
||||
self._logger.exception(str(exc))
|
||||
raise InternalException(message=f'Database error: {exc}') from exc
|
||||
@@ -0,0 +1,115 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from ulid import ULID
|
||||
|
||||
from src.application.abstractions.repositories import IPurchaseRequestRepository
|
||||
from src.application.contracts import ILogger
|
||||
from src.application.domain.entities.purchase_request import PurchaseRequestEntity
|
||||
from src.application.domain.exceptions import ApplicationException, InternalException, NotFoundException
|
||||
from src.infrastructure.database.models import PurchaseRequestModel
|
||||
|
||||
|
||||
class PurchaseRequestRepository(IPurchaseRequestRepository):
|
||||
def __init__(self, session: AsyncSession, logger: ILogger):
|
||||
self._session = session
|
||||
self._logger = logger
|
||||
|
||||
@staticmethod
|
||||
def _to_entity(model: PurchaseRequestModel) -> PurchaseRequestEntity:
|
||||
return PurchaseRequestEntity(
|
||||
id=model.id,
|
||||
organization_id=model.organization_id,
|
||||
status=model.status,
|
||||
usdt_amount=model.usdt_amount,
|
||||
rub_amount=model.rub_amount,
|
||||
exchange_rate=model.exchange_rate,
|
||||
service_fee_percent=model.service_fee_percent,
|
||||
comment=model.comment,
|
||||
admin_comment=model.admin_comment,
|
||||
target_wallet_chain=model.target_wallet_chain,
|
||||
target_wallet_address=model.target_wallet_address,
|
||||
tx_hash=model.tx_hash,
|
||||
created_at=model.created_at,
|
||||
updated_at=model.updated_at,
|
||||
completed_at=model.completed_at,
|
||||
)
|
||||
|
||||
def _apply_filters(self, stmt, *, organization_id: str, status: str | None):
|
||||
stmt = stmt.where(PurchaseRequestModel.organization_id == organization_id)
|
||||
if status:
|
||||
stmt = stmt.where(PurchaseRequestModel.status == status)
|
||||
return stmt
|
||||
|
||||
async def create(
|
||||
self,
|
||||
*,
|
||||
organization_id: str,
|
||||
usdt_amount: Decimal,
|
||||
comment: str | None,
|
||||
target_wallet_chain: str | None,
|
||||
target_wallet_address: str | None,
|
||||
) -> PurchaseRequestEntity:
|
||||
try:
|
||||
model = PurchaseRequestModel(
|
||||
id=str(ULID()),
|
||||
organization_id=organization_id,
|
||||
status='submitted',
|
||||
usdt_amount=usdt_amount,
|
||||
comment=comment,
|
||||
target_wallet_chain=target_wallet_chain or 'ETH',
|
||||
target_wallet_address=target_wallet_address,
|
||||
)
|
||||
self._session.add(model)
|
||||
await self._session.flush()
|
||||
await self._session.refresh(model)
|
||||
return self._to_entity(model)
|
||||
except SQLAlchemyError as exc:
|
||||
self._logger.exception(str(exc))
|
||||
raise InternalException(message=f'Database error: {exc}') from exc
|
||||
|
||||
async def get_by_id(self, request_id: str) -> PurchaseRequestEntity:
|
||||
try:
|
||||
res = await self._session.execute(
|
||||
select(PurchaseRequestModel).where(PurchaseRequestModel.id == request_id)
|
||||
)
|
||||
model = res.scalar_one_or_none()
|
||||
if model is None:
|
||||
raise NotFoundException(message='Purchase request not found')
|
||||
return self._to_entity(model)
|
||||
except ApplicationException:
|
||||
raise
|
||||
except SQLAlchemyError as exc:
|
||||
self._logger.exception(str(exc))
|
||||
raise InternalException(message=f'Database error: {exc}') from exc
|
||||
|
||||
async def list_by_organization(
|
||||
self,
|
||||
*,
|
||||
organization_id: str,
|
||||
status: str | None,
|
||||
limit: int,
|
||||
offset: int,
|
||||
) -> list[PurchaseRequestEntity]:
|
||||
try:
|
||||
stmt = select(PurchaseRequestModel).order_by(PurchaseRequestModel.created_at.desc())
|
||||
stmt = self._apply_filters(stmt, organization_id=organization_id, status=status)
|
||||
res = await self._session.execute(stmt.limit(limit).offset(offset))
|
||||
return [self._to_entity(model) for model in res.scalars().all()]
|
||||
except SQLAlchemyError as exc:
|
||||
self._logger.exception(str(exc))
|
||||
raise InternalException(message=f'Database error: {exc}') from exc
|
||||
|
||||
async def count_by_organization(self, *, organization_id: str, status: str | None) -> int:
|
||||
try:
|
||||
stmt = select(func.count()).select_from(PurchaseRequestModel)
|
||||
stmt = self._apply_filters(stmt, organization_id=organization_id, status=status)
|
||||
res = await self._session.execute(stmt)
|
||||
return int(res.scalar_one())
|
||||
except SQLAlchemyError as exc:
|
||||
self._logger.exception(str(exc))
|
||||
raise InternalException(message=f'Database error: {exc}') from exc
|
||||
198
src/infrastructure/database/repositories/session_repository.py
Normal file
198
src/infrastructure/database/repositories/session_repository.py
Normal file
@@ -0,0 +1,198 @@
|
||||
from __future__ import annotations
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from src.application.contracts import ILogger
|
||||
from src.application.domain.entities import SessionEntity
|
||||
from src.application.abstractions.repositories import ISessionRepository
|
||||
from src.infrastructure.database.models import Session
|
||||
|
||||
|
||||
class SessionRepository(ISessionRepository):
|
||||
def __init__(self, session: AsyncSession, logger: ILogger):
|
||||
self._session = session
|
||||
self._logger = logger
|
||||
|
||||
async def get_by_sid(self, sid: str) -> Optional[SessionEntity]:
|
||||
res = await self._session.execute(select(Session).where(Session.sid == sid))
|
||||
m = res.scalar_one_or_none()
|
||||
if m is None:
|
||||
return None
|
||||
|
||||
return SessionEntity(
|
||||
sid=m.sid,
|
||||
user_id=m.user_id,
|
||||
device_id=m.device_id,
|
||||
revoked_at=m.revoked_at,
|
||||
last_seen_at=m.last_seen_at,
|
||||
refresh_jti_hash=m.refresh_jti_hash,
|
||||
refresh_expires_at=m.refresh_expires_at,
|
||||
user_agent=m.user_agent,
|
||||
first_ip=m.first_ip,
|
||||
last_ip=m.last_ip,
|
||||
)
|
||||
|
||||
async def get_by_user_device(self, user_id: str, device_id: str) -> Optional[SessionEntity]:
|
||||
res = await self._session.execute(
|
||||
select(Session).where(Session.user_id == user_id, Session.device_id == device_id)
|
||||
)
|
||||
m = res.scalar_one_or_none()
|
||||
if m is None:
|
||||
return None
|
||||
|
||||
return SessionEntity(
|
||||
sid=m.sid,
|
||||
user_id=m.user_id,
|
||||
device_id=m.device_id,
|
||||
revoked_at=m.revoked_at,
|
||||
last_seen_at=m.last_seen_at,
|
||||
refresh_jti_hash=m.refresh_jti_hash,
|
||||
refresh_expires_at=m.refresh_expires_at,
|
||||
user_agent=m.user_agent,
|
||||
first_ip=m.first_ip,
|
||||
last_ip=m.last_ip,
|
||||
)
|
||||
|
||||
async def upsert_by_device(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
sid: str,
|
||||
refresh_jti_hash: str,
|
||||
refresh_expires_at: datetime,
|
||||
user_agent: str | None,
|
||||
ip: str | None,
|
||||
now: datetime,
|
||||
) -> SessionEntity:
|
||||
res = await self._session.execute(
|
||||
select(Session).where(Session.user_id == user_id, Session.device_id == device_id)
|
||||
)
|
||||
m = res.scalar_one_or_none()
|
||||
|
||||
if m is None:
|
||||
m = Session(
|
||||
sid=sid,
|
||||
user_id=user_id,
|
||||
device_id=device_id,
|
||||
revoked_at=None,
|
||||
last_seen_at=now,
|
||||
refresh_jti_hash=refresh_jti_hash,
|
||||
refresh_expires_at=refresh_expires_at,
|
||||
user_agent=user_agent,
|
||||
first_ip=ip,
|
||||
last_ip=ip,
|
||||
)
|
||||
self._session.add(m)
|
||||
await self._session.flush()
|
||||
|
||||
self._logger.info(f'Session created (user_id={user_id}, device_id={device_id}, sid={sid})')
|
||||
else:
|
||||
m.sid = sid
|
||||
m.revoked_at = None
|
||||
m.last_seen_at = now
|
||||
m.refresh_jti_hash = refresh_jti_hash
|
||||
m.refresh_expires_at = refresh_expires_at
|
||||
m.user_agent = user_agent
|
||||
m.last_ip = ip
|
||||
|
||||
await self._session.flush()
|
||||
|
||||
self._logger.info(f'Session updated (user_id={user_id}, device_id={device_id}, sid={sid})')
|
||||
|
||||
return SessionEntity(
|
||||
sid=m.sid,
|
||||
user_id=m.user_id,
|
||||
device_id=m.device_id,
|
||||
revoked_at=m.revoked_at,
|
||||
last_seen_at=m.last_seen_at,
|
||||
refresh_jti_hash=m.refresh_jti_hash,
|
||||
refresh_expires_at=m.refresh_expires_at,
|
||||
user_agent=m.user_agent,
|
||||
first_ip=m.first_ip,
|
||||
last_ip=m.last_ip,
|
||||
)
|
||||
|
||||
async def revoke_by_sid(self, sid: str, now: datetime) -> None:
|
||||
# Интерфейс требует None -> просто делаем update и flush
|
||||
await self._session.execute(
|
||||
update(Session)
|
||||
.where(Session.sid == sid, Session.revoked_at.is_(None))
|
||||
.values(revoked_at=now)
|
||||
.execution_options(synchronize_session='fetch')
|
||||
)
|
||||
await self._session.flush()
|
||||
|
||||
async def rotate_refresh(
|
||||
self,
|
||||
sid: str,
|
||||
new_jti_hash: str,
|
||||
new_refresh_expires_at: datetime,
|
||||
now: datetime,
|
||||
ip: str | None,
|
||||
user_agent: str | None,
|
||||
) -> None:
|
||||
|
||||
values = {
|
||||
'refresh_jti_hash': new_jti_hash,
|
||||
'refresh_expires_at': new_refresh_expires_at,
|
||||
'last_seen_at': now,
|
||||
'user_agent': user_agent,
|
||||
}
|
||||
if ip is not None:
|
||||
values['last_ip'] = ip
|
||||
|
||||
await self._session.execute(
|
||||
update(Session)
|
||||
.where(Session.sid == sid, Session.revoked_at.is_(None))
|
||||
.values(**values)
|
||||
.execution_options(synchronize_session='fetch')
|
||||
)
|
||||
await self._session.flush()
|
||||
|
||||
async def touch_last_seen(self, sid: str, *, ip: str | None, now: datetime) -> None:
|
||||
values = {'last_seen_at': now}
|
||||
if ip is not None:
|
||||
values['last_ip'] = ip
|
||||
|
||||
await self._session.execute(
|
||||
update(Session)
|
||||
.where(Session.sid == sid, Session.revoked_at.is_(None))
|
||||
.values(**values)
|
||||
.execution_options(synchronize_session='fetch')
|
||||
)
|
||||
await self._session.flush()
|
||||
|
||||
async def rotate_refresh_if_match(
|
||||
self,
|
||||
*,
|
||||
sid: str,
|
||||
old_jti_hash: str,
|
||||
new_jti_hash: str,
|
||||
new_refresh_expires_at: datetime,
|
||||
now: datetime,
|
||||
ip: str | None,
|
||||
user_agent: str | None,
|
||||
) -> bool:
|
||||
values = {
|
||||
'refresh_jti_hash': new_jti_hash,
|
||||
'refresh_expires_at': new_refresh_expires_at,
|
||||
'last_seen_at': now,
|
||||
'user_agent': user_agent,
|
||||
}
|
||||
if ip is not None:
|
||||
values['last_ip'] = ip
|
||||
|
||||
res = await self._session.execute(
|
||||
update(Session)
|
||||
.where(
|
||||
Session.sid == sid,
|
||||
Session.revoked_at.is_(None),
|
||||
Session.refresh_jti_hash == old_jti_hash, # ✅ защита от гонок
|
||||
)
|
||||
.values(**values)
|
||||
.execution_options(synchronize_session='fetch')
|
||||
)
|
||||
await self._session.flush()
|
||||
return (res.rowcount or 0) > 0
|
||||
58
src/infrastructure/database/repositories/user_repository.py
Normal file
58
src/infrastructure/database/repositories/user_repository.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from __future__ import annotations
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from src.application.contracts import ILogger
|
||||
from src.application.domain.exceptions import ApplicationException, InternalException, NotFoundException
|
||||
from src.application.abstractions.repositories import IUserRepository
|
||||
from src.application.domain.entities import UserEntity
|
||||
from src.infrastructure.database.models import UserModel
|
||||
|
||||
|
||||
class UserRepository(IUserRepository):
|
||||
def __init__(self, session: AsyncSession, logger: ILogger):
|
||||
self._session = session
|
||||
self._logger = logger
|
||||
|
||||
async def _get_active_user(self, user_id: str) -> UserModel:
|
||||
stmt = (
|
||||
select(UserModel)
|
||||
.where(
|
||||
UserModel.id == user_id,
|
||||
UserModel.is_deleted.is_(False),
|
||||
)
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
user: UserModel | None = result.scalar_one_or_none()
|
||||
if user is None:
|
||||
self._logger.warning(f'User not found with user_id {user_id}')
|
||||
raise NotFoundException(message='User not found')
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def _to_entity(user: UserModel) -> UserEntity:
|
||||
return UserEntity(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
account_type=user.account_type,
|
||||
first_name=user.first_name,
|
||||
middle_name=user.middle_name,
|
||||
last_name=user.last_name,
|
||||
birth_date=user.birth_date,
|
||||
phone=user.phone,
|
||||
kyc_verified_at=user.kyc_verified_at,
|
||||
kyc_verified=user.kyc_verified,
|
||||
is_deleted=user.is_deleted,
|
||||
created_at=user.created_at,
|
||||
updated_at=user.updated_at,
|
||||
)
|
||||
|
||||
async def get_user_by_id(self, user_id: str) -> UserEntity:
|
||||
try:
|
||||
user = await self._get_active_user(user_id)
|
||||
return self._to_entity(user)
|
||||
except ApplicationException:
|
||||
raise
|
||||
except SQLAlchemyError as exception:
|
||||
self._logger.exception(str(exception))
|
||||
raise InternalException(message=f'Database error: {str(exception)}')
|
||||
61
src/infrastructure/database/unit_of_work.py
Normal file
61
src/infrastructure/database/unit_of_work.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.abstractions.repositories import (
|
||||
IUserRepository,
|
||||
ILegalEntityRepository,
|
||||
IPurchaseRequestRepository,
|
||||
)
|
||||
from src.application.contracts import ILogger
|
||||
from src.infrastructure.database.repositories import (
|
||||
UserRepository,
|
||||
LegalEntityRepository,
|
||||
PurchaseRequestRepository,
|
||||
)
|
||||
|
||||
|
||||
class UnitOfWork(IUnitOfWork):
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession], logger: ILogger):
|
||||
self.session_factory = session_factory
|
||||
self._session: AsyncSession = None
|
||||
self._user_repository: IUserRepository = None
|
||||
self._legal_entity_repository: ILegalEntityRepository = None
|
||||
self._purchase_request_repository: IPurchaseRequestRepository = None
|
||||
self._logger: ILogger = logger
|
||||
|
||||
async def __aenter__(self):
|
||||
self._logger.debug('UnitOfWork enter')
|
||||
self._user_repository = None
|
||||
self._legal_entity_repository = None
|
||||
self._purchase_request_repository = None
|
||||
self._session = self.session_factory()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if exc_type:
|
||||
self._logger.error(f'UnitOfWork rollback_on_error exc_type={exc_type.__name__} exc_val={exc_val!r}')
|
||||
await self._session.rollback()
|
||||
self._logger.debug(f'UnitOfWork session rollback done exc_type={exc_type.__name__}')
|
||||
else:
|
||||
await self._session.flush()
|
||||
await self._session.commit()
|
||||
self._logger.debug('UnitOfWork commit')
|
||||
await self._session.close()
|
||||
self._logger.debug('UnitOfWork exit session closed')
|
||||
|
||||
@property
|
||||
def user_repository(self) -> IUserRepository:
|
||||
if self._user_repository is None:
|
||||
self._user_repository = UserRepository(session=self._session, logger=self._logger)
|
||||
return self._user_repository
|
||||
|
||||
@property
|
||||
def legal_entity_repository(self) -> ILegalEntityRepository:
|
||||
if self._legal_entity_repository is None:
|
||||
self._legal_entity_repository = LegalEntityRepository(session=self._session, logger=self._logger)
|
||||
return self._legal_entity_repository
|
||||
|
||||
@property
|
||||
def purchase_request_repository(self) -> IPurchaseRequestRepository:
|
||||
if self._purchase_request_repository is None:
|
||||
self._purchase_request_repository = PurchaseRequestRepository(session=self._session, logger=self._logger)
|
||||
return self._purchase_request_repository
|
||||
28
src/infrastructure/logger/__init__.py
Normal file
28
src/infrastructure/logger/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from src.application.contracts import ILogger
|
||||
from src.application.domain.enums import LogFormat
|
||||
from src.application.domain.enums import LogLevel
|
||||
from src.infrastructure.config.settings import settings
|
||||
from src.infrastructure.logger.logger import Logger
|
||||
|
||||
log_levels = {
|
||||
'DEBUG': LogLevel.DEBUG,
|
||||
'INFO': LogLevel.INFO,
|
||||
'WARNING': LogLevel.WARNING,
|
||||
'ERROR': LogLevel.ERROR,
|
||||
'CRITICAL': LogLevel.CRITICAL,
|
||||
'EXCEPTION': LogLevel.EXCEPTION,
|
||||
}
|
||||
|
||||
log_formats = {
|
||||
'JSON': LogFormat.JSON,
|
||||
'TEXT': LogFormat.TEXT,
|
||||
}
|
||||
|
||||
logger = Logger(
|
||||
min_level=log_levels.get(settings.LOG_LEVEL, LogLevel.INFO),
|
||||
log_format=log_formats.get(settings.LOG_FORMAT, LogFormat.JSON),
|
||||
)
|
||||
|
||||
|
||||
def get_logger() -> ILogger:
|
||||
return logger
|
||||
129
src/infrastructure/logger/logger.py
Normal file
129
src/infrastructure/logger/logger.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import traceback
|
||||
import inspect
|
||||
import sys
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Callable, Optional, Any
|
||||
from ulid import ULID
|
||||
from src.application.contracts import ILogger
|
||||
from src.application.domain.enums import LogFormat, LogLevel
|
||||
from src.infrastructure.context_vars import trace_id_var
|
||||
|
||||
|
||||
class Logger(ILogger):
|
||||
_instance = None
|
||||
__default_format = LogFormat.JSON
|
||||
|
||||
def __new__(cls, *args: Any, **kwargs: Any) -> "Logger":
|
||||
if cls._instance is None:
|
||||
cls._instance = super(Logger, cls).__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
log_format: LogFormat = __default_format,
|
||||
min_level: LogLevel = LogLevel.INFO,
|
||||
id_generator: Optional[Callable[[], str]] = lambda: str(ULID()),
|
||||
instance_id: str = "N/A",
|
||||
):
|
||||
self.log_format = log_format
|
||||
self.min_level = min_level
|
||||
self.id_generator = id_generator
|
||||
self.instance_id = instance_id
|
||||
|
||||
def set_instance_id(self, instance_id: str) -> None:
|
||||
self.instance_id = instance_id
|
||||
|
||||
def get_instance_id(self) -> str:
|
||||
return self.instance_id
|
||||
|
||||
def set_format(self, log_format: LogFormat) -> None:
|
||||
if not isinstance(log_format, LogFormat):
|
||||
raise ValueError("Log format must be an instance of LogFormat enum")
|
||||
self.log_format = log_format
|
||||
|
||||
def set_min_level(self, level: LogLevel) -> None:
|
||||
self.min_level = level
|
||||
|
||||
def new_trace_id(self) -> str:
|
||||
trace_id = str(ULID()) if self.id_generator is None else self.id_generator()
|
||||
trace_id_var.set(trace_id)
|
||||
return trace_id
|
||||
|
||||
def set_trace_id(self, trace_id: str) -> None:
|
||||
trace_id_var.set(trace_id)
|
||||
|
||||
def get_trace_id(self) -> str:
|
||||
return trace_id_var.get()
|
||||
|
||||
def clear_trace_id(self) -> None:
|
||||
trace_id_var.set("N/A")
|
||||
|
||||
def _prepare_log_data(self, level: LogLevel, message: str) -> dict[str, Any]:
|
||||
current_frame = inspect.currentframe()
|
||||
if (
|
||||
current_frame
|
||||
and current_frame.f_back
|
||||
and current_frame.f_back.f_back
|
||||
and current_frame.f_back.f_back.f_back
|
||||
):
|
||||
frame = current_frame.f_back.f_back.f_back
|
||||
filename = frame.f_code.co_filename
|
||||
line_number = frame.f_lineno
|
||||
else:
|
||||
filename = "unknown"
|
||||
line_number = 0
|
||||
|
||||
log_data = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"level": level.name,
|
||||
"instance_id": self.instance_id,
|
||||
"file": filename,
|
||||
"line": line_number,
|
||||
"trace_id": trace_id_var.get(),
|
||||
"message": message,
|
||||
}
|
||||
|
||||
if level == LogLevel.EXCEPTION:
|
||||
log_data["exception"] = traceback.format_exc()
|
||||
|
||||
return log_data
|
||||
|
||||
def _log(self, level: LogLevel, message: str) -> None:
|
||||
if level >= self.min_level:
|
||||
log_data = self._prepare_log_data(level, message)
|
||||
|
||||
if self.log_format == LogFormat.JSON:
|
||||
log_message = json.dumps(log_data, ensure_ascii=False)
|
||||
else:
|
||||
log_message = (
|
||||
f"{log_data['timestamp']} - {log_data['level']} - "
|
||||
f"{log_data['instance_id']} - {log_data['trace_id']} - "
|
||||
f"{log_data['file']}:{log_data['line']} - "
|
||||
f"{log_data['message']}"
|
||||
)
|
||||
if "exception" in log_data:
|
||||
log_message += f"\nTraceback:\n{log_data['exception']}"
|
||||
|
||||
self._write(log_message)
|
||||
|
||||
def _write(self, message: str) -> None:
|
||||
sys.stdout.write(message + "\n")
|
||||
|
||||
def debug(self, message: str) -> None:
|
||||
self._log(LogLevel.DEBUG, message)
|
||||
|
||||
def info(self, message: str) -> None:
|
||||
self._log(LogLevel.INFO, message)
|
||||
|
||||
def warning(self, message: str) -> None:
|
||||
self._log(LogLevel.WARNING, message)
|
||||
|
||||
def error(self, message: str) -> None:
|
||||
self._log(LogLevel.ERROR, message)
|
||||
|
||||
def critical(self, message: str) -> None:
|
||||
self._log(LogLevel.CRITICAL, message)
|
||||
|
||||
def exception(self, message: str) -> None:
|
||||
self._log(LogLevel.EXCEPTION, message)
|
||||
18
src/infrastructure/media/webp.py
Normal file
18
src/infrastructure/media/webp.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def image_bytes_to_webp(raw: bytes, *, quality: int = 82) -> bytes:
|
||||
im = Image.open(BytesIO(raw))
|
||||
if im.mode == 'P':
|
||||
im = im.convert('RGBA')
|
||||
elif im.mode == 'LA':
|
||||
im = im.convert('RGBA')
|
||||
elif im.mode not in ('RGBA', 'RGB'):
|
||||
im = im.convert('RGB')
|
||||
out = BytesIO()
|
||||
im.save(out, format='WEBP', quality=quality)
|
||||
return out.getvalue()
|
||||
1
src/infrastructure/messanger/__init__.py
Normal file
1
src/infrastructure/messanger/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.messanger.rabbit_client import RabbitClient
|
||||
72
src/infrastructure/messanger/rabbit_client.py
Normal file
72
src/infrastructure/messanger/rabbit_client.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from typing import Any, Mapping
|
||||
from faststream.rabbit import RabbitBroker
|
||||
from src.application.contracts import IQueueMessanger
|
||||
from src.infrastructure.config import settings
|
||||
|
||||
|
||||
class RabbitClient(IQueueMessanger):
|
||||
def __init__(self) -> None:
|
||||
self._broker = RabbitBroker(
|
||||
settings.RABBIT_URL,
|
||||
)
|
||||
self._connected = False
|
||||
|
||||
async def connect(self) -> None:
|
||||
if self._connected:
|
||||
return
|
||||
await self._broker.connect()
|
||||
self._connected = True
|
||||
|
||||
async def close(self) -> None:
|
||||
if not self._connected:
|
||||
return
|
||||
await self._broker.close()
|
||||
self._connected = False
|
||||
|
||||
async def _ensure_connected(self) -> None:
|
||||
if not self._connected:
|
||||
await self.connect()
|
||||
|
||||
async def publish_to_queue(
|
||||
self,
|
||||
queue: str,
|
||||
message: Any,
|
||||
*,
|
||||
persist: bool | None = None,
|
||||
headers: Mapping[str, Any] | None = None,
|
||||
correlation_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> None:
|
||||
await self._ensure_connected()
|
||||
|
||||
await self._broker.publish(
|
||||
message,
|
||||
queue=queue,
|
||||
persist=settings.RABBIT_PUBLISH_PERSIST if persist is None else persist,
|
||||
headers=headers,
|
||||
correlation_id=correlation_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
async def publish(
|
||||
self,
|
||||
message: Any,
|
||||
*,
|
||||
exchange: str,
|
||||
routing_key: str,
|
||||
persist: bool | None = None,
|
||||
headers: Mapping[str, Any] | None = None,
|
||||
correlation_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> None:
|
||||
await self._ensure_connected()
|
||||
|
||||
await self._broker.publish(
|
||||
message,
|
||||
exchange=exchange,
|
||||
routing_key=routing_key,
|
||||
persist=settings.RABBIT_PUBLISH_PERSIST if persist is None else persist,
|
||||
headers=headers,
|
||||
correlation_id=correlation_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
3
src/infrastructure/security/__init__.py
Normal file
3
src/infrastructure/security/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from src.infrastructure.security.jwt import JwtService
|
||||
from src.infrastructure.security.csrf import CsrfService
|
||||
from src.infrastructure.security.hash import HashService
|
||||
81
src/infrastructure/security/csrf.py
Normal file
81
src/infrastructure/security/csrf.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from __future__ import annotations
|
||||
import secrets
|
||||
from typing import Any, Optional, Mapping
|
||||
from itsdangerous import URLSafeTimedSerializer, SignatureExpired, BadSignature
|
||||
from src.application.contracts import ICsrfService
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.config.settings import settings
|
||||
|
||||
|
||||
class CsrfService(ICsrfService):
|
||||
COOKIE_NAME = 'csrf_token'
|
||||
HEADER_NAME = 'X-CSRF-Token'
|
||||
SALT = 'csrf'
|
||||
TTL_SECONDS = 3600
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._serializer = URLSafeTimedSerializer(
|
||||
secret_key=settings.CSRF_SECRET_KEY,
|
||||
salt=self.SALT,
|
||||
)
|
||||
|
||||
@property
|
||||
def cookie_name(self) -> str:
|
||||
return self.COOKIE_NAME
|
||||
|
||||
@property
|
||||
def header_name(self) -> str:
|
||||
return self.HEADER_NAME
|
||||
|
||||
@property
|
||||
def ttl_seconds(self) -> int:
|
||||
return self.TTL_SECONDS
|
||||
|
||||
def issue(self, subject: Optional[str] = None) -> str:
|
||||
payload = {
|
||||
'sub': subject,
|
||||
'nonce': secrets.token_urlsafe(32),
|
||||
}
|
||||
return self._serializer.dumps(payload)
|
||||
|
||||
def verify(self, token: str, expected_subject: Optional[str] = None) -> dict[str, Any]:
|
||||
try:
|
||||
data = self._serializer.loads(token, max_age=self.TTL_SECONDS)
|
||||
except SignatureExpired:
|
||||
raise ApplicationException(
|
||||
status_code=403,
|
||||
message='CSRF token expired',
|
||||
)
|
||||
except BadSignature:
|
||||
raise ApplicationException(
|
||||
status_code=403,
|
||||
message='CSRF token invalid',
|
||||
)
|
||||
|
||||
if expected_subject is not None and data.get('sub') != expected_subject:
|
||||
raise ApplicationException(
|
||||
status_code=403,
|
||||
message='CSRF token subject mismatch',
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
def extract(self, cookies: Mapping[str, str], headers: Mapping[str, str]) -> tuple[Optional[str], Optional[str]]:
|
||||
cookie_token = cookies.get(self.COOKIE_NAME)
|
||||
header_token = headers.get(self.HEADER_NAME)
|
||||
return cookie_token, header_token
|
||||
|
||||
def verify_pair(self, cookie_token: Optional[str], header_token: Optional[str], expected_subject: Optional[str] = None) -> None:
|
||||
if not cookie_token or not header_token:
|
||||
raise ApplicationException(
|
||||
status_code=403,
|
||||
message='CSRF token missing',
|
||||
)
|
||||
|
||||
if not secrets.compare_digest(cookie_token, header_token):
|
||||
raise ApplicationException(
|
||||
status_code=403,
|
||||
message='CSRF token mismatch',
|
||||
)
|
||||
|
||||
self.verify(cookie_token, expected_subject=expected_subject)
|
||||
17
src/infrastructure/security/hash.py
Normal file
17
src/infrastructure/security/hash.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import bcrypt
|
||||
from src.application.contracts import IHashService, ILogger
|
||||
|
||||
|
||||
class HashService(IHashService):
|
||||
|
||||
def __init__(self, logger: ILogger):
|
||||
self._logger = logger
|
||||
|
||||
async def hash(self, value: str) -> str:
|
||||
hashed_value = bcrypt.hashpw(value.encode(), bcrypt.gensalt())
|
||||
self._logger.info(f'Hash value {hashed_value.decode()}')
|
||||
return hashed_value.decode()
|
||||
|
||||
async def verify(self, hashed_value: str, plain_value: str) -> bool:
|
||||
self._logger.info(f'Hash value {hashed_value[:10]}')
|
||||
return bcrypt.checkpw(plain_value.encode(), hashed_value.encode())
|
||||
109
src/infrastructure/security/jwt.py
Normal file
109
src/infrastructure/security/jwt.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from __future__ import annotations
|
||||
from jose import jwt, ExpiredSignatureError, JWTError
|
||||
from src.application.contracts import ILogger, IJwtService
|
||||
from src.application.domain.dto import AccessTokenPayload
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.config.settings import settings
|
||||
from src.infrastructure.vault import JwtKeyStore
|
||||
|
||||
|
||||
class JwtService(IJwtService):
|
||||
def __init__(self, logger: ILogger, key_store: JwtKeyStore) -> None:
|
||||
self._logger = logger
|
||||
self._key_store = key_store
|
||||
|
||||
async def decode_access_token(self, token: str) -> AccessTokenPayload:
|
||||
payload = await self._decode_and_verify(token)
|
||||
|
||||
if payload.get('type') != 'access':
|
||||
self._logger.warning(f'Access token invalid type received_type={payload.get('type')}')
|
||||
raise ApplicationException(status_code=401, message='Invalid token type')
|
||||
|
||||
try:
|
||||
return AccessTokenPayload(
|
||||
sub=str(payload['sub']),
|
||||
type='access',
|
||||
sid=str(payload['sid']),
|
||||
iat=int(payload['iat']),
|
||||
nbf=int(payload['nbf']),
|
||||
exp=int(payload['exp']),
|
||||
iss=payload.get('iss'),
|
||||
aud=payload.get('aud'),
|
||||
)
|
||||
except KeyError as exception:
|
||||
self._logger.warning(f'Access token missing claim error={str(exception)}')
|
||||
raise ApplicationException(status_code=401, message=f'Missing token claim: {exception}')
|
||||
|
||||
async def _decode_and_verify(self, token: str) -> dict:
|
||||
kid: str | None = None
|
||||
try:
|
||||
header = jwt.get_unverified_header(token)
|
||||
|
||||
kid = header.get('kid')
|
||||
if not kid:
|
||||
self._logger.warning(f'JWT header missing kid header={header}')
|
||||
raise ApplicationException(status_code=401, message='Missing token header: kid')
|
||||
|
||||
received_alg = header.get('alg')
|
||||
if received_alg != settings.JWT_ALGORITHM:
|
||||
self._logger.warning(f'JWT invalid algorithm kid={kid} received_alg={received_alg} expected_alg={settings.JWT_ALGORITHM}')
|
||||
raise ApplicationException(status_code=401, message='Invalid token algorithm')
|
||||
|
||||
public_pem = await self._key_store.get_public_key_for_kid(str(kid))
|
||||
|
||||
if not public_pem:
|
||||
self._logger.info(f'JWT kid miss kid={kid} forcing keystore refresh')
|
||||
await self._key_store.refresh()
|
||||
public_pem = await self._key_store.get_public_key_for_kid(str(kid))
|
||||
|
||||
if not public_pem:
|
||||
self._logger.warning(f'JWT unknown kid kid={kid}')
|
||||
raise ApplicationException(status_code=401, message='Unknown token kid')
|
||||
|
||||
options = {
|
||||
'verify_signature': True,
|
||||
'verify_exp': True,
|
||||
'verify_nbf': True,
|
||||
'verify_iat': True,
|
||||
'verify_aud': bool(settings.JWT_AUDIENCE),
|
||||
'verify_iss': bool(settings.JWT_ISSUER),
|
||||
'require_exp': True,
|
||||
'require_iat': True,
|
||||
'require_nbf': True,
|
||||
'require_sub': True,
|
||||
'leeway': 10,
|
||||
}
|
||||
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
public_pem,
|
||||
algorithms=[settings.JWT_ALGORITHM],
|
||||
audience=settings.JWT_AUDIENCE or None,
|
||||
issuer=settings.JWT_ISSUER or None,
|
||||
options=options,
|
||||
)
|
||||
|
||||
if 'sid' not in payload:
|
||||
self._logger.warning(f'JWT missing sid claim kid={kid}')
|
||||
raise ApplicationException(status_code=401, message='Missing token claim: sid')
|
||||
|
||||
if 'type' not in payload:
|
||||
self._logger.warning(f'JWT missing type claim kid={kid}')
|
||||
raise ApplicationException(status_code=401, message='Missing token claim: type')
|
||||
|
||||
return payload
|
||||
|
||||
except ExpiredSignatureError as exception:
|
||||
self._logger.info(f'JWT expired kid={kid} error={str(exception)}')
|
||||
raise ApplicationException(status_code=401, message='Token expired')
|
||||
|
||||
except ApplicationException:
|
||||
raise
|
||||
|
||||
except JWTError as exception:
|
||||
self._logger.warning(f'JWT decode failed kid={kid} error={str(exception)}')
|
||||
raise ApplicationException(status_code=401, message='Invalid token')
|
||||
|
||||
except Exception as exception:
|
||||
self._logger.error(f'Unexpected JWT decode error kid={kid} error={str(exception)}')
|
||||
raise ApplicationException(status_code=500, message='JWT decode failed')
|
||||
125
src/infrastructure/storage/s3_service.py
Normal file
125
src/infrastructure/storage/s3_service.py
Normal file
@@ -0,0 +1,125 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from aiobotocore.session import get_session
|
||||
|
||||
|
||||
class S3Service:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
bucket: str,
|
||||
region: str,
|
||||
access_key_id: str | None,
|
||||
secret_access_key: str | None,
|
||||
public_base_url: str | None,
|
||||
endpoint_url: str | None,
|
||||
use_reg_ru_website_public_host: bool,
|
||||
):
|
||||
self._bucket = bucket
|
||||
self._region = region or 'us-east-1'
|
||||
self._access_key_id = access_key_id
|
||||
self._secret_access_key = secret_access_key
|
||||
pb = (public_base_url or '').strip().rstrip('/')
|
||||
self._public_base_url = pb if pb else None
|
||||
self._endpoint_url = endpoint_url.strip().rstrip('/') if endpoint_url and endpoint_url.strip() else None
|
||||
self._use_reg_ru_website_public_host = use_reg_ru_website_public_host
|
||||
|
||||
@staticmethod
|
||||
def _url_prefix_variants(prefix: str) -> list[str]:
|
||||
p = prefix.rstrip('/') + '/'
|
||||
out = [p]
|
||||
if p.startswith('https://'):
|
||||
out.append('http://' + p[8:])
|
||||
elif p.startswith('http://'):
|
||||
out.append('https://' + p[7:])
|
||||
return out
|
||||
|
||||
def _public_url_prefixes(self) -> list[str]:
|
||||
acc: list[str] = []
|
||||
pb = self._public_base_url
|
||||
if pb:
|
||||
acc.extend(self._url_prefix_variants(pb))
|
||||
ep = self._endpoint_url
|
||||
if ep:
|
||||
base = f'{ep.rstrip("/")}/{self._bucket}'
|
||||
acc.extend(self._url_prefix_variants(base))
|
||||
if ep and self._use_reg_ru_website_public_host and 's3.regru.cloud' in ep.lower():
|
||||
wh = f'https://{self._bucket}.website.regru.cloud'
|
||||
acc.extend(self._url_prefix_variants(wh))
|
||||
if not ep:
|
||||
if self._region == 'us-east-1':
|
||||
h = f'https://{self._bucket}.s3.amazonaws.com'
|
||||
else:
|
||||
h = f'https://{self._bucket}.s3.{self._region}.amazonaws.com'
|
||||
acc.extend(self._url_prefix_variants(h))
|
||||
seen: set[str] = set()
|
||||
uniq: list[str] = []
|
||||
for x in sorted(acc, key=len, reverse=True):
|
||||
if x not in seen:
|
||||
seen.add(x)
|
||||
uniq.append(x)
|
||||
return uniq
|
||||
|
||||
def object_key_from_public_url(self, url: str) -> str | None:
|
||||
u = (url or '').strip()
|
||||
if not u:
|
||||
return None
|
||||
for p in self._public_url_prefixes():
|
||||
if u.startswith(p):
|
||||
k = u[len(p):].split('?', 1)[0].split('#', 1)[0]
|
||||
return k if k else None
|
||||
return None
|
||||
|
||||
def _object_url(self, key: str) -> str:
|
||||
if self._public_base_url:
|
||||
return f'{self._public_base_url}/{key}'
|
||||
endpoint = self._endpoint_url
|
||||
if endpoint:
|
||||
if (
|
||||
self._use_reg_ru_website_public_host
|
||||
and 's3.regru.cloud' in endpoint.lower()
|
||||
):
|
||||
return f'https://{self._bucket}.website.regru.cloud/{key}'
|
||||
return f'{endpoint}/{self._bucket}/{key}'
|
||||
region = self._region
|
||||
if region == 'us-east-1':
|
||||
host = 's3.amazonaws.com'
|
||||
else:
|
||||
host = f's3.{region}.amazonaws.com'
|
||||
return f'https://{self._bucket}.{host}/{key}'
|
||||
|
||||
async def upload_bytes(self, *, key: str, body: bytes, content_type: str) -> str:
|
||||
session = get_session()
|
||||
kw: dict[str, object] = {'region_name': self._region}
|
||||
aid = self._access_key_id
|
||||
sk = self._secret_access_key
|
||||
ep = self._endpoint_url
|
||||
if aid:
|
||||
kw['aws_access_key_id'] = aid
|
||||
if sk:
|
||||
kw['aws_secret_access_key'] = sk
|
||||
if ep:
|
||||
kw['endpoint_url'] = ep
|
||||
async with session.create_client('s3', **kw) as client:
|
||||
await client.put_object(
|
||||
Bucket=self._bucket,
|
||||
Key=key,
|
||||
Body=body,
|
||||
ContentType=content_type,
|
||||
)
|
||||
return self._object_url(key)
|
||||
|
||||
async def delete_object(self, *, key: str) -> None:
|
||||
session = get_session()
|
||||
kw: dict[str, object] = {'region_name': self._region}
|
||||
aid = self._access_key_id
|
||||
sk = self._secret_access_key
|
||||
ep = self._endpoint_url
|
||||
if aid:
|
||||
kw['aws_access_key_id'] = aid
|
||||
if sk:
|
||||
kw['aws_secret_access_key'] = sk
|
||||
if ep:
|
||||
kw['endpoint_url'] = ep
|
||||
async with session.create_client('s3', **kw) as client:
|
||||
await client.delete_object(Bucket=self._bucket, Key=key)
|
||||
1
src/infrastructure/utils/__init__.py
Normal file
1
src/infrastructure/utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.utils.instance_id import generate_instance_id
|
||||
14
src/infrastructure/utils/instance_id.py
Normal file
14
src/infrastructure/utils/instance_id.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from ulid import ULID
|
||||
|
||||
|
||||
def generate_instance_id() -> str:
|
||||
"""
|
||||
Generate a process-wide instance id in ULID format.
|
||||
|
||||
ULID is 26 chars (Crockford Base32) and lexicographically sortable by time.
|
||||
"""
|
||||
|
||||
|
||||
return str(ULID())
|
||||
|
||||
|
||||
4
src/infrastructure/vault/__init__.py
Normal file
4
src/infrastructure/vault/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from src.infrastructure.vault.client import VaultClient
|
||||
from src.infrastructure.vault.utils import create_hvac_client, read_kv2_secret
|
||||
from src.infrastructure.vault.keys import JwtKeyStore
|
||||
from src.infrastructure.vault.scheduler import start_jwt_keys_scheduler
|
||||
66
src/infrastructure/vault/client.py
Normal file
66
src/infrastructure/vault/client.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import hvac
|
||||
|
||||
|
||||
def _vault_token_renew_failed(exception: Exception) -> bool:
|
||||
if isinstance(exception, (hvac.exceptions.Forbidden, hvac.exceptions.Unauthorized)):
|
||||
return True
|
||||
message = getattr(exception, 'message', None) or str(exception)
|
||||
if isinstance(message, str):
|
||||
lower = message.lower()
|
||||
return 'permission denied' in lower or 'invalid token' in lower or '403' in lower
|
||||
return False
|
||||
|
||||
|
||||
class VaultClient:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
addr: str,
|
||||
role_id: str,
|
||||
secret_id: str,
|
||||
namespace: str | None,
|
||||
mount_point: str,
|
||||
) -> None:
|
||||
self._mount_point = mount_point
|
||||
self._addr = addr
|
||||
self._role_id = role_id
|
||||
self._secret_id = secret_id
|
||||
self._namespace = namespace
|
||||
self._client = hvac.Client(url=addr, namespace=namespace)
|
||||
self._approle_login()
|
||||
|
||||
def _approle_login(self) -> None:
|
||||
self._client.auth.approle.login(role_id=self._role_id, secret_id=self._secret_id)
|
||||
|
||||
def _renew_or_login(self) -> None:
|
||||
try:
|
||||
self._client.auth.token.renew_self()
|
||||
except Exception:
|
||||
self._approle_login()
|
||||
|
||||
def read_secret(self, path: str) -> dict[str, Any]:
|
||||
for attempt in range(2):
|
||||
try:
|
||||
secret = self._client.secrets.kv.v2.read_secret_version(
|
||||
path=path,
|
||||
mount_point=self._mount_point,
|
||||
)
|
||||
return dict(secret.get('data', {}).get('data', {}))
|
||||
except Exception as exc:
|
||||
if attempt == 0 and _vault_token_renew_failed(exc):
|
||||
self._renew_or_login()
|
||||
continue
|
||||
raise
|
||||
|
||||
def read_secret_optional(self, path: str) -> dict[str, Any]:
|
||||
if not path:
|
||||
return {}
|
||||
try:
|
||||
return self.read_secret(path)
|
||||
except (hvac.exceptions.InvalidPath, hvac.exceptions.Forbidden, hvac.exceptions.Unauthorized):
|
||||
return {}
|
||||
111
src/infrastructure/vault/keys.py
Normal file
111
src/infrastructure/vault/keys.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from src.application.domain.dto import JwtPublicKeySet, JwtPublicKey
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.vault.client import VaultClient
|
||||
|
||||
|
||||
class JwtKeyStore:
|
||||
|
||||
_instance: 'JwtKeyStore | None' = None
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vault_addr: str,
|
||||
vault_role_id: str,
|
||||
vault_secret_id: str,
|
||||
vault_namespace: str | None,
|
||||
mount_point: str,
|
||||
kid_path: str = 'jwt/kid',
|
||||
kids_prefix: str = 'jwt/kids',
|
||||
refresh_ttl_seconds: int = 60,
|
||||
):
|
||||
if getattr(self, '_initialized', False):
|
||||
return
|
||||
|
||||
self._vault_client = VaultClient(
|
||||
addr=vault_addr,
|
||||
role_id=vault_role_id,
|
||||
secret_id=vault_secret_id,
|
||||
namespace=vault_namespace,
|
||||
mount_point=mount_point,
|
||||
)
|
||||
|
||||
self._kid_path = kid_path
|
||||
self._kids_prefix = kids_prefix
|
||||
|
||||
self._refresh_ttl_seconds = refresh_ttl_seconds
|
||||
|
||||
self._lock = asyncio.Lock()
|
||||
self._keyset: JwtPublicKeySet | None = None
|
||||
self._last_refresh_at: datetime | None = None
|
||||
|
||||
self._initialized = True
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> 'JwtKeyStore':
|
||||
if cls._instance is None:
|
||||
raise ApplicationException(status_code=500, message='JwtKeyStore not initialized')
|
||||
return cls._instance
|
||||
|
||||
def _read_keyset_sync(self) -> JwtPublicKeySet:
|
||||
kids = self._vault_client.read_secret(self._kid_path)
|
||||
active_kid = kids.get('active')
|
||||
previous_kid = kids.get('previous')
|
||||
|
||||
if not active_kid:
|
||||
raise RuntimeError('Vault jwt/kid secret missing "active"')
|
||||
|
||||
active = self._read_public_key_sync(str(active_kid))
|
||||
|
||||
previous = None
|
||||
if previous_kid and previous_kid != active_kid:
|
||||
previous = self._read_public_key_sync(str(previous_kid))
|
||||
|
||||
return JwtPublicKeySet(active=active, previous=previous)
|
||||
|
||||
def _read_public_key_sync(self, kid: str) -> JwtPublicKey:
|
||||
data = self._vault_client.read_secret(f'{self._kids_prefix}/{kid}')
|
||||
pub = data.get('public_key')
|
||||
if not pub:
|
||||
raise RuntimeError(f'Vault jwt/kids/{kid} missing public_key')
|
||||
return JwtPublicKey(kid=kid, public_key_pem=pub)
|
||||
|
||||
async def refresh(self) -> JwtPublicKeySet:
|
||||
keyset = await asyncio.to_thread(self._read_keyset_sync)
|
||||
async with self._lock:
|
||||
self._keyset = keyset
|
||||
self._last_refresh_at = datetime.now(timezone.utc)
|
||||
return keyset
|
||||
|
||||
async def get_public_key_for_kid(self, kid: str) -> str | None:
|
||||
ks = await self._get_or_refresh()
|
||||
return ks.public_keys_by_kid().get(kid)
|
||||
|
||||
async def last_refresh_at(self) -> datetime | None:
|
||||
async with self._lock:
|
||||
return self._last_refresh_at
|
||||
|
||||
async def _get_or_refresh(self) -> JwtPublicKeySet:
|
||||
async with self._lock:
|
||||
ks = self._keyset
|
||||
last = self._last_refresh_at
|
||||
|
||||
if ks is None:
|
||||
return await self.refresh()
|
||||
|
||||
if last is None:
|
||||
return await self.refresh()
|
||||
|
||||
age = (datetime.now(timezone.utc) - last).total_seconds()
|
||||
if age >= self._refresh_ttl_seconds:
|
||||
return await self.refresh()
|
||||
|
||||
return ks
|
||||
23
src/infrastructure/vault/scheduler.py
Normal file
23
src/infrastructure/vault/scheduler.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
from src.infrastructure.vault import JwtKeyStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def start_jwt_keys_scheduler(store: JwtKeyStore, *, refresh_seconds: int = 3600) -> AsyncIOScheduler:
|
||||
scheduler = AsyncIOScheduler()
|
||||
scheduler.add_job(
|
||||
store.refresh,
|
||||
trigger=IntervalTrigger(seconds=refresh_seconds),
|
||||
id="jwt_keys_refresh",
|
||||
replace_existing=True,
|
||||
max_instances=1,
|
||||
coalesce=True,
|
||||
misfire_grace_time=60,
|
||||
)
|
||||
scheduler.start()
|
||||
logger.info("JWT keys scheduler started (interval=%s seconds)", refresh_seconds)
|
||||
return scheduler
|
||||
17
src/infrastructure/vault/utils.py
Normal file
17
src/infrastructure/vault/utils.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from __future__ import annotations
|
||||
import hvac
|
||||
|
||||
|
||||
def create_hvac_client(*, url: str, token: str, timeout: int = 5) -> hvac.Client:
|
||||
client = hvac.Client(url=url, token=token, timeout=timeout)
|
||||
if not client.is_authenticated():
|
||||
raise RuntimeError("Vault authentication failed. Check VAULT_ADDR / VAULT_TOKEN")
|
||||
return client
|
||||
|
||||
|
||||
def read_kv2_secret(*, client: hvac.Client, mount_point: str, path: str) -> dict:
|
||||
secret = client.secrets.kv.v2.read_secret_version(
|
||||
mount_point=mount_point,
|
||||
path=path,
|
||||
)
|
||||
return secret["data"]["data"]
|
||||
140
src/main.py
Normal file
140
src/main.py
Normal file
@@ -0,0 +1,140 @@
|
||||
from __future__ import annotations
|
||||
from contextlib import asynccontextmanager
|
||||
import secrets
|
||||
from typing import AsyncGenerator
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.exceptions import HTTPException
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from src.application.domain.exceptions import ApplicationException, UnauthorizedException
|
||||
from src.infrastructure.cache import create_redis_client
|
||||
from src.infrastructure.vault import JwtKeyStore, start_jwt_keys_scheduler
|
||||
from src.infrastructure.utils import generate_instance_id
|
||||
from src.infrastructure.logger import logger
|
||||
from src.infrastructure.config import settings
|
||||
from src.presentation.handlers import (
|
||||
application_exception_handler,
|
||||
http_exception_handler,
|
||||
unhandled_exception_handler,
|
||||
validation_exception_handler,
|
||||
)
|
||||
from src.presentation.middleware import TraceIDMiddleware, SecurityHeadersMiddleware
|
||||
from src.presentation.routing import purchase_requests_router
|
||||
|
||||
security = HTTPBasic()
|
||||
|
||||
|
||||
async def verify_credentials(credentials: HTTPBasicCredentials = Depends(security)) -> HTTPBasicCredentials:
|
||||
user_ok = secrets.compare_digest(credentials.username, settings.DOCS_USERNAME)
|
||||
pass_ok = secrets.compare_digest(credentials.password, settings.DOCS_PASSWORD)
|
||||
if not (user_ok and pass_ok):
|
||||
raise UnauthorizedException(
|
||||
message='Unauthorized',
|
||||
headers={'WWW-Authenticate': 'Basic'},
|
||||
)
|
||||
return credentials
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
|
||||
instance_id = generate_instance_id()
|
||||
logger.set_instance_id(instance_id)
|
||||
logger.info(f'B2B service instance started with id {instance_id}')
|
||||
|
||||
app.state.redis = create_redis_client()
|
||||
|
||||
if not settings.VAULT_ROLE_ID.strip() or not settings.VAULT_SECRET_ID.strip():
|
||||
raise RuntimeError('VAULT_ROLE_ID and VAULT_SECRET_ID must be set')
|
||||
|
||||
jwt_store = JwtKeyStore(
|
||||
vault_addr=settings.VAULT_ADDR,
|
||||
vault_role_id=settings.VAULT_ROLE_ID,
|
||||
vault_secret_id=settings.VAULT_SECRET_ID,
|
||||
vault_namespace=settings.VAULT_NAMESPACE,
|
||||
mount_point=settings.VAULT_MOUNT_POINT,
|
||||
kid_path=settings.VAULT_JWT_KID_PATH,
|
||||
kids_prefix=settings.VAULT_JWT_KIDS_PREFIX,
|
||||
)
|
||||
|
||||
await jwt_store.refresh()
|
||||
|
||||
jwt_scheduler = start_jwt_keys_scheduler(jwt_store, refresh_seconds=settings.JWT_KEYS_REFRESH_SECONDS)
|
||||
|
||||
app.state.jwt_key_store = jwt_store
|
||||
app.state.jwt_keys_scheduler = jwt_scheduler
|
||||
yield
|
||||
await app.state.redis.aclose()
|
||||
logger.info(f'B2B service instance ended with id {instance_id}')
|
||||
|
||||
|
||||
app: FastAPI = FastAPI(
|
||||
redoc_url=None,
|
||||
docs_url=None,
|
||||
lifespan=lifespan,
|
||||
title='B2B Service',
|
||||
version='1.0.0',
|
||||
description='Purchase requests API for legal entity client users.',
|
||||
license_info={
|
||||
'name': 'MIT',
|
||||
'url': 'https://opensource.org/licenses/MIT',
|
||||
},
|
||||
)
|
||||
|
||||
app.add_exception_handler(RequestValidationError, validation_exception_handler)
|
||||
app.add_exception_handler(HTTPException, http_exception_handler)
|
||||
app.add_exception_handler(ApplicationException, application_exception_handler)
|
||||
app.add_exception_handler(Exception, unhandled_exception_handler)
|
||||
|
||||
app.include_router(purchase_requests_router)
|
||||
|
||||
app.add_middleware(TraceIDMiddleware, logger=logger)
|
||||
app.add_middleware(
|
||||
SecurityHeadersMiddleware,
|
||||
hsts=True,
|
||||
hsts_preload=False,
|
||||
frame_options='DENY',
|
||||
referrer_policy='strict-origin-when-cross-origin',
|
||||
content_security_policy="default-src 'self'; frame-ancestors 'none'; base-uri 'self'; object-src 'none'",
|
||||
)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=[],
|
||||
allow_origin_regex='.*',
|
||||
allow_credentials=True,
|
||||
allow_methods=['*'],
|
||||
allow_headers=['*'],
|
||||
)
|
||||
|
||||
|
||||
@app.get('/docs', include_in_schema=False)
|
||||
async def custom_swagger_ui_html(_credentials: HTTPBasicCredentials = Depends(verify_credentials)) -> HTMLResponse:
|
||||
'''Custom Swagger documentation, optionally protected with basic authentication.'''
|
||||
return get_swagger_ui_html(
|
||||
openapi_url=getattr(app, 'openapi_url', '/openapi.json'),
|
||||
title=getattr(app, 'title', 'FastAPI') + ' - Swagger UI',
|
||||
oauth2_redirect_url=getattr(app, 'swagger_ui_oauth2_redirect_url', None),
|
||||
swagger_js_url='https://unpkg.com/swagger-ui-dist@5/swagger-ui-bundle.js',
|
||||
swagger_css_url='https://unpkg.com/swagger-ui-dist@5/swagger-ui.css',
|
||||
)
|
||||
|
||||
|
||||
@app.get('/redoc', include_in_schema=False)
|
||||
async def custom_redoc_html(_credentials: HTTPBasicCredentials = Depends(verify_credentials)) -> HTMLResponse:
|
||||
'''Custom ReDoc documentation, optionally protected with basic authentication.'''
|
||||
return get_redoc_html(
|
||||
openapi_url=getattr(app, 'openapi_url', '/openapi.json'),
|
||||
title=getattr(app, 'title', 'FastAPI') + ' - ReDoc',
|
||||
redoc_js_url='https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js',
|
||||
)
|
||||
|
||||
|
||||
@app.post('/ping')
|
||||
async def ping() -> dict[str, str]:
|
||||
return {
|
||||
'message': 'pong',
|
||||
'status': 'ok',
|
||||
}
|
||||
4
src/presentation/decorators/__init__.py
Normal file
4
src/presentation/decorators/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from src.presentation.decorators.csrf import csrf_protect
|
||||
from src.presentation.decorators.rate_limit import rate_limit, _email_rl_key as email_rl_key
|
||||
from src.presentation.decorators.auth import require_access_token
|
||||
from src.presentation.decorators.cache import cached
|
||||
36
src/presentation/decorators/auth.py
Normal file
36
src/presentation/decorators/auth.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from fastapi import Depends, Request
|
||||
from fastapi.security.utils import get_authorization_scheme_param
|
||||
from src.application.contracts import IJwtService
|
||||
from src.application.domain.exceptions import UnauthorizedException
|
||||
from src.application.domain.dto import AccessTokenPayload, AuthContext
|
||||
from src.presentation.dependencies import get_jwt_service
|
||||
|
||||
|
||||
def _extract_access_token(request: Request) -> str | None:
|
||||
token = request.cookies.get('access_token')
|
||||
|
||||
if token:
|
||||
return token
|
||||
|
||||
auth = request.headers.get('Authorization')
|
||||
if auth:
|
||||
scheme, param = get_authorization_scheme_param(auth)
|
||||
if scheme.lower() == 'bearer' and param:
|
||||
return param
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def require_access_token(
|
||||
request: Request,
|
||||
jwt_service: IJwtService = Depends(get_jwt_service),
|
||||
) -> AuthContext:
|
||||
token = _extract_access_token(request)
|
||||
if not token:
|
||||
raise UnauthorizedException(message='Not authenticated')
|
||||
|
||||
payload: AccessTokenPayload = await jwt_service.decode_access_token(token)
|
||||
if payload.type != 'access':
|
||||
raise UnauthorizedException(message='Invalid token type')
|
||||
|
||||
return AuthContext(user_id=payload.sub, sid=payload.sid, token=payload)
|
||||
46
src/presentation/decorators/cache.py
Normal file
46
src/presentation/decorators/cache.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
from typing import Any, Awaitable, Callable
|
||||
from fastapi import Request
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from src.infrastructure.cache import KeydbCache
|
||||
from src.infrastructure.logger import get_logger
|
||||
from src.presentation.dependencies.cache import get_redis
|
||||
|
||||
|
||||
def cached(*, prefix: str) -> Callable:
|
||||
|
||||
def decorator(func: Callable[..., Awaitable[Any]]):
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
logger = get_logger()
|
||||
|
||||
request = kwargs.get('request')
|
||||
if not isinstance(request, Request):
|
||||
for a in args:
|
||||
if isinstance(a, Request):
|
||||
request = a
|
||||
break
|
||||
|
||||
auth = kwargs.get('auth')
|
||||
user_id = getattr(auth, 'user_id', None) if auth else None
|
||||
|
||||
if request is None or user_id is None:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
cache_key = f'{prefix}:{user_id}'
|
||||
|
||||
try:
|
||||
redis = get_redis(request)
|
||||
cache = KeydbCache(redis)
|
||||
hit = await cache.get_user(user_id)
|
||||
if hit is not None:
|
||||
logger.debug(f'Cache hit key={cache_key}')
|
||||
return ORJSONResponse(status_code=200, content=hit)
|
||||
except Exception as e:
|
||||
logger.warning(f'Cache read failed key={cache_key} error={e}')
|
||||
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user