Initial commit

This commit is contained in:
2026-04-12 09:16:16 +03:00
commit 5fe8efc5d4
98 changed files with 5351 additions and 0 deletions

145
.gitignore vendored Normal file
View File

@@ -0,0 +1,145 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
*.pyd
*.dll
# Distribution / packaging
.Python
build/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache/
.pytest_cache/
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
# Type checkers / linters
.mypy_cache/
.dmypy.json
dmypy.json
.pyre/
.pytype/
.ruff_cache/
# Jupyter Notebook
.ipynb_checkpoints/
# Environments
.env
.env.*
.venv/
venv/
ENV/
env/
env.bak/
venv.bak/
# Poetry
poetry.lock
# Pipenv
Pipfile.lock
# Hatch
.hatch/
# pyenv
.python-version
# Logs
*.log
logs/
# Local databases
*.sqlite3
*.db
# Secrets / credentials
secrets.json
credentials.json
*.pem
*.key
*.crt
# OS generated files
.DS_Store
Thumbs.db
Desktop.ini
# PyCharm / IntelliJ IDEA
.idea/
*.iml
out/
# VS Code (optional)
.vscode/
# Temporary files
*.tmp
*.temp
*.swp
*.swo
*~
# Sphinx docs
docs/_build/
# mkdocs
site/
# celery
celerybeat-schedule
celerybeat.pid
# mypy compiled cache
.mypy_cache/
# pyinstaller
*.manifest
*.spec
# pytest debug
pytestdebug.log
# Local config overrides
config.local.py
settings.local.py
# Vault / local dev secrets
.env.vault
vault.token
.env
.dockerignore
/sql

28
Dockerfile Normal file
View File

@@ -0,0 +1,28 @@
FROM ghcr.io/astral-sh/uv:python3.12-bookworm AS builder
WORKDIR /app
# Install dependencies (cached layer)
COPY pyproject.toml uv.lock ./
RUN uv sync --frozen --no-dev
# Copy source last (fast rebuilds)
COPY src ./src
FROM ghcr.io/astral-sh/uv:python3.12-bookworm AS runtime
WORKDIR /app
# Use the virtualenv created by `uv sync` in builder
COPY --from=builder /app/.venv /app/.venv
COPY --from=builder /app/src /app/src
ENV PATH="/app/.venv/bin:$PATH" \
PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1 \
PYTHONPATH=/app
EXPOSE 8000
CMD ["sh", "-c", "granian --interface asgi ${APP_MODULE:-src.main:app} --host ${APP_HOST:-0.0.0.0} --port ${APP_PORT:-8000} --workers ${APP_WORKERS:-1} --loop uvloop"]

83
docker-compose.yml Normal file
View File

@@ -0,0 +1,83 @@
services:
auth:
container_name: auth-service
build:
context: .
dockerfile: Dockerfile
ports:
- "8000:8000"
environment:
PYTHONUNBUFFERED: "1"
APP_MODULE: "src.main:app"
APP_HOST: "0.0.0.0"
APP_PORT: "8000"
APP_WORKERS: "1"
env_file:
- .env
depends_on:
keydb:
condition: service_healthy
restart: no
keydb:
image: eqalpha/keydb
container_name: keydb
restart: no
expose:
- "6379"
volumes:
- keydb_data:/data
command:
- keydb-server
- --requirepass
- keydb
- --dir
- /data
- --appendonly
- "yes"
- --appendfsync
- everysec
- --save
- "900"
- "1"
- --save
- "300"
- "10"
- --save
- "60"
- "10000"
healthcheck:
test: [ "CMD", "redis-cli", "-a", "keydb", "ping" ]
interval: 5s
timeout: 2s
retries: 20
# keydb:
# image: eqalpha/keydb
# container_name: keydb
# restart: no
# expose:
# - "6379"
# volumes:
# - keydb_data:/data
# environment:
# KEYDB_PASSWORD: keydb
# command: >
# sh -c "
# keydb-server
# --requirepass $$KEYDB_PASSWORD
# --dir /data
# --appendonly yes
# --appendfsync everysec
# --save 900 1
# --save 300 10
# --save 60 10000
# "
# healthcheck:
# test: ["CMD", "redis-cli", "ping"]
# interval: 5s
# timeout: 2s
# retries: 20
volumes:
keydb_data:

24
pyproject.toml Normal file
View File

@@ -0,0 +1,24 @@
[project]
name = "bitok"
version = "0.1.0"
description = "Add your description here"
requires-python = "==3.12.*"
dependencies = [
"apscheduler==3.11.2",
"asyncpg==0.31.0",
"bcrypt==5.0.0",
"dotenv==0.9.9",
"email-validator==2.3.0",
"fastapi==0.128.7",
"faststream[rabbit]==0.6.6",
"granian==2.6.1",
"hvac==2.4.0",
"itsdangerous==2.2.0",
"orjson==3.11.7",
"pydantic-settings==2.12.0",
"python-jose==3.5.0",
"python-ulid==3.1.0",
"redis==7.2.0",
"sqlalchemy==2.0.46",
"uvloop==0.22.1; platform_system != 'Windows'",
]

View File

@@ -0,0 +1 @@
from src.application.abstractions.i_unit_of_work import IUnitOfWork

View File

@@ -0,0 +1,19 @@
from __future__ import annotations
from typing import Protocol, runtime_checkable
from src.application.abstractions.repositories import IUserRepository, ISessionRepository
@runtime_checkable
class IUnitOfWork(Protocol):
async def __aenter__(self) -> "IUnitOfWork": ...
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: ...
async def commit(self) -> None: ...
async def rollback(self) -> None: ...
@property
def user_repository(self) -> IUserRepository: ...
@property
def session_repository(self) -> ISessionRepository: ...

View File

@@ -0,0 +1,2 @@
from src.application.abstractions.repositories.i_user_repository import IUserRepository
from src.application.abstractions.repositories.i_session_repository import ISessionRepository

View File

@@ -0,0 +1,59 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from datetime import datetime
from src.application.domain.entities import SessionEntity
class ISessionRepository(ABC):
@abstractmethod
async def get_by_sid(self, sid: str) -> SessionEntity | None:
raise NotImplementedError
@abstractmethod
async def get_by_user_device(self, user_id: str, device_id: str) -> SessionEntity | None:
raise NotImplementedError
@abstractmethod
async def upsert_by_device(
self,
user_id: str,
device_id: str,
sid: str,
refresh_jti_hash: str,
refresh_expires_at: datetime,
user_agent: str | None,
ip: str | None,
now: datetime,
) -> SessionEntity:
raise NotImplementedError
@abstractmethod
async def revoke_by_sid(self, sid: str, now: datetime) -> None:
raise NotImplementedError
@abstractmethod
async def rotate_refresh(
self,
sid: str,
new_jti_hash: str,
new_refresh_expires_at: datetime,
now: datetime,
ip: str | None,
user_agent: str | None,
) -> None:
raise NotImplementedError
@abstractmethod
async def rotate_refresh_if_match(
self,
*,
sid: str,
old_jti_hash: str,
new_jti_hash: str,
new_refresh_expires_at: datetime,
now: datetime,
ip: str | None,
user_agent: str | None,
) -> bool:
raise NotImplementedError

View File

@@ -0,0 +1,19 @@
from abc import ABC
from abc import abstractmethod
from src.application.domain.entities import UserEntity
class IUserRepository(ABC):
@abstractmethod
async def create_user(self, email: str, password_hash: str) -> UserEntity:
raise NotImplementedError
@abstractmethod
async def get_user_by_email(self, email: str) -> UserEntity:
raise NotImplementedError
@abstractmethod
async def exists_by_email(self, email: str) -> bool:
raise NotImplementedError

View File

@@ -0,0 +1,6 @@
from src.application.commands.user_registration_complete import UserRegistrationCompleteCommand
from src.application.commands.user_login_complete import UserLoginCompleteCommand
from src.application.commands.user_logout import UserLogoutCommand
from src.application.commands.jwt_refresh import JwtRefreshCommand
from src.application.commands.user_registration_start import UserRegistrationStartCommand
from src.application.commands.user_login_start import UserLoginStartCommand

View File

@@ -0,0 +1,70 @@
from datetime import datetime, timezone, timedelta
from ulid import ULID
from src.application.abstractions import IUnitOfWork
from src.application.contracts import IHashService, IJwtService, ILogger
from src.application.domain.dto import RefreshTokenPayload
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.config import settings
from src.infrastructure.database.decorators import transactional
class JwtRefreshCommand:
def __init__(self, unit_of_work: IUnitOfWork, hash_service: IHashService, jwt_service: IJwtService, logger: ILogger):
self._unit_of_work = unit_of_work
self._hash_service = hash_service
self._jwt_service = jwt_service
self._logger = logger
@transactional
async def __call__(self, *, refresh_token: str, ip: str | None, user_agent: str | None) -> tuple[str, str]:
now = datetime.now(timezone.utc)
payload: RefreshTokenPayload = await self._jwt_service.decode_refresh_token(refresh_token)
sid = payload.sid
user_id = payload.sub
jti = payload.jti
sess = await self._unit_of_work.session_repository.get_by_sid(sid)
if sess is None:
raise ApplicationException(status_code=401, message='Session not found')
if sess.revoked_at is not None:
raise ApplicationException(status_code=401, message='Session revoked')
if sess.refresh_expires_at <= now:
raise ApplicationException(status_code=401, message='Session expired')
if str(sess.user_id) != str(user_id):
raise ApplicationException(status_code=401, message='Invalid session subject')
ok = await self._hash_service.verify(
plain_value=jti,
hashed_value=sess.refresh_jti_hash,
)
if not ok:
await self._unit_of_work.session_repository.revoke_by_sid(sid=sid, now=now)
raise ApplicationException(status_code=401, message='Refresh token reuse detected')
new_jti = str(ULID())
new_jti_hash = await self._hash_service.hash(value=new_jti)
new_refresh_expires_at = now + timedelta(seconds=int(settings.JWT_REFRESH_TTL_SECONDS))
rotated = await self._unit_of_work.session_repository.rotate_refresh_if_match(
sid=sid,
old_jti_hash=sess.refresh_jti_hash,
new_jti_hash=new_jti_hash,
new_refresh_expires_at=new_refresh_expires_at,
now=now,
ip=ip,
user_agent=user_agent,
)
if not rotated:
raise ApplicationException(status_code=401, message='Refresh already rotated')
access = await self._jwt_service.create_access_token(user_id=user_id, sid=sid)
refresh = await self._jwt_service.create_refresh_token(user_id=user_id, sid=sid, refresh_jti=new_jti)
self._logger.info(f'Tokens refreshed (user_id={user_id}, sid={sid})')
return access, refresh

View File

@@ -0,0 +1,117 @@
from datetime import timedelta, datetime, timezone
from ulid import ULID
from src.application.abstractions import IUnitOfWork
from src.application.contracts import IHashService, IJwtService, ILogger, ICache
from src.application.domain.dto import UserLoginDto
from src.application.domain.entities import UserEntity
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.config import settings
from src.infrastructure.database.decorators import transactional
class UserLoginCompleteCommand:
def __init__(
self,
unit_of_work: IUnitOfWork,
hash_service: IHashService,
jwt_service: IJwtService,
cache: ICache,
logger: ILogger,
):
self._unit_of_work = unit_of_work
self._hash_service = hash_service
self._jwt_service = jwt_service
self._cache = cache
self._logger = logger
@transactional
async def __call__(
self,
*,
email: str,
password: str,
code: str,
device_id: str,
user_agent: str | None,
ip: str | None,
) -> UserLoginDto:
email = (email or '').strip().lower()
code = (code or '').strip()
code_key = f'login:code:{code}'
email_key = f'login:email:{email}'
cached_email = await self._cache.get(code_key)
if not cached_email:
self._logger.info(f'Login failed: code not found (email={email})')
raise ApplicationException(400, 'Invalid or expired code')
if cached_email != email:
self._logger.info(f'Login failed: code-email mismatch (email={email}, cached_email={cached_email})')
raise ApplicationException(400, 'Invalid or expired code')
code_hash = await self._cache.get(email_key)
if not code_hash:
self._logger.info(f'Login failed: email key missing (email={email})')
raise ApplicationException(400, 'Invalid or expired code')
ok = await self._hash_service.verify(hashed_value=code_hash, plain_value=code)
if not ok:
self._logger.info(f'Login failed: code hash mismatch (email={email})')
raise ApplicationException(400, 'Invalid or expired code')
now = datetime.now(timezone.utc)
user: UserEntity = await self._unit_of_work.user_repository.get_user_by_email(email=email)
ok = await self._hash_service.verify(plain_value=password, hashed_value=user.password_hash)
if not ok:
self._logger.warning(f'{user.id} login failed: invalid credentials')
raise ApplicationException(status_code=401, message='Invalid credentials')
try:
await self._cache.delete(code_key)
await self._cache.delete(email_key)
except Exception as e:
self._logger.warning(f'Login cleanup failed (email={email}): {e}')
sid = str(ULID())
jti = str(ULID())
refresh_jti_hash = await self._hash_service.hash(value=jti)
refresh_expires_at = now + timedelta(seconds=int(settings.JWT_REFRESH_TTL_SECONDS))
await self._unit_of_work.session_repository.upsert_by_device(
user_id=user.id,
device_id=device_id,
sid=sid,
refresh_jti_hash=refresh_jti_hash,
refresh_expires_at=refresh_expires_at,
user_agent=user_agent,
ip=ip,
now=now,
)
access_token = await self._jwt_service.create_access_token(user_id=user.id, sid=sid)
refresh_token = await self._jwt_service.create_refresh_token(user_id=user.id, sid=sid, refresh_jti=jti)
return UserLoginDto(
id=user.id,
email=user.email,
first_name=user.first_name,
middle_name=user.middle_name,
last_name=user.last_name,
birth_date=user.birth_date,
crypto_wallet=user.crypto_wallet,
phone=user.phone,
bik=user.bik,
account_number=user.account_number,
card_number=user.card_number,
inn=user.inn,
kyc_verified=user.kyc_verified,
kyc_verified_at=user.kyc_verified_at,
created_at=user.created_at,
updated_at=user.updated_at,
access_token=access_token,
refresh_token=refresh_token,
)

View File

