feat(account): GET /me user endpoint only, disable cache and extra routers

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
2026-05-12 20:44:35 +03:00
commit d94dd31439
107 changed files with 5083 additions and 0 deletions

2
src/infrastructure/cache/__init__.py vendored Normal file
View File

@@ -0,0 +1,2 @@
from src.infrastructure.cache.client import create_redis_client
from src.infrastructure.cache.keydb_client import KeydbCache

16
src/infrastructure/cache/client.py vendored Normal file
View File

@@ -0,0 +1,16 @@
import redis.asyncio as redis
from redis.asyncio.client import Redis
from src.infrastructure.config import settings
def create_redis_client() -> Redis:
return redis.from_url(
settings.REDIS_URL,
max_connections=50,
decode_responses=True,
socket_timeout=5,
socket_connect_timeout=5,
health_check_interval=30,
retry_on_timeout=True,
socket_keepalive=True,
)

View File

@@ -0,0 +1,52 @@
from __future__ import annotations
import orjson
from redis.asyncio.client import Redis
from src.application.contracts import ICache
from src.application.domain.entities.user import UserEntity
class KeydbCache(ICache):
USER_PREFIX = 'user:me'
def __init__(self, redis_client: Redis):
self._r = redis_client
async def set(self, key: str, value: str, ttl: int) -> bool:
return bool(await self._r.set(key, value, ex=ttl))
async def set_nx(self, key: str, value: str, ttl: int) -> bool:
return bool(await self._r.set(key, value, ex=ttl, nx=True))
async def get(self, key: str) -> str | None:
return await self._r.get(key)
async def delete(self, key: str) -> bool:
return (await self._r.delete(key)) > 0
async def get_user(self, user_id: str) -> dict | None:
raw = await self._r.get(f'{self.USER_PREFIX}:{user_id}')
if raw is None:
return None
return orjson.loads(raw)
async def set_user(self, user_id: str, user: UserEntity, ttl: int = 300) -> None:
data = orjson.dumps({
'id': user.id,
'email': user.email,
'first_name': user.first_name,
'middle_name': user.middle_name,
'last_name': user.last_name,
'birth_date': str(user.birth_date) if user.birth_date else None,
'crypto_wallet': user.crypto_wallet,
'phone': user.phone,
'bik': user.bik,
'account_number': user.account_number,
'card_number': user.card_number,
'inn': user.inn,
'kyc_verified': user.kyc_verified,
'is_deleted': user.is_deleted,
'created_at': user.created_at.isoformat() if user.created_at else None,
'updated_at': user.updated_at.isoformat() if user.updated_at else None,
'kyc_verified_at': user.kyc_verified_at.isoformat() if user.kyc_verified_at else None,
})
await self._r.set(f'{self.USER_PREFIX}:{user_id}', data, ex=ttl)

View File

@@ -0,0 +1 @@
from src.infrastructure.config.settings import settings

View File

