Initial commit
This commit is contained in:
145
.gitignore
vendored
Normal file
145
.gitignore
vendored
Normal file
@@ -0,0 +1,145 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
*.pyd
|
||||
*.dll
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache/
|
||||
.pytest_cache/
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
|
||||
# Type checkers / linters
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
.pyre/
|
||||
.pytype/
|
||||
.ruff_cache/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints/
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.env.*
|
||||
.venv/
|
||||
venv/
|
||||
ENV/
|
||||
env/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Poetry
|
||||
poetry.lock
|
||||
|
||||
# Pipenv
|
||||
Pipfile.lock
|
||||
|
||||
# Hatch
|
||||
.hatch/
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
logs/
|
||||
|
||||
# Local databases
|
||||
*.sqlite3
|
||||
*.db
|
||||
|
||||
# Secrets / credentials
|
||||
secrets.json
|
||||
credentials.json
|
||||
*.pem
|
||||
*.key
|
||||
*.crt
|
||||
|
||||
# OS generated files
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
Desktop.ini
|
||||
|
||||
# PyCharm / IntelliJ IDEA
|
||||
.idea/
|
||||
*.iml
|
||||
out/
|
||||
|
||||
# VS Code (optional)
|
||||
.vscode/
|
||||
|
||||
# Temporary files
|
||||
*.tmp
|
||||
*.temp
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# Sphinx docs
|
||||
docs/_build/
|
||||
|
||||
# mkdocs
|
||||
site/
|
||||
|
||||
# celery
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# mypy compiled cache
|
||||
.mypy_cache/
|
||||
|
||||
# pyinstaller
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# pytest debug
|
||||
pytestdebug.log
|
||||
|
||||
# Local config overrides
|
||||
config.local.py
|
||||
settings.local.py
|
||||
|
||||
# Vault / local dev secrets
|
||||
.env.vault
|
||||
vault.token
|
||||
|
||||
.env
|
||||
.dockerignore
|
||||
/sql
|
||||
28
Dockerfile
Normal file
28
Dockerfile
Normal file
@@ -0,0 +1,28 @@
|
||||
FROM ghcr.io/astral-sh/uv:python3.12-bookworm AS builder
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install dependencies (cached layer)
|
||||
COPY pyproject.toml uv.lock ./
|
||||
RUN uv sync --frozen --no-dev
|
||||
|
||||
# Copy source last (fast rebuilds)
|
||||
COPY src ./src
|
||||
|
||||
|
||||
FROM ghcr.io/astral-sh/uv:python3.12-bookworm AS runtime
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Use the virtualenv created by `uv sync` in builder
|
||||
COPY --from=builder /app/.venv /app/.venv
|
||||
COPY --from=builder /app/src /app/src
|
||||
|
||||
ENV PATH="/app/.venv/bin:$PATH" \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
PYTHONDONTWRITEBYTECODE=1 \
|
||||
PYTHONPATH=/app
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
CMD ["sh", "-c", "granian --interface asgi ${APP_MODULE:-src.main:app} --host ${APP_HOST:-0.0.0.0} --port ${APP_PORT:-8000} --workers ${APP_WORKERS:-1} --loop uvloop"]
|
||||
83
docker-compose.yml
Normal file
83
docker-compose.yml
Normal file
@@ -0,0 +1,83 @@
|
||||
services:
|
||||
auth:
|
||||
container_name: auth-service
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
ports:
|
||||
- "8000:8000"
|
||||
environment:
|
||||
PYTHONUNBUFFERED: "1"
|
||||
APP_MODULE: "src.main:app"
|
||||
APP_HOST: "0.0.0.0"
|
||||
APP_PORT: "8000"
|
||||
APP_WORKERS: "1"
|
||||
env_file:
|
||||
- .env
|
||||
depends_on:
|
||||
keydb:
|
||||
condition: service_healthy
|
||||
restart: no
|
||||
|
||||
keydb:
|
||||
image: eqalpha/keydb
|
||||
container_name: keydb
|
||||
restart: no
|
||||
expose:
|
||||
- "6379"
|
||||
volumes:
|
||||
- keydb_data:/data
|
||||
command:
|
||||
- keydb-server
|
||||
- --requirepass
|
||||
- keydb
|
||||
- --dir
|
||||
- /data
|
||||
- --appendonly
|
||||
- "yes"
|
||||
- --appendfsync
|
||||
- everysec
|
||||
- --save
|
||||
- "900"
|
||||
- "1"
|
||||
- --save
|
||||
- "300"
|
||||
- "10"
|
||||
- --save
|
||||
- "60"
|
||||
- "10000"
|
||||
healthcheck:
|
||||
test: [ "CMD", "redis-cli", "-a", "keydb", "ping" ]
|
||||
interval: 5s
|
||||
timeout: 2s
|
||||
retries: 20
|
||||
|
||||
# keydb:
|
||||
# image: eqalpha/keydb
|
||||
# container_name: keydb
|
||||
# restart: no
|
||||
# expose:
|
||||
# - "6379"
|
||||
# volumes:
|
||||
# - keydb_data:/data
|
||||
# environment:
|
||||
# KEYDB_PASSWORD: keydb
|
||||
# command: >
|
||||
# sh -c "
|
||||
# keydb-server
|
||||
# --requirepass $$KEYDB_PASSWORD
|
||||
# --dir /data
|
||||
# --appendonly yes
|
||||
# --appendfsync everysec
|
||||
# --save 900 1
|
||||
# --save 300 10
|
||||
# --save 60 10000
|
||||
# "
|
||||
# healthcheck:
|
||||
# test: ["CMD", "redis-cli", "ping"]
|
||||
# interval: 5s
|
||||
# timeout: 2s
|
||||
# retries: 20
|
||||
|
||||
volumes:
|
||||
keydb_data:
|
||||
24
pyproject.toml
Normal file
24
pyproject.toml
Normal file
@@ -0,0 +1,24 @@
|
||||
[project]
|
||||
name = "bitok"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
requires-python = "==3.12.*"
|
||||
dependencies = [
|
||||
"apscheduler==3.11.2",
|
||||
"asyncpg==0.31.0",
|
||||
"bcrypt==5.0.0",
|
||||
"dotenv==0.9.9",
|
||||
"email-validator==2.3.0",
|
||||
"fastapi==0.128.7",
|
||||
"faststream[rabbit]==0.6.6",
|
||||
"granian==2.6.1",
|
||||
"hvac==2.4.0",
|
||||
"itsdangerous==2.2.0",
|
||||
"orjson==3.11.7",
|
||||
"pydantic-settings==2.12.0",
|
||||
"python-jose==3.5.0",
|
||||
"python-ulid==3.1.0",
|
||||
"redis==7.2.0",
|
||||
"sqlalchemy==2.0.46",
|
||||
"uvloop==0.22.1; platform_system != 'Windows'",
|
||||
]
|
||||
1
src/application/abstractions/__init__.py
Normal file
1
src/application/abstractions/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.application.abstractions.i_unit_of_work import IUnitOfWork
|
||||
19
src/application/abstractions/i_unit_of_work.py
Normal file
19
src/application/abstractions/i_unit_of_work.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from __future__ import annotations
|
||||
from typing import Protocol, runtime_checkable
|
||||
from src.application.abstractions.repositories import IUserRepository, ISessionRepository
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class IUnitOfWork(Protocol):
|
||||
async def __aenter__(self) -> "IUnitOfWork": ...
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: ...
|
||||
|
||||
async def commit(self) -> None: ...
|
||||
async def rollback(self) -> None: ...
|
||||
|
||||
@property
|
||||
def user_repository(self) -> IUserRepository: ...
|
||||
|
||||
@property
|
||||
def session_repository(self) -> ISessionRepository: ...
|
||||
|
||||
2
src/application/abstractions/repositories/__init__.py
Normal file
2
src/application/abstractions/repositories/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from src.application.abstractions.repositories.i_user_repository import IUserRepository
|
||||
from src.application.abstractions.repositories.i_session_repository import ISessionRepository
|
||||
@@ -0,0 +1,59 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from src.application.domain.entities import SessionEntity
|
||||
|
||||
|
||||
|
||||
class ISessionRepository(ABC):
|
||||
@abstractmethod
|
||||
async def get_by_sid(self, sid: str) -> SessionEntity | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_by_user_device(self, user_id: str, device_id: str) -> SessionEntity | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def upsert_by_device(
|
||||
self,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
sid: str,
|
||||
refresh_jti_hash: str,
|
||||
refresh_expires_at: datetime,
|
||||
user_agent: str | None,
|
||||
ip: str | None,
|
||||
now: datetime,
|
||||
) -> SessionEntity:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def revoke_by_sid(self, sid: str, now: datetime) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def rotate_refresh(
|
||||
self,
|
||||
sid: str,
|
||||
new_jti_hash: str,
|
||||
new_refresh_expires_at: datetime,
|
||||
now: datetime,
|
||||
ip: str | None,
|
||||
user_agent: str | None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def rotate_refresh_if_match(
|
||||
self,
|
||||
*,
|
||||
sid: str,
|
||||
old_jti_hash: str,
|
||||
new_jti_hash: str,
|
||||
new_refresh_expires_at: datetime,
|
||||
now: datetime,
|
||||
ip: str | None,
|
||||
user_agent: str | None,
|
||||
) -> bool:
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,19 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from src.application.domain.entities import UserEntity
|
||||
|
||||
|
||||
class IUserRepository(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def create_user(self, email: str, password_hash: str) -> UserEntity:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def get_user_by_email(self, email: str) -> UserEntity:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def exists_by_email(self, email: str) -> bool:
|
||||
raise NotImplementedError
|
||||
6
src/application/commands/__init__.py
Normal file
6
src/application/commands/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from src.application.commands.user_registration_complete import UserRegistrationCompleteCommand
|
||||
from src.application.commands.user_login_complete import UserLoginCompleteCommand
|
||||
from src.application.commands.user_logout import UserLogoutCommand
|
||||
from src.application.commands.jwt_refresh import JwtRefreshCommand
|
||||
from src.application.commands.user_registration_start import UserRegistrationStartCommand
|
||||
from src.application.commands.user_login_start import UserLoginStartCommand
|
||||
70
src/application/commands/jwt_refresh.py
Normal file
70
src/application/commands/jwt_refresh.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from ulid import ULID
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.contracts import IHashService, IJwtService, ILogger
|
||||
from src.application.domain.dto import RefreshTokenPayload
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.config import settings
|
||||
from src.infrastructure.database.decorators import transactional
|
||||
|
||||
|
||||
class JwtRefreshCommand:
|
||||
def __init__(self, unit_of_work: IUnitOfWork, hash_service: IHashService, jwt_service: IJwtService, logger: ILogger):
|
||||
self._unit_of_work = unit_of_work
|
||||
self._hash_service = hash_service
|
||||
self._jwt_service = jwt_service
|
||||
self._logger = logger
|
||||
|
||||
@transactional
|
||||
async def __call__(self, *, refresh_token: str, ip: str | None, user_agent: str | None) -> tuple[str, str]:
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
payload: RefreshTokenPayload = await self._jwt_service.decode_refresh_token(refresh_token)
|
||||
|
||||
sid = payload.sid
|
||||
user_id = payload.sub
|
||||
jti = payload.jti
|
||||
|
||||
sess = await self._unit_of_work.session_repository.get_by_sid(sid)
|
||||
if sess is None:
|
||||
raise ApplicationException(status_code=401, message='Session not found')
|
||||
|
||||
if sess.revoked_at is not None:
|
||||
raise ApplicationException(status_code=401, message='Session revoked')
|
||||
|
||||
if sess.refresh_expires_at <= now:
|
||||
raise ApplicationException(status_code=401, message='Session expired')
|
||||
|
||||
if str(sess.user_id) != str(user_id):
|
||||
raise ApplicationException(status_code=401, message='Invalid session subject')
|
||||
|
||||
ok = await self._hash_service.verify(
|
||||
plain_value=jti,
|
||||
hashed_value=sess.refresh_jti_hash,
|
||||
)
|
||||
if not ok:
|
||||
await self._unit_of_work.session_repository.revoke_by_sid(sid=sid, now=now)
|
||||
raise ApplicationException(status_code=401, message='Refresh token reuse detected')
|
||||
|
||||
new_jti = str(ULID())
|
||||
new_jti_hash = await self._hash_service.hash(value=new_jti)
|
||||
new_refresh_expires_at = now + timedelta(seconds=int(settings.JWT_REFRESH_TTL_SECONDS))
|
||||
|
||||
rotated = await self._unit_of_work.session_repository.rotate_refresh_if_match(
|
||||
sid=sid,
|
||||
old_jti_hash=sess.refresh_jti_hash,
|
||||
new_jti_hash=new_jti_hash,
|
||||
new_refresh_expires_at=new_refresh_expires_at,
|
||||
now=now,
|
||||
ip=ip,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
if not rotated:
|
||||
raise ApplicationException(status_code=401, message='Refresh already rotated')
|
||||
|
||||
access = await self._jwt_service.create_access_token(user_id=user_id, sid=sid)
|
||||
refresh = await self._jwt_service.create_refresh_token(user_id=user_id, sid=sid, refresh_jti=new_jti)
|
||||
|
||||
self._logger.info(f'Tokens refreshed (user_id={user_id}, sid={sid})')
|
||||
return access, refresh
|
||||
117
src/application/commands/user_login_complete.py
Normal file
117
src/application/commands/user_login_complete.py
Normal file
@@ -0,0 +1,117 @@
|
||||
from datetime import timedelta, datetime, timezone
|
||||
from ulid import ULID
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.contracts import IHashService, IJwtService, ILogger, ICache
|
||||
from src.application.domain.dto import UserLoginDto
|
||||
from src.application.domain.entities import UserEntity
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.config import settings
|
||||
from src.infrastructure.database.decorators import transactional
|
||||
|
||||
|
||||
class UserLoginCompleteCommand:
|
||||
def __init__(
|
||||
self,
|
||||
unit_of_work: IUnitOfWork,
|
||||
hash_service: IHashService,
|
||||
jwt_service: IJwtService,
|
||||
cache: ICache,
|
||||
logger: ILogger,
|
||||
):
|
||||
self._unit_of_work = unit_of_work
|
||||
self._hash_service = hash_service
|
||||
self._jwt_service = jwt_service
|
||||
self._cache = cache
|
||||
self._logger = logger
|
||||
|
||||
@transactional
|
||||
async def __call__(
|
||||
self,
|
||||
*,
|
||||
email: str,
|
||||
password: str,
|
||||
code: str,
|
||||
device_id: str,
|
||||
user_agent: str | None,
|
||||
ip: str | None,
|
||||
) -> UserLoginDto:
|
||||
email = (email or '').strip().lower()
|
||||
code = (code or '').strip()
|
||||
|
||||
code_key = f'login:code:{code}'
|
||||
email_key = f'login:email:{email}'
|
||||
|
||||
cached_email = await self._cache.get(code_key)
|
||||
if not cached_email:
|
||||
self._logger.info(f'Login failed: code not found (email={email})')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
if cached_email != email:
|
||||
self._logger.info(f'Login failed: code-email mismatch (email={email}, cached_email={cached_email})')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
code_hash = await self._cache.get(email_key)
|
||||
if not code_hash:
|
||||
self._logger.info(f'Login failed: email key missing (email={email})')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
ok = await self._hash_service.verify(hashed_value=code_hash, plain_value=code)
|
||||
if not ok:
|
||||
self._logger.info(f'Login failed: code hash mismatch (email={email})')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
user: UserEntity = await self._unit_of_work.user_repository.get_user_by_email(email=email)
|
||||
|
||||
ok = await self._hash_service.verify(plain_value=password, hashed_value=user.password_hash)
|
||||
if not ok:
|
||||
self._logger.warning(f'{user.id} login failed: invalid credentials')
|
||||
raise ApplicationException(status_code=401, message='Invalid credentials')
|
||||
|
||||
try:
|
||||
await self._cache.delete(code_key)
|
||||
await self._cache.delete(email_key)
|
||||
except Exception as e:
|
||||
self._logger.warning(f'Login cleanup failed (email={email}): {e}')
|
||||
|
||||
sid = str(ULID())
|
||||
jti = str(ULID())
|
||||
|
||||
refresh_jti_hash = await self._hash_service.hash(value=jti)
|
||||
refresh_expires_at = now + timedelta(seconds=int(settings.JWT_REFRESH_TTL_SECONDS))
|
||||
|
||||
await self._unit_of_work.session_repository.upsert_by_device(
|
||||
user_id=user.id,
|
||||
device_id=device_id,
|
||||
sid=sid,
|
||||
refresh_jti_hash=refresh_jti_hash,
|
||||
refresh_expires_at=refresh_expires_at,
|
||||
user_agent=user_agent,
|
||||
ip=ip,
|
||||
now=now,
|
||||
)
|
||||
|
||||
access_token = await self._jwt_service.create_access_token(user_id=user.id, sid=sid)
|
||||
refresh_token = await self._jwt_service.create_refresh_token(user_id=user.id, sid=sid, refresh_jti=jti)
|
||||
|
||||
return UserLoginDto(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
first_name=user.first_name,
|
||||
middle_name=user.middle_name,
|
||||
last_name=user.last_name,
|
||||
birth_date=user.birth_date,
|
||||
crypto_wallet=user.crypto_wallet,
|
||||
phone=user.phone,
|
||||
bik=user.bik,
|
||||
account_number=user.account_number,
|
||||
card_number=user.card_number,
|
||||
inn=user.inn,
|
||||
kyc_verified=user.kyc_verified,
|
||||
kyc_verified_at=user.kyc_verified_at,
|
||||
created_at=user.created_at,
|
||||
updated_at=user.updated_at,
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
)
|
||||
133
src/application/commands/user_login_start.py
Normal file
133
src/application/commands/user_login_start.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import secrets
|
||||
from datetime import timezone, datetime
|
||||
from ulid import ULID
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.contracts import IHashService, ICache, ILogger, IQueueMessanger
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.config import settings
|
||||
from src.infrastructure.context_vars import trace_id_var
|
||||
from src.infrastructure.database.decorators import transactional
|
||||
|
||||
|
||||
class UserLoginStartCommand:
|
||||
def __init__(
|
||||
self,
|
||||
hash_service: IHashService,
|
||||
cache: ICache,
|
||||
unit_of_work: IUnitOfWork,
|
||||
logger: ILogger,
|
||||
messanger: IQueueMessanger,
|
||||
):
|
||||
self._hash_service = hash_service
|
||||
self._unit_of_work = unit_of_work
|
||||
self._cache = cache
|
||||
self._logger = logger
|
||||
self._messanger = messanger
|
||||
|
||||
|
||||
@transactional
|
||||
async def __call__(self, email: str) -> bool:
|
||||
TTL = 300
|
||||
LOCK_TTL = 30
|
||||
MAX_ATTEMPTS = 20
|
||||
|
||||
EMAIL_PREFIX = 'login:email:'
|
||||
CODE_PREFIX = 'login:code:'
|
||||
LOCK_PREFIX = 'login:lock:'
|
||||
|
||||
email = (email or '').strip().lower()
|
||||
if not email:
|
||||
self._logger.info('Login start failed: empty email')
|
||||
raise ApplicationException(400, 'Invalid email')
|
||||
|
||||
exists = await self._unit_of_work.user_repository.exists_by_email(email)
|
||||
if not exists:
|
||||
self._logger.info(f'Login failed: email already registered ({email})')
|
||||
raise ApplicationException(404, 'Email registered')
|
||||
|
||||
trace_id = trace_id_var.get()
|
||||
if not trace_id or trace_id == 'N/A':
|
||||
trace_id = None
|
||||
|
||||
lock_key = f'{LOCK_PREFIX}{email}'
|
||||
locked = await self._cache.set_nx(lock_key, '1', ttl=LOCK_TTL)
|
||||
if not locked:
|
||||
self._logger.info(f'Login start throttled by lock ({email})')
|
||||
raise ApplicationException(429, 'Too many requests. Please wait.')
|
||||
|
||||
try:
|
||||
email_key = f'{EMAIL_PREFIX}{email}'
|
||||
|
||||
existing = await self._cache.get(email_key)
|
||||
if existing:
|
||||
self._logger.info(f'Login start denied: code already exists for {email}')
|
||||
raise ApplicationException(429, 'Code already sent. Please wait before retrying.')
|
||||
|
||||
for _ in range(MAX_ATTEMPTS):
|
||||
code = f'{secrets.randbelow(1_000_000):06d}'
|
||||
|
||||
code_key = f'{CODE_PREFIX}{code}'
|
||||
|
||||
code_hash = await self._hash_service.hash(code)
|
||||
|
||||
reserved = await self._cache.set_nx(code_key, email, ttl=TTL)
|
||||
if not reserved:
|
||||
continue
|
||||
|
||||
saved = await self._cache.set(email_key, code_hash, ttl=TTL)
|
||||
if not saved:
|
||||
await self._cache.delete(code_key)
|
||||
self._logger.error(f'Login start failed: cannot save code hash for {email}')
|
||||
raise ApplicationException(503, 'Temporary error. Please try again.')
|
||||
|
||||
message_id = str(ULID())
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
metadata = {
|
||||
'trace_id': trace_id,
|
||||
'source': 'auth-service',
|
||||
'timestamp': now,
|
||||
'message_id': message_id,
|
||||
}
|
||||
|
||||
payload = {
|
||||
'email': email,
|
||||
'code': code,
|
||||
'ttl_seconds': TTL,
|
||||
}
|
||||
|
||||
message = {
|
||||
'event': 'login',
|
||||
'payload': payload,
|
||||
'metadata': metadata,
|
||||
}
|
||||
|
||||
self._logger.info(f'payload: {payload})')
|
||||
|
||||
try:
|
||||
await self._messanger.publish_to_queue(
|
||||
queue=settings.RABBIT_EMAIL_CODE_QUEUE,
|
||||
message=message,
|
||||
persist=True,
|
||||
correlation_id=trace_id,
|
||||
message_id=message_id,
|
||||
headers={'trace_id': trace_id} if trace_id else None,
|
||||
)
|
||||
except Exception as exception:
|
||||
try:
|
||||
await self._cache.delete(email_key)
|
||||
await self._cache.delete(code_key)
|
||||
except Exception as rollback_err:
|
||||
self._logger.error(f'Publish failed and rollback cache failed for {email}: {str(rollback_err)}')
|
||||
|
||||
self._logger.error(f'Failed to publish login email event for {email}: {str(exception)}')
|
||||
raise ApplicationException(503, 'Temporary error. Please try again.')
|
||||
|
||||
self._logger.info(f'login code created for {email}')
|
||||
return True
|
||||
|
||||
self._logger.error(f'login start failed: code space exhausted for {email}')
|
||||
raise ApplicationException(503, 'Temporary error. Please try again.')
|
||||
|
||||
finally:
|
||||
await self._cache.delete(lock_key)
|
||||
28
src/application/commands/user_logout.py
Normal file
28
src/application/commands/user_logout.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
from datetime import datetime, timezone
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.contracts import ILogger, IJwtService
|
||||
from src.application.domain.dto import RefreshTokenPayload
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.database.decorators import transactional
|
||||
|
||||
|
||||
class UserLogoutCommand:
|
||||
def __init__(self, unit_of_work: IUnitOfWork, jwt_service: IJwtService, logger: ILogger):
|
||||
self._unit_of_work = unit_of_work
|
||||
self._jwt_service = jwt_service
|
||||
self._logger = logger
|
||||
|
||||
@transactional
|
||||
async def __call__(self, *, refresh_token: str | None) -> None:
|
||||
if not refresh_token:
|
||||
return
|
||||
try:
|
||||
payload: RefreshTokenPayload = self._jwt_service.decode_refresh_token(refresh_token)
|
||||
except ApplicationException:
|
||||
self._logger.debug('Logout: refresh token invalid/expired, skipping revoke')
|
||||
return
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
await self._unit_of_work.session_repository.revoke_by_sid(sid=payload.sid, now=now)
|
||||
self._logger.info(f'Logout: session revoked (sid={payload.sid}, user_id={payload.sub})')
|
||||
121
src/application/commands/user_registration_complete.py
Normal file
121
src/application/commands/user_registration_complete.py
Normal file
@@ -0,0 +1,121 @@
|
||||
from datetime import timedelta, datetime, timezone
|
||||
from ulid import ULID
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.contracts import IHashService, IJwtService, ILogger, ICache
|
||||
from src.application.domain.dto import UserCreatedDto
|
||||
from src.application.domain.entities import UserEntity
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.config import settings
|
||||
from src.infrastructure.database.decorators import transactional
|
||||
|
||||
|
||||
class UserRegistrationCompleteCommand:
|
||||
def __init__(
|
||||
self,
|
||||
unit_of_work: IUnitOfWork,
|
||||
hash_service: IHashService,
|
||||
jwt_service: IJwtService,
|
||||
cache: ICache,
|
||||
logger: ILogger,
|
||||
):
|
||||
self._unit_of_work = unit_of_work
|
||||
self._cache = cache
|
||||
self._hash_service = hash_service
|
||||
self._jwt_service = jwt_service
|
||||
self._logger = logger
|
||||
|
||||
|
||||
@transactional
|
||||
async def __call__(
|
||||
self,
|
||||
*,
|
||||
email: str,
|
||||
password: str,
|
||||
device_id: str,
|
||||
code: str,
|
||||
user_agent: str | None,
|
||||
ip: str | None,
|
||||
) -> UserCreatedDto:
|
||||
|
||||
email = (email or '').strip().lower()
|
||||
code = (code or '').strip()
|
||||
|
||||
code_key = f'reg:code:{code}'
|
||||
email_key = f'reg:email:{email}'
|
||||
|
||||
cached_email = await self._cache.get(code_key)
|
||||
if not cached_email:
|
||||
self._logger.info(f'Registration failed: code not found (email={email}, code={code})')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
if cached_email != email:
|
||||
self._logger.info(f'Registration failed: code-email mismatch (email={email}, cached_email={cached_email})')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
code_hash = await self._cache.get(email_key)
|
||||
if not code_hash:
|
||||
self._logger.info(f'Registration failed: email key missing (email={email})')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
ok = await self._hash_service.verify(
|
||||
hashed_value=code_hash,
|
||||
plain_value=code,
|
||||
)
|
||||
if not ok:
|
||||
self._logger.info(f'Registration failed: code hash mismatch (email={email})')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
deleted_code = await self._cache.delete(code_key)
|
||||
deleted_email = await self._cache.delete(email_key)
|
||||
|
||||
if not deleted_code or not deleted_email:
|
||||
self._logger.info(
|
||||
f'Registration cleanup: keys already missing '
|
||||
f'(email={email}, deleted_code={deleted_code}, deleted_email={deleted_email})'
|
||||
)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
password_hash = await self._hash_service.hash(value=password)
|
||||
|
||||
user: UserEntity = await self._unit_of_work.user_repository.create_user(
|
||||
email=email,
|
||||
password_hash=password_hash,
|
||||
)
|
||||
|
||||
sid = str(ULID())
|
||||
jti = str(ULID())
|
||||
|
||||
refresh_jti_hash = await self._hash_service.hash(value=jti)
|
||||
refresh_expires_at = now + timedelta(seconds=int(settings.JWT_REFRESH_TTL_SECONDS))
|
||||
|
||||
await self._unit_of_work.session_repository.upsert_by_device(
|
||||
user_id=user.id,
|
||||
device_id=device_id,
|
||||
sid=sid,
|
||||
refresh_jti_hash=refresh_jti_hash,
|
||||
refresh_expires_at=refresh_expires_at,
|
||||
user_agent=user_agent,
|
||||
ip=ip,
|
||||
now=now,
|
||||
)
|
||||
|
||||
access_token = await self._jwt_service.create_access_token(
|
||||
user_id=user.id,
|
||||
sid=sid,
|
||||
)
|
||||
refresh_token = await (
|
||||
self._jwt_service.create_refresh_token(
|
||||
user_id=user.id,
|
||||
sid=sid,
|
||||
refresh_jti=jti,
|
||||
))
|
||||
|
||||
self._logger.info(f'User registered successfully user_id={user.id} device_id={device_id} sid={sid}')
|
||||
|
||||
return UserCreatedDto(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
)
|
||||
133
src/application/commands/user_registration_start.py
Normal file
133
src/application/commands/user_registration_start.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import secrets
|
||||
from datetime import datetime, timezone
|
||||
from ulid import ULID
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.contracts import IHashService, ILogger, ICache, IQueueMessanger
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.config import settings
|
||||
from src.infrastructure.context_vars import trace_id_var
|
||||
from src.infrastructure.database.decorators import transactional
|
||||
|
||||
|
||||
class UserRegistrationStartCommand:
|
||||
def __init__(
|
||||
self,
|
||||
hash_service: IHashService,
|
||||
cache: ICache,
|
||||
unit_of_work: IUnitOfWork,
|
||||
logger: ILogger,
|
||||
messanger: IQueueMessanger,
|
||||
):
|
||||
self._hash_service = hash_service
|
||||
self._unit_of_work = unit_of_work
|
||||
self._cache = cache
|
||||
self._logger = logger
|
||||
self._messanger = messanger
|
||||
|
||||
|
||||
@transactional
|
||||
async def __call__(self, email: str) -> bool:
|
||||
TTL = 300
|
||||
LOCK_TTL = 30
|
||||
MAX_ATTEMPTS = 20
|
||||
|
||||
EMAIL_PREFIX = 'reg:email:'
|
||||
CODE_PREFIX = 'reg:code:'
|
||||
LOCK_PREFIX = 'reg:lock:'
|
||||
|
||||
email = (email or '').strip().lower()
|
||||
if not email:
|
||||
self._logger.info('Registration start failed: empty email')
|
||||
raise ApplicationException(400, 'Invalid email')
|
||||
|
||||
exists = await self._unit_of_work.user_repository.exists_by_email(email)
|
||||
if exists:
|
||||
self._logger.info(f'Registration failed: email already registered ({email})')
|
||||
raise ApplicationException(409, 'Email already registered')
|
||||
|
||||
trace_id = trace_id_var.get()
|
||||
if not trace_id or trace_id == 'N/A':
|
||||
trace_id = None
|
||||
|
||||
lock_key = f'{LOCK_PREFIX}{email}'
|
||||
locked = await self._cache.set_nx(lock_key, '1', ttl=LOCK_TTL)
|
||||
if not locked:
|
||||
self._logger.info(f'Registration start throttled by lock ({email})')
|
||||
raise ApplicationException(429, 'Too many requests. Please wait.')
|
||||
|
||||
try:
|
||||
email_key = f'{EMAIL_PREFIX}{email}'
|
||||
|
||||
existing = await self._cache.get(email_key)
|
||||
if existing:
|
||||
self._logger.info(f'Registration start denied: code already exists for {email}')
|
||||
raise ApplicationException(429, 'Code already sent. Please wait before retrying.')
|
||||
|
||||
for _ in range(MAX_ATTEMPTS):
|
||||
code = f'{secrets.randbelow(1_000_000):06d}'
|
||||
|
||||
code_key = f'{CODE_PREFIX}{code}'
|
||||
|
||||
code_hash = await self._hash_service.hash(code)
|
||||
|
||||
reserved = await self._cache.set_nx(code_key, email, ttl=TTL)
|
||||
if not reserved:
|
||||
continue
|
||||
|
||||
saved = await self._cache.set(email_key, code_hash, ttl=TTL)
|
||||
if not saved:
|
||||
await self._cache.delete(code_key)
|
||||
self._logger.error(f'Registration start failed: cannot save code hash for {email}')
|
||||
raise ApplicationException(503, 'Temporary error. Please try again.')
|
||||
|
||||
message_id = str(ULID())
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
metadata = {
|
||||
'trace_id': trace_id,
|
||||
'source': 'auth-service',
|
||||
'timestamp': now,
|
||||
'message_id': message_id,
|
||||
}
|
||||
|
||||
payload = {
|
||||
'email': email,
|
||||
'code': code,
|
||||
'ttl_seconds': TTL,
|
||||
}
|
||||
|
||||
message = {
|
||||
'event': 'registration',
|
||||
'payload': payload,
|
||||
'metadata': metadata,
|
||||
}
|
||||
|
||||
self._logger.info(f'payload: {payload})')
|
||||
|
||||
try:
|
||||
await self._messanger.publish_to_queue(
|
||||
queue=settings.RABBIT_EMAIL_CODE_QUEUE,
|
||||
message=message,
|
||||
persist=True,
|
||||
correlation_id=trace_id,
|
||||
message_id=message_id,
|
||||
headers={'trace_id': trace_id} if trace_id else None,
|
||||
)
|
||||
except Exception as exception:
|
||||
try:
|
||||
await self._cache.delete(email_key)
|
||||
await self._cache.delete(code_key)
|
||||
except Exception as rollback_err:
|
||||
self._logger.error(f'Publish failed and rollback cache failed for {email}: {str(rollback_err)}')
|
||||
|
||||
self._logger.error(f'Failed to publish registration email event for {email}: {str(exception)}')
|
||||
raise ApplicationException(503, 'Temporary error. Please try again.')
|
||||
|
||||
self._logger.info(f'Registration code created for {email}')
|
||||
return True
|
||||
|
||||
self._logger.error(f'Registration start failed: code space exhausted for {email}')
|
||||
raise ApplicationException(503, 'Temporary error. Please try again.')
|
||||
|
||||
finally:
|
||||
await self._cache.delete(lock_key)
|
||||
7
src/application/contracts/__init__.py
Normal file
7
src/application/contracts/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from src.application.contracts.i_hash_service import IHashService
|
||||
from src.application.contracts.i_logger import ILogger
|
||||
from src.application.contracts.i_user_service import IUserService
|
||||
from src.application.contracts.i_jwt_service import IJwtService
|
||||
from src.application.contracts.i_csrf_service import ICsrfService
|
||||
from src.application.contracts.i_cache import ICache
|
||||
from src.application.contracts.i_queue_messanger import IQueueMessanger
|
||||
20
src/application/contracts/i_cache.py
Normal file
20
src/application/contracts/i_cache.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class ICache(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def set(self, key: str, value: str, ttl: int) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def set_nx(self, key: str, value: str, ttl: int) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get(self, key: str) -> str | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def delete(self, key: str) -> bool:
|
||||
raise NotImplementedError
|
||||
26
src/application/contracts/i_csrf_service.py
Normal file
26
src/application/contracts/i_csrf_service.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional, Mapping
|
||||
|
||||
|
||||
class ICsrfService(ABC):
|
||||
@abstractmethod
|
||||
def issue(self, subject: Optional[str] = None) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def verify(self, token: str, expected_subject: Optional[str] = None) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def extract(self, cookies: Mapping[str, str], headers: Mapping[str, str]) -> tuple[Optional[str], Optional[str]]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def verify_pair(
|
||||
self,
|
||||
cookie_token: Optional[str],
|
||||
header_token: Optional[str],
|
||||
expected_subject: Optional[str] = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
12
src/application/contracts/i_hash_service.py
Normal file
12
src/application/contracts/i_hash_service.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class IHashService(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def hash(self, value: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def verify(self, hashed_value: str, plain_value: str) -> bool:
|
||||
raise NotImplementedError
|
||||
22
src/application/contracts/i_jwt_service.py
Normal file
22
src/application/contracts/i_jwt_service.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from src.application.domain.dto import AccessTokenPayload, RefreshTokenPayload
|
||||
|
||||
|
||||
class IJwtService(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def create_access_token(self, user_id: str, sid: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def create_refresh_token(self, user_id: str, sid: str, refresh_jti: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def decode_access_token(self, token: str) -> AccessTokenPayload:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def decode_refresh_token(self, token: str) -> RefreshTokenPayload:
|
||||
raise NotImplementedError
|
||||
68
src/application/contracts/i_logger.py
Normal file
68
src/application/contracts/i_logger.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from typing import Protocol, Optional, Callable
|
||||
from src.application.domain.enums.log_format import LogFormat
|
||||
from src.application.domain.enums.log_level import LogLevel
|
||||
|
||||
|
||||
class ILogger(Protocol):
|
||||
"""Interface for synchronous logger with ContextVar support for trace_id."""
|
||||
|
||||
log_format: LogFormat
|
||||
min_level: LogLevel
|
||||
id_generator: Optional[Callable[[], str]]
|
||||
instance_id: str
|
||||
|
||||
def set_format(self, log_format: LogFormat) -> None:
|
||||
"""Set log format using LogFormat enum"""
|
||||
...
|
||||
|
||||
def set_min_level(self, level: LogLevel) -> None:
|
||||
"""Set minimum log level"""
|
||||
...
|
||||
|
||||
def new_trace_id(self) -> str:
|
||||
"""Create and set new trace_id in context"""
|
||||
...
|
||||
|
||||
def set_trace_id(self, trace_id: str) -> None:
|
||||
"""Set existing trace_id in context"""
|
||||
...
|
||||
|
||||
def get_trace_id(self) -> str:
|
||||
"""Get current trace_id from context"""
|
||||
...
|
||||
|
||||
def clear_trace_id(self) -> None:
|
||||
"""Clear the trace_id in the context"""
|
||||
...
|
||||
|
||||
def set_instance_id(self, instance_id: str) -> None:
|
||||
"""Set service instance id (ULID recommended)"""
|
||||
...
|
||||
|
||||
def get_instance_id(self) -> str:
|
||||
"""Get current service instance id"""
|
||||
...
|
||||
|
||||
def debug(self, message: str) -> None:
|
||||
"""Log debug message"""
|
||||
...
|
||||
|
||||
def info(self, message: str) -> None:
|
||||
"""Log info message"""
|
||||
...
|
||||
|
||||
def warning(self, message: str) -> None:
|
||||
"""Log warning message"""
|
||||
...
|
||||
|
||||
def error(self, message: str) -> None:
|
||||
"""Log error message"""
|
||||
...
|
||||
|
||||
def critical(self, message: str) -> None:
|
||||
"""Log critical message"""
|
||||
...
|
||||
|
||||
def exception(self, message: str) -> None:
|
||||
"""Log exception with traceback"""
|
||||
...
|
||||
40
src/application/contracts/i_queue_messanger.py
Normal file
40
src/application/contracts/i_queue_messanger.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Mapping, Any
|
||||
|
||||
|
||||
class IQueueMessanger(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def connect(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def close(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def publish_to_queue(
|
||||
self,
|
||||
queue: str,
|
||||
message: Any,
|
||||
*,
|
||||
persist: bool = True,
|
||||
headers: Mapping[str, Any] | None = None,
|
||||
correlation_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def publish(
|
||||
self,
|
||||
message: Any,
|
||||
*,
|
||||
exchange: str,
|
||||
routing_key: str,
|
||||
persist: bool = True,
|
||||
headers: Mapping[str, Any] | None = None,
|
||||
correlation_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
14
src/application/contracts/i_user_service.py
Normal file
14
src/application/contracts/i_user_service.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from src.application.domain.dto import UserCreatedDto, UserLoginDto
|
||||
|
||||
|
||||
class IUserService(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def registration(self, email: str, password: str) -> UserCreatedDto:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def login(self, email: str, password: str) -> UserLoginDto:
|
||||
raise NotImplementedError
|
||||
3
src/application/domain/dto/__init__.py
Normal file
3
src/application/domain/dto/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from src.application.domain.dto.user import UserCreatedDto, UserLoginDto
|
||||
from src.application.domain.dto.token import AccessTokenPayload, RefreshTokenPayload, AuthContext
|
||||
from src.application.domain.dto.keys import JwtKeySet, JwtKeyPair
|
||||
21
src/application/domain/dto/keys.py
Normal file
21
src/application/domain/dto/keys.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Dict
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class JwtKeyPair:
|
||||
kid: str
|
||||
private_key_pem: str
|
||||
public_key_pem: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class JwtKeySet:
|
||||
active: JwtKeyPair
|
||||
previous: Optional[JwtKeyPair] = None
|
||||
|
||||
def public_keys_by_kid(self) -> Dict[str, str]:
|
||||
out = {self.active.kid: self.active.public_key_pem}
|
||||
if self.previous:
|
||||
out[self.previous.kid] = self.previous.public_key_pem
|
||||
return out
|
||||
30
src/application/domain/dto/token.py
Normal file
30
src/application/domain/dto/token.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AccessTokenPayload(BaseModel):
|
||||
sub: str
|
||||
type: str
|
||||
sid: str
|
||||
iat: int
|
||||
nbf: int
|
||||
exp: int
|
||||
iss: str | None = None
|
||||
aud: str | None = None
|
||||
|
||||
|
||||
class RefreshTokenPayload(BaseModel):
|
||||
sub: str
|
||||
type: str
|
||||
sid: str
|
||||
jti: str
|
||||
iat: int
|
||||
nbf: int
|
||||
exp: int
|
||||
iss: str | None = None
|
||||
aud: str | None = None
|
||||
|
||||
|
||||
class AuthContext(BaseModel):
|
||||
user_id: str
|
||||
sid: str
|
||||
token: AccessTokenPayload
|
||||
33
src/application/domain/dto/user.py
Normal file
33
src/application/domain/dto/user.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, date
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class UserCreatedDto:
|
||||
id: str
|
||||
email: str
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class UserLoginDto:
|
||||
id: str | None = None
|
||||
email: str | None = None
|
||||
first_name: str | None = None
|
||||
middle_name: str | None = None
|
||||
last_name: str | None = None
|
||||
birth_date: date | None = None
|
||||
crypto_wallet: str | None = None
|
||||
phone: str | None = None
|
||||
bik: str | None = None
|
||||
account_number: str | None = None
|
||||
card_number: str | None = None
|
||||
inn: str | None = None
|
||||
kyc_verified: bool | None = None
|
||||
access_token: str | None = None
|
||||
refresh_token: str | None = None
|
||||
created_at: datetime | None = None
|
||||
updated_at: datetime | None = None
|
||||
kyc_verified_at: datetime | None = None
|
||||
|
||||
5
src/application/domain/entities/__init__.py
Normal file
5
src/application/domain/entities/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from src.application.domain.entities.user import UserEntity
|
||||
from src.application.domain.entities.session import SessionEntity
|
||||
|
||||
|
||||
__all__ = ['UserEntity', 'SessionEntity']
|
||||
20
src/application/domain/entities/session.py
Normal file
20
src/application/domain/entities/session.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class SessionEntity:
|
||||
sid: str
|
||||
user_id: str
|
||||
device_id: str
|
||||
|
||||
revoked_at: datetime | None
|
||||
last_seen_at: datetime
|
||||
|
||||
refresh_jti_hash: str | None
|
||||
refresh_expires_at: datetime | None
|
||||
|
||||
user_agent: str | None = None
|
||||
first_ip: str | None = None
|
||||
last_ip: str | None = None
|
||||
30
src/application/domain/entities/user.py
Normal file
30
src/application/domain/entities/user.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from datetime import date, datetime
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class UserEntity:
|
||||
id: str | None = None
|
||||
email: str | None = None
|
||||
password_hash: str | None = None
|
||||
|
||||
first_name: str | None = None
|
||||
middle_name: str | None = None
|
||||
last_name: str | None = None
|
||||
birth_date: date | None = None
|
||||
|
||||
crypto_wallet: str | None = None
|
||||
phone: str | None = None
|
||||
|
||||
bik: str | None = None
|
||||
account_number: str | None = None
|
||||
card_number: str | None = None
|
||||
inn: str | None = None
|
||||
|
||||
kyc_verified: bool | None = None
|
||||
is_deleted: bool | None = None
|
||||
|
||||
created_at: datetime | None = None
|
||||
updated_at: datetime | None = None
|
||||
kyc_verified_at: datetime | None = None
|
||||
2
src/application/domain/enums/__init__.py
Normal file
2
src/application/domain/enums/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from src.application.domain.enums.log_level import LogLevel
|
||||
from src.application.domain.enums.log_format import LogFormat
|
||||
7
src/application/domain/enums/log_format.py
Normal file
7
src/application/domain/enums/log_format.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class LogFormat(Enum):
|
||||
"""Enum for supported log formats"""
|
||||
TEXT = 'text'
|
||||
JSON = 'json'
|
||||
54
src/application/domain/enums/log_level.py
Normal file
54
src/application/domain/enums/log_level.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class LogLevel(Enum):
|
||||
DEBUG = 10
|
||||
INFO = 20
|
||||
WARNING = 30
|
||||
ERROR = 40
|
||||
CRITICAL = 50
|
||||
EXCEPTION = 60
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"[{self.value}, '{self.name}']"
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if isinstance(other, LogLevel):
|
||||
return self.value == other.value
|
||||
if isinstance(other, int):
|
||||
return self.value == other
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other: object) -> bool:
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __lt__(self, other: object) -> bool:
|
||||
if isinstance(other, LogLevel):
|
||||
return self.value < other.value
|
||||
if isinstance(other, int):
|
||||
return self.value < other
|
||||
return NotImplemented
|
||||
|
||||
def __le__(self, other: object) -> bool:
|
||||
if isinstance(other, LogLevel):
|
||||
return self.value <= other.value
|
||||
if isinstance(other, int):
|
||||
return self.value <= other
|
||||
return NotImplemented
|
||||
|
||||
def __gt__(self, other: object) -> bool:
|
||||
if isinstance(other, LogLevel):
|
||||
return self.value > other.value
|
||||
if isinstance(other, int):
|
||||
return self.value > other
|
||||
return NotImplemented
|
||||
|
||||
def __ge__(self, other: object) -> bool:
|
||||
if isinstance(other, LogLevel):
|
||||
return self.value >= other.value
|
||||
if isinstance(other, int):
|
||||
return self.value >= other
|
||||
return NotImplemented
|
||||
1
src/application/domain/exceptions/__init__.py
Normal file
1
src/application/domain/exceptions/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.application.domain.exceptions.application_exceptions import ApplicationException
|
||||
18
src/application/domain/exceptions/application_exceptions.py
Normal file
18
src/application/domain/exceptions/application_exceptions.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from __future__ import annotations
|
||||
from typing import Mapping
|
||||
|
||||
|
||||
class ApplicationException(Exception):
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int,
|
||||
message: str,
|
||||
headers: Mapping[str, str] | None = None,
|
||||
):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.headers = headers
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.status_code}: {self.message}"
|
||||
2
src/infrastructure/cache/__init__.py
vendored
Normal file
2
src/infrastructure/cache/__init__.py
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
from src.infrastructure.cache.client import create_redis_client
|
||||
from src.infrastructure.cache.keydb_client import KeydbCache
|
||||
16
src/infrastructure/cache/client.py
vendored
Normal file
16
src/infrastructure/cache/client.py
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
import redis.asyncio as redis
|
||||
from redis.asyncio.client import Redis
|
||||
from src.infrastructure.config import settings
|
||||
|
||||
|
||||
def create_redis_client() -> Redis:
|
||||
return redis.from_url(
|
||||
settings.REDIS_URL,
|
||||
max_connections=50,
|
||||
decode_responses=True,
|
||||
socket_timeout=5,
|
||||
socket_connect_timeout=5,
|
||||
health_check_interval=30,
|
||||
retry_on_timeout=True,
|
||||
socket_keepalive=True,
|
||||
)
|
||||
20
src/infrastructure/cache/keydb_client.py
vendored
Normal file
20
src/infrastructure/cache/keydb_client.py
vendored
Normal file
@@ -0,0 +1,20 @@
|
||||
from redis.asyncio.client import Redis
|
||||
from src.application.contracts import ICache
|
||||
|
||||
|
||||
class KeydbCache(ICache):
|
||||
def __init__(self, redis_client: Redis):
|
||||
self._r = redis_client
|
||||
|
||||
async def set(self, key: str, value: str, ttl: int) -> None:
|
||||
return bool(await self._r.set(key, value, ex=ttl))
|
||||
|
||||
async def set_nx(self, key: str, value: str, ttl: int) -> bool:
|
||||
return bool(await self._r.set(key, value, ex=ttl, nx=True))
|
||||
|
||||
async def get(self, key: str) -> str | None:
|
||||
return await self._r.get(key)
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
deleted = await self._r.delete(key)
|
||||
return deleted > 0
|
||||
1
src/infrastructure/config/__init__.py
Normal file
1
src/infrastructure/config/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.config.settings import settings
|
||||
252
src/infrastructure/config/settings.py
Normal file
252
src/infrastructure/config/settings.py
Normal file
@@ -0,0 +1,252 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import List, Literal
|
||||
import os
|
||||
from dotenv import load_dotenv, find_dotenv
|
||||
from pydantic import AliasChoices, Field, field_validator, model_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from src.infrastructure.vault import create_hvac_client_from_approle, read_kv2_secret
|
||||
|
||||
env_file = find_dotenv(".env")
|
||||
if env_file:
|
||||
load_dotenv(env_file)
|
||||
|
||||
|
||||
def normalize_vault_base_url(raw: str) -> str:
|
||||
u = raw.strip().rstrip('/')
|
||||
if not u:
|
||||
return raw.strip()
|
||||
if '://' not in u:
|
||||
return f'https://{u}'
|
||||
return u
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
VAULT_ADDR: str = Field(default='http://localhost:8200')
|
||||
VAULT_ROLE_ID: str = Field(..., description='AppRole role_id')
|
||||
VAULT_SECRET_ID: str = Field(
|
||||
...,
|
||||
description='AppRole secret_id',
|
||||
validation_alias=AliasChoices('VAULT_SECRET_ID', 'VAULT_SECRET_TOKEN'),
|
||||
)
|
||||
VAULT_NAMESPACE: str | None = Field(default=None)
|
||||
VAULT_MOUNT_POINT: str = Field(default='dev-secrets')
|
||||
|
||||
VAULT_JWT_KID_PATH: str = 'jwt/kid'
|
||||
VAULT_JWT_KIDS_PREFIX: str = 'jwt/kids'
|
||||
JWT_KEYS_REFRESH_SECONDS: int = 3600
|
||||
|
||||
DATABASE_HOST: str
|
||||
DATABASE_PORT: int = Field(default=5432, ge=1, le=65535)
|
||||
DATABASE_NAME: str
|
||||
DATABASE_USER: str
|
||||
DATABASE_PASSWORD: str
|
||||
|
||||
DATABASE_POOL_SIZE: int = 10
|
||||
DATABASE_MAX_OVERFLOW: int = 20
|
||||
DATABASE_POOL_TIMEOUT: int = 30
|
||||
DATABASE_POOL_RECYCLE: int = 3600
|
||||
DATABASE_ECHO: bool = False
|
||||
|
||||
CSRF_SECRET_KEY: str = Field(min_length=32)
|
||||
|
||||
CSRF_COOKIE_SECURE: bool = False
|
||||
CSRF_COOKIE_HTTPONLY: bool = True
|
||||
CSRF_COOKIE_SAMESITE: Literal['Lax', 'Strict', 'None'] = 'Lax'
|
||||
CSRF_COOKIE_PATH: str = '/'
|
||||
CSRF_COOKIE_DOMAIN: str | None = None
|
||||
|
||||
DOCS_USERNAME: str = 'admin'
|
||||
DOCS_PASSWORD: str = 'admin'
|
||||
|
||||
JWT_ACCESS_TTL_SECONDS: int = 15 * 60
|
||||
JWT_REFRESH_TTL_SECONDS: int = 30 * 24 * 60 * 60
|
||||
JWT_ISSUER: str | None = None
|
||||
JWT_AUDIENCE: str | None = None
|
||||
JWT_ALGORITHM: str = 'RS256'
|
||||
|
||||
REDIS_HOST: str = 'localhost'
|
||||
REDIS_PORT: int = 6379
|
||||
REDIS_PASSWORD: str | None = None
|
||||
REDIS_DB: int = 0
|
||||
|
||||
RABBIT_HOST: str = 'localhost'
|
||||
RABBIT_PORT: int = 5672
|
||||
RABBIT_USER: str = 'guest'
|
||||
RABBIT_PASSWORD: str = 'guest'
|
||||
RABBIT_VHOST: str = '/'
|
||||
|
||||
RABBIT_PUBLISH_PERSIST: bool = True
|
||||
RABBIT_CONNECT_TIMEOUT: int = 5
|
||||
RABBIT_EMAIL_CODE_QUEUE: str = 'email.verification_code'
|
||||
|
||||
CORS_ORIGINS: str = 'http://localhost:3000'
|
||||
CORS_ALLOW_CREDENTIALS: bool = True
|
||||
|
||||
RATE_LIMIT_REQUESTS: int = 60
|
||||
RATE_LIMIT_WINDOW: int = 60
|
||||
|
||||
LOG_LEVEL: Literal['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] = 'INFO'
|
||||
LOG_FORMAT: Literal['JSON', 'TEXT'] = 'TEXT'
|
||||
|
||||
@field_validator('VAULT_ADDR', mode='before')
|
||||
@classmethod
|
||||
def vault_addr_scheme(cls, v):
|
||||
if v is None or not isinstance(v, str):
|
||||
return v
|
||||
return normalize_vault_base_url(v)
|
||||
|
||||
@field_validator('CSRF_COOKIE_DOMAIN', mode='before')
|
||||
@classmethod
|
||||
def empty_csrf_domain_to_none(cls, v):
|
||||
if v is None or (isinstance(v, str) and not v.strip()):
|
||||
return None
|
||||
return v
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file='.env',
|
||||
env_file_encoding='utf-8',
|
||||
case_sensitive=True,
|
||||
extra='ignore',
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@model_validator(mode='before')
|
||||
@classmethod
|
||||
def load_from_vault(cls, data: dict):
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
addr_raw = data.get('VAULT_ADDR') or os.getenv('VAULT_ADDR') or 'http://localhost:8200'
|
||||
addr = normalize_vault_base_url(addr_raw)
|
||||
data['VAULT_ADDR'] = addr
|
||||
role_id = data.get('VAULT_ROLE_ID') or os.getenv('VAULT_ROLE_ID')
|
||||
secret_id = (
|
||||
data.get('VAULT_SECRET_ID')
|
||||
or data.get('VAULT_SECRET_TOKEN')
|
||||
or os.getenv('VAULT_SECRET_ID')
|
||||
or os.getenv('VAULT_SECRET_TOKEN')
|
||||
)
|
||||
namespace = data.get('VAULT_NAMESPACE')
|
||||
if namespace is None:
|
||||
namespace = os.getenv('VAULT_NAMESPACE')
|
||||
namespace = namespace if namespace else None
|
||||
mount = data.get('VAULT_MOUNT_POINT') or os.getenv('VAULT_MOUNT_POINT') or 'dev-secrets'
|
||||
|
||||
if not role_id or not secret_id:
|
||||
raise RuntimeError(
|
||||
'VAULT_ROLE_ID and VAULT_SECRET_ID (or VAULT_SECRET_TOKEN) are required for Vault AppRole'
|
||||
)
|
||||
|
||||
data['VAULT_ROLE_ID'] = str(role_id).strip()
|
||||
data['VAULT_SECRET_ID'] = str(secret_id).strip()
|
||||
|
||||
client = create_hvac_client_from_approle(
|
||||
url=addr,
|
||||
role_id=role_id,
|
||||
secret_id=secret_id,
|
||||
namespace=namespace,
|
||||
timeout=5,
|
||||
)
|
||||
|
||||
def read_secret(path: str) -> dict:
|
||||
return read_kv2_secret(client=client, mount_point=mount, path=path)
|
||||
|
||||
def read_secret_optional(path: str) -> dict:
|
||||
try:
|
||||
return read_secret(path)
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
database = read_secret('database')
|
||||
csrf = read_secret('csrf')
|
||||
|
||||
db_ci = {str(k).lower(): v for k, v in database.items()}
|
||||
|
||||
def db_nonempty(key: str) -> bool:
|
||||
v = db_ci.get(key)
|
||||
if v is None:
|
||||
return False
|
||||
if isinstance(v, str) and not v.strip():
|
||||
return False
|
||||
return True
|
||||
|
||||
required_db = ['host', 'name', 'user', 'password', 'port']
|
||||
missing_db = [k for k in required_db if not db_nonempty(k)]
|
||||
if missing_db:
|
||||
raise RuntimeError(f'Vault secret database missing non-empty keys: {missing_db}')
|
||||
|
||||
data['DATABASE_HOST'] = str(db_ci['host']).strip()
|
||||
data['DATABASE_PORT'] = int(db_ci['port'])
|
||||
data['DATABASE_NAME'] = str(db_ci['name']).strip()
|
||||
data['DATABASE_USER'] = str(db_ci['user']).strip()
|
||||
data['DATABASE_PASSWORD'] = str(db_ci['password']).strip()
|
||||
|
||||
csrf_secret = None
|
||||
for entry_key, entry_val in csrf.items():
|
||||
if str(entry_key).lower() == 'key' and entry_val is not None and str(entry_val).strip():
|
||||
csrf_secret = str(entry_val).strip()
|
||||
break
|
||||
if not csrf_secret:
|
||||
raise RuntimeError(
|
||||
'Vault secret at csrf must contain a non-empty field named key (e.g. key=...)'
|
||||
)
|
||||
|
||||
data['CSRF_SECRET_KEY'] = csrf_secret
|
||||
|
||||
rabbit = read_secret_optional('rabbitmq')
|
||||
if rabbit:
|
||||
r_ci = {str(k).lower(): v for k, v in rabbit.items()}
|
||||
|
||||
def rb_set(field: str, env_key: str, *, as_int: bool = False) -> None:
|
||||
v = r_ci.get(field)
|
||||
if v is None:
|
||||
return
|
||||
if isinstance(v, str) and not v.strip():
|
||||
return
|
||||
if as_int:
|
||||
data[env_key] = int(v)
|
||||
else:
|
||||
data[env_key] = str(v).strip()
|
||||
|
||||
rb_set('host', 'RABBIT_HOST')
|
||||
rb_set('port', 'RABBIT_PORT', as_int=True)
|
||||
rb_set('user', 'RABBIT_USER')
|
||||
rb_set('password', 'RABBIT_PASSWORD')
|
||||
rb_set('vhost', 'RABBIT_VHOST')
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def cors_origins_list(self) -> List[str]:
|
||||
return [o.strip() for o in self.CORS_ORIGINS.split(',') if o.strip()]
|
||||
|
||||
|
||||
@property
|
||||
def DATABASE_URL(self) -> str:
|
||||
return (
|
||||
f"postgresql+asyncpg://{self.DATABASE_USER}:{self.DATABASE_PASSWORD}"
|
||||
f"@{self.DATABASE_HOST}:{self.DATABASE_PORT}/{self.DATABASE_NAME}"
|
||||
)
|
||||
|
||||
@property
|
||||
def REDIS_URL(self) -> str:
|
||||
auth = f":{self.REDIS_PASSWORD}@" if self.REDIS_PASSWORD else ""
|
||||
return f"redis://{auth}{self.REDIS_HOST}:{self.REDIS_PORT}/{self.REDIS_DB}"
|
||||
|
||||
@property
|
||||
def RABBIT_URL(self) -> str:
|
||||
vhost = "%2F" if self.RABBIT_VHOST == "/" else self.RABBIT_VHOST.lstrip("/")
|
||||
return f"amqp://{self.RABBIT_USER}:{self.RABBIT_PASSWORD}@{self.RABBIT_HOST}:{self.RABBIT_PORT}/{vhost}"
|
||||
|
||||
@property
|
||||
def EXCLUDED_PATHS(self) -> List[str]:
|
||||
return ['/docs', '/redoc', '/openapi.json', '/ping', '/health']
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_settings() -> Settings:
|
||||
return Settings()
|
||||
|
||||
|
||||
settings = get_settings()
|
||||
1
src/infrastructure/context_vars/__init__.py
Normal file
1
src/infrastructure/context_vars/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.context_vars.trace_id import trace_id_var
|
||||
4
src/infrastructure/context_vars/trace_id.py
Normal file
4
src/infrastructure/context_vars/trace_id.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from contextvars import ContextVar
|
||||
|
||||
|
||||
trace_id_var: ContextVar[str] = ContextVar('trace_id', default='N/A')
|
||||
1
src/infrastructure/database/__init__.py
Normal file
1
src/infrastructure/database/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.database.unit_of_work import UnitOfWork
|
||||
22
src/infrastructure/database/context.py
Normal file
22
src/infrastructure/database/context.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
from sqlalchemy.ext.asyncio.engine import create_async_engine
|
||||
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||
from typing import AsyncGenerator
|
||||
from src.infrastructure.config import settings
|
||||
|
||||
|
||||
engine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
pool_size=settings.DATABASE_POOL_SIZE,
|
||||
max_overflow=settings.DATABASE_MAX_OVERFLOW,
|
||||
pool_timeout=settings.DATABASE_POOL_TIMEOUT,
|
||||
pool_recycle=settings.DATABASE_POOL_RECYCLE,
|
||||
echo=settings.DATABASE_ECHO
|
||||
)
|
||||
|
||||
async_session_maker = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
|
||||
|
||||
|
||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with async_session_maker() as session:
|
||||
yield session
|
||||
1
src/infrastructure/database/decorators/__init__.py
Normal file
1
src/infrastructure/database/decorators/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.database.decorators.transactional import transactional
|
||||
15
src/infrastructure/database/decorators/transactional.py
Normal file
15
src/infrastructure/database/decorators/transactional.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
from functools import wraps
|
||||
from typing import Callable, Awaitable, TypeVar, ParamSpec
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def transactional(method: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
||||
@wraps(method)
|
||||
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
async with self._unit_of_work:
|
||||
return await method(self, *args, **kwargs)
|
||||
return wrapper
|
||||
6
src/infrastructure/database/models/__init__.py
Normal file
6
src/infrastructure/database/models/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from src.infrastructure.database.models.base import Base
|
||||
from src.infrastructure.database.models.user import UserModel
|
||||
from src.infrastructure.database.models.sessions import Session
|
||||
|
||||
__all__ = ['Base', 'UserModel', 'Session']
|
||||
|
||||
19
src/infrastructure/database/models/base.py
Normal file
19
src/infrastructure/database/models/base.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
|
||||
class Base(AsyncAttrs, DeclarativeBase):
|
||||
__abstract__ = True
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
attributes = ', '.join(f"{col.name}={getattr(self, col.name, None)!r}"
|
||||
for col in self.__table__.columns)
|
||||
return f"<{class_name}({attributes})>"
|
||||
|
||||
def __str__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
attributes = ', '.join(f"{col.name}={getattr(self, col.name)}"
|
||||
for col in self.__table__.columns
|
||||
if getattr(self, col.name) is not None)
|
||||
return f"{class_name}({attributes})"
|
||||
3
src/infrastructure/database/models/mixins/__init__.py
Normal file
3
src/infrastructure/database/models/mixins/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from src.infrastructure.database.models.mixins.audit import AuditTimestampsMixin
|
||||
from src.infrastructure.database.models.mixins.ulid import UlidPrimaryKeyMixin
|
||||
from src.infrastructure.database.models.mixins.soft_delete import SoftDeleteMixin
|
||||
16
src/infrastructure/database/models/mixins/audit.py
Normal file
16
src/infrastructure/database/models/mixins/audit.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from sqlalchemy import DateTime, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
|
||||
class AuditTimestampsMixin:
|
||||
created_at: Mapped[DateTime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=func.now(),
|
||||
)
|
||||
updated_at: Mapped[DateTime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
6
src/infrastructure/database/models/mixins/soft_delete.py
Normal file
6
src/infrastructure/database/models/mixins/soft_delete.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from sqlalchemy import Boolean
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
|
||||
class SoftDeleteMixin:
|
||||
is_deleted: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default='false', default=False)
|
||||
8
src/infrastructure/database/models/mixins/ulid.py
Normal file
8
src/infrastructure/database/models/mixins/ulid.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from ulid import ULID
|
||||
|
||||
|
||||
class UlidPrimaryKeyMixin:
|
||||
|
||||
id: Mapped[str] = mapped_column(String(26), primary_key=True, default=lambda: str(ULID()))
|
||||
50
src/infrastructure/database/models/sessions.py
Normal file
50
src/infrastructure/database/models/sessions.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy import String, DateTime, ForeignKey, Index
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from ulid import ULID
|
||||
from src.infrastructure.database.models import Base
|
||||
from src.infrastructure.database.models.mixins import UlidPrimaryKeyMixin, AuditTimestampsMixin
|
||||
|
||||
|
||||
class Session(Base, UlidPrimaryKeyMixin, AuditTimestampsMixin):
|
||||
__tablename__ = "sessions"
|
||||
|
||||
sid: Mapped[str] = mapped_column(
|
||||
String(26),
|
||||
unique=True,
|
||||
index=True,
|
||||
nullable=False,
|
||||
default=lambda: str(ULID()),
|
||||
)
|
||||
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(26),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
index=True,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
device_id: Mapped[str] = mapped_column(
|
||||
String(26),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
user_agent: Mapped[str | None] = mapped_column(String(500))
|
||||
first_ip: Mapped[str | None] = mapped_column(String(64))
|
||||
last_ip: Mapped[str | None] = mapped_column(String(64))
|
||||
|
||||
last_seen_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
revoked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||
|
||||
refresh_jti_hash: Mapped[str | None] = mapped_column(String(255))
|
||||
refresh_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||
|
||||
|
||||
Index("ux_sessions_user_device", Session.user_id, Session.device_id, unique=True)
|
||||
Index("ix_sessions_user_active", Session.user_id, Session.revoked_at)
|
||||
28
src/infrastructure/database/models/user.py
Normal file
28
src/infrastructure/database/models/user.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
from sqlalchemy import Boolean, Date, String, DateTime
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from src.infrastructure.database.models.base import Base
|
||||
from src.infrastructure.database.models.mixins import UlidPrimaryKeyMixin, AuditTimestampsMixin, SoftDeleteMixin
|
||||
|
||||
|
||||
class UserModel(Base, UlidPrimaryKeyMixin, AuditTimestampsMixin, SoftDeleteMixin):
|
||||
__tablename__ = 'users'
|
||||
|
||||
email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True, index=True)
|
||||
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
|
||||
last_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
first_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
middle_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
birth_date: Mapped[Date | None] = mapped_column(Date, nullable=True)
|
||||
|
||||
crypto_wallet: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
phone: Mapped[str | None] = mapped_column(String(16), nullable=True)
|
||||
|
||||
bik: Mapped[str | None] = mapped_column(String(9), nullable=True)
|
||||
account_number: Mapped[str | None] = mapped_column(String(20), nullable=True)
|
||||
card_number: Mapped[str | None] = mapped_column(String(19), nullable=True)
|
||||
inn: Mapped[str | None] = mapped_column(String(12), nullable=True)
|
||||
|
||||
kyc_verified: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default='false', default=False)
|
||||
kyc_verified_at: Mapped[DateTime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
2
src/infrastructure/database/repositories/__init__.py
Normal file
2
src/infrastructure/database/repositories/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from src.infrastructure.database.repositories.user_repository import UserRepository
|
||||
from src.infrastructure.database.repositories.session_repository import SessionRepository
|
||||
198
src/infrastructure/database/repositories/session_repository.py
Normal file
198
src/infrastructure/database/repositories/session_repository.py
Normal file
@@ -0,0 +1,198 @@
|
||||
from __future__ import annotations
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from src.application.contracts import ILogger
|
||||
from src.application.domain.entities import SessionEntity
|
||||
from src.application.abstractions.repositories import ISessionRepository
|
||||
from src.infrastructure.database.models import Session
|
||||
|
||||
|
||||
class SessionRepository(ISessionRepository):
|
||||
def __init__(self, session: AsyncSession, logger: ILogger):
|
||||
self._session = session
|
||||
self._logger = logger
|
||||
|
||||
async def get_by_sid(self, sid: str) -> Optional[SessionEntity]:
|
||||
res = await self._session.execute(select(Session).where(Session.sid == sid))
|
||||
m = res.scalar_one_or_none()
|
||||
if m is None:
|
||||
return None
|
||||
|
||||
return SessionEntity(
|
||||
sid=m.sid,
|
||||
user_id=m.user_id,
|
||||
device_id=m.device_id,
|
||||
revoked_at=m.revoked_at,
|
||||
last_seen_at=m.last_seen_at,
|
||||
refresh_jti_hash=m.refresh_jti_hash,
|
||||
refresh_expires_at=m.refresh_expires_at,
|
||||
user_agent=m.user_agent,
|
||||
first_ip=m.first_ip,
|
||||
last_ip=m.last_ip,
|
||||
)
|
||||
|
||||
async def get_by_user_device(self, user_id: str, device_id: str) -> Optional[SessionEntity]:
|
||||
res = await self._session.execute(
|
||||
select(Session).where(Session.user_id == user_id, Session.device_id == device_id)
|
||||
)
|
||||
m = res.scalar_one_or_none()
|
||||
if m is None:
|
||||
return None
|
||||
|
||||
return SessionEntity(
|
||||
sid=m.sid,
|
||||
user_id=m.user_id,
|
||||
device_id=m.device_id,
|
||||
revoked_at=m.revoked_at,
|
||||
last_seen_at=m.last_seen_at,
|
||||
refresh_jti_hash=m.refresh_jti_hash,
|
||||
refresh_expires_at=m.refresh_expires_at,
|
||||
user_agent=m.user_agent,
|
||||
first_ip=m.first_ip,
|
||||
last_ip=m.last_ip,
|
||||
)
|
||||
|
||||
async def upsert_by_device(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
sid: str,
|
||||
refresh_jti_hash: str,
|
||||
refresh_expires_at: datetime,
|
||||
user_agent: str | None,
|
||||
ip: str | None,
|
||||
now: datetime,
|
||||
) -> SessionEntity:
|
||||
res = await self._session.execute(
|
||||
select(Session).where(Session.user_id == user_id, Session.device_id == device_id)
|
||||
)
|
||||
m = res.scalar_one_or_none()
|
||||
|
||||
if m is None:
|
||||
m = Session(
|
||||
sid=sid,
|
||||
user_id=user_id,
|
||||
device_id=device_id,
|
||||
revoked_at=None,
|
||||
last_seen_at=now,
|
||||
refresh_jti_hash=refresh_jti_hash,
|
||||
refresh_expires_at=refresh_expires_at,
|
||||
user_agent=user_agent,
|
||||
first_ip=ip,
|
||||
last_ip=ip,
|
||||
)
|
||||
self._session.add(m)
|
||||
await self._session.flush()
|
||||
|
||||
self._logger.info(f'Session created (user_id={user_id}, device_id={device_id}, sid={sid})')
|
||||
else:
|
||||
m.sid = sid
|
||||
m.revoked_at = None
|
||||
m.last_seen_at = now
|
||||
m.refresh_jti_hash = refresh_jti_hash
|
||||
m.refresh_expires_at = refresh_expires_at
|
||||
m.user_agent = user_agent
|
||||
m.last_ip = ip
|
||||
|
||||
await self._session.flush()
|
||||
|
||||
self._logger.info(f'Session updated (user_id={user_id}, device_id={device_id}, sid={sid})')
|
||||
|
||||
return SessionEntity(
|
||||
sid=m.sid,
|
||||
user_id=m.user_id,
|
||||
device_id=m.device_id,
|
||||
revoked_at=m.revoked_at,
|
||||
last_seen_at=m.last_seen_at,
|
||||
refresh_jti_hash=m.refresh_jti_hash,
|
||||
refresh_expires_at=m.refresh_expires_at,
|
||||
user_agent=m.user_agent,
|
||||
first_ip=m.first_ip,
|
||||
last_ip=m.last_ip,
|
||||
)
|
||||
|
||||
async def revoke_by_sid(self, sid: str, now: datetime) -> None:
|
||||
# Интерфейс требует None -> просто делаем update и flush
|
||||
await self._session.execute(
|
||||
update(Session)
|
||||
.where(Session.sid == sid, Session.revoked_at.is_(None))
|
||||
.values(revoked_at=now)
|
||||
.execution_options(synchronize_session='fetch')
|
||||
)
|
||||
await self._session.flush()
|
||||
|
||||
async def rotate_refresh(
|
||||
self,
|
||||
sid: str,
|
||||
new_jti_hash: str,
|
||||
new_refresh_expires_at: datetime,
|
||||
now: datetime,
|
||||
ip: str | None,
|
||||
user_agent: str | None,
|
||||
) -> None:
|
||||
|
||||
values = {
|
||||
'refresh_jti_hash': new_jti_hash,
|
||||
'refresh_expires_at': new_refresh_expires_at,
|
||||
'last_seen_at': now,
|
||||
'user_agent': user_agent,
|
||||
}
|
||||
if ip is not None:
|
||||
values['last_ip'] = ip
|
||||
|
||||
await self._session.execute(
|
||||
update(Session)
|
||||
.where(Session.sid == sid, Session.revoked_at.is_(None))
|
||||
.values(**values)
|
||||
.execution_options(synchronize_session='fetch')
|
||||
)
|
||||
await self._session.flush()
|
||||
|
||||
async def touch_last_seen(self, sid: str, *, ip: str | None, now: datetime) -> None:
|
||||
values = {'last_seen_at': now}
|
||||
if ip is not None:
|
||||
values['last_ip'] = ip
|
||||
|
||||
await self._session.execute(
|
||||
update(Session)
|
||||
.where(Session.sid == sid, Session.revoked_at.is_(None))
|
||||
.values(**values)
|
||||
.execution_options(synchronize_session='fetch')
|
||||
)
|
||||
await self._session.flush()
|
||||
|
||||
async def rotate_refresh_if_match(
|
||||
self,
|
||||
*,
|
||||
sid: str,
|
||||
old_jti_hash: str,
|
||||
new_jti_hash: str,
|
||||
new_refresh_expires_at: datetime,
|
||||
now: datetime,
|
||||
ip: str | None,
|
||||
user_agent: str | None,
|
||||
) -> bool:
|
||||
values = {
|
||||
'refresh_jti_hash': new_jti_hash,
|
||||
'refresh_expires_at': new_refresh_expires_at,
|
||||
'last_seen_at': now,
|
||||
'user_agent': user_agent,
|
||||
}
|
||||
if ip is not None:
|
||||
values['last_ip'] = ip
|
||||
|
||||
res = await self._session.execute(
|
||||
update(Session)
|
||||
.where(
|
||||
Session.sid == sid,
|
||||
Session.revoked_at.is_(None),
|
||||
Session.refresh_jti_hash == old_jti_hash, # ✅ защита от гонок
|
||||
)
|
||||
.values(**values)
|
||||
.execution_options(synchronize_session='fetch')
|
||||
)
|
||||
await self._session.flush()
|
||||
return (res.rowcount or 0) > 0
|
||||
114
src/infrastructure/database/repositories/user_repository.py
Normal file
114
src/infrastructure/database/repositories/user_repository.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from __future__ import annotations
|
||||
from fastapi import status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
|
||||
from src.application.contracts import ILogger
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.application.abstractions.repositories import IUserRepository
|
||||
from src.application.domain.entities import UserEntity
|
||||
from src.infrastructure.database.models import UserModel
|
||||
|
||||
|
||||
class UserRepository(IUserRepository):
|
||||
def __init__(self, session: AsyncSession, logger: ILogger):
|
||||
self._session = session
|
||||
self._logger = logger
|
||||
|
||||
async def create_user(self, email: str, password_hash: str) -> UserEntity:
|
||||
user = UserModel(email=email, password_hash=password_hash)
|
||||
self._session.add(user)
|
||||
try:
|
||||
await self._session.flush()
|
||||
return UserEntity(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
created_at=user.created_at,
|
||||
kyc_verified=user.kyc_verified,
|
||||
is_deleted=user.is_deleted
|
||||
)
|
||||
|
||||
except IntegrityError:
|
||||
self._logger.error(f'User already exists with email {user.email}')
|
||||
raise ApplicationException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
message='User with this email already exists',
|
||||
)
|
||||
|
||||
except SQLAlchemyError as exception:
|
||||
self._logger.exception(str(exception))
|
||||
raise ApplicationException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
message=f'Database error: {str(exception)}',
|
||||
)
|
||||
|
||||
async def get_user_by_email(self, email: str) -> UserEntity:
|
||||
try:
|
||||
stmt = (
|
||||
select(UserModel)
|
||||
.where(
|
||||
UserModel.email == email,
|
||||
UserModel.is_deleted.is_(False),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
user: UserModel | None = result.scalar_one_or_none()
|
||||
|
||||
if user is None:
|
||||
self._logger.warning(f'User not found with email {email}')
|
||||
raise ApplicationException(status_code=status.HTTP_404_NOT_FOUND, message='User not found',)
|
||||
|
||||
return UserEntity(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
password_hash=user.password_hash,
|
||||
first_name=user.first_name,
|
||||
middle_name=user.middle_name,
|
||||
last_name=user.last_name,
|
||||
birth_date=user.birth_date,
|
||||
crypto_wallet=user.crypto_wallet,
|
||||
phone=user.phone,
|
||||
bik=user.bik,
|
||||
account_number=user.account_number,
|
||||
card_number=user.card_number,
|
||||
inn=user.inn,
|
||||
kyc_verified_at=user.kyc_verified_at,
|
||||
kyc_verified=user.kyc_verified,
|
||||
is_deleted=user.is_deleted,
|
||||
created_at=user.created_at,
|
||||
updated_at=user.updated_at,
|
||||
)
|
||||
|
||||
except ApplicationException:
|
||||
raise
|
||||
|
||||
except SQLAlchemyError as exception:
|
||||
self._logger.exception(str(exception))
|
||||
raise ApplicationException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
message=f'Database error: {str(exception)}',
|
||||
)
|
||||
|
||||
async def exists_by_email(self, email: str) -> bool:
|
||||
try:
|
||||
stmt = (
|
||||
select(UserModel.id)
|
||||
.where(
|
||||
UserModel.email == email,
|
||||
UserModel.is_deleted.is_(False),
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
return result.scalar_one_or_none() is not None
|
||||
|
||||
except SQLAlchemyError as exception:
|
||||
self._logger.exception(str(exception))
|
||||
raise ApplicationException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
message=f'Database error: {str(exception)}',
|
||||
)
|
||||
|
||||
|
||||
42
src/infrastructure/database/unit_of_work.py
Normal file
42
src/infrastructure/database/unit_of_work.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.abstractions.repositories import IUserRepository, ISessionRepository
|
||||
from src.application.contracts import ILogger
|
||||
from src.infrastructure.database.repositories import UserRepository, SessionRepository
|
||||
|
||||
|
||||
|
||||
class UnitOfWork(IUnitOfWork):
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession], logger: ILogger):
|
||||
self.session_factory = session_factory
|
||||
self._session: AsyncSession = None
|
||||
self._user_repository: IUserRepository = None
|
||||
self._session_repository: ISessionRepository = None
|
||||
self._logger: ILogger = logger
|
||||
|
||||
async def __aenter__(self):
|
||||
self._session = self.session_factory()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if exc_type:
|
||||
self._logger.error(str(exc_val))
|
||||
await self._session.rollback()
|
||||
self._logger.error(f'Rollback: str{exc_val})')
|
||||
else:
|
||||
await self._session.flush()
|
||||
await self._session.commit()
|
||||
self._logger.debug('Commit')
|
||||
await self._session.close()
|
||||
|
||||
@property
|
||||
def user_repository(self) -> IUserRepository:
|
||||
if self._user_repository is None:
|
||||
self._user_repository = UserRepository(session=self._session, logger=self._logger)
|
||||
return self._user_repository
|
||||
|
||||
@property
|
||||
def session_repository(self) -> ISessionRepository:
|
||||
if self._session_repository is None:
|
||||
self._session_repository = SessionRepository(session=self._session, logger=self._logger)
|
||||
return self._session_repository
|
||||
28
src/infrastructure/logger/__init__.py
Normal file
28
src/infrastructure/logger/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from src.application.contracts import ILogger
|
||||
from src.application.domain.enums import LogFormat
|
||||
from src.application.domain.enums import LogLevel
|
||||
from src.infrastructure.config.settings import settings
|
||||
from src.infrastructure.logger.logger import Logger
|
||||
|
||||
log_levels = {
|
||||
'DEBUG': LogLevel.DEBUG,
|
||||
'INFO': LogLevel.INFO,
|
||||
'WARNING': LogLevel.WARNING,
|
||||
'ERROR': LogLevel.ERROR,
|
||||
'CRITICAL': LogLevel.CRITICAL,
|
||||
'EXCEPTION': LogLevel.EXCEPTION,
|
||||
}
|
||||
|
||||
log_formats = {
|
||||
'JSON': LogFormat.JSON,
|
||||
'TEXT': LogFormat.TEXT,
|
||||
}
|
||||
|
||||
logger = Logger(
|
||||
min_level=log_levels.get(settings.LOG_LEVEL, LogLevel.INFO),
|
||||
log_format=log_formats.get(settings.LOG_FORMAT, LogFormat.JSON),
|
||||
)
|
||||
|
||||
|
||||
def get_logger() -> ILogger:
|
||||
return logger
|
||||
129
src/infrastructure/logger/logger.py
Normal file
129
src/infrastructure/logger/logger.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import traceback
|
||||
import inspect
|
||||
import sys
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Callable, Optional, Any
|
||||
from ulid import ULID
|
||||
from src.application.contracts import ILogger
|
||||
from src.application.domain.enums import LogFormat, LogLevel
|
||||
from src.infrastructure.context_vars import trace_id_var
|
||||
|
||||
|
||||
class Logger(ILogger):
|
||||
_instance = None
|
||||
__default_format = LogFormat.JSON
|
||||
|
||||
def __new__(cls, *args: Any, **kwargs: Any) -> "Logger":
|
||||
if cls._instance is None:
|
||||
cls._instance = super(Logger, cls).__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
log_format: LogFormat = __default_format,
|
||||
min_level: LogLevel = LogLevel.INFO,
|
||||
id_generator: Optional[Callable[[], str]] = lambda: str(ULID()),
|
||||
instance_id: str = "N/A",
|
||||
):
|
||||
self.log_format = log_format
|
||||
self.min_level = min_level
|
||||
self.id_generator = id_generator
|
||||
self.instance_id = instance_id
|
||||
|
||||
def set_instance_id(self, instance_id: str) -> None:
|
||||
self.instance_id = instance_id
|
||||
|
||||
def get_instance_id(self) -> str:
|
||||
return self.instance_id
|
||||
|
||||
def set_format(self, log_format: LogFormat) -> None:
|
||||
if not isinstance(log_format, LogFormat):
|
||||
raise ValueError("Log format must be an instance of LogFormat enum")
|
||||
self.log_format = log_format
|
||||
|
||||
def set_min_level(self, level: LogLevel) -> None:
|
||||
self.min_level = level
|
||||
|
||||
def new_trace_id(self) -> str:
|
||||
trace_id = str(ULID()) if self.id_generator is None else self.id_generator()
|
||||
trace_id_var.set(trace_id)
|
||||
return trace_id
|
||||
|
||||
def set_trace_id(self, trace_id: str) -> None:
|
||||
trace_id_var.set(trace_id)
|
||||
|
||||
def get_trace_id(self) -> str:
|
||||
return trace_id_var.get()
|
||||
|
||||
def clear_trace_id(self) -> None:
|
||||
trace_id_var.set("N/A")
|
||||
|
||||
def _prepare_log_data(self, level: LogLevel, message: str) -> dict[str, Any]:
|
||||
current_frame = inspect.currentframe()
|
||||
if (
|
||||
current_frame
|
||||
and current_frame.f_back
|
||||
and current_frame.f_back.f_back
|
||||
and current_frame.f_back.f_back.f_back
|
||||
):
|
||||
frame = current_frame.f_back.f_back.f_back
|
||||
filename = frame.f_code.co_filename
|
||||
line_number = frame.f_lineno
|
||||
else:
|
||||
filename = "unknown"
|
||||
line_number = 0
|
||||
|
||||
log_data = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"level": level.name,
|
||||
"instance_id": self.instance_id,
|
||||
"file": filename,
|
||||
"line": line_number,
|
||||
"trace_id": trace_id_var.get(),
|
||||
"message": message,
|
||||
}
|
||||
|
||||
if level == LogLevel.EXCEPTION:
|
||||
log_data["exception"] = traceback.format_exc()
|
||||
|
||||
return log_data
|
||||
|
||||
def _log(self, level: LogLevel, message: str) -> None:
|
||||
if level >= self.min_level:
|
||||
log_data = self._prepare_log_data(level, message)
|
||||
|
||||
if self.log_format == LogFormat.JSON:
|
||||
log_message = json.dumps(log_data, ensure_ascii=False)
|
||||
else:
|
||||
log_message = (
|
||||
f"{log_data['timestamp']} - {log_data['level']} - "
|
||||
f"{log_data['instance_id']} - {log_data['trace_id']} - "
|
||||
f"{log_data['file']}:{log_data['line']} - "
|
||||
f"{log_data['message']}"
|
||||
)
|
||||
if "exception" in log_data:
|
||||
log_message += f"\nTraceback:\n{log_data['exception']}"
|
||||
|
||||
self._write(log_message)
|
||||
|
||||
def _write(self, message: str) -> None:
|
||||
sys.stdout.write(message + "\n")
|
||||
|
||||
def debug(self, message: str) -> None:
|
||||
self._log(LogLevel.DEBUG, message)
|
||||
|
||||
def info(self, message: str) -> None:
|
||||
self._log(LogLevel.INFO, message)
|
||||
|
||||
def warning(self, message: str) -> None:
|
||||
self._log(LogLevel.WARNING, message)
|
||||
|
||||
def error(self, message: str) -> None:
|
||||
self._log(LogLevel.ERROR, message)
|
||||
|
||||
def critical(self, message: str) -> None:
|
||||
self._log(LogLevel.CRITICAL, message)
|
||||
|
||||
def exception(self, message: str) -> None:
|
||||
self._log(LogLevel.EXCEPTION, message)
|
||||
1
src/infrastructure/messanger/__init__.py
Normal file
1
src/infrastructure/messanger/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.messanger.rabbit_client import RabbitClient
|
||||
72
src/infrastructure/messanger/rabbit_client.py
Normal file
72
src/infrastructure/messanger/rabbit_client.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from typing import Any, Mapping
|
||||
from faststream.rabbit import RabbitBroker
|
||||
from src.application.contracts import IQueueMessanger
|
||||
from src.infrastructure.config import settings
|
||||
|
||||
|
||||
class RabbitClient(IQueueMessanger):
|
||||
def __init__(self) -> None:
|
||||
self._broker = RabbitBroker(
|
||||
settings.RABBIT_URL,
|
||||
)
|
||||
self._connected = False
|
||||
|
||||
async def connect(self) -> None:
|
||||
if self._connected:
|
||||
return
|
||||
await self._broker.connect()
|
||||
self._connected = True
|
||||
|
||||
async def close(self) -> None:
|
||||
if not self._connected:
|
||||
return
|
||||
await self._broker.close()
|
||||
self._connected = False
|
||||
|
||||
async def _ensure_connected(self) -> None:
|
||||
if not self._connected:
|
||||
await self.connect()
|
||||
|
||||
async def publish_to_queue(
|
||||
self,
|
||||
queue: str,
|
||||
message: Any,
|
||||
*,
|
||||
persist: bool | None = None,
|
||||
headers: Mapping[str, Any] | None = None,
|
||||
correlation_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> None:
|
||||
await self._ensure_connected()
|
||||
|
||||
await self._broker.publish(
|
||||
message,
|
||||
queue=queue,
|
||||
persist=settings.RABBIT_PUBLISH_PERSIST if persist is None else persist,
|
||||
headers=headers,
|
||||
correlation_id=correlation_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
async def publish(
|
||||
self,
|
||||
message: Any,
|
||||
*,
|
||||
exchange: str,
|
||||
routing_key: str,
|
||||
persist: bool | None = None,
|
||||
headers: Mapping[str, Any] | None = None,
|
||||
correlation_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> None:
|
||||
await self._ensure_connected()
|
||||
|
||||
await self._broker.publish(
|
||||
message,
|
||||
exchange=exchange,
|
||||
routing_key=routing_key,
|
||||
persist=settings.RABBIT_PUBLISH_PERSIST if persist is None else persist,
|
||||
headers=headers,
|
||||
correlation_id=correlation_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
3
src/infrastructure/security/__init__.py
Normal file
3
src/infrastructure/security/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from src.infrastructure.security.jwt import JwtService
|
||||
from src.infrastructure.security.csrf import CsrfService
|
||||
from src.infrastructure.security.hash import HashService
|
||||
81
src/infrastructure/security/csrf.py
Normal file
81
src/infrastructure/security/csrf.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from __future__ import annotations
|
||||
import secrets
|
||||
from typing import Any, Optional, Mapping
|
||||
from itsdangerous import URLSafeTimedSerializer, SignatureExpired, BadSignature
|
||||
from src.application.contracts import ICsrfService
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.config.settings import settings
|
||||
|
||||
|
||||
class CsrfService(ICsrfService):
|
||||
COOKIE_NAME = "csrf_token"
|
||||
HEADER_NAME = "X-CSRF-Token"
|
||||
SALT = "csrf"
|
||||
TTL_SECONDS = 3600
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._serializer = URLSafeTimedSerializer(
|
||||
secret_key=settings.CSRF_SECRET_KEY,
|
||||
salt=self.SALT,
|
||||
)
|
||||
|
||||
@property
|
||||
def cookie_name(self) -> str:
|
||||
return self.COOKIE_NAME
|
||||
|
||||
@property
|
||||
def header_name(self) -> str:
|
||||
return self.HEADER_NAME
|
||||
|
||||
@property
|
||||
def ttl_seconds(self) -> int:
|
||||
return self.TTL_SECONDS
|
||||
|
||||
def issue(self, subject: Optional[str] = None) -> str:
|
||||
payload = {
|
||||
"sub": subject,
|
||||
"nonce": secrets.token_urlsafe(32),
|
||||
}
|
||||
return self._serializer.dumps(payload)
|
||||
|
||||
def verify(self, token: str, expected_subject: Optional[str] = None) -> dict[str, Any]:
|
||||
try:
|
||||
data = self._serializer.loads(token, max_age=self.TTL_SECONDS)
|
||||
except SignatureExpired:
|
||||
raise ApplicationException(
|
||||
status_code=403,
|
||||
message="CSRF token expired",
|
||||
)
|
||||
except BadSignature:
|
||||
raise ApplicationException(
|
||||
status_code=403,
|
||||
message="CSRF token invalid",
|
||||
)
|
||||
|
||||
if expected_subject is not None and data.get("sub") != expected_subject:
|
||||
raise ApplicationException(
|
||||
status_code=403,
|
||||
message="CSRF token subject mismatch",
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
def extract(self, cookies: Mapping[str, str], headers: Mapping[str, str]) -> tuple[Optional[str], Optional[str]]:
|
||||
cookie_token = cookies.get(self.COOKIE_NAME)
|
||||
header_token = headers.get(self.HEADER_NAME)
|
||||
return cookie_token, header_token
|
||||
|
||||
def verify_pair(self, cookie_token: Optional[str], header_token: Optional[str], expected_subject: Optional[str] = None) -> None:
|
||||
if not cookie_token or not header_token:
|
||||
raise ApplicationException(
|
||||
status_code=403,
|
||||
message="CSRF token missing",
|
||||
)
|
||||
|
||||
if not secrets.compare_digest(cookie_token, header_token):
|
||||
raise ApplicationException(
|
||||
status_code=403,
|
||||
message="CSRF token mismatch",
|
||||
)
|
||||
|
||||
self.verify(cookie_token, expected_subject=expected_subject)
|
||||
17
src/infrastructure/security/hash.py
Normal file
17
src/infrastructure/security/hash.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import bcrypt
|
||||
from src.application.contracts import IHashService, ILogger
|
||||
|
||||
|
||||
class HashService(IHashService):
|
||||
|
||||
def __init__(self, logger: ILogger):
|
||||
self._logger = logger
|
||||
|
||||
async def hash(self, value: str) -> str:
|
||||
hashed_value = bcrypt.hashpw(value.encode(), bcrypt.gensalt())
|
||||
self._logger.info(f'Hash value {hashed_value.decode()}')
|
||||
return hashed_value.decode()
|
||||
|
||||
async def verify(self, hashed_value: str, plain_value: str) -> bool:
|
||||
self._logger.info(f'Hash value {hashed_value[:10]}')
|
||||
return bcrypt.checkpw(plain_value.encode(), hashed_value.encode())
|
||||
207
src/infrastructure/security/jwt.py
Normal file
207
src/infrastructure/security/jwt.py
Normal file
@@ -0,0 +1,207 @@
|
||||
from __future__ import annotations
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from jose import jwt, ExpiredSignatureError, JWTError
|
||||
from src.application.contracts import ILogger, IJwtService
|
||||
from src.application.domain.dto import AccessTokenPayload, RefreshTokenPayload
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.config.settings import settings
|
||||
from src.infrastructure.vault import JwtKeyStore
|
||||
|
||||
|
||||
class JwtService(IJwtService):
|
||||
def __init__(self, logger: ILogger, key_store: JwtKeyStore) -> None:
|
||||
self._logger = logger
|
||||
self._key_store = key_store
|
||||
|
||||
async def create_access_token(self, user_id: str, sid: str) -> str:
|
||||
now = datetime.now(timezone.utc)
|
||||
exp = now + timedelta(seconds=int(settings.JWT_ACCESS_TTL_SECONDS))
|
||||
|
||||
payload: dict[str, object] = {
|
||||
'sub': user_id,
|
||||
'type': 'access',
|
||||
'sid': sid,
|
||||
'iat': int(now.timestamp()),
|
||||
'nbf': int(now.timestamp()),
|
||||
'exp': int(exp.timestamp()),
|
||||
}
|
||||
|
||||
if settings.JWT_ISSUER:
|
||||
payload['iss'] = settings.JWT_ISSUER
|
||||
if settings.JWT_AUDIENCE:
|
||||
payload['aud'] = settings.JWT_AUDIENCE
|
||||
|
||||
try:
|
||||
kid, private_pem = await self._key_store.get_signing_key()
|
||||
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
private_pem,
|
||||
algorithm=settings.JWT_ALGORITHM,
|
||||
headers={'kid': kid},
|
||||
)
|
||||
|
||||
self._logger.info(f'Access token created user_id={user_id} sid={sid} kid={kid}')
|
||||
return token
|
||||
|
||||
except ApplicationException:
|
||||
raise
|
||||
except Exception as exception:
|
||||
self._logger.error(f'JWT access signing failed user_id={user_id} sid={sid} error={str(exception)}')
|
||||
raise ApplicationException(status_code=500, message='JWT signing failed')
|
||||
|
||||
async def create_refresh_token(self, user_id: str, sid: str, refresh_jti: str) -> str:
|
||||
now = datetime.now(timezone.utc)
|
||||
exp = now + timedelta(seconds=int(settings.JWT_REFRESH_TTL_SECONDS))
|
||||
|
||||
payload: dict[str, object] = {
|
||||
'sub': user_id,
|
||||
'type': 'refresh',
|
||||
'sid': sid,
|
||||
'jti': refresh_jti,
|
||||
'iat': int(now.timestamp()),
|
||||
'nbf': int(now.timestamp()),
|
||||
'exp': int(exp.timestamp()),
|
||||
}
|
||||
|
||||
if settings.JWT_ISSUER:
|
||||
payload['iss'] = settings.JWT_ISSUER
|
||||
if settings.JWT_AUDIENCE:
|
||||
payload['aud'] = settings.JWT_AUDIENCE
|
||||
|
||||
try:
|
||||
kid, private_pem = await self._key_store.get_signing_key()
|
||||
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
private_pem,
|
||||
algorithm=settings.JWT_ALGORITHM,
|
||||
headers={'kid': kid},
|
||||
)
|
||||
|
||||
self._logger.info(f'Refresh token created user_id={user_id} sid={sid} jti={refresh_jti} kid={kid}')
|
||||
return token
|
||||
|
||||
except ApplicationException:
|
||||
raise
|
||||
except Exception as exception:
|
||||
self._logger.error(f'JWT refresh signing failed user_id={user_id} sid={sid} error={str(exception)}')
|
||||
raise ApplicationException(status_code=500, message='JWT signing failed')
|
||||
|
||||
async def decode_access_token(self, token: str) -> AccessTokenPayload:
|
||||
payload = await self._decode_and_verify(token)
|
||||
|
||||
if payload.get('type') != 'access':
|
||||
self._logger.warning(f'Access token invalid type received_type={payload.get('type')}')
|
||||
raise ApplicationException(status_code=401, message='Invalid token type')
|
||||
|
||||
try:
|
||||
return AccessTokenPayload(
|
||||
sub=str(payload['sub']),
|
||||
type='access',
|
||||
sid=str(payload['sid']),
|
||||
iat=int(payload['iat']),
|
||||
nbf=int(payload['nbf']),
|
||||
exp=int(payload['exp']),
|
||||
iss=payload.get('iss'),
|
||||
aud=payload.get('aud'),
|
||||
)
|
||||
except KeyError as exception:
|
||||
self._logger.warning(f'Access token missing claim error={str(exception)}')
|
||||
raise ApplicationException(status_code=401, message=f'Missing token claim: {exception}')
|
||||
|
||||
async def decode_refresh_token(self, token: str) -> RefreshTokenPayload:
|
||||
payload = await self._decode_and_verify(token)
|
||||
|
||||
if payload.get('type') != 'refresh':
|
||||
self._logger.warning(f'Refresh token invalid type received_type={payload.get('type')}')
|
||||
raise ApplicationException(status_code=401, message='Invalid token type')
|
||||
|
||||
try:
|
||||
return RefreshTokenPayload(
|
||||
sub=str(payload['sub']),
|
||||
type='refresh',
|
||||
sid=str(payload['sid']),
|
||||
jti=str(payload['jti']),
|
||||
iat=int(payload['iat']),
|
||||
nbf=int(payload['nbf']),
|
||||
exp=int(payload['exp']),
|
||||
iss=payload.get('iss'),
|
||||
aud=payload.get('aud'),
|
||||
)
|
||||
except KeyError as exception:
|
||||
self._logger.warning(f'Refresh token missing claim error={str(exception)}')
|
||||
raise ApplicationException(status_code=401, message=f'Missing token claim: {exception}')
|
||||
|
||||
async def _decode_and_verify(self, token: str) -> dict:
|
||||
kid: str | None = None
|
||||
try:
|
||||
header = jwt.get_unverified_header(token)
|
||||
|
||||
kid = header.get('kid')
|
||||
if not kid:
|
||||
self._logger.warning(f'JWT header missing kid header={header}')
|
||||
raise ApplicationException(status_code=401, message='Missing token header: kid')
|
||||
|
||||
if header.get('alg') != settings.JWT_ALGORITHM:
|
||||
self._logger.warning(f'JWT invalid algorithm kid={kid} received_alg={header.get('alg')} expected_alg={settings.JWT_ALGORITHM}')
|
||||
raise ApplicationException(status_code=401, message='Invalid token algorithm')
|
||||
|
||||
public_pem = await self._key_store.get_public_key_for_kid(str(kid))
|
||||
if not public_pem:
|
||||
self._logger.info(f'JWT kid cache miss kid={kid} refreshing keystore')
|
||||
public_pem = await self._key_store.get_public_key_for_kid(str(kid))
|
||||
|
||||
if not public_pem:
|
||||
self._logger.warning(f'JWT unknown kid kid={kid}')
|
||||
raise ApplicationException(status_code=401, message='Unknown token kid')
|
||||
|
||||
options = {
|
||||
'verify_signature': True,
|
||||
'verify_exp': True,
|
||||
'verify_nbf': True,
|
||||
'verify_iat': True,
|
||||
'verify_aud': bool(settings.JWT_AUDIENCE),
|
||||
'verify_iss': bool(settings.JWT_ISSUER),
|
||||
'require_exp': True,
|
||||
'require_iat': True,
|
||||
'require_nbf': True,
|
||||
'require_sub': True,
|
||||
'require_sid': True,
|
||||
'require_type': True,
|
||||
'leeway': 10,
|
||||
}
|
||||
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
public_pem,
|
||||
algorithms=[settings.JWT_ALGORITHM],
|
||||
audience=settings.JWT_AUDIENCE or None,
|
||||
issuer=settings.JWT_ISSUER or None,
|
||||
options=options,
|
||||
)
|
||||
|
||||
if options.get('require_sid') and 'sid' not in payload:
|
||||
self._logger.warning(f'JWT missing sid claim kid={kid}')
|
||||
raise ApplicationException(status_code=401, message='Missing token claim: sid')
|
||||
|
||||
if options.get('require_type') and 'type' not in payload:
|
||||
self._logger.warning(f'JWT missing type claim kid={kid}')
|
||||
raise ApplicationException(status_code=401, message='Missing token claim: type')
|
||||
|
||||
return payload
|
||||
|
||||
except ExpiredSignatureError as exception:
|
||||
self._logger.info(f'JWT expired kid={kid} error={str(exception)}')
|
||||
raise ApplicationException(status_code=401, message='Token expired')
|
||||
|
||||
except ApplicationException:
|
||||
raise
|
||||
|
||||
except JWTError as exception:
|
||||
self._logger.warning(f'JWT decode failed kid={kid} error={str(exception)}')
|
||||
raise ApplicationException(status_code=401, message='Invalid token')
|
||||
|
||||
except Exception as exception:
|
||||
self._logger.error(f'Unexpected JWT decode error kid={kid} error={str(exception)}')
|
||||
raise ApplicationException(status_code=500, message='JWT decode failed')
|
||||
1
src/infrastructure/utils/__init__.py
Normal file
1
src/infrastructure/utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.utils.instance_id import generate_instance_id
|
||||
14
src/infrastructure/utils/instance_id.py
Normal file
14
src/infrastructure/utils/instance_id.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from ulid import ULID
|
||||
|
||||
|
||||
def generate_instance_id() -> str:
|
||||
"""
|
||||
Generate a process-wide instance id in ULID format.
|
||||
|
||||
ULID is 26 chars (Crockford Base32) and lexicographically sortable by time.
|
||||
"""
|
||||
|
||||
|
||||
return str(ULID())
|
||||
|
||||
|
||||
3
src/infrastructure/vault/__init__.py
Normal file
3
src/infrastructure/vault/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from src.infrastructure.vault.utils import create_hvac_client_from_approle, read_kv2_secret
|
||||
from src.infrastructure.vault.keys import JwtKeyStore
|
||||
from src.infrastructure.vault.scheduler import start_jwt_keys_scheduler
|
||||
118
src/infrastructure/vault/keys.py
Normal file
118
src/infrastructure/vault/keys.py
Normal file
@@ -0,0 +1,118 @@
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from src.application.domain.dto import JwtKeySet, JwtKeyPair
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.vault.utils import create_hvac_client_from_approle, read_kv2_secret
|
||||
|
||||
|
||||
|
||||
class JwtKeyStore:
|
||||
|
||||
_instance: "JwtKeyStore | None" = None
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vault_addr: str,
|
||||
vault_role_id: str,
|
||||
vault_secret_id: str,
|
||||
vault_namespace: str | None,
|
||||
mount_point: str,
|
||||
kid_path: str = 'jwt/kid',
|
||||
kids_prefix: str = 'jwt/kids',
|
||||
timeout_seconds: int = 5,
|
||||
):
|
||||
if getattr(self, '_initialized', False):
|
||||
return
|
||||
|
||||
self._vault_addr = vault_addr
|
||||
self._vault_role_id = vault_role_id
|
||||
self._vault_secret_id = vault_secret_id
|
||||
self._vault_namespace = vault_namespace
|
||||
self._timeout = timeout_seconds
|
||||
|
||||
self._mount = mount_point
|
||||
self._kid_path = kid_path
|
||||
self._kids_prefix = kids_prefix
|
||||
|
||||
self._lock = asyncio.Lock()
|
||||
self._keyset: JwtKeySet | None = None
|
||||
self._last_refresh_at: datetime | None = None
|
||||
|
||||
self._initialized = True
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> 'JwtKeyStore':
|
||||
|
||||
if cls._instance is None:
|
||||
raise ApplicationException(status_code=500, message='JwtKeyStore not initialized')
|
||||
|
||||
return cls._instance
|
||||
|
||||
def _read_keyset_sync(self) -> JwtKeySet:
|
||||
client = create_hvac_client_from_approle(
|
||||
url=self._vault_addr,
|
||||
role_id=self._vault_role_id,
|
||||
secret_id=self._vault_secret_id,
|
||||
namespace=self._vault_namespace,
|
||||
timeout=self._timeout,
|
||||
)
|
||||
|
||||
kids = read_kv2_secret(client=client, mount_point=self._mount, path=self._kid_path)
|
||||
active_kid = kids.get('active')
|
||||
previous_kid = kids.get('previous')
|
||||
|
||||
if not active_kid:
|
||||
raise RuntimeError('Vault jwt/kid secret missing active')
|
||||
|
||||
active_pair = self._read_keypair_sync(client, active_kid)
|
||||
|
||||
prev_pair = None
|
||||
if previous_kid and previous_kid != active_kid:
|
||||
prev_pair = self._read_keypair_sync(client, previous_kid)
|
||||
|
||||
return JwtKeySet(active=active_pair, previous=prev_pair)
|
||||
|
||||
def _read_keypair_sync(self, client, kid: str) -> JwtKeyPair:
|
||||
data = read_kv2_secret(
|
||||
client=client,
|
||||
mount_point=self._mount,
|
||||
path=f'{self._kids_prefix}/{kid}',
|
||||
)
|
||||
priv = data.get('private_key')
|
||||
pub = data.get('public_key')
|
||||
if not priv or not pub:
|
||||
raise RuntimeError(f'Vault jwt/kids/{kid} missing private_key/public_key')
|
||||
return JwtKeyPair(kid=kid, private_key_pem=priv, public_key_pem=pub)
|
||||
|
||||
|
||||
async def refresh(self) -> JwtKeySet:
|
||||
keyset = await asyncio.to_thread(self._read_keyset_sync)
|
||||
async with self._lock:
|
||||
self._keyset = keyset
|
||||
self._last_refresh_at = datetime.now(timezone.utc)
|
||||
return keyset
|
||||
|
||||
async def get_signing_key(self) -> tuple[str, str]:
|
||||
ks = await self._get_or_refresh()
|
||||
return ks.active.kid, ks.active.private_key_pem
|
||||
|
||||
async def get_public_key_for_kid(self, kid: str) -> str | None:
|
||||
ks = await self._get_or_refresh()
|
||||
return ks.public_keys_by_kid().get(kid)
|
||||
|
||||
async def last_refresh_at(self) -> datetime | None:
|
||||
async with self._lock:
|
||||
return self._last_refresh_at
|
||||
|
||||
async def _get_or_refresh(self) -> JwtKeySet:
|
||||
async with self._lock:
|
||||
ks = self._keyset
|
||||
return ks if ks else await self.refresh()
|
||||
24
src/infrastructure/vault/scheduler.py
Normal file
24
src/infrastructure/vault/scheduler.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
from src.infrastructure.vault import JwtKeyStore
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def start_jwt_keys_scheduler(store: JwtKeyStore, *, refresh_seconds: int = 3600) -> AsyncIOScheduler:
|
||||
scheduler = AsyncIOScheduler()
|
||||
scheduler.add_job(
|
||||
store.refresh,
|
||||
trigger=IntervalTrigger(seconds=refresh_seconds),
|
||||
id="jwt_keys_refresh",
|
||||
replace_existing=True,
|
||||
max_instances=1,
|
||||
coalesce=True,
|
||||
misfire_grace_time=60,
|
||||
)
|
||||
scheduler.start()
|
||||
logger.info("JWT keys scheduler started (interval=%s seconds)", refresh_seconds)
|
||||
return scheduler
|
||||
30
src/infrastructure/vault/utils.py
Normal file
30
src/infrastructure/vault/utils.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
import hvac
|
||||
|
||||
|
||||
def create_hvac_client_from_approle(
|
||||
*,
|
||||
url: str,
|
||||
role_id: str,
|
||||
secret_id: str,
|
||||
namespace: str | None = None,
|
||||
timeout: int = 5,
|
||||
) -> hvac.Client:
|
||||
kwargs: dict = {'url': url, 'timeout': timeout}
|
||||
if namespace:
|
||||
kwargs['namespace'] = namespace
|
||||
client = hvac.Client(**kwargs)
|
||||
client.auth.approle.login(role_id=role_id, secret_id=secret_id)
|
||||
if not client.is_authenticated():
|
||||
raise RuntimeError(
|
||||
'Vault AppRole authentication failed. Check VAULT_ADDR, VAULT_ROLE_ID, VAULT_SECRET_ID'
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
def read_kv2_secret(*, client: hvac.Client, mount_point: str, path: str) -> dict:
|
||||
secret = client.secrets.kv.v2.read_secret_version(
|
||||
mount_point=mount_point,
|
||||
path=path,
|
||||
)
|
||||
return secret["data"]["data"]
|
||||
154
src/main.py
Normal file
154
src/main.py
Normal file
@@ -0,0 +1,154 @@
|
||||
from __future__ import annotations
|
||||
from contextlib import asynccontextmanager
|
||||
import secrets
|
||||
from typing import AsyncGenerator
|
||||
from fastapi import Depends, FastAPI, status
|
||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.cache import create_redis_client
|
||||
from src.infrastructure.config.settings import get_settings
|
||||
from src.infrastructure.vault import JwtKeyStore, start_jwt_keys_scheduler
|
||||
from src.infrastructure.utils import generate_instance_id
|
||||
from src.infrastructure.logger import logger
|
||||
from src.infrastructure.config import settings
|
||||
from src.presentation.dependencies import get_rabbit
|
||||
from src.presentation.handlers import application_exception_handler, unhandled_exception_handler
|
||||
from src.presentation.middleware import TraceIDMiddleware, SecurityHeadersMiddleware
|
||||
from src.presentation.routing import v1_router
|
||||
|
||||
|
||||
security = HTTPBasic()
|
||||
|
||||
|
||||
async def verify_credentials(credentials: HTTPBasicCredentials = Depends(security)) -> HTTPBasicCredentials:
|
||||
user_ok = secrets.compare_digest(credentials.username, settings.DOCS_USERNAME)
|
||||
pass_ok = secrets.compare_digest(credentials.password, settings.DOCS_PASSWORD)
|
||||
if not (user_ok and pass_ok):
|
||||
raise ApplicationException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
message='Unauthorized',
|
||||
headers={'WWW-Authenticate': 'Basic'},
|
||||
)
|
||||
return credentials
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
|
||||
instance_id = generate_instance_id()
|
||||
logger.set_instance_id(instance_id)
|
||||
logger.info(f'Auth service instance started with id {instance_id}')
|
||||
|
||||
jwt_store = JwtKeyStore(
|
||||
vault_addr=settings.VAULT_ADDR,
|
||||
vault_role_id=settings.VAULT_ROLE_ID,
|
||||
vault_secret_id=settings.VAULT_SECRET_ID,
|
||||
vault_namespace=settings.VAULT_NAMESPACE,
|
||||
mount_point=settings.VAULT_MOUNT_POINT,
|
||||
kid_path=settings.VAULT_JWT_KID_PATH,
|
||||
kids_prefix=settings.VAULT_JWT_KIDS_PREFIX,
|
||||
)
|
||||
|
||||
await jwt_store.refresh()
|
||||
|
||||
jwt_scheduler = start_jwt_keys_scheduler(
|
||||
jwt_store,
|
||||
refresh_seconds=settings.JWT_KEYS_REFRESH_SECONDS,
|
||||
)
|
||||
|
||||
app.state.jwt_key_store = jwt_store
|
||||
app.state.jwt_keys_scheduler = jwt_scheduler
|
||||
|
||||
redis_client = create_redis_client()
|
||||
|
||||
await get_rabbit().connect()
|
||||
logger.info('Rabbit connected')
|
||||
|
||||
try:
|
||||
await redis_client.ping()
|
||||
app.state.redis = redis_client
|
||||
logger.info('Redis connected')
|
||||
yield
|
||||
finally:
|
||||
logger.info('Shutting down...')
|
||||
|
||||
sched = getattr(app.state, 'jwt_keys_scheduler', None)
|
||||
if sched:
|
||||
sched.shutdown(wait=False)
|
||||
|
||||
await redis_client.close()
|
||||
await redis_client.connection_pool.disconnect()
|
||||
await get_rabbit().close()
|
||||
|
||||
logger.info('Redis disconnected')
|
||||
logger.info('API stopped')
|
||||
|
||||
|
||||
app: FastAPI = FastAPI(
|
||||
redoc_url=None,
|
||||
docs_url=None,
|
||||
lifespan=lifespan,
|
||||
title='Bitforce. Auth Service',
|
||||
version='1.0.0',
|
||||
description='',
|
||||
license_info={
|
||||
'name': 'MIT',
|
||||
'url': 'https://opensource.org/licenses/MIT',
|
||||
},
|
||||
)
|
||||
|
||||
app.add_exception_handler(ApplicationException, application_exception_handler)
|
||||
app.add_exception_handler(Exception, unhandled_exception_handler)
|
||||
|
||||
app.include_router(v1_router)
|
||||
|
||||
app.add_middleware(TraceIDMiddleware, logger=logger)
|
||||
app.add_middleware(
|
||||
SecurityHeadersMiddleware,
|
||||
hsts=True,
|
||||
hsts_preload=False,
|
||||
frame_options='DENY',
|
||||
referrer_policy='strict-origin-when-cross-origin',
|
||||
content_security_policy="default-src 'self'; frame-ancestors 'none'; base-uri 'self'; object-src 'none'",
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.cors_origins_list(),
|
||||
allow_credentials=settings.CORS_ALLOW_CREDENTIALS,
|
||||
allow_methods=['*'],
|
||||
allow_headers=['*'],
|
||||
)
|
||||
|
||||
|
||||
@app.get('/docs', include_in_schema=False)
|
||||
async def custom_swagger_ui_html(_credentials: HTTPBasicCredentials = Depends(verify_credentials)) -> HTMLResponse:
|
||||
'''Custom Swagger documentation, optionally protected with basic authentication.'''
|
||||
return get_swagger_ui_html(
|
||||
openapi_url=getattr(app, 'openapi_url', '/openapi.json'),
|
||||
title=getattr(app, 'title', 'FastAPI') + ' - Swagger UI',
|
||||
oauth2_redirect_url=getattr(app, 'swagger_ui_oauth2_redirect_url', None),
|
||||
swagger_js_url='https://unpkg.com/swagger-ui-dist@5/swagger-ui-bundle.js',
|
||||
swagger_css_url='https://unpkg.com/swagger-ui-dist@5/swagger-ui.css',
|
||||
)
|
||||
|
||||
|
||||
@app.get('/redoc', include_in_schema=False)
|
||||
async def custom_redoc_html(_credentials: HTTPBasicCredentials = Depends(verify_credentials)) -> HTMLResponse:
|
||||
'''Custom ReDoc documentation, optionally protected with basic authentication.'''
|
||||
return get_redoc_html(
|
||||
openapi_url=getattr(app, 'openapi_url', '/openapi.json'),
|
||||
title=getattr(app, 'title', 'FastAPI') + ' - ReDoc',
|
||||
redoc_js_url='https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js',
|
||||
)
|
||||
|
||||
|
||||
@app.post('/ping')
|
||||
async def ping() -> dict[str, str]:
|
||||
return {
|
||||
'message': 'pong',
|
||||
'status': 'ok',
|
||||
}
|
||||
2
src/presentation/decorators/__init__.py
Normal file
2
src/presentation/decorators/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from src.presentation.decorators.csrf import csrf_protect
|
||||
from src.presentation.decorators.rate_limit import rate_limit, _email_rl_key as email_rl_key
|
||||
36
src/presentation/decorators/auth.py
Normal file
36
src/presentation/decorators/auth.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from fastapi import Depends, Request
|
||||
from fastapi.security.utils import get_authorization_scheme_param
|
||||
from src.application.contracts import IJwtService
|
||||
from src.application.domain.dto import AuthContext
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.presentation.dependencies import get_jwt_service
|
||||
|
||||
|
||||
def _extract_access_token(request: Request) -> str | None:
|
||||
token = request.cookies.get("access_token")
|
||||
|
||||
if token:
|
||||
return token
|
||||
|
||||
auth = request.headers.get("Authorization")
|
||||
if auth:
|
||||
scheme, param = get_authorization_scheme_param(auth)
|
||||
if scheme.lower() == "bearer" and param:
|
||||
return param
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def require_access_token(
|
||||
request: Request,
|
||||
jwt_service: IJwtService = Depends(get_jwt_service), # твой DI
|
||||
) -> AuthContext:
|
||||
token = _extract_access_token(request)
|
||||
if not token:
|
||||
raise ApplicationException(status_code=401, message="Not authenticated")
|
||||
|
||||
payload = jwt_service.decode_access_token(token)
|
||||
if payload.type != "access":
|
||||
raise ApplicationException(status_code=401, message="Invalid token type")
|
||||
|
||||
return AuthContext(user_id=payload.sub, sid=payload.sid, token=payload)
|
||||
61
src/presentation/decorators/csrf.py
Normal file
61
src/presentation/decorators/csrf.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from __future__ import annotations
|
||||
import inspect
|
||||
from functools import wraps
|
||||
from typing import Callable, Awaitable, Any, Optional, Annotated
|
||||
from fastapi import Request, Header
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.security import CsrfService
|
||||
|
||||
|
||||
def csrf_protect(
|
||||
expected_subject_getter: Optional[Callable[[Request], Optional[str]]] = None,
|
||||
):
|
||||
def decorator(func: Callable[..., Awaitable[Any]]):
|
||||
sig = inspect.signature(func)
|
||||
params = list(sig.parameters.values())
|
||||
|
||||
has_request = any(p.annotation is Request or p.name == 'request' for p in params)
|
||||
if not has_request:
|
||||
raise RuntimeError('csrf_protect requires endpoint to accept `request: Request`')
|
||||
|
||||
has_header = any(p.name == 'x_csrf_token' for p in params)
|
||||
if not has_header:
|
||||
params.append(
|
||||
inspect.Parameter(
|
||||
name='x_csrf_token',
|
||||
kind=inspect.Parameter.KEYWORD_ONLY,
|
||||
default=None,
|
||||
annotation=Annotated[str | None, Header(alias='X-CSRF-Token')],
|
||||
)
|
||||
)
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
request: Request | None = kwargs.get('request')
|
||||
if request is None:
|
||||
for arg in args:
|
||||
if isinstance(arg, Request):
|
||||
request = arg
|
||||
break
|
||||
|
||||
if request is None:
|
||||
raise ApplicationException(
|
||||
status_code=500,
|
||||
message='Request is required for CSRF protection',
|
||||
)
|
||||
|
||||
csrf = CsrfService()
|
||||
|
||||
cookie_token, _ = csrf.extract(request.cookies, request.headers)
|
||||
header_token = kwargs.get('x_csrf_token')
|
||||
|
||||
expected_subject = expected_subject_getter(request) if expected_subject_getter else None
|
||||
csrf.verify_pair(cookie_token, header_token, expected_subject)
|
||||
|
||||
kwargs.pop('x_csrf_token', None)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
wrapper.__signature__ = sig.replace(parameters=params)
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
171
src/presentation/decorators/rate_limit.py
Normal file
171
src/presentation/decorators/rate_limit.py
Normal file
@@ -0,0 +1,171 @@
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
import inspect
|
||||
import hashlib
|
||||
from typing import Any, Awaitable, Callable, Literal, Optional, Protocol, runtime_checkable
|
||||
from fastapi import Request
|
||||
from redis.asyncio.client import Redis
|
||||
from src.application.contracts import ILogger
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.logger import get_logger
|
||||
from src.presentation.dependencies import get_redis
|
||||
|
||||
|
||||
def _find_request(args: tuple[Any, ...], kwargs: dict[str, Any]) -> Request:
|
||||
req = kwargs.get('request')
|
||||
if isinstance(req, Request):
|
||||
return req
|
||||
for a in args:
|
||||
if isinstance(a, Request):
|
||||
return a
|
||||
raise RuntimeError('rate_limit decorator requires fastapi.Request argument')
|
||||
|
||||
|
||||
def _client_ip(request: Request) -> str:
|
||||
xff = request.headers.get('x-forwarded-for')
|
||||
if xff:
|
||||
return xff.split(',')[0].strip()
|
||||
if request.client:
|
||||
return request.client.host
|
||||
return 'unknown'
|
||||
|
||||
|
||||
_LUA_INCR_EXPIRE_TTL = '''
|
||||
local key = KEYS[1]
|
||||
local window = tonumber(ARGV[1])
|
||||
|
||||
local current = redis.call('INCR', key)
|
||||
if current == 1 then
|
||||
redis.call('EXPIRE', key, window)
|
||||
end
|
||||
|
||||
local ttl = redis.call('TTL', key)
|
||||
return { current, ttl }
|
||||
'''
|
||||
|
||||
|
||||
Scope = Literal['ip', 'device', 'user', 'key']
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class KeyBuilder1(Protocol):
|
||||
def __call__(self, request: Request) -> str: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class KeyBuilder3(Protocol):
|
||||
def __call__(self, request: Request, args: tuple[Any, ...], kwargs: dict[str, Any]) -> str: ...
|
||||
|
||||
|
||||
KeyBuilder = KeyBuilder1 | KeyBuilder3
|
||||
|
||||
|
||||
def _call_key_builder(builder: KeyBuilder, request: Request, args: tuple[Any, ...], kwargs: dict[str, Any]) -> str:
|
||||
try:
|
||||
sig = inspect.signature(builder)
|
||||
if len(sig.parameters) >= 3:
|
||||
return builder(request, args, kwargs)
|
||||
return builder(request)
|
||||
except Exception as e:
|
||||
try:
|
||||
return builder(request, args, kwargs)
|
||||
except Exception:
|
||||
raise e
|
||||
|
||||
def _email_rl_key(request: Request, args: tuple[Any, ...], kwargs: dict[str, Any]) -> str:
|
||||
|
||||
body = kwargs.get('body')
|
||||
if body is None and args:
|
||||
for a in args:
|
||||
if hasattr(a, 'email'):
|
||||
body = a
|
||||
break
|
||||
|
||||
email = (getattr(body, 'email', '') or '').strip().lower()
|
||||
if not email:
|
||||
email = _client_ip(request)
|
||||
|
||||
digest = hashlib.sha256(email.encode('utf-8')).hexdigest()[:24]
|
||||
return f'email:{digest}'
|
||||
|
||||
def rate_limit(
|
||||
*,
|
||||
limit: int,
|
||||
window_seconds: int,
|
||||
scope: Scope = 'ip',
|
||||
key_prefix: str = 'rl',
|
||||
key_builder: Optional[KeyBuilder] = None,
|
||||
fail_open: bool = True,
|
||||
) -> Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable[Any]]]:
|
||||
|
||||
if limit <= 0:
|
||||
raise ValueError('rate_limit: limit must be > 0')
|
||||
if window_seconds <= 0:
|
||||
raise ValueError('rate_limit: window_seconds must be > 0')
|
||||
if scope == 'key' and not key_builder:
|
||||
raise ValueError('rate_limit: scope="key" requires key_builder')
|
||||
|
||||
def decorator(func: Callable[..., Awaitable[Any]]):
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any):
|
||||
request = _find_request(args, kwargs)
|
||||
logger: ILogger = get_logger()
|
||||
|
||||
if scope == 'ip':
|
||||
ident = _client_ip(request)
|
||||
elif scope == 'device':
|
||||
ident = request.cookies.get('device_id') or _client_ip(request)
|
||||
elif scope == 'user':
|
||||
user = getattr(request.state, 'user', None)
|
||||
user_id = getattr(user, 'id', None) if user else None
|
||||
ident = str(user_id) if user_id else _client_ip(request)
|
||||
else:
|
||||
try:
|
||||
ident = _call_key_builder(key_builder, request, args, kwargs) # type: ignore[arg-type]
|
||||
except Exception as e:
|
||||
logger.error(f'RateLimit key_builder failed error={str(e)}')
|
||||
raise ApplicationException(500, 'Rate limiter key_builder failed')
|
||||
|
||||
route = request.url.path
|
||||
method = request.method
|
||||
redis_key = f'{key_prefix}:{scope}:{method}:{route}:{ident}'
|
||||
|
||||
logger.debug(f'RateLimit check key={redis_key} limit={limit} window={window_seconds}')
|
||||
|
||||
try:
|
||||
redis: Redis = get_redis(request)
|
||||
|
||||
result = await redis.eval(
|
||||
_LUA_INCR_EXPIRE_TTL,
|
||||
1,
|
||||
redis_key,
|
||||
str(window_seconds),
|
||||
)
|
||||
|
||||
count = int(result[0])
|
||||
ttl_raw = int(result[1]) if result and len(result) > 1 else window_seconds
|
||||
ttl = window_seconds if ttl_raw < 0 else ttl_raw
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'RateLimit redis failure key={redis_key} error={str(e)}')
|
||||
|
||||
if fail_open:
|
||||
logger.warning(f'RateLimit fail-open activated key={redis_key}')
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
raise ApplicationException(503, 'Rate limiter unavailable')
|
||||
|
||||
if count > limit:
|
||||
retry_after = max(ttl, 0)
|
||||
logger.warning(f'RateLimit exceeded key={redis_key} count={count} limit={limit} retry_after={retry_after}')
|
||||
raise ApplicationException(
|
||||
status_code=429,
|
||||
message='Too Many Requests',
|
||||
headers={'Retry-After': str(retry_after)},
|
||||
)
|
||||
|
||||
logger.debug(f'RateLimit passed key={redis_key} count={count}')
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
11
src/presentation/dependencies/__init__.py
Normal file
11
src/presentation/dependencies/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from src.presentation.dependencies.commands import (
|
||||
get_user_registration_complete_command,
|
||||
get_user_login_start_command,
|
||||
get_user_login_complete_command,
|
||||
get_user_logout_command,
|
||||
get_user_registration_start_command,
|
||||
get_jwt_refresh_command
|
||||
)
|
||||
from src.presentation.dependencies.security import get_jwt_service, get_jwt_service
|
||||
from src.presentation.dependencies.cache import get_redis
|
||||
from src.presentation.dependencies.queue_messanger import get_rabbit
|
||||
12
src/presentation/dependencies/cache.py
Normal file
12
src/presentation/dependencies/cache.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from fastapi import Depends, Request
|
||||
from redis.asyncio.client import Redis
|
||||
from src.application.contracts import ICache
|
||||
from src.infrastructure.cache import KeydbCache
|
||||
|
||||
|
||||
def get_redis(request: Request) -> Redis:
|
||||
return request.app.state.redis
|
||||
|
||||
|
||||
def get_cache(redis_client: Redis = Depends(get_redis)) -> ICache:
|
||||
return KeydbCache(redis_client)
|
||||
98
src/presentation/dependencies/commands.py
Normal file
98
src/presentation/dependencies/commands.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from fastapi import Depends
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.commands import (
|
||||
UserRegistrationCompleteCommand,
|
||||
JwtRefreshCommand,
|
||||
UserRegistrationStartCommand,
|
||||
UserLogoutCommand,
|
||||
UserLoginCompleteCommand,
|
||||
UserLoginStartCommand
|
||||
)
|
||||
from src.application.contracts import IHashService, IJwtService, ILogger, IQueueMessanger
|
||||
from src.application.contracts import ICache
|
||||
from src.presentation.dependencies.queue_messanger import get_rabbit
|
||||
from src.presentation.dependencies.cache import get_cache
|
||||
from src.presentation.dependencies.logger import get_logger
|
||||
from src.presentation.dependencies.security import get_hash_service, get_jwt_service
|
||||
from src.presentation.dependencies.unit_of_work import get_unit_of_work
|
||||
|
||||
|
||||
def get_user_registration_start_command(
|
||||
logger: ILogger = Depends(get_logger),
|
||||
hash_service: IHashService = Depends(get_hash_service),
|
||||
cache: ICache = Depends(get_cache),
|
||||
unit_of_work: IUnitOfWork = Depends(get_unit_of_work),
|
||||
messanger: IQueueMessanger = Depends(get_rabbit),
|
||||
) -> UserRegistrationStartCommand:
|
||||
return UserRegistrationStartCommand(
|
||||
logger=logger,
|
||||
unit_of_work=unit_of_work,
|
||||
hash_service=hash_service,
|
||||
cache=cache,
|
||||
messanger=messanger,
|
||||
)
|
||||
|
||||
|
||||
def get_user_registration_complete_command(
|
||||
uow: IUnitOfWork = Depends(get_unit_of_work),
|
||||
logger: ILogger = Depends(get_logger),
|
||||
hash_service: IHashService = Depends(get_hash_service),
|
||||
jwt_service: IJwtService = Depends(get_jwt_service),
|
||||
cache: ICache = Depends(get_cache),
|
||||
) -> UserRegistrationCompleteCommand:
|
||||
return UserRegistrationCompleteCommand(
|
||||
unit_of_work=uow,
|
||||
logger=logger,
|
||||
hash_service=hash_service,
|
||||
jwt_service=jwt_service,
|
||||
cache=cache
|
||||
)
|
||||
|
||||
def get_user_login_start_command(
|
||||
logger: ILogger = Depends(get_logger),
|
||||
hash_service: IHashService = Depends(get_hash_service),
|
||||
cache: ICache = Depends(get_cache),
|
||||
unit_of_work: IUnitOfWork = Depends(get_unit_of_work),
|
||||
messanger: IQueueMessanger = Depends(get_rabbit),
|
||||
) -> UserLoginStartCommand:
|
||||
return UserLoginStartCommand(
|
||||
logger=logger,
|
||||
unit_of_work=unit_of_work,
|
||||
hash_service=hash_service,
|
||||
cache=cache,
|
||||
messanger=messanger,
|
||||
)
|
||||
|
||||
|
||||
|
||||
def get_user_login_complete_command(
|
||||
uow: IUnitOfWork = Depends(get_unit_of_work),
|
||||
logger: ILogger = Depends(get_logger),
|
||||
hash_service: IHashService = Depends(get_hash_service),
|
||||
jwt_service: IJwtService = Depends(get_jwt_service),
|
||||
cache: ICache = Depends(get_cache),
|
||||
) -> UserLoginCompleteCommand:
|
||||
return UserLoginCompleteCommand(
|
||||
unit_of_work=uow,
|
||||
logger=logger,
|
||||
hash_service=hash_service,
|
||||
jwt_service=jwt_service,
|
||||
cache=cache
|
||||
)
|
||||
|
||||
|
||||
def get_user_logout_command(
|
||||
uow: IUnitOfWork = Depends(get_unit_of_work),
|
||||
jwt_service: IJwtService = Depends(get_jwt_service),
|
||||
logger: ILogger = Depends(get_logger),
|
||||
) -> UserLogoutCommand:
|
||||
return UserLogoutCommand(unit_of_work=uow, logger=logger, jwt_service=jwt_service)
|
||||
|
||||
|
||||
def get_jwt_refresh_command(
|
||||
uow: IUnitOfWork = Depends(get_unit_of_work),
|
||||
hash_service: IHashService = Depends(get_hash_service),
|
||||
jwt_service: IJwtService = Depends(get_jwt_service),
|
||||
logger: ILogger = Depends(get_logger),
|
||||
) -> JwtRefreshCommand:
|
||||
return JwtRefreshCommand(uow, hash_service, jwt_service, logger)
|
||||
7
src/presentation/dependencies/logger.py
Normal file
7
src/presentation/dependencies/logger.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from functools import lru_cache
|
||||
from src.application.contracts import ILogger
|
||||
from src.infrastructure.logger import logger
|
||||
|
||||
@lru_cache
|
||||
def get_logger() -> ILogger:
|
||||
return logger
|
||||
8
src/presentation/dependencies/queue_messanger.py
Normal file
8
src/presentation/dependencies/queue_messanger.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from functools import lru_cache
|
||||
from src.application.contracts import IQueueMessanger
|
||||
from src.infrastructure.messanger import RabbitClient
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_rabbit() -> IQueueMessanger:
|
||||
return RabbitClient()
|
||||
25
src/presentation/dependencies/security.py
Normal file
25
src/presentation/dependencies/security.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from functools import lru_cache
|
||||
from fastapi import Depends
|
||||
from src.application.contracts import IHashService, IJwtService, ILogger
|
||||
from src.infrastructure.security import HashService, JwtService
|
||||
from src.infrastructure.vault import JwtKeyStore
|
||||
from src.presentation.dependencies.logger import get_logger
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _hash_service(logger: ILogger) -> IHashService:
|
||||
return HashService(logger=logger)
|
||||
|
||||
|
||||
def get_hash_service(logger: ILogger = Depends(get_logger)) -> IHashService:
|
||||
return _hash_service(logger)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _jwt_service(logger: ILogger) -> IJwtService:
|
||||
key_store = JwtKeyStore.get_instance()
|
||||
return JwtService(logger=logger, key_store=key_store)
|
||||
|
||||
|
||||
def get_jwt_service(logger: ILogger = Depends(get_logger)) -> IJwtService:
|
||||
return _jwt_service(logger)
|
||||
10
src/presentation/dependencies/unit_of_work.py
Normal file
10
src/presentation/dependencies/unit_of_work.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from fastapi import Depends
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.contracts import ILogger
|
||||
from src.infrastructure.database import UnitOfWork
|
||||
from src.infrastructure.database.context import async_session_maker
|
||||
from src.infrastructure.logger import get_logger
|
||||
|
||||
|
||||
def get_unit_of_work(logger: ILogger = Depends(get_logger)) -> IUnitOfWork:
|
||||
return UnitOfWork(session_factory=async_session_maker, logger=logger)
|
||||
2
src/presentation/handlers/__init__.py
Normal file
2
src/presentation/handlers/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from src.presentation.handlers.unhandled_handler import unhandled_exception_handler
|
||||
from src.presentation.handlers.application_handler import application_exception_handler
|
||||
17
src/presentation/handlers/application_handler.py
Normal file
17
src/presentation/handlers/application_handler.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from fastapi import Request
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
|
||||
|
||||
async def application_exception_handler(_request: Request, exc: ApplicationException) -> ORJSONResponse:
|
||||
detail = exc.message
|
||||
if 500 <= exc.status_code:
|
||||
detail = "Internal Server Error"
|
||||
|
||||
return ORJSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={"detail": detail},
|
||||
headers=dict(exc.headers) if exc.headers else None,
|
||||
)
|
||||
|
||||
|
||||
12
src/presentation/handlers/unhandled_handler.py
Normal file
12
src/presentation/handlers/unhandled_handler.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from fastapi import Request
|
||||
from starlette import status
|
||||
from src.infrastructure.logger import logger
|
||||
|
||||
|
||||
async def unhandled_exception_handler(_request: Request, exc: Exception) -> ORJSONResponse:
|
||||
logger.exception(f'Unhandled exception: {type(exc).__name__}')
|
||||
return ORJSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={'detail': 'Internal Server Error'},
|
||||
)
|
||||
2
src/presentation/middleware/__init__.py
Normal file
2
src/presentation/middleware/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from src.presentation.middleware.trace_id import TraceIDMiddleware
|
||||
from src.presentation.middleware.security_headers import SecurityHeadersMiddleware
|
||||
51
src/presentation/middleware/security_headers.py
Normal file
51
src/presentation/middleware/security_headers.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
def __init__(
|
||||
self,
|
||||
app,
|
||||
*,
|
||||
hsts: bool = True,
|
||||
hsts_max_age: int = 31536000, # 1 год
|
||||
hsts_include_subdomains: bool = True,
|
||||
hsts_preload: bool = False,
|
||||
frame_options: str = 'DENY', # или 'SAMEORIGIN'
|
||||
referrer_policy: str = 'strict-origin-when-cross-origin',
|
||||
content_security_policy: str | None = None,
|
||||
):
|
||||
super().__init__(app)
|
||||
self.hsts = hsts
|
||||
self.hsts_max_age = hsts_max_age
|
||||
self.hsts_include_subdomains = hsts_include_subdomains
|
||||
self.hsts_preload = hsts_preload
|
||||
self.frame_options = frame_options
|
||||
self.referrer_policy = referrer_policy
|
||||
self.csp = content_security_policy
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
response: Response = await call_next(request)
|
||||
|
||||
if request.url.path in ('/docs', '/redoc', '/openapi.json'):
|
||||
return response
|
||||
|
||||
if self.hsts and request.url.scheme == 'https':
|
||||
hsts = f'max-age={self.hsts_max_age}'
|
||||
if self.hsts_include_subdomains:
|
||||
hsts += '; includeSubDomains'
|
||||
if self.hsts_preload:
|
||||
hsts += '; preload'
|
||||
response.headers['Strict-Transport-Security'] = hsts
|
||||
|
||||
response.headers['X-Content-Type-Options'] = 'nosniff'
|
||||
|
||||
response.headers['X-Frame-Options'] = self.frame_options
|
||||
|
||||
response.headers['Referrer-Policy'] = self.referrer_policy
|
||||
|
||||
if self.csp:
|
||||
response.headers['Content-Security-Policy'] = self.csp
|
||||
|
||||
return response
|
||||
69
src/presentation/middleware/trace_id.py
Normal file
69
src/presentation/middleware/trace_id.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from __future__ import annotations
|
||||
from typing import Optional
|
||||
from contextvars import Token
|
||||
from starlette.requests import Request
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
from ulid import ULID
|
||||
from src.application.contracts import ILogger
|
||||
from src.infrastructure.config import settings
|
||||
from src.infrastructure.context_vars import trace_id_var
|
||||
|
||||
|
||||
class TraceIDMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
logger: ILogger,
|
||||
response_header_name: str = "X-Trace-ID",
|
||||
attach_response_header: bool = True,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.logger = logger
|
||||
self.response_header_name = response_header_name
|
||||
self.attach_response_header = attach_response_header
|
||||
|
||||
def _is_excluded(self, path: str) -> bool:
|
||||
return any(path == p or path.startswith(p.rstrip("/") + "/") for p in settings.EXCLUDED_PATHS)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
request = Request(scope)
|
||||
|
||||
if self._is_excluded(request.url.path):
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
trace_id = request.headers.get("X-Trace-ID") or request.headers.get("X-Request-ID")
|
||||
if not trace_id:
|
||||
trace_id = str(ULID())
|
||||
|
||||
request.state.trace_id = trace_id
|
||||
|
||||
token: Token = trace_id_var.set(trace_id)
|
||||
|
||||
self.logger.debug(f"Request started: {request.method} {request.url} - TraceID: {trace_id}")
|
||||
|
||||
status_code_holder: dict[str, Optional[int]] = {"status": None}
|
||||
|
||||
async def send_wrapper(message: Message) -> None:
|
||||
if message["type"] == "http.response.start":
|
||||
status_code_holder["status"] = int(message["status"])
|
||||
|
||||
if self.attach_response_header:
|
||||
headers = list(message.get("headers", []))
|
||||
headers.append((self.response_header_name.lower().encode(), trace_id.encode()))
|
||||
message["headers"] = headers
|
||||
await send(message)
|
||||
|
||||
try:
|
||||
await self.app(scope, receive, send_wrapper)
|
||||
finally:
|
||||
status = status_code_holder["status"]
|
||||
status_part = f"{status}" if status is not None else "unknown"
|
||||
self.logger.debug(
|
||||
f"Request finished: {request.method} {request.url} - TraceID: {trace_id} - Status: {status_part}"
|
||||
)
|
||||
trace_id_var.reset(token)
|
||||
9
src/presentation/routing/__init__.py
Normal file
9
src/presentation/routing/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from fastapi import APIRouter
|
||||
from src.presentation.routing.auth import auth_router
|
||||
from src.presentation.routing.csrf import csrf_router
|
||||
from src.presentation.routing.jwt import jwt_router
|
||||
|
||||
v1_router = APIRouter(prefix='/v1')
|
||||
v1_router.include_router(auth_router)
|
||||
v1_router.include_router(csrf_router)
|
||||
v1_router.include_router(jwt_router)
|
||||
222
src/presentation/routing/auth.py
Normal file
222
src/presentation/routing/auth.py
Normal file
@@ -0,0 +1,222 @@
|
||||
from fastapi import APIRouter, Depends, status, Request
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from ulid import ULID
|
||||
from src.application.commands import (
|
||||
UserLogoutCommand,
|
||||
UserRegistrationStartCommand,
|
||||
UserLoginStartCommand,
|
||||
UserRegistrationCompleteCommand,
|
||||
UserLoginCompleteCommand
|
||||
)
|
||||
from src.application.contracts import ILogger
|
||||
from src.application.domain.dto import UserLoginDto
|
||||
from src.infrastructure.config import settings
|
||||
from src.infrastructure.logger import get_logger
|
||||
from src.presentation.decorators import rate_limit, email_rl_key
|
||||
from src.presentation.dependencies import (
|
||||
get_user_registration_complete_command,
|
||||
get_user_logout_command,
|
||||
get_user_registration_start_command,
|
||||
get_user_login_start_command,
|
||||
get_user_login_complete_command
|
||||
)
|
||||
from src.presentation.schemas import UserLogin, RegistrationStart, RegistrationComplete, LoginStart
|
||||
#from src.presentation.decorators import csrf_protect
|
||||
|
||||
|
||||
auth_router = APIRouter(prefix='/auth', tags=['auth'])
|
||||
|
||||
@auth_router.post(
|
||||
path='/registration/start',
|
||||
response_class=ORJSONResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
)
|
||||
@rate_limit(limit=5, window_seconds=60, scope='ip')
|
||||
@rate_limit(limit=3, window_seconds=600, scope='key', key_prefix='rl:reg_start', key_builder=email_rl_key)
|
||||
async def registration_start(
|
||||
request: Request,
|
||||
body: RegistrationStart,
|
||||
command: UserRegistrationStartCommand = Depends(get_user_registration_start_command),
|
||||
):
|
||||
result = await command(body.email)
|
||||
|
||||
return {'success': result}
|
||||
|
||||
@auth_router.post(path='/registration/complete', response_class=ORJSONResponse, status_code=status.HTTP_201_CREATED)
|
||||
@rate_limit(limit=10, window_seconds=300, scope='ip')
|
||||
async def registration(
|
||||
request: Request,
|
||||
user: RegistrationComplete,
|
||||
command: UserRegistrationCompleteCommand = Depends(get_user_registration_complete_command),
|
||||
logger: ILogger = Depends(get_logger),
|
||||
):
|
||||
device_id = request.cookies.get('device_id')
|
||||
|
||||
if not device_id:
|
||||
device_id = str(ULID())
|
||||
|
||||
xff = request.headers.get('x-forwarded-for')
|
||||
ip = xff.split(',')[0].strip() if xff else (request.client.host if request.client else None)
|
||||
user_agent = request.headers.get('user-agent')
|
||||
|
||||
created = await command(
|
||||
email=str(user.email),
|
||||
password=user.password,
|
||||
device_id=device_id,
|
||||
code=user.code,
|
||||
user_agent=user_agent,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
logger.info(f'Registration completed for user_id={created.id}')
|
||||
|
||||
response = ORJSONResponse(content={'id': created.id, 'email': created.email})
|
||||
|
||||
response.set_cookie(
|
||||
key='device_id',
|
||||
value=device_id,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite='lax',
|
||||
path='/',
|
||||
max_age=60 * 60 * 24 * 365 * 5
|
||||
)
|
||||
|
||||
response.set_cookie(
|
||||
key='access_token',
|
||||
value=created.access_token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite='lax',
|
||||
path='/',
|
||||
max_age=int(settings.JWT_ACCESS_TTL_SECONDS),
|
||||
)
|
||||
response.set_cookie(
|
||||
key='refresh_token',
|
||||
value=created.refresh_token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite='lax',
|
||||
path='/',
|
||||
max_age=int(settings.JWT_REFRESH_TTL_SECONDS),
|
||||
)
|
||||
return response
|
||||
|
||||
@auth_router.post(path='/login/start', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
|
||||
@rate_limit(limit=5, window_seconds=60, scope='ip')
|
||||
@rate_limit(limit=3, window_seconds=600, scope='key', key_prefix='rl:login_start', key_builder=email_rl_key)
|
||||
async def login_start(
|
||||
request: Request,
|
||||
body: LoginStart,
|
||||
command: UserLoginStartCommand = Depends(get_user_login_start_command),
|
||||
):
|
||||
result = await command(body.email)
|
||||
|
||||
return {'success': result}
|
||||
|
||||
@auth_router.post(path='/login/compete', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
|
||||
@rate_limit(limit=10, window_seconds=300, scope='ip')
|
||||
async def login(
|
||||
request: Request,
|
||||
user: UserLogin,
|
||||
command: UserLoginCompleteCommand = Depends(get_user_login_complete_command),
|
||||
logger: ILogger = Depends(get_logger)
|
||||
):
|
||||
device_id = request.cookies.get('device_id')
|
||||
|
||||
if not device_id:
|
||||
device_id = str(ULID())
|
||||
|
||||
xff = request.headers.get('x-forwarded-for')
|
||||
ip = xff.split(',')[0].strip() if xff else (request.client.host if request.client else None)
|
||||
user_agent = request.headers.get('user-agent')
|
||||
|
||||
dto: UserLoginDto = await command(
|
||||
email=str(user.email),
|
||||
password=user.password,
|
||||
code=user.code,
|
||||
device_id=device_id,
|
||||
user_agent=user_agent,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
logger.info(f'Login completed for user_id={dto.id}')
|
||||
|
||||
response = ORJSONResponse(
|
||||
content={
|
||||
'id': dto.id,
|
||||
'email': dto.email,
|
||||
'first_name': dto.first_name,
|
||||
'middle_name': dto.middle_name,
|
||||
'last_name': dto.last_name,
|
||||
'birth_date': dto.birth_date.isoformat() if dto.birth_date else None,
|
||||
'crypto_wallet': dto.crypto_wallet,
|
||||
'phone': dto.phone,
|
||||
'bik': dto.bik,
|
||||
'account_number': dto.account_number,
|
||||
'card_number': dto.card_number,
|
||||
'inn': dto.inn,
|
||||
'kyc_verified': dto.kyc_verified,
|
||||
'kyc_verified_at': dto.kyc_verified_at,
|
||||
'created_at': dto.created_at.isoformat() if dto.created_at else None,
|
||||
'updated_at': dto.updated_at.isoformat() if dto.updated_at else None,
|
||||
}
|
||||
)
|
||||
|
||||
response.set_cookie(
|
||||
key='device_id',
|
||||
value=device_id,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite='lax',
|
||||
path='/',
|
||||
max_age=60 * 60 * 24 * 365 * 5
|
||||
)
|
||||
|
||||
response.set_cookie(
|
||||
key='access_token',
|
||||
value=dto.access_token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite='lax',
|
||||
path='/',
|
||||
max_age=int(settings.JWT_ACCESS_TTL_SECONDS),
|
||||
)
|
||||
|
||||
response.set_cookie(
|
||||
key='refresh_token',
|
||||
value=dto.refresh_token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite='lax',
|
||||
path='/',
|
||||
max_age=int(settings.JWT_REFRESH_TTL_SECONDS),
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@auth_router.post(path='/logout', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
|
||||
@rate_limit(limit=settings.RATE_LIMIT_REQUESTS, window_seconds=settings.RATE_LIMIT_WINDOW, scope='ip')
|
||||
async def logout_current(
|
||||
request: Request,
|
||||
command: UserLogoutCommand = Depends(get_user_logout_command),
|
||||
):
|
||||
refresh_token = request.cookies.get('refresh_token')
|
||||
|
||||
await command(refresh_token=refresh_token)
|
||||
|
||||
response = ORJSONResponse({'ok': True})
|
||||
response.delete_cookie('access_token', path='/')
|
||||
response.delete_cookie('refresh_token', path='/')
|
||||
return response
|
||||
|
||||
|
||||
|
||||
# @auth_router.get(path='/ping')
|
||||
# @csrf_protect()
|
||||
# async def ping(request: Request):
|
||||
# return ORJSONResponse(
|
||||
# content={
|
||||
# 'status': 'pong'
|
||||
# }
|
||||
# )
|
||||
37
src/presentation/routing/csrf.py
Normal file
37
src/presentation/routing/csrf.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from starlette import status
|
||||
from src.infrastructure.security import CsrfService
|
||||
from src.infrastructure.config import settings
|
||||
from src.presentation.decorators import rate_limit
|
||||
|
||||
csrf_router = APIRouter(prefix='/csrf', tags=['csrf'])
|
||||
|
||||
|
||||
@csrf_router.get('/token', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
|
||||
@rate_limit(limit=settings.RATE_LIMIT_REQUESTS, window_seconds=settings.RATE_LIMIT_WINDOW, scope='ip')
|
||||
async def issue_csrf_token():
|
||||
csrf = CsrfService()
|
||||
|
||||
token = csrf.issue()
|
||||
|
||||
response = ORJSONResponse(
|
||||
content={
|
||||
'token': token,
|
||||
'header_name': csrf.header_name,
|
||||
}
|
||||
)
|
||||
|
||||
response.set_cookie(
|
||||
key=csrf.cookie_name,
|
||||
value=token,
|
||||
secure=settings.CSRF_COOKIE_SECURE,
|
||||
httponly=settings.CSRF_COOKIE_HTTPONLY,
|
||||
samesite=settings.CSRF_COOKIE_SAMESITE,
|
||||
path=settings.CSRF_COOKIE_PATH,
|
||||
domain=settings.CSRF_COOKIE_DOMAIN,
|
||||
max_age=csrf.ttl_seconds,
|
||||
)
|
||||
|
||||
return response
|
||||
64
src/presentation/routing/jwt.py
Normal file
64
src/presentation/routing/jwt.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from fastapi import APIRouter, Request, Depends
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from starlette import status
|
||||
from src.application.commands import JwtRefreshCommand
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.config import settings
|
||||
from src.presentation.decorators import rate_limit
|
||||
from src.presentation.dependencies import get_jwt_refresh_command
|
||||
|
||||
|
||||
jwt_router = APIRouter(prefix='/jwt', tags=['Jwt'])
|
||||
|
||||
|
||||
@jwt_router.post('/refresh', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
|
||||
@rate_limit(limit=settings.RATE_LIMIT_REQUESTS, window_seconds=settings.RATE_LIMIT_WINDOW, scope='ip')
|
||||
async def refresh_tokens(
|
||||
request: Request,
|
||||
command: JwtRefreshCommand = Depends(get_jwt_refresh_command)
|
||||
):
|
||||
refresh_token = request.cookies.get('refresh_token')
|
||||
|
||||
if not refresh_token:
|
||||
response = ORJSONResponse({'ok': False, 'error': 'No refresh token'}, status_code=401)
|
||||
response.delete_cookie('access_token', path='/')
|
||||
response.delete_cookie('refresh_token', path='/')
|
||||
return response
|
||||
|
||||
ip = request.client.host if request.client else None
|
||||
user_agent = request.headers.get('user-agent')
|
||||
|
||||
try:
|
||||
access, refresh = await command(refresh_token=refresh_token, ip=ip, user_agent=user_agent)
|
||||
except ApplicationException:
|
||||
response = ORJSONResponse({'result': False}, status_code=401)
|
||||
response.delete_cookie('access_token', path='/')
|
||||
response.delete_cookie('refresh_token', path='/')
|
||||
return response
|
||||
|
||||
response = ORJSONResponse({'result': True})
|
||||
|
||||
response.set_cookie(
|
||||
key='access_token',
|
||||
value=access,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite='lax',
|
||||
path='/',
|
||||
max_age=int(settings.JWT_ACCESS_TTL_SECONDS),
|
||||
)
|
||||
response.set_cookie(
|
||||
key='refresh_token',
|
||||
value=refresh,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite='lax',
|
||||
path='/',
|
||||
max_age=int(settings.JWT_REFRESH_TTL_SECONDS),
|
||||
)
|
||||
return response
|
||||
|
||||
# Usage
|
||||
# @jwt_router.get("/test")
|
||||
# async def profile(auth: AuthContext = Depends(require_access_token)):
|
||||
# return 'ok'
|
||||
1
src/presentation/schemas/__init__.py
Normal file
1
src/presentation/schemas/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.presentation.schemas.user import RegistrationStart, RegistrationComplete, UserLogin, LoginStart
|
||||
89
src/presentation/schemas/user.py
Normal file
89
src/presentation/schemas/user.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from __future__ import annotations
|
||||
import re
|
||||
from typing import ClassVar
|
||||
from pydantic import BaseModel, EmailStr, Field, ValidationError, field_validator, model_validator
|
||||
|
||||
|
||||
|
||||
class EmailNoSubaddressing(BaseModel):
|
||||
email: EmailStr = Field(title='Email', description='Email without subaddressing')
|
||||
|
||||
@field_validator('email')
|
||||
@classmethod
|
||||
def validate_and_normalize_email(cls, v: EmailStr) -> str:
|
||||
email = str(v).strip().lower()
|
||||
local, _, domain = email.partition('@')
|
||||
if not local or not domain:
|
||||
raise ValueError('Invalid email')
|
||||
if '+' in local:
|
||||
raise ValueError('Email subaddressing is not allowed')
|
||||
if any(ord(ch) > 127 for ch in local):
|
||||
raise ValueError('Email must be ASCII')
|
||||
if local.startswith('.') or local.endswith('.') or '..' in local:
|
||||
raise ValueError('Invalid email local part')
|
||||
if not re.fullmatch(r'[A-Za-z0-9._-]+', local):
|
||||
raise ValueError('Email contains запрещенные символы')
|
||||
|
||||
return email
|
||||
|
||||
|
||||
class RegistrationStart(EmailNoSubaddressing):
|
||||
pass
|
||||
|
||||
class LoginStart(EmailNoSubaddressing):
|
||||
pass
|
||||
|
||||
|
||||
class RegistrationComplete(EmailNoSubaddressing):
|
||||
password: str = Field(min_length=12)
|
||||
confirm_password: str = Field(min_length=12)
|
||||
code: str = Field(
|
||||
min_length=6,
|
||||
max_length=6,
|
||||
pattern=r"^\d{6}$",
|
||||
)
|
||||
|
||||
_allowed_specials: ClassVar[str] = '!@#$%^&*()_+-=.,:;?/[]{}<>'
|
||||
|
||||
@field_validator('password')
|
||||
@classmethod
|
||||
def validate_password_policy(cls, v: str) -> str:
|
||||
if len(v) < 12:
|
||||
raise ValueError('Password must be at least 12 characters long')
|
||||
if not any(c.islower() for c in v):
|
||||
raise ValueError('Password must contain at least one lowercase letter')
|
||||
if not any(c.isupper() for c in v):
|
||||
raise ValueError('Password must contain at least one uppercase letter')
|
||||
if not any(c.isdigit() for c in v):
|
||||
raise ValueError('Password must contain at least one digit')
|
||||
if not any(c in cls._allowed_specials for c in v):
|
||||
raise ValueError(
|
||||
'Password must contain at least one special character '
|
||||
f'from: {cls._allowed_specials}'
|
||||
)
|
||||
if any(c.isspace() for c in v):
|
||||
raise ValueError('Password must not contain whitespace')
|
||||
return v
|
||||
|
||||
@model_validator(mode='after')
|
||||
def validate_password_confirmation(self) -> 'RegistrationComplete':
|
||||
if self.password != self.confirm_password:
|
||||
raise ValidationError.from_exception_data(
|
||||
title='Passwords do not match',
|
||||
line_errors=[{
|
||||
'type': 'value_error',
|
||||
'loc': ('confirm_password',),
|
||||
'msg': 'Passwords do not match',
|
||||
'input': self.confirm_password,
|
||||
}],
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class UserLogin(EmailNoSubaddressing):
|
||||
password: str = Field(min_length=12)
|
||||
code: str = Field(
|
||||
min_length=6,
|
||||
max_length=6,
|
||||
pattern=r"^\d{6}$",
|
||||
)
|
||||
Reference in New Issue
Block a user