@@ -0,0 +1,133 @@
import secrets
from datetime import timezone, datetime
from ulid import ULID
from src.application.abstractions import IUnitOfWork
from src.application.contracts import IHashService, ICache, ILogger, IQueueMessanger
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.config import settings
from src.infrastructure.context_vars import trace_id_var
from src.infrastructure.database.decorators import transactional
class UserLoginStartCommand:
def __init__(
self,
hash_service: IHashService,
cache: ICache,
unit_of_work: IUnitOfWork,
logger: ILogger,
messanger: IQueueMessanger,
):
self._hash_service = hash_service
self._unit_of_work = unit_of_work
self._cache = cache
self._logger = logger
self._messanger = messanger
@transactional
async def __call__(self, email: str) -> bool:
TTL = 300
LOCK_TTL = 30
MAX_ATTEMPTS = 20
EMAIL_PREFIX = 'login:email:'
CODE_PREFIX = 'login:code:'
LOCK_PREFIX = 'login:lock:'
email = (email or '').strip().lower()
if not email:
self._logger.info('Login start failed: empty email')
raise ApplicationException(400, 'Invalid email')
exists = await self._unit_of_work.user_repository.exists_by_email(email)
if not exists:
self._logger.info(f'Login failed: email already registered ({email})')
raise ApplicationException(404, 'Email registered')
trace_id = trace_id_var.get()
if not trace_id or trace_id == 'N/A':
trace_id = None
lock_key = f'{LOCK_PREFIX}{email}'
locked = await self._cache.set_nx(lock_key, '1', ttl=LOCK_TTL)
if not locked:
self._logger.info(f'Login start throttled by lock ({email})')
raise ApplicationException(429, 'Too many requests. Please wait.')
try:
email_key = f'{EMAIL_PREFIX}{email}'
existing = await self._cache.get(email_key)
if existing:
self._logger.info(f'Login start denied: code already exists for {email}')
raise ApplicationException(429, 'Code already sent. Please wait before retrying.')
for _ in range(MAX_ATTEMPTS):
code = f'{secrets.randbelow(1_000_000):06d}'
code_key = f'{CODE_PREFIX}{code}'
code_hash = await self._hash_service.hash(code)
reserved = await self._cache.set_nx(code_key, email, ttl=TTL)
if not reserved:
continue
saved = await self._cache.set(email_key, code_hash, ttl=TTL)
if not saved:
await self._cache.delete(code_key)
self._logger.error(f'Login start failed: cannot save code hash for {email}')
raise ApplicationException(503, 'Temporary error. Please try again.')
message_id = str(ULID())
now = datetime.now(timezone.utc).isoformat()
metadata = {
'trace_id': trace_id,
'source': 'auth-service',
'timestamp': now,
'message_id': message_id,
}
payload = {
'email': email,
'code': code,
'ttl_seconds': TTL,
}
message = {
'event': 'login',
'payload': payload,
'metadata': metadata,
}
self._logger.info(f'payload: {payload})')
try:
await self._messanger.publish_to_queue(
queue=settings.RABBIT_EMAIL_CODE_QUEUE,
message=message,
persist=True,
correlation_id=trace_id,
message_id=message_id,
headers={'trace_id': trace_id} if trace_id else None,
)
except Exception as exception:
try:
await self._cache.delete(email_key)
await self._cache.delete(code_key)
except Exception as rollback_err:
self._logger.error(f'Publish failed and rollback cache failed for {email}: {str(rollback_err)}')
self._logger.error(f'Failed to publish login email event for {email}: {str(exception)}')
raise ApplicationException(503, 'Temporary error. Please try again.')
self._logger.info(f'login code created for {email}')
return True
self._logger.error(f'login start failed: code space exhausted for {email}')
raise ApplicationException(503, 'Temporary error. Please try again.')
finally:
await self._cache.delete(lock_key)

View File

@@ -0,0 +1,28 @@
from __future__ import annotations
from datetime import datetime, timezone
from src.application.abstractions import IUnitOfWork
from src.application.contracts import ILogger, IJwtService
from src.application.domain.dto import RefreshTokenPayload
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.database.decorators import transactional
class UserLogoutCommand:
def __init__(self, unit_of_work: IUnitOfWork, jwt_service: IJwtService, logger: ILogger):
self._unit_of_work = unit_of_work
self._jwt_service = jwt_service
self._logger = logger
@transactional
async def __call__(self, *, refresh_token: str | None) -> None:
if not refresh_token:
return
try:
payload: RefreshTokenPayload = self._jwt_service.decode_refresh_token(refresh_token)
except ApplicationException:
self._logger.debug('Logout: refresh token invalid/expired, skipping revoke')
return
now = datetime.now(timezone.utc)
await self._unit_of_work.session_repository.revoke_by_sid(sid=payload.sid, now=now)
self._logger.info(f'Logout: session revoked (sid={payload.sid}, user_id={payload.sub})')

View File

@@ -0,0 +1,121 @@
from datetime import timedelta, datetime, timezone
from ulid import ULID
from src.application.abstractions import IUnitOfWork
from src.application.contracts import IHashService, IJwtService, ILogger, ICache
from src.application.domain.dto import UserCreatedDto
from src.application.domain.entities import UserEntity
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.config import settings
from src.infrastructure.database.decorators import transactional
class UserRegistrationCompleteCommand:
def __init__(
self,
unit_of_work: IUnitOfWork,
hash_service: IHashService,
jwt_service: IJwtService,
cache: ICache,
logger: ILogger,
):
self._unit_of_work = unit_of_work
self._cache = cache
self._hash_service = hash_service
self._jwt_service = jwt_service
self._logger = logger
@transactional
async def __call__(
self,
*,
email: str,
password: str,
device_id: str,
code: str,
user_agent: str | None,
ip: str | None,
) -> UserCreatedDto:
email = (email or '').strip().lower()
code = (code or '').strip()
code_key = f'reg:code:{code}'
email_key = f'reg:email:{email}'
cached_email = await self._cache.get(code_key)
if not cached_email:
self._logger.info(f'Registration failed: code not found (email={email}, code={code})')
raise ApplicationException(400, 'Invalid or expired code')
if cached_email != email:
self._logger.info(f'Registration failed: code-email mismatch (email={email}, cached_email={cached_email})')
raise ApplicationException(400, 'Invalid or expired code')
code_hash = await self._cache.get(email_key)
if not code_hash:
self._logger.info(f'Registration failed: email key missing (email={email})')
raise ApplicationException(400, 'Invalid or expired code')
ok = await self._hash_service.verify(
hashed_value=code_hash,
plain_value=code,
)
if not ok:
self._logger.info(f'Registration failed: code hash mismatch (email={email})')
raise ApplicationException(400, 'Invalid or expired code')
deleted_code = await self._cache.delete(code_key)
deleted_email = await self._cache.delete(email_key)
if not deleted_code or not deleted_email:
self._logger.info(
f'Registration cleanup: keys already missing '
f'(email={email}, deleted_code={deleted_code}, deleted_email={deleted_email})'
)
now = datetime.now(timezone.utc)
password_hash = await self._hash_service.hash(value=password)
user: UserEntity = await self._unit_of_work.user_repository.create_user(
email=email,
password_hash=password_hash,
)
sid = str(ULID())
jti = str(ULID())
refresh_jti_hash = await self._hash_service.hash(value=jti)
refresh_expires_at = now + timedelta(seconds=int(settings.JWT_REFRESH_TTL_SECONDS))
await self._unit_of_work.session_repository.upsert_by_device(
user_id=user.id,
device_id=device_id,
sid=sid,
refresh_jti_hash=refresh_jti_hash,
refresh_expires_at=refresh_expires_at,
user_agent=user_agent,
ip=ip,
now=now,
)
access_token = await self._jwt_service.create_access_token(
user_id=user.id,
sid=sid,
)
refresh_token = await (
self._jwt_service.create_refresh_token(
user_id=user.id,
sid=sid,
refresh_jti=jti,
))
self._logger.info(f'User registered successfully user_id={user.id} device_id={device_id} sid={sid}')
return UserCreatedDto(
id=user.id,
email=user.email,
access_token=access_token,
refresh_token=refresh_token,
)

View File

@@ -0,0 +1,133 @@
import secrets
from datetime import datetime, timezone
from ulid import ULID
from src.application.abstractions import IUnitOfWork
from src.application.contracts import IHashService, ILogger, ICache, IQueueMessanger
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.config import settings
from src.infrastructure.context_vars import trace_id_var
from src.infrastructure.database.decorators import transactional
class UserRegistrationStartCommand:
def __init__(
self,
hash_service: IHashService,
cache: ICache,
unit_of_work: IUnitOfWork,
logger: ILogger,
messanger: IQueueMessanger,
):
self._hash_service = hash_service
self._unit_of_work = unit_of_work
self._cache = cache
self._logger = logger
self._messanger = messanger
@transactional
async def __call__(self, email: str) -> bool:
TTL = 300
LOCK_TTL = 30
MAX_ATTEMPTS = 20
EMAIL_PREFIX = 'reg:email:'
CODE_PREFIX = 'reg:code:'
LOCK_PREFIX = 'reg:lock:'
email = (email or '').strip().lower()
if not email:
self._logger.info('Registration start failed: empty email')
raise ApplicationException(400, 'Invalid email')
exists = await self._unit_of_work.user_repository.exists_by_email(email)
if exists:
self._logger.info(f'Registration failed: email already registered ({email})')
raise ApplicationException(409, 'Email already registered')
trace_id = trace_id_var.get()
if not trace_id or trace_id == 'N/A':
trace_id = None
lock_key = f'{LOCK_PREFIX}{email}'
locked = await self._cache.set_nx(lock_key, '1', ttl=LOCK_TTL)
if not locked:
self._logger.info(f'Registration start throttled by lock ({email})')
raise ApplicationException(429, 'Too many requests. Please wait.')
try:
email_key = f'{EMAIL_PREFIX}{email}'
existing = await self._cache.get(email_key)
if existing:
self._logger.info(f'Registration start denied: code already exists for {email}')
raise ApplicationException(429, 'Code already sent. Please wait before retrying.')
for _ in range(MAX_ATTEMPTS):
code = f'{secrets.randbelow(1_000_000):06d}'
code_key = f'{CODE_PREFIX}{code}'
code_hash = await self._hash_service.hash(code)
reserved = await self._cache.set_nx(code_key, email, ttl=TTL)
if not reserved:
continue
saved = await self._cache.set(email_key, code_hash, ttl=TTL)
if not saved:
await self._cache.delete(code_key)
self._logger.error(f'Registration start failed: cannot save code hash for {email}')
raise ApplicationException(503, 'Temporary error. Please try again.')
message_id = str(ULID())
now = datetime.now(timezone.utc).isoformat()
metadata = {
'trace_id': trace_id,
'source': 'auth-service',
'timestamp': now,
'message_id': message_id,
}
payload = {
'email': email,
'code': code,
'ttl_seconds': TTL,
}
message = {
'event': 'registration',
'payload': payload,
'metadata': metadata,
}
self._logger.info(f'payload: {payload})')
try:
await self._messanger.publish_to_queue(
queue=settings.RABBIT_EMAIL_CODE_QUEUE,
message=message,
persist=True,
correlation_id=trace_id,
message_id=message_id,
headers={'trace_id': trace_id} if trace_id else None,
)
except Exception as exception:
try:
await self._cache.delete(email_key)
await self._cache.delete(code_key)
except Exception as rollback_err:
self._logger.error(f'Publish failed and rollback cache failed for {email}: {str(rollback_err)}')
self._logger.error(f'Failed to publish registration email event for {email}: {str(exception)}')
raise ApplicationException(503, 'Temporary error. Please try again.')
self._logger.info(f'Registration code created for {email}')
return True
self._logger.error(f'Registration start failed: code space exhausted for {email}')
raise ApplicationException(503, 'Temporary error. Please try again.')
finally:
await self._cache.delete(lock_key)

View File

@@ -0,0 +1,7 @@
from src.application.contracts.i_hash_service import IHashService
from src.application.contracts.i_logger import ILogger
from src.application.contracts.i_user_service import IUserService
from src.application.contracts.i_jwt_service import IJwtService
from src.application.contracts.i_csrf_service import ICsrfService
from src.application.contracts.i_cache import ICache
from src.application.contracts.i_queue_messanger import IQueueMessanger

View File

@@ -0,0 +1,20 @@
from abc import ABC, abstractmethod
class ICache(ABC):
@abstractmethod
async def set(self, key: str, value: str, ttl: int) -> bool:
raise NotImplementedError
@abstractmethod
async def set_nx(self, key: str, value: str, ttl: int) -> bool:
raise NotImplementedError
@abstractmethod
async def get(self, key: str) -> str | None:
raise NotImplementedError
@abstractmethod
async def delete(self, key: str) -> bool:
raise NotImplementedError

View File

@@ -0,0 +1,26 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Optional, Mapping
class ICsrfService(ABC):
@abstractmethod
def issue(self, subject: Optional[str] = None) -> str:
raise NotImplementedError
@abstractmethod
def verify(self, token: str, expected_subject: Optional[str] = None) -> dict[str, Any]:
raise NotImplementedError
@abstractmethod
def extract(self, cookies: Mapping[str, str], headers: Mapping[str, str]) -> tuple[Optional[str], Optional[str]]:
raise NotImplementedError
@abstractmethod
def verify_pair(
self,
cookie_token: Optional[str],
header_token: Optional[str],
expected_subject: Optional[str] = None,
) -> None:
raise NotImplementedError

View File

@@ -0,0 +1,12 @@
from abc import ABC, abstractmethod
class IHashService(ABC):
@abstractmethod
async def hash(self, value: str) -> str:
raise NotImplementedError
@abstractmethod
async def verify(self, hashed_value: str, plain_value: str) -> bool:
raise NotImplementedError

View File

@@ -0,0 +1,22 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from src.application.domain.dto import AccessTokenPayload, RefreshTokenPayload
class IJwtService(ABC):
@abstractmethod
async def create_access_token(self, user_id: str, sid: str) -> str:
raise NotImplementedError
@abstractmethod
async def create_refresh_token(self, user_id: str, sid: str, refresh_jti: str) -> str:
raise NotImplementedError
@abstractmethod
async def decode_access_token(self, token: str) -> AccessTokenPayload:
raise NotImplementedError
@abstractmethod
async def decode_refresh_token(self, token: str) -> RefreshTokenPayload:
raise NotImplementedError

View File

@@ -0,0 +1,68 @@
from typing import Protocol, Optional, Callable
from src.application.domain.enums.log_format import LogFormat
from src.application.domain.enums.log_level import LogLevel
class ILogger(Protocol):
"""Interface for synchronous logger with ContextVar support for trace_id."""
log_format: LogFormat
min_level: LogLevel
id_generator: Optional[Callable[[], str]]
instance_id: str
def set_format(self, log_format: LogFormat) -> None:
"""Set log format using LogFormat enum"""
...
def set_min_level(self, level: LogLevel) -> None:
"""Set minimum log level"""
...
def new_trace_id(self) -> str:
"""Create and set new trace_id in context"""
...
def set_trace_id(self, trace_id: str) -> None:
"""Set existing trace_id in context"""
...
def get_trace_id(self) -> str:
"""Get current trace_id from context"""
...
def clear_trace_id(self) -> None:
"""Clear the trace_id in the context"""
...
def set_instance_id(self, instance_id: str) -> None:
"""Set service instance id (ULID recommended)"""
...
def get_instance_id(self) -> str:
"""Get current service instance id"""
...
def debug(self, message: str) -> None:
"""Log debug message"""
...
def info(self, message: str) -> None:
"""Log info message"""
...
def warning(self, message: str) -> None:
"""Log warning message"""
...
def error(self, message: str) -> None:
"""Log error message"""
...
def critical(self, message: str) -> None:
"""Log critical message"""
...
def exception(self, message: str) -> None:
"""Log exception with traceback"""
...