@@ -0,0 +1,155 @@
from __future__ import annotations
from functools import lru_cache
from typing import List, Literal
import os
from dotenv import load_dotenv, find_dotenv
from pydantic import Field, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from src.infrastructure.vault import create_hvac_client, read_kv2_secret
env_file = find_dotenv(".env")
if env_file:
load_dotenv(env_file)
class Settings(BaseSettings):
VAULT_ADDR: str = Field(default="http://localhost:8200")
VAULT_TOKEN: str = Field(..., description="Vault token is required")
VAULT_MOUNT_POINT: str = Field(default="secrets")
VAULT_JWT_KID_PATH: str = "jwt/kid"
VAULT_JWT_KIDS_PREFIX: str = "jwt/kids"
JWT_KEYS_REFRESH_SECONDS: int = 3600
DATABASE_HOST: str
DATABASE_PORT: int = Field(default=5432, ge=1, le=65535)
DATABASE_NAME: str
DATABASE_USER: str
DATABASE_PASSWORD: str
DATABASE_POOL_SIZE: int = 10
DATABASE_MAX_OVERFLOW: int = 20
DATABASE_POOL_TIMEOUT: int = 30
DATABASE_POOL_RECYCLE: int = 3600
DATABASE_ECHO: bool = False
CSRF_SECRET_KEY: str = Field(
default="change-me-change-me-change-me-change-me",
min_length=32,
)
CSRF_COOKIE_SECURE: bool = False
CSRF_COOKIE_HTTPONLY: bool = True
CSRF_COOKIE_SAMESITE: Literal["Lax", "Strict", "None"] = "Lax"
CSRF_COOKIE_PATH: str = "/"
CSRF_COOKIE_DOMAIN: str | None = None
DOCS_USERNAME: str = "admin"
DOCS_PASSWORD: str = "admin"
JWT_ACCESS_TTL_SECONDS: int = 15 * 60
JWT_REFRESH_TTL_SECONDS: int = 30 * 24 * 60 * 60
JWT_ISSUER: str | None = None
JWT_AUDIENCE: str | None = None
JWT_ALGORITHM: str = "RS256"
REDIS_HOST: str = "localhost"
REDIS_PORT: int = 6379
REDIS_PASSWORD: str | None = None
REDIS_DB: int = 0
RABBIT_HOST: str = "localhost"
RABBIT_PORT: int = 5672
RABBIT_USER: str = "guest"
RABBIT_PASSWORD: str = "guest"
RABBIT_VHOST: str = "/"
RABBIT_PUBLISH_PERSIST: bool = True
RABBIT_CONNECT_TIMEOUT: int = 5
RABBIT_EMAIL_CODE_QUEUE: str = "email.verification_code"
LOG_LEVEL: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
LOG_FORMAT: Literal["JSON", "TEXT"] = "TEXT"
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
case_sensitive=True,
extra="ignore",
)
@model_validator(mode="before")
@classmethod
def load_from_vault(cls, data: dict):
addr = data.get("VAULT_ADDR") or os.getenv("VAULT_ADDR") or "http://localhost:8200"
token = data.get("VAULT_TOKEN") or os.getenv("VAULT_TOKEN")
mount = data.get("VAULT_MOUNT_POINT") or os.getenv("VAULT_MOUNT_POINT") or "secrets"
if not token:
raise RuntimeError("VAULT_TOKEN is required")
client = create_hvac_client(url=addr, token=token, timeout=5)
def safe_read(path: str) -> dict:
try:
return read_kv2_secret(client=client, mount_point=mount, path=path)
except Exception:
return {}
database = safe_read("database")
rabbitmq = safe_read("rabbitmq")
csrf = safe_read("csrf")
if database:
required = ["HOST", "NAME", "USER", "PASSWORD", "PORT"]
missing = [k for k in required if k not in database]
if missing:
raise RuntimeError(f"Vault database secret missing keys {missing}")
data["DATABASE_HOST"] = database["HOST"]
data["DATABASE_PORT"] = database["PORT"]
data["DATABASE_NAME"] = database["NAME"]
data["DATABASE_USER"] = database["USER"]
data["DATABASE_PASSWORD"] = database["PASSWORD"]
if rabbitmq:
data["RABBIT_HOST"] = rabbitmq.get("HOST", data.get("RABBIT_HOST"))
data["RABBIT_PORT"] = rabbitmq.get("PORT", data.get("RABBIT_PORT"))
data["RABBIT_USER"] = rabbitmq.get("USER", data.get("RABBIT_USER"))
data["RABBIT_PASSWORD"] = rabbitmq.get("PASSWORD", data.get("RABBIT_PASSWORD"))
data["RABBIT_VHOST"] = rabbitmq.get("VHOST", data.get("RABBIT_VHOST"))
if csrf:
data["CSRF_SECRET_KEY"] = csrf.get("KEY", data.get("CSRF_SECRET_KEY"))
return data
@property
def DATABASE_URL(self) -> str:
return (
f"postgresql+asyncpg://{self.DATABASE_USER}:{self.DATABASE_PASSWORD}"
f"@{self.DATABASE_HOST}:{self.DATABASE_PORT}/{self.DATABASE_NAME}"
)
@property
def REDIS_URL(self) -> str:
auth = f":{self.REDIS_PASSWORD}@" if self.REDIS_PASSWORD else ""
return f"redis://{auth}{self.REDIS_HOST}:{self.REDIS_PORT}/{self.REDIS_DB}"
@property
def RABBIT_URL(self) -> str:
vhost = "%2F" if self.RABBIT_VHOST == "/" else self.RABBIT_VHOST.lstrip("/")
return f"amqp://{self.RABBIT_USER}:{self.RABBIT_PASSWORD}@{self.RABBIT_HOST}:{self.RABBIT_PORT}/{vhost}"
@property
def EXCLUDED_PATHS(self) -> List[str]:
return ["/docs", "/redoc", "/openapi.json", "/ping", "/health"]
@lru_cache(maxsize=1)
def get_settings() -> Settings:
return Settings()
settings = get_settings()

View File

@@ -0,0 +1 @@
from src.infrastructure.context_vars.trace_id import trace_id_var

View File

@@ -0,0 +1,4 @@
from contextvars import ContextVar
trace_id_var: ContextVar[str] = ContextVar('trace_id', default='N/A')

View File

@@ -0,0 +1 @@
from src.infrastructure.database.unit_of_work import UnitOfWork

View File

@@ -0,0 +1,22 @@
from sqlalchemy.ext.asyncio import async_sessionmaker
from sqlalchemy.ext.asyncio.engine import create_async_engine
from sqlalchemy.ext.asyncio.session import AsyncSession
from typing import AsyncGenerator
from src.infrastructure.config import settings
engine = create_async_engine(
settings.DATABASE_URL,
pool_size=settings.DATABASE_POOL_SIZE,
max_overflow=settings.DATABASE_MAX_OVERFLOW,
pool_timeout=settings.DATABASE_POOL_TIMEOUT,
pool_recycle=settings.DATABASE_POOL_RECYCLE,
echo=settings.DATABASE_ECHO
)
async_session_maker = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
async def get_session() -> AsyncGenerator[AsyncSession, None]:
async with async_session_maker() as session:
yield session

View File

@@ -0,0 +1 @@
from src.infrastructure.database.decorators.transactional import transactional

View File

@@ -0,0 +1,15 @@
from __future__ import annotations
from functools import wraps
from typing import Callable, Awaitable, TypeVar, ParamSpec
P = ParamSpec("P")
R = TypeVar("R")
def transactional(method: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
@wraps(method)
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
async with self._unit_of_work:
return await method(self, *args, **kwargs)
return wrapper

View File

@@ -0,0 +1,6 @@
from src.infrastructure.database.models.base import Base
from src.infrastructure.database.models.user import UserModel
from src.infrastructure.database.models.sessions import Session
__all__ = ['Base', 'UserModel', 'Session']

View File

@@ -0,0 +1,19 @@
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import DeclarativeBase
class Base(AsyncAttrs, DeclarativeBase):
__abstract__ = True
def __repr__(self) -> str:
class_name = self.__class__.__name__
attributes = ', '.join(f"{col.name}={getattr(self, col.name, None)!r}"
for col in self.__table__.columns)
return f"<{class_name}({attributes})>"
def __str__(self) -> str:
class_name = self.__class__.__name__
attributes = ', '.join(f"{col.name}={getattr(self, col.name)}"
for col in self.__table__.columns
if getattr(self, col.name) is not None)
return f"{class_name}({attributes})"

View File

@@ -0,0 +1,3 @@
from src.infrastructure.database.models.mixins.audit import AuditTimestampsMixin
from src.infrastructure.database.models.mixins.ulid import UlidPrimaryKeyMixin
from src.infrastructure.database.models.mixins.soft_delete import SoftDeleteMixin

View File

@@ -0,0 +1,16 @@
from sqlalchemy import DateTime, func
from sqlalchemy.orm import Mapped, mapped_column
class AuditTimestampsMixin:
created_at: Mapped[DateTime] = mapped_column(
DateTime(timezone=True),
nullable=False,
server_default=func.now(),
)
updated_at: Mapped[DateTime] = mapped_column(
DateTime(timezone=True),
nullable=False,
server_default=func.now(),
onupdate=func.now(),
)

View File

@@ -0,0 +1,6 @@
from sqlalchemy import Boolean
from sqlalchemy.orm import Mapped, mapped_column
class SoftDeleteMixin:
is_deleted: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default='false', default=False)

View File

@@ -0,0 +1,8 @@
from sqlalchemy import String
from sqlalchemy.orm import Mapped, mapped_column
from ulid import ULID
class UlidPrimaryKeyMixin:
id: Mapped[str] = mapped_column(String(26), primary_key=True, default=lambda: str(ULID()))

View File

@@ -0,0 +1,50 @@
from datetime import datetime, timezone
from sqlalchemy import String, DateTime, ForeignKey, Index
from sqlalchemy.orm import Mapped, mapped_column
from ulid import ULID
from src.infrastructure.database.models import Base
from src.infrastructure.database.models.mixins import UlidPrimaryKeyMixin, AuditTimestampsMixin
class Session(Base, UlidPrimaryKeyMixin, AuditTimestampsMixin):
__tablename__ = "sessions"
sid: Mapped[str] = mapped_column(
String(26),
unique=True,
index=True,
nullable=False,
default=lambda: str(ULID()),
)
user_id: Mapped[str] = mapped_column(
String(26),
ForeignKey("users.id", ondelete="CASCADE"),
index=True,
nullable=False,
)
device_id: Mapped[str] = mapped_column(
String(26),
nullable=False,
index=True,
)
user_agent: Mapped[str | None] = mapped_column(String(500))
first_ip: Mapped[str | None] = mapped_column(String(64))
last_ip: Mapped[str | None] = mapped_column(String(64))
last_seen_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
default=lambda: datetime.now(timezone.utc),
)
revoked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
refresh_jti_hash: Mapped[str | None] = mapped_column(String(255))
refresh_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
Index("ux_sessions_user_device", Session.user_id, Session.device_id, unique=True)
Index("ix_sessions_user_active", Session.user_id, Session.revoked_at)

View File

@@ -0,0 +1,28 @@
from __future__ import annotations
from sqlalchemy import Boolean,Date,String,DateTime
from sqlalchemy.orm import Mapped,mapped_column
from src.infrastructure.database.models.base import Base
from src.infrastructure.database.models.mixins import UlidPrimaryKeyMixin,AuditTimestampsMixin,SoftDeleteMixin
class UserModel(Base, UlidPrimaryKeyMixin, AuditTimestampsMixin, SoftDeleteMixin):
__tablename__ = 'users'
email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True, index=True)
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
last_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
first_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
middle_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
birth_date: Mapped[Date | None] = mapped_column(Date, nullable=True)
crypto_wallet: Mapped[str | None] = mapped_column(String(255), nullable=True)
phone: Mapped[str | None] = mapped_column(String(16), nullable=True)
passport_data: Mapped[str | None] = mapped_column(String(255), nullable=True)
inn: Mapped[str | None] = mapped_column(String(12), nullable=True)
erc20: Mapped[str | None] = mapped_column(String(255), nullable=True)
kyc_verified: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default='false', default=False)
kyc_verified_at: Mapped[DateTime | None] = mapped_column(DateTime(timezone=True), nullable=True)

View File

@@ -0,0 +1,2 @@
from src.infrastructure.database.repositories.user_repository import UserRepository
from src.infrastructure.database.repositories.session_repository import SessionRepository

View File