View File

@@ -0,0 +1,40 @@
from abc import ABC, abstractmethod
from typing import Mapping, Any
class IQueueMessanger(ABC):
@abstractmethod
async def connect(self) -> None:
raise NotImplementedError
@abstractmethod
async def close(self) -> None:
raise NotImplementedError
@abstractmethod
async def publish_to_queue(
self,
queue: str,
message: Any,
*,
persist: bool = True,
headers: Mapping[str, Any] | None = None,
correlation_id: str | None = None,
message_id: str | None = None,
) -> None:
raise NotImplementedError
@abstractmethod
async def publish(
self,
message: Any,
*,
exchange: str,
routing_key: str,
persist: bool = True,
headers: Mapping[str, Any] | None = None,
correlation_id: str | None = None,
message_id: str | None = None,
) -> None:
raise NotImplementedError

View File

@@ -0,0 +1,14 @@
from abc import ABC, abstractmethod
from src.application.domain.dto import UserCreatedDto, UserLoginDto
class IUserService(ABC):
@abstractmethod
async def registration(self, email: str, password: str) -> UserCreatedDto:
raise NotImplementedError
@abstractmethod
async def login(self, email: str, password: str) -> UserLoginDto:
raise NotImplementedError

View File

@@ -0,0 +1,3 @@
from src.application.domain.dto.user import UserCreatedDto, UserLoginDto
from src.application.domain.dto.token import AccessTokenPayload, RefreshTokenPayload, AuthContext
from src.application.domain.dto.keys import JwtKeySet, JwtKeyPair

View File

@@ -0,0 +1,21 @@
from dataclasses import dataclass
from typing import Optional, Dict
@dataclass(frozen=True)
class JwtKeyPair:
kid: str
private_key_pem: str
public_key_pem: str
@dataclass(frozen=True)
class JwtKeySet:
active: JwtKeyPair
previous: Optional[JwtKeyPair] = None
def public_keys_by_kid(self) -> Dict[str, str]:
out = {self.active.kid: self.active.public_key_pem}
if self.previous:
out[self.previous.kid] = self.previous.public_key_pem
return out

View File

@@ -0,0 +1,30 @@
from pydantic import BaseModel
class AccessTokenPayload(BaseModel):
sub: str
type: str
sid: str
iat: int
nbf: int
exp: int
iss: str | None = None
aud: str | None = None
class RefreshTokenPayload(BaseModel):
sub: str
type: str
sid: str
jti: str
iat: int
nbf: int
exp: int
iss: str | None = None
aud: str | None = None
class AuthContext(BaseModel):
user_id: str
sid: str
token: AccessTokenPayload

View File

@@ -0,0 +1,33 @@
from dataclasses import dataclass
from datetime import datetime, date
@dataclass(slots=True)
class UserCreatedDto:
id: str
email: str
access_token: str
refresh_token: str
@dataclass(slots=True)
class UserLoginDto:
id: str | None = None
email: str | None = None
first_name: str | None = None
middle_name: str | None = None
last_name: str | None = None
birth_date: date | None = None
crypto_wallet: str | None = None
phone: str | None = None
bik: str | None = None
account_number: str | None = None
card_number: str | None = None
inn: str | None = None
kyc_verified: bool | None = None
access_token: str | None = None
refresh_token: str | None = None
created_at: datetime | None = None
updated_at: datetime | None = None
kyc_verified_at: datetime | None = None

View File

@@ -0,0 +1,5 @@
from src.application.domain.entities.user import UserEntity
from src.application.domain.entities.session import SessionEntity
__all__ = ['UserEntity', 'SessionEntity']

View File

@@ -0,0 +1,20 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime
@dataclass(slots=True)
class SessionEntity:
sid: str
user_id: str
device_id: str
revoked_at: datetime | None
last_seen_at: datetime
refresh_jti_hash: str | None
refresh_expires_at: datetime | None
user_agent: str | None = None
first_ip: str | None = None
last_ip: str | None = None

View File

@@ -0,0 +1,30 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import date, datetime
@dataclass(slots=True)
class UserEntity:
id: str | None = None
email: str | None = None
password_hash: str | None = None
first_name: str | None = None
middle_name: str | None = None
last_name: str | None = None
birth_date: date | None = None
crypto_wallet: str | None = None
phone: str | None = None
bik: str | None = None
account_number: str | None = None
card_number: str | None = None
inn: str | None = None
kyc_verified: bool | None = None
is_deleted: bool | None = None
created_at: datetime | None = None
updated_at: datetime | None = None
kyc_verified_at: datetime | None = None

View File

@@ -0,0 +1,2 @@
from src.application.domain.enums.log_level import LogLevel
from src.application.domain.enums.log_format import LogFormat

View File

@@ -0,0 +1,7 @@
from enum import Enum
class LogFormat(Enum):
"""Enum for supported log formats"""
TEXT = 'text'
JSON = 'json'

View File

@@ -0,0 +1,54 @@
from enum import Enum
class LogLevel(Enum):
DEBUG = 10
INFO = 20
WARNING = 30
ERROR = 40
CRITICAL = 50
EXCEPTION = 60
def __str__(self) -> str:
return self.name
def __repr__(self) -> str:
return f"[{self.value}, '{self.name}']"
def __eq__(self, other: object) -> bool:
if isinstance(other, LogLevel):
return self.value == other.value
if isinstance(other, int):
return self.value == other
return NotImplemented
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
def __lt__(self, other: object) -> bool:
if isinstance(other, LogLevel):
return self.value < other.value
if isinstance(other, int):
return self.value < other
return NotImplemented
def __le__(self, other: object) -> bool:
if isinstance(other, LogLevel):
return self.value <= other.value
if isinstance(other, int):
return self.value <= other
return NotImplemented
def __gt__(self, other: object) -> bool:
if isinstance(other, LogLevel):
return self.value > other.value
if isinstance(other, int):
return self.value > other
return NotImplemented
def __ge__(self, other: object) -> bool:
if isinstance(other, LogLevel):
return self.value >= other.value
if isinstance(other, int):
return self.value >= other
return NotImplemented

View File

@@ -0,0 +1 @@
from src.application.domain.exceptions.application_exceptions import ApplicationException

View File

@@ -0,0 +1,18 @@
from __future__ import annotations
from typing import Mapping
class ApplicationException(Exception):
def __init__(
self,
status_code: int,
message: str,
headers: Mapping[str, str] | None = None,
):
super().__init__(message)
self.status_code = status_code
self.message = message
self.headers = headers
def __str__(self):
return f"{self.status_code}: {self.message}"

2
src/infrastructure/cache/__init__.py vendored Normal file
View File

@@ -0,0 +1,2 @@
from src.infrastructure.cache.client import create_redis_client
from src.infrastructure.cache.keydb_client import KeydbCache

16
src/infrastructure/cache/client.py vendored Normal file
View File

@@ -0,0 +1,16 @@
import redis.asyncio as redis
from redis.asyncio.client import Redis
from src.infrastructure.config import settings
def create_redis_client() -> Redis:
return redis.from_url(
settings.REDIS_URL,
max_connections=50,
decode_responses=True,
socket_timeout=5,
socket_connect_timeout=5,
health_check_interval=30,
retry_on_timeout=True,
socket_keepalive=True,
)

View File

@@ -0,0 +1,20 @@
from redis.asyncio.client import Redis
from src.application.contracts import ICache
class KeydbCache(ICache):
def __init__(self, redis_client: Redis):
self._r = redis_client
async def set(self, key: str, value: str, ttl: int) -> None:
return bool(await self._r.set(key, value, ex=ttl))
async def set_nx(self, key: str, value: str, ttl: int) -> bool:
return bool(await self._r.set(key, value, ex=ttl, nx=True))
async def get(self, key: str) -> str | None:
return await self._r.get(key)
async def delete(self, key: str) -> bool:
deleted = await self._r.delete(key)
return deleted > 0

View File

@@ -0,0 +1 @@
from src.infrastructure.config.settings import settings

View File