@@ -0,0 +1,198 @@
from __future__ import annotations
from datetime import datetime
from typing import Optional
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from src.application.contracts import ILogger
from src.application.domain.entities import SessionEntity
from src.application.abstractions.repositories import ISessionRepository
from src.infrastructure.database.models import Session
class SessionRepository(ISessionRepository):
def __init__(self, session: AsyncSession, logger: ILogger):
self._session = session
self._logger = logger
async def get_by_sid(self, sid: str) -> Optional[SessionEntity]:
res = await self._session.execute(select(Session).where(Session.sid == sid))
m = res.scalar_one_or_none()
if m is None:
return None
return SessionEntity(
sid=m.sid,
user_id=m.user_id,
device_id=m.device_id,
revoked_at=m.revoked_at,
last_seen_at=m.last_seen_at,
refresh_jti_hash=m.refresh_jti_hash,
refresh_expires_at=m.refresh_expires_at,
user_agent=m.user_agent,
first_ip=m.first_ip,
last_ip=m.last_ip,
)
async def get_by_user_device(self, user_id: str, device_id: str) -> Optional[SessionEntity]:
res = await self._session.execute(
select(Session).where(Session.user_id == user_id, Session.device_id == device_id)
)
m = res.scalar_one_or_none()
if m is None:
return None
return SessionEntity(
sid=m.sid,
user_id=m.user_id,
device_id=m.device_id,
revoked_at=m.revoked_at,
last_seen_at=m.last_seen_at,
refresh_jti_hash=m.refresh_jti_hash,
refresh_expires_at=m.refresh_expires_at,
user_agent=m.user_agent,
first_ip=m.first_ip,
last_ip=m.last_ip,
)
async def upsert_by_device(
self,
*,
user_id: str,
device_id: str,
sid: str,
refresh_jti_hash: str,
refresh_expires_at: datetime,
user_agent: str | None,
ip: str | None,
now: datetime,
) -> SessionEntity:
res = await self._session.execute(
select(Session).where(Session.user_id == user_id, Session.device_id == device_id)
)
m = res.scalar_one_or_none()
if m is None:
m = Session(
sid=sid,
user_id=user_id,
device_id=device_id,
revoked_at=None,
last_seen_at=now,
refresh_jti_hash=refresh_jti_hash,
refresh_expires_at=refresh_expires_at,
user_agent=user_agent,
first_ip=ip,
last_ip=ip,
)
self._session.add(m)
await self._session.flush()
self._logger.info(f'Session created (user_id={user_id}, device_id={device_id}, sid={sid})')
else:
m.sid = sid
m.revoked_at = None
m.last_seen_at = now
m.refresh_jti_hash = refresh_jti_hash
m.refresh_expires_at = refresh_expires_at
m.user_agent = user_agent
m.last_ip = ip
await self._session.flush()
self._logger.info(f'Session updated (user_id={user_id}, device_id={device_id}, sid={sid})')
return SessionEntity(
sid=m.sid,
user_id=m.user_id,
device_id=m.device_id,
revoked_at=m.revoked_at,
last_seen_at=m.last_seen_at,
refresh_jti_hash=m.refresh_jti_hash,
refresh_expires_at=m.refresh_expires_at,
user_agent=m.user_agent,
first_ip=m.first_ip,
last_ip=m.last_ip,
)
async def revoke_by_sid(self, sid: str, now: datetime) -> None:
# Интерфейс требует None -> просто делаем update и flush
await self._session.execute(
update(Session)
.where(Session.sid == sid, Session.revoked_at.is_(None))
.values(revoked_at=now)
.execution_options(synchronize_session='fetch')
)
await self._session.flush()
async def rotate_refresh(
self,
sid: str,
new_jti_hash: str,
new_refresh_expires_at: datetime,
now: datetime,
ip: str | None,
user_agent: str | None,
) -> None:
values = {
'refresh_jti_hash': new_jti_hash,
'refresh_expires_at': new_refresh_expires_at,
'last_seen_at': now,
'user_agent': user_agent,
}
if ip is not None:
values['last_ip'] = ip
await self._session.execute(
update(Session)
.where(Session.sid == sid, Session.revoked_at.is_(None))
.values(**values)
.execution_options(synchronize_session='fetch')
)
await self._session.flush()
async def touch_last_seen(self, sid: str, *, ip: str | None, now: datetime) -> None:
values = {'last_seen_at': now}
if ip is not None:
values['last_ip'] = ip
await self._session.execute(
update(Session)
.where(Session.sid == sid, Session.revoked_at.is_(None))
.values(**values)
.execution_options(synchronize_session='fetch')
)
await self._session.flush()
async def rotate_refresh_if_match(
self,
*,
sid: str,
old_jti_hash: str,
new_jti_hash: str,
new_refresh_expires_at: datetime,
now: datetime,
ip: str | None,
user_agent: str | None,
) -> bool:
values = {
'refresh_jti_hash': new_jti_hash,
'refresh_expires_at': new_refresh_expires_at,
'last_seen_at': now,
'user_agent': user_agent,
}
if ip is not None:
values['last_ip'] = ip
res = await self._session.execute(
update(Session)
.where(
Session.sid == sid,
Session.revoked_at.is_(None),
Session.refresh_jti_hash == old_jti_hash, # ✅ защита от гонок
)
.values(**values)
.execution_options(synchronize_session='fetch')
)
await self._session.flush()
return (res.rowcount or 0) > 0

View File

@@ -0,0 +1,118 @@
from __future__ import annotations
from fastapi import status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import SQLAlchemyError
from src.application.contracts import ILogger
from src.application.domain.exceptions import ApplicationException
from src.application.abstractions.repositories import IUserRepository
from src.application.domain.entities import UserEntity
from src.infrastructure.database.models import UserModel
class UserRepository(IUserRepository):
def __init__(self, session: AsyncSession, logger: ILogger):
self._session = session
self._logger = logger
async def _get_active_user(self, user_id: str) -> UserModel:
stmt = (
select(UserModel)
.where(
UserModel.id == user_id,
UserModel.is_deleted.is_(False),
)
)
result = await self._session.execute(stmt)
user: UserModel | None = result.scalar_one_or_none()
if user is None:
self._logger.warning(f'User not found with user_id {user_id}')
raise ApplicationException(status_code=status.HTTP_404_NOT_FOUND, message='User not found')
return user
@staticmethod
def _to_entity(user: UserModel) -> UserEntity:
return UserEntity(
id=user.id,
email=user.email,
password_hash=None,
first_name=user.first_name,
middle_name=user.middle_name,
last_name=user.last_name,
birth_date=user.birth_date,
crypto_wallet=user.crypto_wallet,
phone=user.phone,
bik=user.bik,
account_number=user.account_number,
card_number=user.card_number,
inn=user.inn,
kyc_verified_at=user.kyc_verified_at,
kyc_verified=user.kyc_verified,
is_deleted=user.is_deleted,
created_at=user.created_at,
updated_at=user.updated_at,
)
async def get_user_by_id(self, user_id: str) -> UserEntity:
try:
user = await self._get_active_user(user_id)
return self._to_entity(user)
except ApplicationException:
raise
except SQLAlchemyError as exception:
self._logger.exception(str(exception))
raise ApplicationException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, message=f'Database error: {str(exception)}')
async def _update_field(self, user_id: str, **fields: object) -> UserEntity:
try:
user = await self._get_active_user(user_id)
for key, value in fields.items():
setattr(user, key, value)
await self._session.flush()
await self._session.refresh(user)
return self._to_entity(user)
except ApplicationException:
raise
except SQLAlchemyError as exception:
self._logger.exception(str(exception))
raise ApplicationException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, message=f'Database error: {str(exception)}')
async def set_phone(self, user_id: str, phone: str) -> UserEntity:
return await self._update_field(user_id, phone=phone)
async def set_bank_details(self, user_id: str, **fields: str) -> UserEntity:
return await self._update_field(user_id, **fields)
async def set_crypto_wallet(self, user_id: str, wallet_address: str) -> UserEntity:
return await self._update_field(user_id, crypto_wallet=wallet_address)
async def get_password_hash(self, user_id: str) -> str:
try:
user = await self._get_active_user(user_id)
return user.password_hash
except ApplicationException:
raise
except SQLAlchemyError as exception:
self._logger.exception(str(exception))
raise ApplicationException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, message=f'Database error: {str(exception)}')
async def set_password(self, user_id: str, password_hash: str) -> UserEntity:
return await self._update_field(user_id, password_hash=password_hash)
async def set_email(self, user_id: str, email: str) -> UserEntity:
return await self._update_field(user_id, email=email)
async def email_exists(self, email: str) -> bool:
try:
stmt = (
select(UserModel)
.where(
UserModel.email == email,
UserModel.is_deleted.is_(False),
)
)
result = await self._session.execute(stmt)
return result.scalar_one_or_none() is not None
except SQLAlchemyError as exception:
self._logger.exception(str(exception))
raise ApplicationException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, message=f'Database error: {str(exception)}')

View File

@@ -0,0 +1,42 @@
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from src.application.abstractions import IUnitOfWork
from src.application.abstractions.repositories import IUserRepository, ISessionRepository
from src.application.contracts import ILogger
from src.infrastructure.database.repositories import UserRepository, SessionRepository
class UnitOfWork(IUnitOfWork):
def __init__(self, session_factory: async_sessionmaker[AsyncSession], logger: ILogger):
self.session_factory = session_factory
self._session: AsyncSession = None
self._user_repository: IUserRepository = None
self._session_repository: ISessionRepository = None
self._logger: ILogger = logger
async def __aenter__(self):
self._session = self.session_factory()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if exc_type:
self._logger.error(str(exc_val))
await self._session.rollback()
self._logger.error(f'Rollback: str{exc_val})')
else:
await self._session.flush()
await self._session.commit()
self._logger.debug('Commit')
await self._session.close()
@property
def user_repository(self) -> IUserRepository:
if self._user_repository is None:
self._user_repository = UserRepository(session=self._session, logger=self._logger)
return self._user_repository
@property
def session_repository(self) -> ISessionRepository:
if self._session_repository is None:
self._session_repository = SessionRepository(session=self._session, logger=self._logger)
return self._session_repository

View File

@@ -0,0 +1,28 @@
from src.application.contracts import ILogger
from src.application.domain.enums import LogFormat
from src.application.domain.enums import LogLevel
from src.infrastructure.config.settings import settings
from src.infrastructure.logger.logger import Logger
log_levels = {
'DEBUG': LogLevel.DEBUG,
'INFO': LogLevel.INFO,
'WARNING': LogLevel.WARNING,
'ERROR': LogLevel.ERROR,
'CRITICAL': LogLevel.CRITICAL,
'EXCEPTION': LogLevel.EXCEPTION,
}
log_formats = {
'JSON': LogFormat.JSON,
'TEXT': LogFormat.TEXT,
}
logger = Logger(
min_level=log_levels.get(settings.LOG_LEVEL, LogLevel.INFO),
log_format=log_formats.get(settings.LOG_FORMAT, LogFormat.JSON),
)
def get_logger() -> ILogger:
return logger

View File