@@ -0,0 +1,252 @@
from __future__ import annotations
from functools import lru_cache
from typing import List, Literal
import os
from dotenv import load_dotenv, find_dotenv
from pydantic import AliasChoices, Field, field_validator, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from src.infrastructure.vault import create_hvac_client_from_approle, read_kv2_secret
env_file = find_dotenv(".env")
if env_file:
load_dotenv(env_file)
def normalize_vault_base_url(raw: str) -> str:
u = raw.strip().rstrip('/')
if not u:
return raw.strip()
if '://' not in u:
return f'https://{u}'
return u
class Settings(BaseSettings):
VAULT_ADDR: str = Field(default='http://localhost:8200')
VAULT_ROLE_ID: str = Field(..., description='AppRole role_id')
VAULT_SECRET_ID: str = Field(
...,
description='AppRole secret_id',
validation_alias=AliasChoices('VAULT_SECRET_ID', 'VAULT_SECRET_TOKEN'),
)
VAULT_NAMESPACE: str | None = Field(default=None)
VAULT_MOUNT_POINT: str = Field(default='dev-secrets')
VAULT_JWT_KID_PATH: str = 'jwt/kid'
VAULT_JWT_KIDS_PREFIX: str = 'jwt/kids'
JWT_KEYS_REFRESH_SECONDS: int = 3600
DATABASE_HOST: str
DATABASE_PORT: int = Field(default=5432, ge=1, le=65535)
DATABASE_NAME: str
DATABASE_USER: str
DATABASE_PASSWORD: str
DATABASE_POOL_SIZE: int = 10
DATABASE_MAX_OVERFLOW: int = 20
DATABASE_POOL_TIMEOUT: int = 30
DATABASE_POOL_RECYCLE: int = 3600
DATABASE_ECHO: bool = False
CSRF_SECRET_KEY: str = Field(min_length=32)
CSRF_COOKIE_SECURE: bool = False
CSRF_COOKIE_HTTPONLY: bool = True
CSRF_COOKIE_SAMESITE: Literal['Lax', 'Strict', 'None'] = 'Lax'
CSRF_COOKIE_PATH: str = '/'
CSRF_COOKIE_DOMAIN: str | None = None
DOCS_USERNAME: str = 'admin'
DOCS_PASSWORD: str = 'admin'
JWT_ACCESS_TTL_SECONDS: int = 15 * 60
JWT_REFRESH_TTL_SECONDS: int = 30 * 24 * 60 * 60
JWT_ISSUER: str | None = None
JWT_AUDIENCE: str | None = None
JWT_ALGORITHM: str = 'RS256'
REDIS_HOST: str = 'localhost'
REDIS_PORT: int = 6379
REDIS_PASSWORD: str | None = None
REDIS_DB: int = 0
RABBIT_HOST: str = 'localhost'
RABBIT_PORT: int = 5672
RABBIT_USER: str = 'guest'
RABBIT_PASSWORD: str = 'guest'
RABBIT_VHOST: str = '/'
RABBIT_PUBLISH_PERSIST: bool = True
RABBIT_CONNECT_TIMEOUT: int = 5
RABBIT_EMAIL_CODE_QUEUE: str = 'email.verification_code'
CORS_ORIGINS: str = 'http://localhost:3000'
CORS_ALLOW_CREDENTIALS: bool = True
RATE_LIMIT_REQUESTS: int = 60
RATE_LIMIT_WINDOW: int = 60
LOG_LEVEL: Literal['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] = 'INFO'
LOG_FORMAT: Literal['JSON', 'TEXT'] = 'TEXT'
@field_validator('VAULT_ADDR', mode='before')
@classmethod
def vault_addr_scheme(cls, v):
if v is None or not isinstance(v, str):
return v
return normalize_vault_base_url(v)
@field_validator('CSRF_COOKIE_DOMAIN', mode='before')
@classmethod
def empty_csrf_domain_to_none(cls, v):
if v is None or (isinstance(v, str) and not v.strip()):
return None
return v
model_config = SettingsConfigDict(
env_file='.env',
env_file_encoding='utf-8',
case_sensitive=True,
extra='ignore',
populate_by_name=True,
)
@model_validator(mode='before')
@classmethod
def load_from_vault(cls, data: dict):
if not isinstance(data, dict):
return data
addr_raw = data.get('VAULT_ADDR') or os.getenv('VAULT_ADDR') or 'http://localhost:8200'
addr = normalize_vault_base_url(addr_raw)
data['VAULT_ADDR'] = addr
role_id = data.get('VAULT_ROLE_ID') or os.getenv('VAULT_ROLE_ID')
secret_id = (
data.get('VAULT_SECRET_ID')
or data.get('VAULT_SECRET_TOKEN')
or os.getenv('VAULT_SECRET_ID')
or os.getenv('VAULT_SECRET_TOKEN')
)
namespace = data.get('VAULT_NAMESPACE')
if namespace is None:
namespace = os.getenv('VAULT_NAMESPACE')
namespace = namespace if namespace else None
mount = data.get('VAULT_MOUNT_POINT') or os.getenv('VAULT_MOUNT_POINT') or 'dev-secrets'
if not role_id or not secret_id:
raise RuntimeError(
'VAULT_ROLE_ID and VAULT_SECRET_ID (or VAULT_SECRET_TOKEN) are required for Vault AppRole'
)
data['VAULT_ROLE_ID'] = str(role_id).strip()
data['VAULT_SECRET_ID'] = str(secret_id).strip()
client = create_hvac_client_from_approle(
url=addr,
role_id=role_id,
secret_id=secret_id,
namespace=namespace,
timeout=5,
)
def read_secret(path: str) -> dict:
return read_kv2_secret(client=client, mount_point=mount, path=path)
def read_secret_optional(path: str) -> dict:
try:
return read_secret(path)
except Exception:
return {}
database = read_secret('database')
csrf = read_secret('csrf')
db_ci = {str(k).lower(): v for k, v in database.items()}
def db_nonempty(key: str) -> bool:
v = db_ci.get(key)
if v is None:
return False
if isinstance(v, str) and not v.strip():
return False
return True
required_db = ['host', 'name', 'user', 'password', 'port']
missing_db = [k for k in required_db if not db_nonempty(k)]
if missing_db:
raise RuntimeError(f'Vault secret database missing non-empty keys: {missing_db}')
data['DATABASE_HOST'] = str(db_ci['host']).strip()
data['DATABASE_PORT'] = int(db_ci['port'])
data['DATABASE_NAME'] = str(db_ci['name']).strip()
data['DATABASE_USER'] = str(db_ci['user']).strip()
data['DATABASE_PASSWORD'] = str(db_ci['password']).strip()
csrf_secret = None
for entry_key, entry_val in csrf.items():
if str(entry_key).lower() == 'key' and entry_val is not None and str(entry_val).strip():
csrf_secret = str(entry_val).strip()
break
if not csrf_secret:
raise RuntimeError(
'Vault secret at csrf must contain a non-empty field named key (e.g. key=...)'
)
data['CSRF_SECRET_KEY'] = csrf_secret
rabbit = read_secret_optional('rabbitmq')
if rabbit:
r_ci = {str(k).lower(): v for k, v in rabbit.items()}
def rb_set(field: str, env_key: str, *, as_int: bool = False) -> None:
v = r_ci.get(field)
if v is None:
return
if isinstance(v, str) and not v.strip():
return
if as_int:
data[env_key] = int(v)
else:
data[env_key] = str(v).strip()
rb_set('host', 'RABBIT_HOST')
rb_set('port', 'RABBIT_PORT', as_int=True)
rb_set('user', 'RABBIT_USER')
rb_set('password', 'RABBIT_PASSWORD')
rb_set('vhost', 'RABBIT_VHOST')
return data
def cors_origins_list(self) -> List[str]:
return [o.strip() for o in self.CORS_ORIGINS.split(',') if o.strip()]
@property
def DATABASE_URL(self) -> str:
return (
f"postgresql+asyncpg://{self.DATABASE_USER}:{self.DATABASE_PASSWORD}"
f"@{self.DATABASE_HOST}:{self.DATABASE_PORT}/{self.DATABASE_NAME}"
)
@property
def REDIS_URL(self) -> str:
auth = f":{self.REDIS_PASSWORD}@" if self.REDIS_PASSWORD else ""
return f"redis://{auth}{self.REDIS_HOST}:{self.REDIS_PORT}/{self.REDIS_DB}"
@property
def RABBIT_URL(self) -> str:
vhost = "%2F" if self.RABBIT_VHOST == "/" else self.RABBIT_VHOST.lstrip("/")
return f"amqp://{self.RABBIT_USER}:{self.RABBIT_PASSWORD}@{self.RABBIT_HOST}:{self.RABBIT_PORT}/{vhost}"
@property
def EXCLUDED_PATHS(self) -> List[str]:
return ['/docs', '/redoc', '/openapi.json', '/ping', '/health']
@lru_cache(maxsize=1)
def get_settings() -> Settings:
return Settings()
settings = get_settings()

View File

@@ -0,0 +1 @@
from src.infrastructure.context_vars.trace_id import trace_id_var

View File

@@ -0,0 +1,4 @@
from contextvars import ContextVar
trace_id_var: ContextVar[str] = ContextVar('trace_id', default='N/A')

View File

@@ -0,0 +1 @@
from src.infrastructure.database.unit_of_work import UnitOfWork

View File

@@ -0,0 +1,22 @@
from sqlalchemy.ext.asyncio import async_sessionmaker
from sqlalchemy.ext.asyncio.engine import create_async_engine
from sqlalchemy.ext.asyncio.session import AsyncSession
from typing import AsyncGenerator
from src.infrastructure.config import settings
engine = create_async_engine(
settings.DATABASE_URL,
pool_size=settings.DATABASE_POOL_SIZE,
max_overflow=settings.DATABASE_MAX_OVERFLOW,
pool_timeout=settings.DATABASE_POOL_TIMEOUT,
pool_recycle=settings.DATABASE_POOL_RECYCLE,
echo=settings.DATABASE_ECHO
)
async_session_maker = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
async def get_session() -> AsyncGenerator[AsyncSession, None]:
async with async_session_maker() as session:
yield session

View File

@@ -0,0 +1 @@
from src.infrastructure.database.decorators.transactional import transactional

View File

@@ -0,0 +1,15 @@
from __future__ import annotations
from functools import wraps
from typing import Callable, Awaitable, TypeVar, ParamSpec
P = ParamSpec("P")
R = TypeVar("R")
def transactional(method: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
@wraps(method)
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
async with self._unit_of_work:
return await method(self, *args, **kwargs)
return wrapper

View File

@@ -0,0 +1,6 @@
from src.infrastructure.database.models.base import Base
from src.infrastructure.database.models.user import UserModel
from src.infrastructure.database.models.sessions import Session
__all__ = ['Base', 'UserModel', 'Session']

View File

@@ -0,0 +1,19 @@
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import DeclarativeBase
class Base(AsyncAttrs, DeclarativeBase):
__abstract__ = True
def __repr__(self) -> str:
class_name = self.__class__.__name__
attributes = ', '.join(f"{col.name}={getattr(self, col.name, None)!r}"
for col in self.__table__.columns)
return f"<{class_name}({attributes})>"
def __str__(self) -> str:
class_name = self.__class__.__name__
attributes = ', '.join(f"{col.name}={getattr(self, col.name)}"
for col in self.__table__.columns
if getattr(self, col.name) is not None)
return f"{class_name}({attributes})"

View File

@@ -0,0 +1,3 @@
from src.infrastructure.database.models.mixins.audit import AuditTimestampsMixin
from src.infrastructure.database.models.mixins.ulid import UlidPrimaryKeyMixin
from src.infrastructure.database.models.mixins.soft_delete import SoftDeleteMixin

View File

@@ -0,0 +1,16 @@
from sqlalchemy import DateTime, func
from sqlalchemy.orm import Mapped, mapped_column
class AuditTimestampsMixin:
created_at: Mapped[DateTime] = mapped_column(
DateTime(timezone=True),
nullable=False,
server_default=func.now(),
)
updated_at: Mapped[DateTime] = mapped_column(
DateTime(timezone=True),
nullable=False,
server_default=func.now(),
onupdate=func.now(),
)

View File

@@ -0,0 +1,6 @@
from sqlalchemy import Boolean
from sqlalchemy.orm import Mapped, mapped_column
class SoftDeleteMixin:
is_deleted: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default='false', default=False)

View File

@@ -0,0 +1,8 @@
from sqlalchemy import String
from sqlalchemy.orm import Mapped, mapped_column
from ulid import ULID
class UlidPrimaryKeyMixin:
id: Mapped[str] = mapped_column(String(26), primary_key=True, default=lambda: str(ULID()))

View File

@@ -0,0 +1,50 @@
from datetime import datetime, timezone
from sqlalchemy import String, DateTime, ForeignKey, Index
from sqlalchemy.orm import Mapped, mapped_column
from ulid import ULID
from src.infrastructure.database.models import Base
from src.infrastructure.database.models.mixins import UlidPrimaryKeyMixin, AuditTimestampsMixin
class Session(Base, UlidPrimaryKeyMixin, AuditTimestampsMixin):
__tablename__ = "sessions"
sid: Mapped[str] = mapped_column(
String(26),
unique=True,
index=True,
nullable=False,
default=lambda: str(ULID()),
)
user_id: Mapped[str] = mapped_column(
String(26),
ForeignKey("users.id", ondelete="CASCADE"),
index=True,
nullable=False,
)
device_id: Mapped[str] = mapped_column(
String(26),
nullable=False,
index=True,
)
user_agent: Mapped[str | None] = mapped_column(String(500))
first_ip: Mapped[str | None] = mapped_column(String(64))
last_ip: Mapped[str | None] = mapped_column(String(64))
last_seen_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
default=lambda: datetime.now(timezone.utc),
)
revoked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
refresh_jti_hash: Mapped[str | None] = mapped_column(String(255))
refresh_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
Index("ux_sessions_user_device", Session.user_id, Session.device_id, unique=True)
Index("ix_sessions_user_active", Session.user_id, Session.revoked_at)

View File

@@ -0,0 +1,28 @@
from __future__ import annotations
from sqlalchemy import Boolean, Date, String, DateTime
from sqlalchemy.orm import Mapped, mapped_column
from src.infrastructure.database.models.base import Base
from src.infrastructure.database.models.mixins import UlidPrimaryKeyMixin, AuditTimestampsMixin, SoftDeleteMixin
class UserModel(Base, UlidPrimaryKeyMixin, AuditTimestampsMixin, SoftDeleteMixin):
__tablename__ = 'users'
email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True, index=True)
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
last_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
first_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
middle_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
birth_date: Mapped[Date | None] = mapped_column(Date, nullable=True)
crypto_wallet: Mapped[str | None] = mapped_column(String(255), nullable=True)
phone: Mapped[str | None] = mapped_column(String(16), nullable=True)
bik: Mapped[str | None] = mapped_column(String(9), nullable=True)
account_number: Mapped[str | None] = mapped_column(String(20), nullable=True)
card_number: Mapped[str | None] = mapped_column(String(19), nullable=True)
inn: Mapped[str | None] = mapped_column(String(12), nullable=True)
kyc_verified: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default='false', default=False)
kyc_verified_at: Mapped[DateTime | None] = mapped_column(DateTime(timezone=True), nullable=True)

View File

@@ -0,0 +1,2 @@
from src.infrastructure.database.repositories.user_repository import UserRepository
from src.infrastructure.database.repositories.session_repository import SessionRepository

View File

@@ -0,0 +1,198 @@
from __future__ import annotations
from datetime import datetime
from typing import Optional
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from src.application.contracts import ILogger
from src.application.domain.entities import SessionEntity
from src.application.abstractions.repositories import ISessionRepository
from src.infrastructure.database.models import Session
class SessionRepository(ISessionRepository):
def __init__(self, session: AsyncSession, logger: ILogger):
self._session = session
self._logger = logger
async def get_by_sid(self, sid: str) -> Optional[SessionEntity]:
res = await self._session.execute(select(Session).where(Session.sid == sid))
m = res.scalar_one_or_none()
if m is None:
return None
return SessionEntity(
sid=m.sid,
user_id=m.user_id,
device_id=m.device_id,
revoked_at=m.revoked_at,
last_seen_at=m.last_seen_at,
refresh_jti_hash=m.refresh_jti_hash,
refresh_expires_at=m.refresh_expires_at,
user_agent=m.user_agent,
first_ip=m.first_ip,
last_ip=m.last_ip,
)
async def get_by_user_device(self, user_id: str, device_id: str) -> Optional[SessionEntity]:
res = await self._session.execute(
select(Session).where(Session.user_id == user_id, Session.device_id == device_id)
)
m = res.scalar_one_or_none()
if m is None:
return None
return SessionEntity(
sid=m.sid,
user_id=m.user_id,
device_id=m.device_id,
revoked_at=m.revoked_at,
last_seen_at=m.last_seen_at,
refresh_jti_hash=m.refresh_jti_hash,
refresh_expires_at=m.refresh_expires_at,
user_agent=m.user_agent,
first_ip=m.first_ip,
last_ip=m.last_ip,
)
async def upsert_by_device(
self,
*,
user_id: str,
device_id: str,
sid: str,
refresh_jti_hash: str,
refresh_expires_at: datetime,
user_agent: str | None,
ip: str | None,
now: datetime,
) -> SessionEntity:
res = await self._session.execute(
select(Session).where(Session.user_id == user_id, Session.device_id == device_id)
)
m = res.scalar_one_or_none()
if m is None:
m = Session(
sid=sid,
user_id=user_id,
device_id=device_id,
revoked_at=None,
last_seen_at=now,
refresh_jti_hash=refresh_jti_hash,
refresh_expires_at=refresh_expires_at,
user_agent=user_agent,
first_ip=ip,
last_ip=ip,
)
self._session.add(m)
await self._session.flush()
self._logger.info(f'Session created (user_id={user_id}, device_id={device_id}, sid={sid})')
else:
m.sid = sid
m.revoked_at = None
m.last_seen_at = now
m.refresh_jti_hash = refresh_jti_hash
m.refresh_expires_at = refresh_expires_at
m.user_agent = user_agent
m.last_ip = ip
await self._session.flush()
self._logger.info(f'Session updated (user_id={user_id}, device_id={device_id}, sid={sid})')
return SessionEntity(
sid=m.sid,
user_id=m.user_id,
device_id=m.device_id,
revoked_at=m.revoked_at,
last_seen_at=m.last_seen_at,
refresh_jti_hash=m.refresh_jti_hash,
refresh_expires_at=m.refresh_expires_at,
user_agent=m.user_agent,
first_ip=m.first_ip,
last_ip=m.last_ip,
)
async def revoke_by_sid(self, sid: str, now: datetime) -> None:
# Интерфейс требует None -> просто делаем update и flush
await self._session.execute(
update(Session)
.where(Session.sid == sid, Session.revoked_at.is_(None))
.values(revoked_at=now)
.execution_options(synchronize_session='fetch')
)
await self._session.flush()
async def rotate_refresh(
self,
sid: str,
new_jti_hash: str,
new_refresh_expires_at: datetime,
now: datetime,
ip: str | None,
user_agent: str | None,
) -> None:
values = {
'refresh_jti_hash': new_jti_hash,
'refresh_expires_at': new_refresh_expires_at,
'last_seen_at': now,
'user_agent': user_agent,
}
if ip is not None:
values['last_ip'] = ip
await self._session.execute(
update(Session)
.where(Session.sid == sid, Session.revoked_at.is_(None))
.values(**values)
.execution_options(synchronize_session='fetch')
)
await self._session.flush()
async def touch_last_seen(self, sid: str, *, ip: str | None, now: datetime) -> None:
values = {'last_seen_at': now}
if ip is not None:
values['last_ip'] = ip
await self._session.execute(
update(Session)
.where(Session.sid == sid, Session.revoked_at.is_(None))
.values(**values)
.execution_options(synchronize_session='fetch')
)
await self._session.flush()
async def rotate_refresh_if_match(
self,
*,
sid: str,
old_jti_hash: str,
new_jti_hash: str,
new_refresh_expires_at: datetime,
now: datetime,
ip: str | None,
user_agent: str | None,
) -> bool:
values = {
'refresh_jti_hash': new_jti_hash,
'refresh_expires_at': new_refresh_expires_at,
'last_seen_at': now,
'user_agent': user_agent,
}
if ip is not None:
values['last_ip'] = ip
res = await self._session.execute(
update(Session)
.where(
Session.sid == sid,
Session.revoked_at.is_(None),
Session.refresh_jti_hash == old_jti_hash, # ✅ защита от гонок
)
.values(**values)
.execution_options(synchronize_session='fetch')
)
await self._session.flush()
return (res.rowcount or 0) > 0

View File