@@ -0,0 +1,129 @@
import traceback
import inspect
import sys
import json
from datetime import datetime
from typing import Callable, Optional, Any
from ulid import ULID
from src.application.contracts import ILogger
from src.application.domain.enums import LogFormat, LogLevel
from src.infrastructure.context_vars import trace_id_var
class Logger(ILogger):
_instance = None
__default_format = LogFormat.JSON
def __new__(cls, *args: Any, **kwargs: Any) -> "Logger":
if cls._instance is None:
cls._instance = super(Logger, cls).__new__(cls)
return cls._instance
def __init__(
self,
log_format: LogFormat = __default_format,
min_level: LogLevel = LogLevel.INFO,
id_generator: Optional[Callable[[], str]] = lambda: str(ULID()),
instance_id: str = "N/A",
):
self.log_format = log_format
self.min_level = min_level
self.id_generator = id_generator
self.instance_id = instance_id
def set_instance_id(self, instance_id: str) -> None:
self.instance_id = instance_id
def get_instance_id(self) -> str:
return self.instance_id
def set_format(self, log_format: LogFormat) -> None:
if not isinstance(log_format, LogFormat):
raise ValueError("Log format must be an instance of LogFormat enum")
self.log_format = log_format
def set_min_level(self, level: LogLevel) -> None:
self.min_level = level
def new_trace_id(self) -> str:
trace_id = str(ULID()) if self.id_generator is None else self.id_generator()
trace_id_var.set(trace_id)
return trace_id
def set_trace_id(self, trace_id: str) -> None:
trace_id_var.set(trace_id)
def get_trace_id(self) -> str:
return trace_id_var.get()
def clear_trace_id(self) -> None:
trace_id_var.set("N/A")
def _prepare_log_data(self, level: LogLevel, message: str) -> dict[str, Any]:
current_frame = inspect.currentframe()
if (
current_frame
and current_frame.f_back
and current_frame.f_back.f_back
and current_frame.f_back.f_back.f_back
):
frame = current_frame.f_back.f_back.f_back
filename = frame.f_code.co_filename
line_number = frame.f_lineno
else:
filename = "unknown"
line_number = 0
log_data = {
"timestamp": datetime.now().isoformat(),
"level": level.name,
"instance_id": self.instance_id,
"file": filename,
"line": line_number,
"trace_id": trace_id_var.get(),
"message": message,
}
if level == LogLevel.EXCEPTION:
log_data["exception"] = traceback.format_exc()
return log_data
def _log(self, level: LogLevel, message: str) -> None:
if level >= self.min_level:
log_data = self._prepare_log_data(level, message)
if self.log_format == LogFormat.JSON:
log_message = json.dumps(log_data, ensure_ascii=False)
else:
log_message = (
f"{log_data['timestamp']} - {log_data['level']} - "
f"{log_data['instance_id']} - {log_data['trace_id']} - "
f"{log_data['file']}:{log_data['line']} - "
f"{log_data['message']}"
)
if "exception" in log_data:
log_message += f"\nTraceback:\n{log_data['exception']}"
self._write(log_message)
def _write(self, message: str) -> None:
sys.stdout.write(message + "\n")
def debug(self, message: str) -> None:
self._log(LogLevel.DEBUG, message)
def info(self, message: str) -> None:
self._log(LogLevel.INFO, message)
def warning(self, message: str) -> None:
self._log(LogLevel.WARNING, message)
def error(self, message: str) -> None:
self._log(LogLevel.ERROR, message)
def critical(self, message: str) -> None:
self._log(LogLevel.CRITICAL, message)
def exception(self, message: str) -> None:
self._log(LogLevel.EXCEPTION, message)

View File

@@ -0,0 +1 @@
from src.infrastructure.messanger.rabbit_client import RabbitClient

View File

@@ -0,0 +1,72 @@
from typing import Any, Mapping
from faststream.rabbit import RabbitBroker
from src.application.contracts import IQueueMessanger
from src.infrastructure.config import settings
class RabbitClient(IQueueMessanger):
def __init__(self) -> None:
self._broker = RabbitBroker(
settings.RABBIT_URL,
)
self._connected = False
async def connect(self) -> None:
if self._connected:
return
await self._broker.connect()
self._connected = True
async def close(self) -> None:
if not self._connected:
return
await self._broker.close()
self._connected = False
async def _ensure_connected(self) -> None:
if not self._connected:
await self.connect()
async def publish_to_queue(
self,
queue: str,
message: Any,
*,
persist: bool | None = None,
headers: Mapping[str, Any] | None = None,
correlation_id: str | None = None,
message_id: str | None = None,
) -> None:
await self._ensure_connected()
await self._broker.publish(
message,
queue=queue,
persist=settings.RABBIT_PUBLISH_PERSIST if persist is None else persist,
headers=headers,
correlation_id=correlation_id,
message_id=message_id,
)
async def publish(
self,
message: Any,
*,
exchange: str,
routing_key: str,
persist: bool | None = None,
headers: Mapping[str, Any] | None = None,
correlation_id: str | None = None,
message_id: str | None = None,
) -> None:
await self._ensure_connected()
await self._broker.publish(
message,
exchange=exchange,
routing_key=routing_key,
persist=settings.RABBIT_PUBLISH_PERSIST if persist is None else persist,
headers=headers,
correlation_id=correlation_id,
message_id=message_id,
)

View File

@@ -0,0 +1,3 @@
from src.infrastructure.security.jwt import JwtService
from src.infrastructure.security.csrf import CsrfService
from src.infrastructure.security.hash import HashService

View File