@@ -0,0 +1,114 @@
from __future__ import annotations
from fastapi import status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from src.application.contracts import ILogger
from src.application.domain.exceptions import ApplicationException
from src.application.abstractions.repositories import IUserRepository
from src.application.domain.entities import UserEntity
from src.infrastructure.database.models import UserModel
class UserRepository(IUserRepository):
def __init__(self, session: AsyncSession, logger: ILogger):
self._session = session
self._logger = logger
async def create_user(self, email: str, password_hash: str) -> UserEntity:
user = UserModel(email=email, password_hash=password_hash)
self._session.add(user)
try:
await self._session.flush()
return UserEntity(
id=user.id,
email=user.email,
created_at=user.created_at,
kyc_verified=user.kyc_verified,
is_deleted=user.is_deleted
)
except IntegrityError:
self._logger.error(f'User already exists with email {user.email}')
raise ApplicationException(
status_code=status.HTTP_409_CONFLICT,
message='User with this email already exists',
)
except SQLAlchemyError as exception:
self._logger.exception(str(exception))
raise ApplicationException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
message=f'Database error: {str(exception)}',
)
async def get_user_by_email(self, email: str) -> UserEntity:
try:
stmt = (
select(UserModel)
.where(
UserModel.email == email,
UserModel.is_deleted.is_(False),
)
)
result = await self._session.execute(stmt)
user: UserModel | None = result.scalar_one_or_none()
if user is None:
self._logger.warning(f'User not found with email {email}')
raise ApplicationException(status_code=status.HTTP_404_NOT_FOUND, message='User not found',)
return UserEntity(
id=user.id,
email=user.email,
password_hash=user.password_hash,
first_name=user.first_name,
middle_name=user.middle_name,
last_name=user.last_name,
birth_date=user.birth_date,
crypto_wallet=user.crypto_wallet,
phone=user.phone,
bik=user.bik,
account_number=user.account_number,
card_number=user.card_number,
inn=user.inn,
kyc_verified_at=user.kyc_verified_at,
kyc_verified=user.kyc_verified,
is_deleted=user.is_deleted,
created_at=user.created_at,
updated_at=user.updated_at,
)
except ApplicationException:
raise
except SQLAlchemyError as exception:
self._logger.exception(str(exception))
raise ApplicationException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
message=f'Database error: {str(exception)}',
)
async def exists_by_email(self, email: str) -> bool:
try:
stmt = (
select(UserModel.id)
.where(
UserModel.email == email,
UserModel.is_deleted.is_(False),
)
.limit(1)
)
result = await self._session.execute(stmt)
return result.scalar_one_or_none() is not None
except SQLAlchemyError as exception:
self._logger.exception(str(exception))
raise ApplicationException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
message=f'Database error: {str(exception)}',
)

View File

@@ -0,0 +1,42 @@
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from src.application.abstractions import IUnitOfWork
from src.application.abstractions.repositories import IUserRepository, ISessionRepository
from src.application.contracts import ILogger
from src.infrastructure.database.repositories import UserRepository, SessionRepository
class UnitOfWork(IUnitOfWork):
def __init__(self, session_factory: async_sessionmaker[AsyncSession], logger: ILogger):
self.session_factory = session_factory
self._session: AsyncSession = None
self._user_repository: IUserRepository = None
self._session_repository: ISessionRepository = None
self._logger: ILogger = logger
async def __aenter__(self):
self._session = self.session_factory()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if exc_type:
self._logger.error(str(exc_val))
await self._session.rollback()
self._logger.error(f'Rollback: str{exc_val})')
else:
await self._session.flush()
await self._session.commit()
self._logger.debug('Commit')
await self._session.close()
@property
def user_repository(self) -> IUserRepository:
if self._user_repository is None:
self._user_repository = UserRepository(session=self._session, logger=self._logger)
return self._user_repository
@property
def session_repository(self) -> ISessionRepository:
if self._session_repository is None:
self._session_repository = SessionRepository(session=self._session, logger=self._logger)
return self._session_repository

View File

@@ -0,0 +1,28 @@
from src.application.contracts import ILogger
from src.application.domain.enums import LogFormat
from src.application.domain.enums import LogLevel
from src.infrastructure.config.settings import settings
from src.infrastructure.logger.logger import Logger
log_levels = {
'DEBUG': LogLevel.DEBUG,
'INFO': LogLevel.INFO,
'WARNING': LogLevel.WARNING,
'ERROR': LogLevel.ERROR,
'CRITICAL': LogLevel.CRITICAL,
'EXCEPTION': LogLevel.EXCEPTION,
}
log_formats = {
'JSON': LogFormat.JSON,
'TEXT': LogFormat.TEXT,
}
logger = Logger(
min_level=log_levels.get(settings.LOG_LEVEL, LogLevel.INFO),
log_format=log_formats.get(settings.LOG_FORMAT, LogFormat.JSON),
)
def get_logger() -> ILogger:
return logger

View File

@@ -0,0 +1,129 @@
import traceback
import inspect
import sys
import json
from datetime import datetime
from typing import Callable, Optional, Any
from ulid import ULID
from src.application.contracts import ILogger
from src.application.domain.enums import LogFormat, LogLevel
from src.infrastructure.context_vars import trace_id_var
class Logger(ILogger):
_instance = None
__default_format = LogFormat.JSON
def __new__(cls, *args: Any, **kwargs: Any) -> "Logger":
if cls._instance is None:
cls._instance = super(Logger, cls).__new__(cls)
return cls._instance
def __init__(
self,
log_format: LogFormat = __default_format,
min_level: LogLevel = LogLevel.INFO,
id_generator: Optional[Callable[[], str]] = lambda: str(ULID()),
instance_id: str = "N/A",
):
self.log_format = log_format
self.min_level = min_level
self.id_generator = id_generator
self.instance_id = instance_id
def set_instance_id(self, instance_id: str) -> None:
self.instance_id = instance_id
def get_instance_id(self) -> str:
return self.instance_id
def set_format(self, log_format: LogFormat) -> None:
if not isinstance(log_format, LogFormat):
raise ValueError("Log format must be an instance of LogFormat enum")
self.log_format = log_format
def set_min_level(self, level: LogLevel) -> None:
self.min_level = level
def new_trace_id(self) -> str:
trace_id = str(ULID()) if self.id_generator is None else self.id_generator()
trace_id_var.set(trace_id)
return trace_id
def set_trace_id(self, trace_id: str) -> None:
trace_id_var.set(trace_id)
def get_trace_id(self) -> str:
return trace_id_var.get()
def clear_trace_id(self) -> None:
trace_id_var.set("N/A")
def _prepare_log_data(self, level: LogLevel, message: str) -> dict[str, Any]:
current_frame = inspect.currentframe()
if (
current_frame
and current_frame.f_back
and current_frame.f_back.f_back
and current_frame.f_back.f_back.f_back
):
frame = current_frame.f_back.f_back.f_back
filename = frame.f_code.co_filename
line_number = frame.f_lineno
else:
filename = "unknown"
line_number = 0
log_data = {
"timestamp": datetime.now().isoformat(),
"level": level.name,
"instance_id": self.instance_id,
"file": filename,
"line": line_number,
"trace_id": trace_id_var.get(),
"message": message,
}
if level == LogLevel.EXCEPTION:
log_data["exception"] = traceback.format_exc()
return log_data
def _log(self, level: LogLevel, message: str) -> None:
if level >= self.min_level:
log_data = self._prepare_log_data(level, message)
if self.log_format == LogFormat.JSON:
log_message = json.dumps(log_data, ensure_ascii=False)
else:
log_message = (
f"{log_data['timestamp']} - {log_data['level']} - "
f"{log_data['instance_id']} - {log_data['trace_id']} - "
f"{log_data['file']}:{log_data['line']} - "
f"{log_data['message']}"
)
if "exception" in log_data:
log_message += f"\nTraceback:\n{log_data['exception']}"
self._write(log_message)
def _write(self, message: str) -> None:
sys.stdout.write(message + "\n")
def debug(self, message: str) -> None:
self._log(LogLevel.DEBUG, message)
def info(self, message: str) -> None:
self._log(LogLevel.INFO, message)
def warning(self, message: str) -> None:
self._log(LogLevel.WARNING, message)
def error(self, message: str) -> None:
self._log(LogLevel.ERROR, message)
def critical(self, message: str) -> None:
self._log(LogLevel.CRITICAL, message)
def exception(self, message: str) -> None:
self._log(LogLevel.EXCEPTION, message)

View File

@@ -0,0 +1 @@
from src.infrastructure.messanger.rabbit_client import RabbitClient

View File

@@ -0,0 +1,72 @@
from typing import Any, Mapping
from faststream.rabbit import RabbitBroker
from src.application.contracts import IQueueMessanger
from src.infrastructure.config import settings
class RabbitClient(IQueueMessanger):
def __init__(self) -> None:
self._broker = RabbitBroker(
settings.RABBIT_URL,
)
self._connected = False
async def connect(self) -> None:
if self._connected:
return
await self._broker.connect()
self._connected = True
async def close(self) -> None:
if not self._connected:
return
await self._broker.close()
self._connected = False
async def _ensure_connected(self) -> None:
if not self._connected:
await self.connect()
async def publish_to_queue(
self,
queue: str,
message: Any,
*,
persist: bool | None = None,
headers: Mapping[str, Any] | None = None,
correlation_id: str | None = None,
message_id: str | None = None,
) -> None:
await self._ensure_connected()
await self._broker.publish(
message,
queue=queue,
persist=settings.RABBIT_PUBLISH_PERSIST if persist is None else persist,
headers=headers,
correlation_id=correlation_id,
message_id=message_id,
)
async def publish(
self,
message: Any,
*,
exchange: str,
routing_key: str,
persist: bool | None = None,
headers: Mapping[str, Any] | None = None,
correlation_id: str | None = None,
message_id: str | None = None,
) -> None:
await self._ensure_connected()
await self._broker.publish(
message,
exchange=exchange,
routing_key=routing_key,
persist=settings.RABBIT_PUBLISH_PERSIST if persist is None else persist,
headers=headers,
correlation_id=correlation_id,
message_id=message_id,
)

View File

@@ -0,0 +1,3 @@
from src.infrastructure.security.jwt import JwtService
from src.infrastructure.security.csrf import CsrfService
from src.infrastructure.security.hash import HashService

View File

@@ -0,0 +1,81 @@
from __future__ import annotations
import secrets
from typing import Any, Optional, Mapping
from itsdangerous import URLSafeTimedSerializer, SignatureExpired, BadSignature
from src.application.contracts import ICsrfService
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.config.settings import settings
class CsrfService(ICsrfService):
COOKIE_NAME = "csrf_token"
HEADER_NAME = "X-CSRF-Token"
SALT = "csrf"
TTL_SECONDS = 3600
def __init__(self) -> None:
self._serializer = URLSafeTimedSerializer(
secret_key=settings.CSRF_SECRET_KEY,
salt=self.SALT,
)
@property
def cookie_name(self) -> str:
return self.COOKIE_NAME
@property
def header_name(self) -> str:
return self.HEADER_NAME
@property
def ttl_seconds(self) -> int:
return self.TTL_SECONDS
def issue(self, subject: Optional[str] = None) -> str:
payload = {
"sub": subject,
"nonce": secrets.token_urlsafe(32),
}
return self._serializer.dumps(payload)
def verify(self, token: str, expected_subject: Optional[str] = None) -> dict[str, Any]:
try:
data = self._serializer.loads(token, max_age=self.TTL_SECONDS)
except SignatureExpired:
raise ApplicationException(
status_code=403,
message="CSRF token expired",
)
except BadSignature:
raise ApplicationException(
status_code=403,
message="CSRF token invalid",
)
if expected_subject is not None and data.get("sub") != expected_subject:
raise ApplicationException(
status_code=403,
message="CSRF token subject mismatch",
)
return data
def extract(self, cookies: Mapping[str, str], headers: Mapping[str, str]) -> tuple[Optional[str], Optional[str]]:
cookie_token = cookies.get(self.COOKIE_NAME)
header_token = headers.get(self.HEADER_NAME)
return cookie_token, header_token
def verify_pair(self, cookie_token: Optional[str], header_token: Optional[str], expected_subject: Optional[str] = None) -> None:
if not cookie_token or not header_token:
raise ApplicationException(
status_code=403,
message="CSRF token missing",
)
if not secrets.compare_digest(cookie_token, header_token):
raise ApplicationException(
status_code=403,
message="CSRF token mismatch",
)
self.verify(cookie_token, expected_subject=expected_subject)

View File

@@ -0,0 +1,17 @@
import bcrypt
from src.application.contracts import IHashService, ILogger
class HashService(IHashService):
def __init__(self, logger: ILogger):
self._logger = logger
async def hash(self, value: str) -> str:
hashed_value = bcrypt.hashpw(value.encode(), bcrypt.gensalt())
self._logger.info(f'Hash value {hashed_value.decode()}')
return hashed_value.decode()
async def verify(self, hashed_value: str, plain_value: str) -> bool:
self._logger.info(f'Hash value {hashed_value[:10]}')
return bcrypt.checkpw(plain_value.encode(), hashed_value.encode())

View File

@@ -0,0 +1,207 @@
from __future__ import annotations
from datetime import datetime, timezone, timedelta
from jose import jwt, ExpiredSignatureError, JWTError
from src.application.contracts import ILogger, IJwtService
from src.application.domain.dto import AccessTokenPayload, RefreshTokenPayload
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.config.settings import settings
from src.infrastructure.vault import JwtKeyStore
class JwtService(IJwtService):
def __init__(self, logger: ILogger, key_store: JwtKeyStore) -> None:
self._logger = logger
self._key_store = key_store
async def create_access_token(self, user_id: str, sid: str) -> str:
now = datetime.now(timezone.utc)
exp = now + timedelta(seconds=int(settings.JWT_ACCESS_TTL_SECONDS))
payload: dict[str, object] = {
'sub': user_id,
'type': 'access',
'sid': sid,
'iat': int(now.timestamp()),
'nbf': int(now.timestamp()),
'exp': int(exp.timestamp()),
}
if settings.JWT_ISSUER:
payload['iss'] = settings.JWT_ISSUER
if settings.JWT_AUDIENCE:
payload['aud'] = settings.JWT_AUDIENCE
try:
kid, private_pem = await self._key_store.get_signing_key()
token = jwt.encode(
payload,
private_pem,
algorithm=settings.JWT_ALGORITHM,
headers={'kid': kid},
)
self._logger.info(f'Access token created user_id={user_id} sid={sid} kid={kid}')
return token
except ApplicationException:
raise
except Exception as exception:
self._logger.error(f'JWT access signing failed user_id={user_id} sid={sid} error={str(exception)}')
raise ApplicationException(status_code=500, message='JWT signing failed')
async def create_refresh_token(self, user_id: str, sid: str, refresh_jti: str) -> str:
now = datetime.now(timezone.utc)
exp = now + timedelta(seconds=int(settings.JWT_REFRESH_TTL_SECONDS))
payload: dict[str, object] = {
'sub': user_id,
'type': 'refresh',
'sid': sid,
'jti': refresh_jti,
'iat': int(now.timestamp()),
'nbf': int(now.timestamp()),
'exp': int(exp.timestamp()),
}
if settings.JWT_ISSUER:
payload['iss'] = settings.JWT_ISSUER
if settings.JWT_AUDIENCE:
payload['aud'] = settings.JWT_AUDIENCE
try:
kid, private_pem = await self._key_store.get_signing_key()
token = jwt.encode(
payload,
private_pem,
algorithm=settings.JWT_ALGORITHM,
headers={'kid': kid},
)
self._logger.info(f'Refresh token created user_id={user_id} sid={sid} jti={refresh_jti} kid={kid}')
return token
except ApplicationException:
raise
except Exception as exception:
self._logger.error(f'JWT refresh signing failed user_id={user_id} sid={sid} error={str(exception)}')
raise ApplicationException(status_code=500, message='JWT signing failed')
async def decode_access_token(self, token: str) -> AccessTokenPayload:
payload = await self._decode_and_verify(token)
if payload.get('type') != 'access':
self._logger.warning(f'Access token invalid type received_type={payload.get('type')}')
raise ApplicationException(status_code=401, message='Invalid token type')
try:
return AccessTokenPayload(
sub=str(payload['sub']),
type='access',
sid=str(payload['sid']),
iat=int(payload['iat']),
nbf=int(payload['nbf']),
exp=int(payload['exp']),
iss=payload.get('iss'),
aud=payload.get('aud'),
)
except KeyError as exception:
self._logger.warning(f'Access token missing claim error={str(exception)}')
raise ApplicationException(status_code=401, message=f'Missing token claim: {exception}')
async def decode_refresh_token(self, token: str) -> RefreshTokenPayload:
payload = await self._decode_and_verify(token)
if payload.get('type') != 'refresh':
self._logger.warning(f'Refresh token invalid type received_type={payload.get('type')}')
raise ApplicationException(status_code=401, message='Invalid token type')
try:
return RefreshTokenPayload(
sub=str(payload['sub']),
type='refresh',
sid=str(payload['sid']),
jti=str(payload['jti']),
iat=int(payload['iat']),
nbf=int(payload['nbf']),
exp=int(payload['exp']),
iss=payload.get('iss'),
aud=payload.get('aud'),
)
except KeyError as exception:
self._logger.warning(f'Refresh token missing claim error={str(exception)}')
raise ApplicationException(status_code=401, message=f'Missing token claim: {exception}')
async def _decode_and_verify(self, token: str) -> dict:
kid: str | None = None
try:
header = jwt.get_unverified_header(token)
kid = header.get('kid')
if not kid:
self._logger.warning(f'JWT header missing kid header={header}')
raise ApplicationException(status_code=401, message='Missing token header: kid')
if header.get('alg') != settings.JWT_ALGORITHM:
self._logger.warning(f'JWT invalid algorithm kid={kid} received_alg={header.get('alg')} expected_alg={settings.JWT_ALGORITHM}')
raise ApplicationException(status_code=401, message='Invalid token algorithm')
public_pem = await self._key_store.get_public_key_for_kid(str(kid))
if not public_pem:
self._logger.info(f'JWT kid cache miss kid={kid} refreshing keystore')
public_pem = await self._key_store.get_public_key_for_kid(str(kid))
if not public_pem:
self._logger.warning(f'JWT unknown kid kid={kid}')
raise ApplicationException(status_code=401, message='Unknown token kid')
options = {
'verify_signature': True,
'verify_exp': True,
'verify_nbf': True,
'verify_iat': True,
'verify_aud': bool(settings.JWT_AUDIENCE),
'verify_iss': bool(settings.JWT_ISSUER),
'require_exp': True,
'require_iat': True,
'require_nbf': True,
'require_sub': True,
'require_sid': True,
'require_type': True,
'leeway': 10,
}
payload = jwt.decode(
token,
public_pem,
algorithms=[settings.JWT_ALGORITHM],
audience=settings.JWT_AUDIENCE or None,
issuer=settings.JWT_ISSUER or None,
options=options,
)
if options.get('require_sid') and 'sid' not in payload:
self._logger.warning(f'JWT missing sid claim kid={kid}')
raise ApplicationException(status_code=401, message='Missing token claim: sid')
if options.get('require_type') and 'type' not in payload:
self._logger.warning(f'JWT missing type claim kid={kid}')
raise ApplicationException(status_code=401, message='Missing token claim: type')
return payload
except ExpiredSignatureError as exception:
self._logger.info(f'JWT expired kid={kid} error={str(exception)}')
raise ApplicationException(status_code=401, message='Token expired')
except ApplicationException:
raise
except JWTError as exception:
self._logger.warning(f'JWT decode failed kid={kid} error={str(exception)}')
raise ApplicationException(status_code=401, message='Invalid token')
except Exception as exception:
self._logger.error(f'Unexpected JWT decode error kid={kid} error={str(exception)}')
raise ApplicationException(status_code=500, message='JWT decode failed')

View File

@@ -0,0 +1 @@
from src.infrastructure.utils.instance_id import generate_instance_id

View File

@@ -0,0 +1,14 @@
from ulid import ULID
def generate_instance_id() -> str:
"""
Generate a process-wide instance id in ULID format.
ULID is 26 chars (Crockford Base32) and lexicographically sortable by time.
"""
return str(ULID())

View File

@@ -0,0 +1,3 @@
from src.infrastructure.vault.utils import create_hvac_client_from_approle, read_kv2_secret
from src.infrastructure.vault.keys import JwtKeyStore
from src.infrastructure.vault.scheduler import start_jwt_keys_scheduler

View File

@@ -0,0 +1,118 @@
from __future__ import annotations
import asyncio
from datetime import datetime, timezone
from src.application.domain.dto import JwtKeySet, JwtKeyPair
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.vault.utils import create_hvac_client_from_approle, read_kv2_secret
class JwtKeyStore:
_instance: "JwtKeyStore | None" = None
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(
self,
*,
vault_addr: str,
vault_role_id: str,
vault_secret_id: str,
vault_namespace: str | None,
mount_point: str,
kid_path: str = 'jwt/kid',
kids_prefix: str = 'jwt/kids',
timeout_seconds: int = 5,
):
if getattr(self, '_initialized', False):
return
self._vault_addr = vault_addr
self._vault_role_id = vault_role_id
self._vault_secret_id = vault_secret_id
self._vault_namespace = vault_namespace
self._timeout = timeout_seconds
self._mount = mount_point
self._kid_path = kid_path
self._kids_prefix = kids_prefix
self._lock = asyncio.Lock()
self._keyset: JwtKeySet | None = None
self._last_refresh_at: datetime | None = None
self._initialized = True
@classmethod
def get_instance(cls) -> 'JwtKeyStore':
if cls._instance is None:
raise ApplicationException(status_code=500, message='JwtKeyStore not initialized')
return cls._instance
def _read_keyset_sync(self) -> JwtKeySet:
client = create_hvac_client_from_approle(
url=self._vault_addr,
role_id=self._vault_role_id,
secret_id=self._vault_secret_id,
namespace=self._vault_namespace,
timeout=self._timeout,
)
kids = read_kv2_secret(client=client, mount_point=self._mount, path=self._kid_path)
active_kid = kids.get('active')
previous_kid = kids.get('previous')
if not active_kid:
raise RuntimeError('Vault jwt/kid secret missing active')
active_pair = self._read_keypair_sync(client, active_kid)
prev_pair = None
if previous_kid and previous_kid != active_kid:
prev_pair = self._read_keypair_sync(client, previous_kid)
return JwtKeySet(active=active_pair, previous=prev_pair)
def _read_keypair_sync(self, client, kid: str) -> JwtKeyPair:
data = read_kv2_secret(
client=client,
mount_point=self._mount,
path=f'{self._kids_prefix}/{kid}',
)
priv = data.get('private_key')
pub = data.get('public_key')
if not priv or not pub:
raise RuntimeError(f'Vault jwt/kids/{kid} missing private_key/public_key')
return JwtKeyPair(kid=kid, private_key_pem=priv, public_key_pem=pub)
async def refresh(self) -> JwtKeySet:
keyset = await asyncio.to_thread(self._read_keyset_sync)
async with self._lock:
self._keyset = keyset
self._last_refresh_at = datetime.now(timezone.utc)
return keyset
async def get_signing_key(self) -> tuple[str, str]:
ks = await self._get_or_refresh()
return ks.active.kid, ks.active.private_key_pem
async def get_public_key_for_kid(self, kid: str) -> str | None:
ks = await self._get_or_refresh()
return ks.public_keys_by_kid().get(kid)
async def last_refresh_at(self) -> datetime | None:
async with self._lock:
return self._last_refresh_at
async def _get_or_refresh(self) -> JwtKeySet:
async with self._lock:
ks = self._keyset
return ks if ks else await self.refresh()

View File

@@ -0,0 +1,24 @@
from __future__ import annotations
import logging
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.interval import IntervalTrigger
from src.infrastructure.vault import JwtKeyStore
logger = logging.getLogger(__name__)
def start_jwt_keys_scheduler(store: JwtKeyStore, *, refresh_seconds: int = 3600) -> AsyncIOScheduler:
scheduler = AsyncIOScheduler()
scheduler.add_job(
store.refresh,
trigger=IntervalTrigger(seconds=refresh_seconds),
id="jwt_keys_refresh",
replace_existing=True,
max_instances=1,
coalesce=True,
misfire_grace_time=60,
)
scheduler.start()
logger.info("JWT keys scheduler started (interval=%s seconds)", refresh_seconds)
return scheduler

View File

@@ -0,0 +1,30 @@
from __future__ import annotations
import hvac
def create_hvac_client_from_approle(
*,
url: str,
role_id: str,
secret_id: str,
namespace: str | None = None,
timeout: int = 5,
) -> hvac.Client:
kwargs: dict = {'url': url, 'timeout': timeout}
if namespace:
kwargs['namespace'] = namespace
client = hvac.Client(**kwargs)
client.auth.approle.login(role_id=role_id, secret_id=secret_id)
if not client.is_authenticated():
raise RuntimeError(
'Vault AppRole authentication failed. Check VAULT_ADDR, VAULT_ROLE_ID, VAULT_SECRET_ID'
)
return client
def read_kv2_secret(*, client: hvac.Client, mount_point: str, path: str) -> dict:
secret = client.secrets.kv.v2.read_secret_version(
mount_point=mount_point,
path=path,
)
return secret["data"]["data"]

154
src/main.py Normal file
View File

@@ -0,0 +1,154 @@
from __future__ import annotations
from contextlib import asynccontextmanager
import secrets
from typing import AsyncGenerator
from fastapi import Depends, FastAPI, status
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.responses import HTMLResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from starlette.middleware.cors import CORSMiddleware
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.cache import create_redis_client
from src.infrastructure.config.settings import get_settings
from src.infrastructure.vault import JwtKeyStore, start_jwt_keys_scheduler
from src.infrastructure.utils import generate_instance_id
from src.infrastructure.logger import logger
from src.infrastructure.config import settings
from src.presentation.dependencies import get_rabbit
from src.presentation.handlers import application_exception_handler, unhandled_exception_handler
from src.presentation.middleware import TraceIDMiddleware, SecurityHeadersMiddleware
from src.presentation.routing import v1_router
security = HTTPBasic()
async def verify_credentials(credentials: HTTPBasicCredentials = Depends(security)) -> HTTPBasicCredentials:
user_ok = secrets.compare_digest(credentials.username, settings.DOCS_USERNAME)
pass_ok = secrets.compare_digest(credentials.password, settings.DOCS_PASSWORD)
if not (user_ok and pass_ok):
raise ApplicationException(
status_code=status.HTTP_401_UNAUTHORIZED,
message='Unauthorized',
headers={'WWW-Authenticate': 'Basic'},
)
return credentials
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
instance_id = generate_instance_id()
logger.set_instance_id(instance_id)
logger.info(f'Auth service instance started with id {instance_id}')
jwt_store = JwtKeyStore(
vault_addr=settings.VAULT_ADDR,
vault_role_id=settings.VAULT_ROLE_ID,
vault_secret_id=settings.VAULT_SECRET_ID,
vault_namespace=settings.VAULT_NAMESPACE,
mount_point=settings.VAULT_MOUNT_POINT,
kid_path=settings.VAULT_JWT_KID_PATH,
kids_prefix=settings.VAULT_JWT_KIDS_PREFIX,
)
await jwt_store.refresh()
jwt_scheduler = start_jwt_keys_scheduler(
jwt_store,
refresh_seconds=settings.JWT_KEYS_REFRESH_SECONDS,
)
app.state.jwt_key_store = jwt_store
app.state.jwt_keys_scheduler = jwt_scheduler
redis_client = create_redis_client()
await get_rabbit().connect()
logger.info('Rabbit connected')
try:
await redis_client.ping()
app.state.redis = redis_client
logger.info('Redis connected')
yield
finally:
logger.info('Shutting down...')
sched = getattr(app.state, 'jwt_keys_scheduler', None)
if sched:
sched.shutdown(wait=False)
await redis_client.close()
await redis_client.connection_pool.disconnect()
await get_rabbit().close()
logger.info('Redis disconnected')
logger.info('API stopped')
app: FastAPI = FastAPI(
redoc_url=None,
docs_url=None,
lifespan=lifespan,
title='Bitforce. Auth Service',
version='1.0.0',
description='',
license_info={
'name': 'MIT',
'url': 'https://opensource.org/licenses/MIT',
},
)
app.add_exception_handler(ApplicationException, application_exception_handler)
app.add_exception_handler(Exception, unhandled_exception_handler)
app.include_router(v1_router)
app.add_middleware(TraceIDMiddleware, logger=logger)
app.add_middleware(
SecurityHeadersMiddleware,
hsts=True,
hsts_preload=False,
frame_options='DENY',
referrer_policy='strict-origin-when-cross-origin',
content_security_policy="default-src 'self'; frame-ancestors 'none'; base-uri 'self'; object-src 'none'",
)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins_list(),
allow_credentials=settings.CORS_ALLOW_CREDENTIALS,
allow_methods=['*'],
allow_headers=['*'],
)
@app.get('/docs', include_in_schema=False)
async def custom_swagger_ui_html(_credentials: HTTPBasicCredentials = Depends(verify_credentials)) -> HTMLResponse:
'''Custom Swagger documentation, optionally protected with basic authentication.'''
return get_swagger_ui_html(
openapi_url=getattr(app, 'openapi_url', '/openapi.json'),
title=getattr(app, 'title', 'FastAPI') + ' - Swagger UI',
oauth2_redirect_url=getattr(app, 'swagger_ui_oauth2_redirect_url', None),
swagger_js_url='https://unpkg.com/swagger-ui-dist@5/swagger-ui-bundle.js',
swagger_css_url='https://unpkg.com/swagger-ui-dist@5/swagger-ui.css',
)
@app.get('/redoc', include_in_schema=False)
async def custom_redoc_html(_credentials: HTTPBasicCredentials = Depends(verify_credentials)) -> HTMLResponse:
'''Custom ReDoc documentation, optionally protected with basic authentication.'''
return get_redoc_html(
openapi_url=getattr(app, 'openapi_url', '/openapi.json'),
title=getattr(app, 'title', 'FastAPI') + ' - ReDoc',
redoc_js_url='https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js',
)
@app.post('/ping')
async def ping() -> dict[str, str]:
return {
'message': 'pong',
'status': 'ok',
}

View File

@@ -0,0 +1,2 @@
from src.presentation.decorators.csrf import csrf_protect
from src.presentation.decorators.rate_limit import rate_limit, _email_rl_key as email_rl_key

View File