@@ -0,0 +1,81 @@
from __future__ import annotations
import secrets
from typing import Any, Optional, Mapping
from itsdangerous import URLSafeTimedSerializer, SignatureExpired, BadSignature
from src.application.contracts import ICsrfService
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.config.settings import settings
class CsrfService(ICsrfService):
COOKIE_NAME = 'csrf_token'
HEADER_NAME = 'X-CSRF-Token'
SALT = 'csrf'
TTL_SECONDS = 3600
def __init__(self) -> None:
self._serializer = URLSafeTimedSerializer(
secret_key=settings.CSRF_SECRET_KEY,
salt=self.SALT,
)
@property
def cookie_name(self) -> str:
return self.COOKIE_NAME
@property
def header_name(self) -> str:
return self.HEADER_NAME
@property
def ttl_seconds(self) -> int:
return self.TTL_SECONDS
def issue(self, subject: Optional[str] = None) -> str:
payload = {
'sub': subject,
'nonce': secrets.token_urlsafe(32),
}
return self._serializer.dumps(payload)
def verify(self, token: str, expected_subject: Optional[str] = None) -> dict[str, Any]:
try:
data = self._serializer.loads(token, max_age=self.TTL_SECONDS)
except SignatureExpired:
raise ApplicationException(
status_code=403,
message='CSRF token expired',
)
except BadSignature:
raise ApplicationException(
status_code=403,
message='CSRF token invalid',
)
if expected_subject is not None and data.get('sub') != expected_subject:
raise ApplicationException(
status_code=403,
message='CSRF token subject mismatch',
)
return data
def extract(self, cookies: Mapping[str, str], headers: Mapping[str, str]) -> tuple[Optional[str], Optional[str]]:
cookie_token = cookies.get(self.COOKIE_NAME)
header_token = headers.get(self.HEADER_NAME)
return cookie_token, header_token
def verify_pair(self, cookie_token: Optional[str], header_token: Optional[str], expected_subject: Optional[str] = None) -> None:
if not cookie_token or not header_token:
raise ApplicationException(
status_code=403,
message='CSRF token missing',
)
if not secrets.compare_digest(cookie_token, header_token):
raise ApplicationException(
status_code=403,
message='CSRF token mismatch',
)
self.verify(cookie_token, expected_subject=expected_subject)

View File

@@ -0,0 +1,17 @@
import bcrypt
from src.application.contracts import IHashService, ILogger
class HashService(IHashService):
def __init__(self, logger: ILogger):
self._logger = logger
async def hash(self, value: str) -> str:
hashed_value = bcrypt.hashpw(value.encode(), bcrypt.gensalt())
self._logger.info(f'Hash value {hashed_value.decode()}')
return hashed_value.decode()
async def verify(self, hashed_value: str, plain_value: str) -> bool:
self._logger.info(f'Hash value {hashed_value[:10]}')
return bcrypt.checkpw(plain_value.encode(), hashed_value.encode())

View File

@@ -0,0 +1,109 @@
from __future__ import annotations
from jose import jwt, ExpiredSignatureError, JWTError
from src.application.contracts import ILogger, IJwtService
from src.application.domain.dto import AccessTokenPayload
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.config.settings import settings
from src.infrastructure.vault import JwtKeyStore
class JwtService(IJwtService):
def __init__(self, logger: ILogger, key_store: JwtKeyStore) -> None:
self._logger = logger
self._key_store = key_store
async def decode_access_token(self, token: str) -> AccessTokenPayload:
payload = await self._decode_and_verify(token)
if payload.get('type') != 'access':
self._logger.warning(f'Access token invalid type received_type={payload.get('type')}')
raise ApplicationException(status_code=401, message='Invalid token type')
try:
return AccessTokenPayload(
sub=str(payload['sub']),
type='access',
sid=str(payload['sid']),
iat=int(payload['iat']),
nbf=int(payload['nbf']),
exp=int(payload['exp']),
iss=payload.get('iss'),
aud=payload.get('aud'),
)
except KeyError as exception:
self._logger.warning(f'Access token missing claim error={str(exception)}')
raise ApplicationException(status_code=401, message=f'Missing token claim: {exception}')
async def _decode_and_verify(self, token: str) -> dict:
kid: str | None = None
try:
header = jwt.get_unverified_header(token)
kid = header.get('kid')
if not kid:
self._logger.warning(f'JWT header missing kid header={header}')
raise ApplicationException(status_code=401, message='Missing token header: kid')
received_alg = header.get('alg')
if received_alg != settings.JWT_ALGORITHM:
self._logger.warning(f'JWT invalid algorithm kid={kid} received_alg={received_alg} expected_alg={settings.JWT_ALGORITHM}')
raise ApplicationException(status_code=401, message='Invalid token algorithm')
public_pem = await self._key_store.get_public_key_for_kid(str(kid))
if not public_pem:
self._logger.info(f'JWT kid miss kid={kid} forcing keystore refresh')
await self._key_store.refresh()
public_pem = await self._key_store.get_public_key_for_kid(str(kid))
if not public_pem:
self._logger.warning(f'JWT unknown kid kid={kid}')
raise ApplicationException(status_code=401, message='Unknown token kid')
options = {
'verify_signature': True,
'verify_exp': True,
'verify_nbf': True,
'verify_iat': True,
'verify_aud': bool(settings.JWT_AUDIENCE),
'verify_iss': bool(settings.JWT_ISSUER),
'require_exp': True,
'require_iat': True,
'require_nbf': True,
'require_sub': True,
'leeway': 10,
}
payload = jwt.decode(
token,
public_pem,
algorithms=[settings.JWT_ALGORITHM],
audience=settings.JWT_AUDIENCE or None,
issuer=settings.JWT_ISSUER or None,
options=options,
)
if 'sid' not in payload:
self._logger.warning(f'JWT missing sid claim kid={kid}')
raise ApplicationException(status_code=401, message='Missing token claim: sid')
if 'type' not in payload:
self._logger.warning(f'JWT missing type claim kid={kid}')
raise ApplicationException(status_code=401, message='Missing token claim: type')
return payload
except ExpiredSignatureError as exception:
self._logger.info(f'JWT expired kid={kid} error={str(exception)}')
raise ApplicationException(status_code=401, message='Token expired')
except ApplicationException:
raise
except JWTError as exception:
self._logger.warning(f'JWT decode failed kid={kid} error={str(exception)}')
raise ApplicationException(status_code=401, message='Invalid token')
except Exception as exception:
self._logger.error(f'Unexpected JWT decode error kid={kid} error={str(exception)}')
raise ApplicationException(status_code=500, message='JWT decode failed')