@@ -0,0 +1,36 @@
from fastapi import Depends, Request
from fastapi.security.utils import get_authorization_scheme_param
from src.application.contracts import IJwtService
from src.application.domain.dto import AuthContext
from src.application.domain.exceptions import ApplicationException
from src.presentation.dependencies import get_jwt_service
def _extract_access_token(request: Request) -> str | None:
token = request.cookies.get("access_token")
if token:
return token
auth = request.headers.get("Authorization")
if auth:
scheme, param = get_authorization_scheme_param(auth)
if scheme.lower() == "bearer" and param:
return param
return None
async def require_access_token(
request: Request,
jwt_service: IJwtService = Depends(get_jwt_service), # твой DI
) -> AuthContext:
token = _extract_access_token(request)
if not token:
raise ApplicationException(status_code=401, message="Not authenticated")
payload = jwt_service.decode_access_token(token)
if payload.type != "access":
raise ApplicationException(status_code=401, message="Invalid token type")
return AuthContext(user_id=payload.sub, sid=payload.sid, token=payload)

View File

@@ -0,0 +1,61 @@
from __future__ import annotations
import inspect
from functools import wraps
from typing import Callable, Awaitable, Any, Optional, Annotated
from fastapi import Request, Header
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.security import CsrfService
def csrf_protect(
expected_subject_getter: Optional[Callable[[Request], Optional[str]]] = None,
):
def decorator(func: Callable[..., Awaitable[Any]]):
sig = inspect.signature(func)
params = list(sig.parameters.values())
has_request = any(p.annotation is Request or p.name == 'request' for p in params)
if not has_request:
raise RuntimeError('csrf_protect requires endpoint to accept `request: Request`')
has_header = any(p.name == 'x_csrf_token' for p in params)
if not has_header:
params.append(
inspect.Parameter(
name='x_csrf_token',
kind=inspect.Parameter.KEYWORD_ONLY,
default=None,
annotation=Annotated[str | None, Header(alias='X-CSRF-Token')],
)
)
@wraps(func)
async def wrapper(*args, **kwargs):
request: Request | None = kwargs.get('request')
if request is None:
for arg in args:
if isinstance(arg, Request):
request = arg
break
if request is None:
raise ApplicationException(
status_code=500,
message='Request is required for CSRF protection',
)
csrf = CsrfService()
cookie_token, _ = csrf.extract(request.cookies, request.headers)
header_token = kwargs.get('x_csrf_token')
expected_subject = expected_subject_getter(request) if expected_subject_getter else None
csrf.verify_pair(cookie_token, header_token, expected_subject)
kwargs.pop('x_csrf_token', None)
return await func(*args, **kwargs)
wrapper.__signature__ = sig.replace(parameters=params)
return wrapper
return decorator

View File

@@ -0,0 +1,171 @@
from __future__ import annotations
import functools
import inspect
import hashlib
from typing import Any, Awaitable, Callable, Literal, Optional, Protocol, runtime_checkable
from fastapi import Request
from redis.asyncio.client import Redis
from src.application.contracts import ILogger
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.logger import get_logger
from src.presentation.dependencies import get_redis
def _find_request(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Request:
req = kwargs.get('request')
if isinstance(req, Request):
return req
for a in args:
if isinstance(a, Request):
return a
raise RuntimeError('rate_limit decorator requires fastapi.Request argument')
def _client_ip(request: Request) -> str:
xff = request.headers.get('x-forwarded-for')
if xff:
return xff.split(',')[0].strip()
if request.client:
return request.client.host
return 'unknown'
_LUA_INCR_EXPIRE_TTL = '''
local key = KEYS[1]
local window = tonumber(ARGV[1])
local current = redis.call('INCR', key)
if current == 1 then
redis.call('EXPIRE', key, window)
end
local ttl = redis.call('TTL', key)
return { current, ttl }
'''
Scope = Literal['ip', 'device', 'user', 'key']
@runtime_checkable
class KeyBuilder1(Protocol):
def __call__(self, request: Request) -> str: ...
@runtime_checkable
class KeyBuilder3(Protocol):
def __call__(self, request: Request, args: tuple[Any, ...], kwargs: dict[str, Any]) -> str: ...
KeyBuilder = KeyBuilder1 | KeyBuilder3
def _call_key_builder(builder: KeyBuilder, request: Request, args: tuple[Any, ...], kwargs: dict[str, Any]) -> str:
try:
sig = inspect.signature(builder)
if len(sig.parameters) >= 3:
return builder(request, args, kwargs)
return builder(request)
except Exception as e:
try:
return builder(request, args, kwargs)
except Exception:
raise e
def _email_rl_key(request: Request, args: tuple[Any, ...], kwargs: dict[str, Any]) -> str:
body = kwargs.get('body')
if body is None and args:
for a in args:
if hasattr(a, 'email'):
body = a
break
email = (getattr(body, 'email', '') or '').strip().lower()
if not email:
email = _client_ip(request)
digest = hashlib.sha256(email.encode('utf-8')).hexdigest()[:24]
return f'email:{digest}'
def rate_limit(
*,
limit: int,
window_seconds: int,
scope: Scope = 'ip',
key_prefix: str = 'rl',
key_builder: Optional[KeyBuilder] = None,
fail_open: bool = True,
) -> Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable[Any]]]:
if limit <= 0:
raise ValueError('rate_limit: limit must be > 0')
if window_seconds <= 0:
raise ValueError('rate_limit: window_seconds must be > 0')
if scope == 'key' and not key_builder:
raise ValueError('rate_limit: scope="key" requires key_builder')
def decorator(func: Callable[..., Awaitable[Any]]):
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any):
request = _find_request(args, kwargs)
logger: ILogger = get_logger()
if scope == 'ip':
ident = _client_ip(request)
elif scope == 'device':
ident = request.cookies.get('device_id') or _client_ip(request)
elif scope == 'user':
user = getattr(request.state, 'user', None)
user_id = getattr(user, 'id', None) if user else None
ident = str(user_id) if user_id else _client_ip(request)
else:
try:
ident = _call_key_builder(key_builder, request, args, kwargs) # type: ignore[arg-type]
except Exception as e:
logger.error(f'RateLimit key_builder failed error={str(e)}')
raise ApplicationException(500, 'Rate limiter key_builder failed')
route = request.url.path
method = request.method
redis_key = f'{key_prefix}:{scope}:{method}:{route}:{ident}'
logger.debug(f'RateLimit check key={redis_key} limit={limit} window={window_seconds}')
try:
redis: Redis = get_redis(request)
result = await redis.eval(
_LUA_INCR_EXPIRE_TTL,
1,
redis_key,
str(window_seconds),
)
count = int(result[0])
ttl_raw = int(result[1]) if result and len(result) > 1 else window_seconds
ttl = window_seconds if ttl_raw < 0 else ttl_raw
except Exception as e:
logger.error(f'RateLimit redis failure key={redis_key} error={str(e)}')
if fail_open:
logger.warning(f'RateLimit fail-open activated key={redis_key}')
return await func(*args, **kwargs)
raise ApplicationException(503, 'Rate limiter unavailable')
if count > limit:
retry_after = max(ttl, 0)
logger.warning(f'RateLimit exceeded key={redis_key} count={count} limit={limit} retry_after={retry_after}')
raise ApplicationException(
status_code=429,
message='Too Many Requests',
headers={'Retry-After': str(retry_after)},
)
logger.debug(f'RateLimit passed key={redis_key} count={count}')
return await func(*args, **kwargs)
return wrapper
return decorator

View File

@@ -0,0 +1,11 @@
from src.presentation.dependencies.commands import (
get_user_registration_complete_command,
get_user_login_start_command,
get_user_login_complete_command,
get_user_logout_command,
get_user_registration_start_command,
get_jwt_refresh_command
)
from src.presentation.dependencies.security import get_jwt_service, get_jwt_service
from src.presentation.dependencies.cache import get_redis
from src.presentation.dependencies.queue_messanger import get_rabbit

View File

@@ -0,0 +1,12 @@
from fastapi import Depends, Request
from redis.asyncio.client import Redis
from src.application.contracts import ICache
from src.infrastructure.cache import KeydbCache
def get_redis(request: Request) -> Redis:
return request.app.state.redis
def get_cache(redis_client: Redis = Depends(get_redis)) -> ICache:
return KeydbCache(redis_client)

View File

@@ -0,0 +1,98 @@
from fastapi import Depends
from src.application.abstractions import IUnitOfWork
from src.application.commands import (
UserRegistrationCompleteCommand,
JwtRefreshCommand,
UserRegistrationStartCommand,
UserLogoutCommand,
UserLoginCompleteCommand,
UserLoginStartCommand
)
from src.application.contracts import IHashService, IJwtService, ILogger, IQueueMessanger
from src.application.contracts import ICache
from src.presentation.dependencies.queue_messanger import get_rabbit
from src.presentation.dependencies.cache import get_cache
from src.presentation.dependencies.logger import get_logger
from src.presentation.dependencies.security import get_hash_service, get_jwt_service
from src.presentation.dependencies.unit_of_work import get_unit_of_work
def get_user_registration_start_command(
logger: ILogger = Depends(get_logger),
hash_service: IHashService = Depends(get_hash_service),
cache: ICache = Depends(get_cache),
unit_of_work: IUnitOfWork = Depends(get_unit_of_work),
messanger: IQueueMessanger = Depends(get_rabbit),
) -> UserRegistrationStartCommand:
return UserRegistrationStartCommand(
logger=logger,
unit_of_work=unit_of_work,
hash_service=hash_service,
cache=cache,
messanger=messanger,
)
def get_user_registration_complete_command(
uow: IUnitOfWork = Depends(get_unit_of_work),
logger: ILogger = Depends(get_logger),
hash_service: IHashService = Depends(get_hash_service),
jwt_service: IJwtService = Depends(get_jwt_service),
cache: ICache = Depends(get_cache),
) -> UserRegistrationCompleteCommand:
return UserRegistrationCompleteCommand(
unit_of_work=uow,
logger=logger,
hash_service=hash_service,
jwt_service=jwt_service,
cache=cache
)
def get_user_login_start_command(
logger: ILogger = Depends(get_logger),
hash_service: IHashService = Depends(get_hash_service),
cache: ICache = Depends(get_cache),
unit_of_work: IUnitOfWork = Depends(get_unit_of_work),
messanger: IQueueMessanger = Depends(get_rabbit),
) -> UserLoginStartCommand:
return UserLoginStartCommand(
logger=logger,
unit_of_work=unit_of_work,
hash_service=hash_service,
cache=cache,
messanger=messanger,
)
def get_user_login_complete_command(
uow: IUnitOfWork = Depends(get_unit_of_work),
logger: ILogger = Depends(get_logger),
hash_service: IHashService = Depends(get_hash_service),
jwt_service: IJwtService = Depends(get_jwt_service),
cache: ICache = Depends(get_cache),
) -> UserLoginCompleteCommand:
return UserLoginCompleteCommand(
unit_of_work=uow,
logger=logger,
hash_service=hash_service,
jwt_service=jwt_service,
cache=cache
)
def get_user_logout_command(
uow: IUnitOfWork = Depends(get_unit_of_work),
jwt_service: IJwtService = Depends(get_jwt_service),
logger: ILogger = Depends(get_logger),
) -> UserLogoutCommand:
return UserLogoutCommand(unit_of_work=uow, logger=logger, jwt_service=jwt_service)
def get_jwt_refresh_command(
uow: IUnitOfWork = Depends(get_unit_of_work),
hash_service: IHashService = Depends(get_hash_service),
jwt_service: IJwtService = Depends(get_jwt_service),
logger: ILogger = Depends(get_logger),
) -> JwtRefreshCommand:
return JwtRefreshCommand(uow, hash_service, jwt_service, logger)

View File

@@ -0,0 +1,7 @@
from functools import lru_cache
from src.application.contracts import ILogger
from src.infrastructure.logger import logger
@lru_cache
def get_logger() -> ILogger:
return logger

View File

@@ -0,0 +1,8 @@
from functools import lru_cache
from src.application.contracts import IQueueMessanger
from src.infrastructure.messanger import RabbitClient
@lru_cache(maxsize=1)
def get_rabbit() -> IQueueMessanger:
return RabbitClient()

View File

@@ -0,0 +1,25 @@
from functools import lru_cache
from fastapi import Depends
from src.application.contracts import IHashService, IJwtService, ILogger
from src.infrastructure.security import HashService, JwtService
from src.infrastructure.vault import JwtKeyStore
from src.presentation.dependencies.logger import get_logger
@lru_cache(maxsize=1)
def _hash_service(logger: ILogger) -> IHashService:
return HashService(logger=logger)
def get_hash_service(logger: ILogger = Depends(get_logger)) -> IHashService:
return _hash_service(logger)
@lru_cache(maxsize=1)
def _jwt_service(logger: ILogger) -> IJwtService:
key_store = JwtKeyStore.get_instance()
return JwtService(logger=logger, key_store=key_store)
def get_jwt_service(logger: ILogger = Depends(get_logger)) -> IJwtService:
return _jwt_service(logger)

View File

@@ -0,0 +1,10 @@
from fastapi import Depends
from src.application.abstractions import IUnitOfWork
from src.application.contracts import ILogger
from src.infrastructure.database import UnitOfWork
from src.infrastructure.database.context import async_session_maker
from src.infrastructure.logger import get_logger
def get_unit_of_work(logger: ILogger = Depends(get_logger)) -> IUnitOfWork:
return UnitOfWork(session_factory=async_session_maker, logger=logger)

View File

@@ -0,0 +1,2 @@
from src.presentation.handlers.unhandled_handler import unhandled_exception_handler
from src.presentation.handlers.application_handler import application_exception_handler

View File

@@ -0,0 +1,17 @@
from fastapi.responses import ORJSONResponse
from fastapi import Request
from src.application.domain.exceptions import ApplicationException
async def application_exception_handler(_request: Request, exc: ApplicationException) -> ORJSONResponse:
detail = exc.message
if 500 <= exc.status_code:
detail = "Internal Server Error"
return ORJSONResponse(
status_code=exc.status_code,
content={"detail": detail},
headers=dict(exc.headers) if exc.headers else None,
)

View File

@@ -0,0 +1,12 @@
from fastapi.responses import ORJSONResponse
from fastapi import Request
from starlette import status
from src.infrastructure.logger import logger
async def unhandled_exception_handler(_request: Request, exc: Exception) -> ORJSONResponse:
logger.exception(f'Unhandled exception: {type(exc).__name__}')
return ORJSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={'detail': 'Internal Server Error'},
)

View File

@@ -0,0 +1,2 @@
from src.presentation.middleware.trace_id import TraceIDMiddleware
from src.presentation.middleware.security_headers import SecurityHeadersMiddleware

View File

@@ -0,0 +1,51 @@
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
def __init__(
self,
app,
*,
hsts: bool = True,
hsts_max_age: int = 31536000, # 1 год
hsts_include_subdomains: bool = True,
hsts_preload: bool = False,
frame_options: str = 'DENY', # или 'SAMEORIGIN'
referrer_policy: str = 'strict-origin-when-cross-origin',
content_security_policy: str | None = None,
):
super().__init__(app)
self.hsts = hsts
self.hsts_max_age = hsts_max_age
self.hsts_include_subdomains = hsts_include_subdomains
self.hsts_preload = hsts_preload
self.frame_options = frame_options
self.referrer_policy = referrer_policy
self.csp = content_security_policy
async def dispatch(self, request: Request, call_next) -> Response:
response: Response = await call_next(request)
if request.url.path in ('/docs', '/redoc', '/openapi.json'):
return response
if self.hsts and request.url.scheme == 'https':
hsts = f'max-age={self.hsts_max_age}'
if self.hsts_include_subdomains:
hsts += '; includeSubDomains'
if self.hsts_preload:
hsts += '; preload'
response.headers['Strict-Transport-Security'] = hsts
response.headers['X-Content-Type-Options'] = 'nosniff'
response.headers['X-Frame-Options'] = self.frame_options
response.headers['Referrer-Policy'] = self.referrer_policy
if self.csp:
response.headers['Content-Security-Policy'] = self.csp
return response

View File

@@ -0,0 +1,69 @@
from __future__ import annotations
from typing import Optional
from contextvars import Token
from starlette.requests import Request
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from ulid import ULID
from src.application.contracts import ILogger
from src.infrastructure.config import settings
from src.infrastructure.context_vars import trace_id_var
class TraceIDMiddleware:
def __init__(
self,
app: ASGIApp,
logger: ILogger,
response_header_name: str = "X-Trace-ID",
attach_response_header: bool = True,
) -> None:
self.app = app
self.logger = logger
self.response_header_name = response_header_name
self.attach_response_header = attach_response_header
def _is_excluded(self, path: str) -> bool:
return any(path == p or path.startswith(p.rstrip("/") + "/") for p in settings.EXCLUDED_PATHS)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return
request = Request(scope)
if self._is_excluded(request.url.path):
await self.app(scope, receive, send)
return
trace_id = request.headers.get("X-Trace-ID") or request.headers.get("X-Request-ID")
if not trace_id:
trace_id = str(ULID())
request.state.trace_id = trace_id
token: Token = trace_id_var.set(trace_id)
self.logger.debug(f"Request started: {request.method} {request.url} - TraceID: {trace_id}")
status_code_holder: dict[str, Optional[int]] = {"status": None}
async def send_wrapper(message: Message) -> None:
if message["type"] == "http.response.start":
status_code_holder["status"] = int(message["status"])
if self.attach_response_header:
headers = list(message.get("headers", []))
headers.append((self.response_header_name.lower().encode(), trace_id.encode()))
message["headers"] = headers
await send(message)
try:
await self.app(scope, receive, send_wrapper)
finally:
status = status_code_holder["status"]
status_part = f"{status}" if status is not None else "unknown"
self.logger.debug(
f"Request finished: {request.method} {request.url} - TraceID: {trace_id} - Status: {status_part}"
)
trace_id_var.reset(token)

View File

@@ -0,0 +1,9 @@
from fastapi import APIRouter
from src.presentation.routing.auth import auth_router
from src.presentation.routing.csrf import csrf_router
from src.presentation.routing.jwt import jwt_router
v1_router = APIRouter(prefix='/v1')
v1_router.include_router(auth_router)
v1_router.include_router(csrf_router)
v1_router.include_router(jwt_router)

View File

@@ -0,0 +1,222 @@
from fastapi import APIRouter, Depends, status, Request
from fastapi.responses import ORJSONResponse
from ulid import ULID
from src.application.commands import (
UserLogoutCommand,
UserRegistrationStartCommand,
UserLoginStartCommand,
UserRegistrationCompleteCommand,
UserLoginCompleteCommand
)
from src.application.contracts import ILogger
from src.application.domain.dto import UserLoginDto
from src.infrastructure.config import settings
from src.infrastructure.logger import get_logger
from src.presentation.decorators import rate_limit, email_rl_key
from src.presentation.dependencies import (
get_user_registration_complete_command,
get_user_logout_command,
get_user_registration_start_command,
get_user_login_start_command,
get_user_login_complete_command
)
from src.presentation.schemas import UserLogin, RegistrationStart, RegistrationComplete, LoginStart
#from src.presentation.decorators import csrf_protect
auth_router = APIRouter(prefix='/auth', tags=['auth'])
@auth_router.post(
path='/registration/start',
response_class=ORJSONResponse,
status_code=status.HTTP_200_OK,
)
@rate_limit(limit=5, window_seconds=60, scope='ip')
@rate_limit(limit=3, window_seconds=600, scope='key', key_prefix='rl:reg_start', key_builder=email_rl_key)
async def registration_start(
request: Request,
body: RegistrationStart,
command: UserRegistrationStartCommand = Depends(get_user_registration_start_command),
):
result = await command(body.email)
return {'success': result}
@auth_router.post(path='/registration/complete', response_class=ORJSONResponse, status_code=status.HTTP_201_CREATED)
@rate_limit(limit=10, window_seconds=300, scope='ip')
async def registration(
request: Request,
user: RegistrationComplete,
command: UserRegistrationCompleteCommand = Depends(get_user_registration_complete_command),
logger: ILogger = Depends(get_logger),
):
device_id = request.cookies.get('device_id')
if not device_id:
device_id = str(ULID())
xff = request.headers.get('x-forwarded-for')
ip = xff.split(',')[0].strip() if xff else (request.client.host if request.client else None)
user_agent = request.headers.get('user-agent')
created = await command(
email=str(user.email),
password=user.password,
device_id=device_id,
code=user.code,
user_agent=user_agent,
ip=ip,
)
logger.info(f'Registration completed for user_id={created.id}')
response = ORJSONResponse(content={'id': created.id, 'email': created.email})
response.set_cookie(
key='device_id',
value=device_id,
httponly=True,
secure=True,
samesite='lax',
path='/',
max_age=60 * 60 * 24 * 365 * 5
)
response.set_cookie(
key='access_token',
value=created.access_token,
httponly=True,
secure=True,
samesite='lax',
path='/',
max_age=int(settings.JWT_ACCESS_TTL_SECONDS),
)
response.set_cookie(
key='refresh_token',
value=created.refresh_token,
httponly=True,
secure=True,
samesite='lax',
path='/',
max_age=int(settings.JWT_REFRESH_TTL_SECONDS),
)
return response
@auth_router.post(path='/login/start', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
@rate_limit(limit=5, window_seconds=60, scope='ip')
@rate_limit(limit=3, window_seconds=600, scope='key', key_prefix='rl:login_start', key_builder=email_rl_key)
async def login_start(
request: Request,
body: LoginStart,
command: UserLoginStartCommand = Depends(get_user_login_start_command),
):
result = await command(body.email)
return {'success': result}
@auth_router.post(path='/login/compete', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
@rate_limit(limit=10, window_seconds=300, scope='ip')
async def login(
request: Request,
user: UserLogin,
command: UserLoginCompleteCommand = Depends(get_user_login_complete_command),
logger: ILogger = Depends(get_logger)
):
device_id = request.cookies.get('device_id')
if not device_id:
device_id = str(ULID())
xff = request.headers.get('x-forwarded-for')
ip = xff.split(',')[0].strip() if xff else (request.client.host if request.client else None)
user_agent = request.headers.get('user-agent')
dto: UserLoginDto = await command(
email=str(user.email),
password=user.password,
code=user.code,
device_id=device_id,
user_agent=user_agent,
ip=ip,
)
logger.info(f'Login completed for user_id={dto.id}')
response = ORJSONResponse(
content={
'id': dto.id,
'email': dto.email,
'first_name': dto.first_name,
'middle_name': dto.middle_name,
'last_name': dto.last_name,
'birth_date': dto.birth_date.isoformat() if dto.birth_date else None,
'crypto_wallet': dto.crypto_wallet,
'phone': dto.phone,
'bik': dto.bik,
'account_number': dto.account_number,
'card_number': dto.card_number,
'inn': dto.inn,
'kyc_verified': dto.kyc_verified,
'kyc_verified_at': dto.kyc_verified_at,
'created_at': dto.created_at.isoformat() if dto.created_at else None,
'updated_at': dto.updated_at.isoformat() if dto.updated_at else None,
}
)
response.set_cookie(
key='device_id',
value=device_id,
httponly=True,
secure=True,
samesite='lax',
path='/',
max_age=60 * 60 * 24 * 365 * 5
)
response.set_cookie(
key='access_token',
value=dto.access_token,
httponly=True,
secure=True,
samesite='lax',
path='/',
max_age=int(settings.JWT_ACCESS_TTL_SECONDS),
)
response.set_cookie(
key='refresh_token',
value=dto.refresh_token,
httponly=True,
secure=True,
samesite='lax',
path='/',
max_age=int(settings.JWT_REFRESH_TTL_SECONDS),
)
return response
@auth_router.post(path='/logout', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
@rate_limit(limit=settings.RATE_LIMIT_REQUESTS, window_seconds=settings.RATE_LIMIT_WINDOW, scope='ip')
async def logout_current(
request: Request,
command: UserLogoutCommand = Depends(get_user_logout_command),
):
refresh_token = request.cookies.get('refresh_token')
await command(refresh_token=refresh_token)
response = ORJSONResponse({'ok': True})
response.delete_cookie('access_token', path='/')
response.delete_cookie('refresh_token', path='/')
return response
# @auth_router.get(path='/ping')
# @csrf_protect()
# async def ping(request: Request):
# return ORJSONResponse(
# content={
# 'status': 'pong'
# }
# )

View File

@@ -0,0 +1,37 @@
from __future__ import annotations
from fastapi import APIRouter
from fastapi.responses import ORJSONResponse
from starlette import status
from src.infrastructure.security import CsrfService
from src.infrastructure.config import settings
from src.presentation.decorators import rate_limit
csrf_router = APIRouter(prefix='/csrf', tags=['csrf'])
@csrf_router.get('/token', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
@rate_limit(limit=settings.RATE_LIMIT_REQUESTS, window_seconds=settings.RATE_LIMIT_WINDOW, scope='ip')
async def issue_csrf_token():
csrf = CsrfService()
token = csrf.issue()
response = ORJSONResponse(
content={
'token': token,
'header_name': csrf.header_name,
}
)
response.set_cookie(
key=csrf.cookie_name,
value=token,
secure=settings.CSRF_COOKIE_SECURE,
httponly=settings.CSRF_COOKIE_HTTPONLY,
samesite=settings.CSRF_COOKIE_SAMESITE,
path=settings.CSRF_COOKIE_PATH,
domain=settings.CSRF_COOKIE_DOMAIN,
max_age=csrf.ttl_seconds,
)
return response

View File

@@ -0,0 +1,64 @@
from fastapi import APIRouter, Request, Depends
from fastapi.responses import ORJSONResponse
from starlette import status
from src.application.commands import JwtRefreshCommand
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.config import settings
from src.presentation.decorators import rate_limit
from src.presentation.dependencies import get_jwt_refresh_command
jwt_router = APIRouter(prefix='/jwt', tags=['Jwt'])
@jwt_router.post('/refresh', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
@rate_limit(limit=settings.RATE_LIMIT_REQUESTS, window_seconds=settings.RATE_LIMIT_WINDOW, scope='ip')
async def refresh_tokens(
request: Request,
command: JwtRefreshCommand = Depends(get_jwt_refresh_command)
):
refresh_token = request.cookies.get('refresh_token')
if not refresh_token:
response = ORJSONResponse({'ok': False, 'error': 'No refresh token'}, status_code=401)
response.delete_cookie('access_token', path='/')
response.delete_cookie('refresh_token', path='/')
return response
ip = request.client.host if request.client else None
user_agent = request.headers.get('user-agent')
try:
access, refresh = await command(refresh_token=refresh_token, ip=ip, user_agent=user_agent)
except ApplicationException:
response = ORJSONResponse({'result': False}, status_code=401)
response.delete_cookie('access_token', path='/')
response.delete_cookie('refresh_token', path='/')
return response
response = ORJSONResponse({'result': True})
response.set_cookie(
key='access_token',
value=access,
httponly=True,
secure=True,
samesite='lax',
path='/',
max_age=int(settings.JWT_ACCESS_TTL_SECONDS),
)
response.set_cookie(
key='refresh_token',
value=refresh,
httponly=True,
secure=True,
samesite='lax',
path='/',
max_age=int(settings.JWT_REFRESH_TTL_SECONDS),
)
return response
# Usage
# @jwt_router.get("/test")
# async def profile(auth: AuthContext = Depends(require_access_token)):
# return 'ok'

View File

@@ -0,0 +1 @@
from src.presentation.schemas.user import RegistrationStart, RegistrationComplete, UserLogin, LoginStart

View File

@@ -0,0 +1,89 @@
from __future__ import annotations
import re
from typing import ClassVar
from pydantic import BaseModel, EmailStr, Field, ValidationError, field_validator, model_validator
class EmailNoSubaddressing(BaseModel):
email: EmailStr = Field(title='Email', description='Email without subaddressing')
@field_validator('email')
@classmethod
def validate_and_normalize_email(cls, v: EmailStr) -> str:
email = str(v).strip().lower()
local, _, domain = email.partition('@')
if not local or not domain:
raise ValueError('Invalid email')
if '+' in local:
raise ValueError('Email subaddressing is not allowed')
if any(ord(ch) > 127 for ch in local):
raise ValueError('Email must be ASCII')
if local.startswith('.') or local.endswith('.') or '..' in local:
raise ValueError('Invalid email local part')
if not re.fullmatch(r'[A-Za-z0-9._-]+', local):
raise ValueError('Email contains запрещенные символы')
return email
class RegistrationStart(EmailNoSubaddressing):
pass
class LoginStart(EmailNoSubaddressing):
pass
class RegistrationComplete(EmailNoSubaddressing):
password: str = Field(min_length=12)
confirm_password: str = Field(min_length=12)
code: str = Field(
min_length=6,
max_length=6,
pattern=r"^\d{6}$",
)
_allowed_specials: ClassVar[str] = '!@#$%^&*()_+-=.,:;?/[]{}<>'
@field_validator('password')
@classmethod
def validate_password_policy(cls, v: str) -> str:
if len(v) < 12:
raise ValueError('Password must be at least 12 characters long')
if not any(c.islower() for c in v):
raise ValueError('Password must contain at least one lowercase letter')
if not any(c.isupper() for c in v):
raise ValueError('Password must contain at least one uppercase letter')
if not any(c.isdigit() for c in v):
raise ValueError('Password must contain at least one digit')
if not any(c in cls._allowed_specials for c in v):
raise ValueError(
'Password must contain at least one special character '
f'from: {cls._allowed_specials}'
)
if any(c.isspace() for c in v):
raise ValueError('Password must not contain whitespace')
return v
@model_validator(mode='after')
def validate_password_confirmation(self) -> 'RegistrationComplete':
if self.password != self.confirm_password:
raise ValidationError.from_exception_data(
title='Passwords do not match',
line_errors=[{
'type': 'value_error',
'loc': ('confirm_password',),
'msg': 'Passwords do not match',
'input': self.confirm_password,
}],
)
return self
class UserLogin(EmailNoSubaddressing):
password: str = Field(min_length=12)
code: str = Field(
min_length=6,
max_length=6,
pattern=r"^\d{6}$",
)

1205
uv.lock generated Normal file

File diff suppressed because it is too large Load Diff