View File

@@ -0,0 +1 @@
from src.infrastructure.utils.instance_id import generate_instance_id

View File

@@ -0,0 +1,14 @@
from ulid import ULID
def generate_instance_id() -> str:
"""
Generate a process-wide instance id in ULID format.
ULID is 26 chars (Crockford Base32) and lexicographically sortable by time.
"""
return str(ULID())

View File

@@ -0,0 +1,3 @@
from src.infrastructure.vault.utils import read_kv2_secret, create_hvac_client
from src.infrastructure.vault.keys import JwtKeyStore
from src.infrastructure.vault.scheduler import start_jwt_keys_scheduler

View File

@@ -0,0 +1,113 @@
from __future__ import annotations
import asyncio
from datetime import datetime, timezone
from src.application.domain.dto import JwtPublicKeySet, JwtPublicKey
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.vault import create_hvac_client, read_kv2_secret
class JwtKeyStore:
_instance: 'JwtKeyStore | None' = None
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(
self,
*,
vault_addr: str,
vault_token: str,
mount_point: str,
kid_path: str = 'jwt/kid',
kids_prefix: str = 'jwt/kids',
timeout_seconds: int = 5,
refresh_ttl_seconds: int = 60,
):
if getattr(self, '_initialized', False):
return
self._vault_addr = vault_addr
self._vault_token = vault_token
self._timeout = timeout_seconds
self._mount = mount_point
self._kid_path = kid_path
self._kids_prefix = kids_prefix
self._refresh_ttl_seconds = refresh_ttl_seconds
self._lock = asyncio.Lock()
self._keyset: JwtPublicKeySet | None = None
self._last_refresh_at: datetime | None = None
self._initialized = True
@classmethod
def get_instance(cls) -> 'JwtKeyStore':
if cls._instance is None:
raise ApplicationException(status_code=500, message='JwtKeyStore not initialized')
return cls._instance
def _read_keyset_sync(self) -> JwtPublicKeySet:
client = create_hvac_client(url=self._vault_addr, token=self._vault_token, timeout=self._timeout)
kids = read_kv2_secret(client=client, mount_point=self._mount, path=self._kid_path)
active_kid = kids.get('active')
previous_kid = kids.get('previous')
if not active_kid:
raise RuntimeError('Vault jwt/kid secret missing "active"')
active = self._read_public_key_sync(client, str(active_kid))
previous = None
if previous_kid and previous_kid != active_kid:
previous = self._read_public_key_sync(client, str(previous_kid))
return JwtPublicKeySet(active=active, previous=previous)
def _read_public_key_sync(self, client, kid: str) -> JwtPublicKey:
data = read_kv2_secret(
client=client,
mount_point=self._mount,
path=f'{self._kids_prefix}/{kid}',
)
pub = data.get('public_key')
if not pub:
raise RuntimeError(f'Vault jwt/kids/{kid} missing public_key')
return JwtPublicKey(kid=kid, public_key_pem=pub)
async def refresh(self) -> JwtPublicKeySet:
keyset = await asyncio.to_thread(self._read_keyset_sync)
async with self._lock:
self._keyset = keyset
self._last_refresh_at = datetime.now(timezone.utc)
return keyset
async def get_public_key_for_kid(self, kid: str) -> str | None:
ks = await self._get_or_refresh()
return ks.public_keys_by_kid().get(kid)
async def last_refresh_at(self) -> datetime | None:
async with self._lock:
return self._last_refresh_at
async def _get_or_refresh(self) -> JwtPublicKeySet:
async with self._lock:
ks = self._keyset
last = self._last_refresh_at
if ks is None:
return await self.refresh()
if last is None:
return await self.refresh()
age = (datetime.now(timezone.utc) - last).total_seconds()
if age >= self._refresh_ttl_seconds:
return await self.refresh()
return ks

View File

@@ -0,0 +1,23 @@
from __future__ import annotations
import logging
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.interval import IntervalTrigger
from src.infrastructure.vault import JwtKeyStore
logger = logging.getLogger(__name__)
def start_jwt_keys_scheduler(store: JwtKeyStore, *, refresh_seconds: int = 3600) -> AsyncIOScheduler:
scheduler = AsyncIOScheduler()
scheduler.add_job(
store.refresh,
trigger=IntervalTrigger(seconds=refresh_seconds),
id="jwt_keys_refresh",
replace_existing=True,
max_instances=1,
coalesce=True,
misfire_grace_time=60,
)
scheduler.start()
logger.info("JWT keys scheduler started (interval=%s seconds)", refresh_seconds)
return scheduler

View File

@@ -0,0 +1,17 @@
from __future__ import annotations
import hvac
def create_hvac_client(*, url: str, token: str, timeout: int = 5) -> hvac.Client:
client = hvac.Client(url=url, token=token, timeout=timeout)
if not client.is_authenticated():
raise RuntimeError("Vault authentication failed. Check VAULT_ADDR / VAULT_TOKEN")
return client
def read_kv2_secret(*, client: hvac.Client, mount_point: str, path: str) -> dict:
secret = client.secrets.kv.v2.read_secret_version(
mount_point=mount_point,
path=path,
)
return secret["data"]["data"]