feat(account): GET /me user endpoint only, disable cache and extra routers
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
2
src/infrastructure/cache/__init__.py
vendored
Normal file
2
src/infrastructure/cache/__init__.py
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
from src.infrastructure.cache.client import create_redis_client
|
||||
from src.infrastructure.cache.keydb_client import KeydbCache
|
||||
16
src/infrastructure/cache/client.py
vendored
Normal file
16
src/infrastructure/cache/client.py
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
import redis.asyncio as redis
|
||||
from redis.asyncio.client import Redis
|
||||
from src.infrastructure.config import settings
|
||||
|
||||
|
||||
def create_redis_client() -> Redis:
|
||||
return redis.from_url(
|
||||
settings.REDIS_URL,
|
||||
max_connections=50,
|
||||
decode_responses=True,
|
||||
socket_timeout=5,
|
||||
socket_connect_timeout=5,
|
||||
health_check_interval=30,
|
||||
retry_on_timeout=True,
|
||||
socket_keepalive=True,
|
||||
)
|
||||
52
src/infrastructure/cache/keydb_client.py
vendored
Normal file
52
src/infrastructure/cache/keydb_client.py
vendored
Normal file
@@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
import orjson
|
||||
from redis.asyncio.client import Redis
|
||||
from src.application.contracts import ICache
|
||||
from src.application.domain.entities.user import UserEntity
|
||||
|
||||
|
||||
class KeydbCache(ICache):
|
||||
USER_PREFIX = 'user:me'
|
||||
|
||||
def __init__(self, redis_client: Redis):
|
||||
self._r = redis_client
|
||||
|
||||
async def set(self, key: str, value: str, ttl: int) -> bool:
|
||||
return bool(await self._r.set(key, value, ex=ttl))
|
||||
|
||||
async def set_nx(self, key: str, value: str, ttl: int) -> bool:
|
||||
return bool(await self._r.set(key, value, ex=ttl, nx=True))
|
||||
|
||||
async def get(self, key: str) -> str | None:
|
||||
return await self._r.get(key)
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
return (await self._r.delete(key)) > 0
|
||||
|
||||
async def get_user(self, user_id: str) -> dict | None:
|
||||
raw = await self._r.get(f'{self.USER_PREFIX}:{user_id}')
|
||||
if raw is None:
|
||||
return None
|
||||
return orjson.loads(raw)
|
||||
|
||||
async def set_user(self, user_id: str, user: UserEntity, ttl: int = 300) -> None:
|
||||
data = orjson.dumps({
|
||||
'id': user.id,
|
||||
'email': user.email,
|
||||
'first_name': user.first_name,
|
||||
'middle_name': user.middle_name,
|
||||
'last_name': user.last_name,
|
||||
'birth_date': str(user.birth_date) if user.birth_date else None,
|
||||
'crypto_wallet': user.crypto_wallet,
|
||||
'phone': user.phone,
|
||||
'bik': user.bik,
|
||||
'account_number': user.account_number,
|
||||
'card_number': user.card_number,
|
||||
'inn': user.inn,
|
||||
'kyc_verified': user.kyc_verified,
|
||||
'is_deleted': user.is_deleted,
|
||||
'created_at': user.created_at.isoformat() if user.created_at else None,
|
||||
'updated_at': user.updated_at.isoformat() if user.updated_at else None,
|
||||
'kyc_verified_at': user.kyc_verified_at.isoformat() if user.kyc_verified_at else None,
|
||||
})
|
||||
await self._r.set(f'{self.USER_PREFIX}:{user_id}', data, ex=ttl)
|
||||
1
src/infrastructure/config/__init__.py
Normal file
1
src/infrastructure/config/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.config.settings import settings
|
||||
155
src/infrastructure/config/settings.py
Normal file
155
src/infrastructure/config/settings.py
Normal file
@@ -0,0 +1,155 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import List, Literal
|
||||
import os
|
||||
from dotenv import load_dotenv, find_dotenv
|
||||
from pydantic import Field, model_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from src.infrastructure.vault import create_hvac_client, read_kv2_secret
|
||||
|
||||
env_file = find_dotenv(".env")
|
||||
if env_file:
|
||||
load_dotenv(env_file)
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
VAULT_ADDR: str = Field(default="http://localhost:8200")
|
||||
VAULT_TOKEN: str = Field(..., description="Vault token is required")
|
||||
VAULT_MOUNT_POINT: str = Field(default="secrets")
|
||||
|
||||
VAULT_JWT_KID_PATH: str = "jwt/kid"
|
||||
VAULT_JWT_KIDS_PREFIX: str = "jwt/kids"
|
||||
JWT_KEYS_REFRESH_SECONDS: int = 3600
|
||||
|
||||
DATABASE_HOST: str
|
||||
DATABASE_PORT: int = Field(default=5432, ge=1, le=65535)
|
||||
DATABASE_NAME: str
|
||||
DATABASE_USER: str
|
||||
DATABASE_PASSWORD: str
|
||||
|
||||
DATABASE_POOL_SIZE: int = 10
|
||||
DATABASE_MAX_OVERFLOW: int = 20
|
||||
DATABASE_POOL_TIMEOUT: int = 30
|
||||
DATABASE_POOL_RECYCLE: int = 3600
|
||||
DATABASE_ECHO: bool = False
|
||||
|
||||
CSRF_SECRET_KEY: str = Field(
|
||||
default="change-me-change-me-change-me-change-me",
|
||||
min_length=32,
|
||||
)
|
||||
|
||||
CSRF_COOKIE_SECURE: bool = False
|
||||
CSRF_COOKIE_HTTPONLY: bool = True
|
||||
CSRF_COOKIE_SAMESITE: Literal["Lax", "Strict", "None"] = "Lax"
|
||||
CSRF_COOKIE_PATH: str = "/"
|
||||
CSRF_COOKIE_DOMAIN: str | None = None
|
||||
|
||||
DOCS_USERNAME: str = "admin"
|
||||
DOCS_PASSWORD: str = "admin"
|
||||
|
||||
JWT_ACCESS_TTL_SECONDS: int = 15 * 60
|
||||
JWT_REFRESH_TTL_SECONDS: int = 30 * 24 * 60 * 60
|
||||
JWT_ISSUER: str | None = None
|
||||
JWT_AUDIENCE: str | None = None
|
||||
JWT_ALGORITHM: str = "RS256"
|
||||
|
||||
REDIS_HOST: str = "localhost"
|
||||
REDIS_PORT: int = 6379
|
||||
REDIS_PASSWORD: str | None = None
|
||||
REDIS_DB: int = 0
|
||||
|
||||
RABBIT_HOST: str = "localhost"
|
||||
RABBIT_PORT: int = 5672
|
||||
RABBIT_USER: str = "guest"
|
||||
RABBIT_PASSWORD: str = "guest"
|
||||
RABBIT_VHOST: str = "/"
|
||||
|
||||
RABBIT_PUBLISH_PERSIST: bool = True
|
||||
RABBIT_CONNECT_TIMEOUT: int = 5
|
||||
RABBIT_EMAIL_CODE_QUEUE: str = "email.verification_code"
|
||||
|
||||
LOG_LEVEL: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
|
||||
LOG_FORMAT: Literal["JSON", "TEXT"] = "TEXT"
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
case_sensitive=True,
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def load_from_vault(cls, data: dict):
|
||||
addr = data.get("VAULT_ADDR") or os.getenv("VAULT_ADDR") or "http://localhost:8200"
|
||||
token = data.get("VAULT_TOKEN") or os.getenv("VAULT_TOKEN")
|
||||
mount = data.get("VAULT_MOUNT_POINT") or os.getenv("VAULT_MOUNT_POINT") or "secrets"
|
||||
|
||||
if not token:
|
||||
raise RuntimeError("VAULT_TOKEN is required")
|
||||
|
||||
client = create_hvac_client(url=addr, token=token, timeout=5)
|
||||
|
||||
def safe_read(path: str) -> dict:
|
||||
try:
|
||||
return read_kv2_secret(client=client, mount_point=mount, path=path)
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
database = safe_read("database")
|
||||
rabbitmq = safe_read("rabbitmq")
|
||||
csrf = safe_read("csrf")
|
||||
|
||||
if database:
|
||||
required = ["HOST", "NAME", "USER", "PASSWORD", "PORT"]
|
||||
missing = [k for k in required if k not in database]
|
||||
if missing:
|
||||
raise RuntimeError(f"Vault database secret missing keys {missing}")
|
||||
|
||||
data["DATABASE_HOST"] = database["HOST"]
|
||||
data["DATABASE_PORT"] = database["PORT"]
|
||||
data["DATABASE_NAME"] = database["NAME"]
|
||||
data["DATABASE_USER"] = database["USER"]
|
||||
data["DATABASE_PASSWORD"] = database["PASSWORD"]
|
||||
|
||||
if rabbitmq:
|
||||
data["RABBIT_HOST"] = rabbitmq.get("HOST", data.get("RABBIT_HOST"))
|
||||
data["RABBIT_PORT"] = rabbitmq.get("PORT", data.get("RABBIT_PORT"))
|
||||
data["RABBIT_USER"] = rabbitmq.get("USER", data.get("RABBIT_USER"))
|
||||
data["RABBIT_PASSWORD"] = rabbitmq.get("PASSWORD", data.get("RABBIT_PASSWORD"))
|
||||
data["RABBIT_VHOST"] = rabbitmq.get("VHOST", data.get("RABBIT_VHOST"))
|
||||
|
||||
if csrf:
|
||||
data["CSRF_SECRET_KEY"] = csrf.get("KEY", data.get("CSRF_SECRET_KEY"))
|
||||
|
||||
return data
|
||||
|
||||
@property
|
||||
def DATABASE_URL(self) -> str:
|
||||
return (
|
||||
f"postgresql+asyncpg://{self.DATABASE_USER}:{self.DATABASE_PASSWORD}"
|
||||
f"@{self.DATABASE_HOST}:{self.DATABASE_PORT}/{self.DATABASE_NAME}"
|
||||
)
|
||||
|
||||
@property
|
||||
def REDIS_URL(self) -> str:
|
||||
auth = f":{self.REDIS_PASSWORD}@" if self.REDIS_PASSWORD else ""
|
||||
return f"redis://{auth}{self.REDIS_HOST}:{self.REDIS_PORT}/{self.REDIS_DB}"
|
||||
|
||||
@property
|
||||
def RABBIT_URL(self) -> str:
|
||||
vhost = "%2F" if self.RABBIT_VHOST == "/" else self.RABBIT_VHOST.lstrip("/")
|
||||
return f"amqp://{self.RABBIT_USER}:{self.RABBIT_PASSWORD}@{self.RABBIT_HOST}:{self.RABBIT_PORT}/{vhost}"
|
||||
|
||||
@property
|
||||
def EXCLUDED_PATHS(self) -> List[str]:
|
||||
return ["/docs", "/redoc", "/openapi.json", "/ping", "/health"]
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_settings() -> Settings:
|
||||
return Settings()
|
||||
|
||||
|
||||
settings = get_settings()
|
||||
1
src/infrastructure/context_vars/__init__.py
Normal file
1
src/infrastructure/context_vars/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.context_vars.trace_id import trace_id_var
|
||||
4
src/infrastructure/context_vars/trace_id.py
Normal file
4
src/infrastructure/context_vars/trace_id.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from contextvars import ContextVar
|
||||
|
||||
|
||||
trace_id_var: ContextVar[str] = ContextVar('trace_id', default='N/A')
|
||||
1
src/infrastructure/database/__init__.py
Normal file
1
src/infrastructure/database/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.database.unit_of_work import UnitOfWork
|
||||
22
src/infrastructure/database/context.py
Normal file
22
src/infrastructure/database/context.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
from sqlalchemy.ext.asyncio.engine import create_async_engine
|
||||
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||
from typing import AsyncGenerator
|
||||
from src.infrastructure.config import settings
|
||||
|
||||
|
||||
engine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
pool_size=settings.DATABASE_POOL_SIZE,
|
||||
max_overflow=settings.DATABASE_MAX_OVERFLOW,
|
||||
pool_timeout=settings.DATABASE_POOL_TIMEOUT,
|
||||
pool_recycle=settings.DATABASE_POOL_RECYCLE,
|
||||
echo=settings.DATABASE_ECHO
|
||||
)
|
||||
|
||||
async_session_maker = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
|
||||
|
||||
|
||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with async_session_maker() as session:
|
||||
yield session
|
||||
1
src/infrastructure/database/decorators/__init__.py
Normal file
1
src/infrastructure/database/decorators/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.database.decorators.transactional import transactional
|
||||
15
src/infrastructure/database/decorators/transactional.py
Normal file
15
src/infrastructure/database/decorators/transactional.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
from functools import wraps
|
||||
from typing import Callable, Awaitable, TypeVar, ParamSpec
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def transactional(method: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
||||
@wraps(method)
|
||||
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
async with self._unit_of_work:
|
||||
return await method(self, *args, **kwargs)
|
||||
return wrapper
|
||||
6
src/infrastructure/database/models/__init__.py
Normal file
6
src/infrastructure/database/models/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from src.infrastructure.database.models.base import Base
|
||||
from src.infrastructure.database.models.user import UserModel
|
||||
from src.infrastructure.database.models.sessions import Session
|
||||
|
||||
__all__ = ['Base', 'UserModel', 'Session']
|
||||
|
||||
19
src/infrastructure/database/models/base.py
Normal file
19
src/infrastructure/database/models/base.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
|
||||
class Base(AsyncAttrs, DeclarativeBase):
|
||||
__abstract__ = True
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
attributes = ', '.join(f"{col.name}={getattr(self, col.name, None)!r}"
|
||||
for col in self.__table__.columns)
|
||||
return f"<{class_name}({attributes})>"
|
||||
|
||||
def __str__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
attributes = ', '.join(f"{col.name}={getattr(self, col.name)}"
|
||||
for col in self.__table__.columns
|
||||
if getattr(self, col.name) is not None)
|
||||
return f"{class_name}({attributes})"
|
||||
3
src/infrastructure/database/models/mixins/__init__.py
Normal file
3
src/infrastructure/database/models/mixins/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from src.infrastructure.database.models.mixins.audit import AuditTimestampsMixin
|
||||
from src.infrastructure.database.models.mixins.ulid import UlidPrimaryKeyMixin
|
||||
from src.infrastructure.database.models.mixins.soft_delete import SoftDeleteMixin
|
||||
16
src/infrastructure/database/models/mixins/audit.py
Normal file
16
src/infrastructure/database/models/mixins/audit.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from sqlalchemy import DateTime, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
|
||||
class AuditTimestampsMixin:
|
||||
created_at: Mapped[DateTime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=func.now(),
|
||||
)
|
||||
updated_at: Mapped[DateTime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
6
src/infrastructure/database/models/mixins/soft_delete.py
Normal file
6
src/infrastructure/database/models/mixins/soft_delete.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from sqlalchemy import Boolean
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
|
||||
class SoftDeleteMixin:
|
||||
is_deleted: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default='false', default=False)
|
||||
8
src/infrastructure/database/models/mixins/ulid.py
Normal file
8
src/infrastructure/database/models/mixins/ulid.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from ulid import ULID
|
||||
|
||||
|
||||
class UlidPrimaryKeyMixin:
|
||||
|
||||
id: Mapped[str] = mapped_column(String(26), primary_key=True, default=lambda: str(ULID()))
|
||||
50
src/infrastructure/database/models/sessions.py
Normal file
50
src/infrastructure/database/models/sessions.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy import String, DateTime, ForeignKey, Index
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from ulid import ULID
|
||||
from src.infrastructure.database.models import Base
|
||||
from src.infrastructure.database.models.mixins import UlidPrimaryKeyMixin, AuditTimestampsMixin
|
||||
|
||||
|
||||
class Session(Base, UlidPrimaryKeyMixin, AuditTimestampsMixin):
|
||||
__tablename__ = "sessions"
|
||||
|
||||
sid: Mapped[str] = mapped_column(
|
||||
String(26),
|
||||
unique=True,
|
||||
index=True,
|
||||
nullable=False,
|
||||
default=lambda: str(ULID()),
|
||||
)
|
||||
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(26),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
index=True,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
device_id: Mapped[str] = mapped_column(
|
||||
String(26),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
user_agent: Mapped[str | None] = mapped_column(String(500))
|
||||
first_ip: Mapped[str | None] = mapped_column(String(64))
|
||||
last_ip: Mapped[str | None] = mapped_column(String(64))
|
||||
|
||||
last_seen_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
revoked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||
|
||||
refresh_jti_hash: Mapped[str | None] = mapped_column(String(255))
|
||||
refresh_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||
|
||||
|
||||
Index("ux_sessions_user_device", Session.user_id, Session.device_id, unique=True)
|
||||
Index("ix_sessions_user_active", Session.user_id, Session.revoked_at)
|
||||
28
src/infrastructure/database/models/user.py
Normal file
28
src/infrastructure/database/models/user.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import Boolean,Date,String,DateTime
|
||||
from sqlalchemy.orm import Mapped,mapped_column
|
||||
from src.infrastructure.database.models.base import Base
|
||||
from src.infrastructure.database.models.mixins import UlidPrimaryKeyMixin,AuditTimestampsMixin,SoftDeleteMixin
|
||||
|
||||
|
||||
class UserModel(Base, UlidPrimaryKeyMixin, AuditTimestampsMixin, SoftDeleteMixin):
|
||||
__tablename__ = 'users'
|
||||
|
||||
email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True, index=True)
|
||||
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
|
||||
last_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
first_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
middle_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
birth_date: Mapped[Date | None] = mapped_column(Date, nullable=True)
|
||||
|
||||
crypto_wallet: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
phone: Mapped[str | None] = mapped_column(String(16), nullable=True)
|
||||
|
||||
passport_data: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
inn: Mapped[str | None] = mapped_column(String(12), nullable=True)
|
||||
erc20: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
|
||||
kyc_verified: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default='false', default=False)
|
||||
kyc_verified_at: Mapped[DateTime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
2
src/infrastructure/database/repositories/__init__.py
Normal file
2
src/infrastructure/database/repositories/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from src.infrastructure.database.repositories.user_repository import UserRepository
|
||||
from src.infrastructure.database.repositories.session_repository import SessionRepository
|
||||
198
src/infrastructure/database/repositories/session_repository.py
Normal file
198
src/infrastructure/database/repositories/session_repository.py
Normal file
@@ -0,0 +1,198 @@
|
||||
from __future__ import annotations
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from src.application.contracts import ILogger
|
||||
from src.application.domain.entities import SessionEntity
|
||||
from src.application.abstractions.repositories import ISessionRepository
|
||||
from src.infrastructure.database.models import Session
|
||||
|
||||
|
||||
class SessionRepository(ISessionRepository):
|
||||
def __init__(self, session: AsyncSession, logger: ILogger):
|
||||
self._session = session
|
||||
self._logger = logger
|
||||
|
||||
async def get_by_sid(self, sid: str) -> Optional[SessionEntity]:
|
||||
res = await self._session.execute(select(Session).where(Session.sid == sid))
|
||||
m = res.scalar_one_or_none()
|
||||
if m is None:
|
||||
return None
|
||||
|
||||
return SessionEntity(
|
||||
sid=m.sid,
|
||||
user_id=m.user_id,
|
||||
device_id=m.device_id,
|
||||
revoked_at=m.revoked_at,
|
||||
last_seen_at=m.last_seen_at,
|
||||
refresh_jti_hash=m.refresh_jti_hash,
|
||||
refresh_expires_at=m.refresh_expires_at,
|
||||
user_agent=m.user_agent,
|
||||
first_ip=m.first_ip,
|
||||
last_ip=m.last_ip,
|
||||
)
|
||||
|
||||
async def get_by_user_device(self, user_id: str, device_id: str) -> Optional[SessionEntity]:
|
||||
res = await self._session.execute(
|
||||
select(Session).where(Session.user_id == user_id, Session.device_id == device_id)
|
||||
)
|
||||
m = res.scalar_one_or_none()
|
||||
if m is None:
|
||||
return None
|
||||
|
||||
return SessionEntity(
|
||||
sid=m.sid,
|
||||
user_id=m.user_id,
|
||||
device_id=m.device_id,
|
||||
revoked_at=m.revoked_at,
|
||||
last_seen_at=m.last_seen_at,
|
||||
refresh_jti_hash=m.refresh_jti_hash,
|
||||
refresh_expires_at=m.refresh_expires_at,
|
||||
user_agent=m.user_agent,
|
||||
first_ip=m.first_ip,
|
||||
last_ip=m.last_ip,
|
||||
)
|
||||
|
||||
async def upsert_by_device(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
sid: str,
|
||||
refresh_jti_hash: str,
|
||||
refresh_expires_at: datetime,
|
||||
user_agent: str | None,
|
||||
ip: str | None,
|
||||
now: datetime,
|
||||
) -> SessionEntity:
|
||||
res = await self._session.execute(
|
||||
select(Session).where(Session.user_id == user_id, Session.device_id == device_id)
|
||||
)
|
||||
m = res.scalar_one_or_none()
|
||||
|
||||
if m is None:
|
||||
m = Session(
|
||||
sid=sid,
|
||||
user_id=user_id,
|
||||
device_id=device_id,
|
||||
revoked_at=None,
|
||||
last_seen_at=now,
|
||||
refresh_jti_hash=refresh_jti_hash,
|
||||
refresh_expires_at=refresh_expires_at,
|
||||
user_agent=user_agent,
|
||||
first_ip=ip,
|
||||
last_ip=ip,
|
||||
)
|
||||
self._session.add(m)
|
||||
await self._session.flush()
|
||||
|
||||
self._logger.info(f'Session created (user_id={user_id}, device_id={device_id}, sid={sid})')
|
||||
else:
|
||||
m.sid = sid
|
||||
m.revoked_at = None
|
||||
m.last_seen_at = now
|
||||
m.refresh_jti_hash = refresh_jti_hash
|
||||
m.refresh_expires_at = refresh_expires_at
|
||||
m.user_agent = user_agent
|
||||
m.last_ip = ip
|
||||
|
||||
await self._session.flush()
|
||||
|
||||
self._logger.info(f'Session updated (user_id={user_id}, device_id={device_id}, sid={sid})')
|
||||
|
||||
return SessionEntity(
|
||||
sid=m.sid,
|
||||
user_id=m.user_id,
|
||||
device_id=m.device_id,
|
||||
revoked_at=m.revoked_at,
|
||||
last_seen_at=m.last_seen_at,
|
||||
refresh_jti_hash=m.refresh_jti_hash,
|
||||
refresh_expires_at=m.refresh_expires_at,
|
||||
user_agent=m.user_agent,
|
||||
first_ip=m.first_ip,
|
||||
last_ip=m.last_ip,
|
||||
)
|
||||
|
||||
async def revoke_by_sid(self, sid: str, now: datetime) -> None:
|
||||
# Интерфейс требует None -> просто делаем update и flush
|
||||
await self._session.execute(
|
||||
update(Session)
|
||||
.where(Session.sid == sid, Session.revoked_at.is_(None))
|
||||
.values(revoked_at=now)
|
||||
.execution_options(synchronize_session='fetch')
|
||||
)
|
||||
await self._session.flush()
|
||||
|
||||
async def rotate_refresh(
|
||||
self,
|
||||
sid: str,
|
||||
new_jti_hash: str,
|
||||
new_refresh_expires_at: datetime,
|
||||
now: datetime,
|
||||
ip: str | None,
|
||||
user_agent: str | None,
|
||||
) -> None:
|
||||
|
||||
values = {
|
||||
'refresh_jti_hash': new_jti_hash,
|
||||
'refresh_expires_at': new_refresh_expires_at,
|
||||
'last_seen_at': now,
|
||||
'user_agent': user_agent,
|
||||
}
|
||||
if ip is not None:
|
||||
values['last_ip'] = ip
|
||||
|
||||
await self._session.execute(
|
||||
update(Session)
|
||||
.where(Session.sid == sid, Session.revoked_at.is_(None))
|
||||
.values(**values)
|
||||
.execution_options(synchronize_session='fetch')
|
||||
)
|
||||
await self._session.flush()
|
||||
|
||||
async def touch_last_seen(self, sid: str, *, ip: str | None, now: datetime) -> None:
|
||||
values = {'last_seen_at': now}
|
||||
if ip is not None:
|
||||
values['last_ip'] = ip
|
||||
|
||||
await self._session.execute(
|
||||
update(Session)
|
||||
.where(Session.sid == sid, Session.revoked_at.is_(None))
|
||||
.values(**values)
|
||||
.execution_options(synchronize_session='fetch')
|
||||
)
|
||||
await self._session.flush()
|
||||
|
||||
async def rotate_refresh_if_match(
|
||||
self,
|
||||
*,
|
||||
sid: str,
|
||||
old_jti_hash: str,
|
||||
new_jti_hash: str,
|
||||
new_refresh_expires_at: datetime,
|
||||
now: datetime,
|
||||
ip: str | None,
|
||||
user_agent: str | None,
|
||||
) -> bool:
|
||||
values = {
|
||||
'refresh_jti_hash': new_jti_hash,
|
||||
'refresh_expires_at': new_refresh_expires_at,
|
||||
'last_seen_at': now,
|
||||
'user_agent': user_agent,
|
||||
}
|
||||
if ip is not None:
|
||||
values['last_ip'] = ip
|
||||
|
||||
res = await self._session.execute(
|
||||
update(Session)
|
||||
.where(
|
||||
Session.sid == sid,
|
||||
Session.revoked_at.is_(None),
|
||||
Session.refresh_jti_hash == old_jti_hash, # ✅ защита от гонок
|
||||
)
|
||||
.values(**values)
|
||||
.execution_options(synchronize_session='fetch')
|
||||
)
|
||||
await self._session.flush()
|
||||
return (res.rowcount or 0) > 0
|
||||
118
src/infrastructure/database/repositories/user_repository.py
Normal file
118
src/infrastructure/database/repositories/user_repository.py
Normal file
@@ -0,0 +1,118 @@
|
||||
from __future__ import annotations
|
||||
from fastapi import status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from src.application.contracts import ILogger
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.application.abstractions.repositories import IUserRepository
|
||||
from src.application.domain.entities import UserEntity
|
||||
from src.infrastructure.database.models import UserModel
|
||||
|
||||
|
||||
class UserRepository(IUserRepository):
|
||||
def __init__(self, session: AsyncSession, logger: ILogger):
|
||||
self._session = session
|
||||
self._logger = logger
|
||||
|
||||
async def _get_active_user(self, user_id: str) -> UserModel:
|
||||
stmt = (
|
||||
select(UserModel)
|
||||
.where(
|
||||
UserModel.id == user_id,
|
||||
UserModel.is_deleted.is_(False),
|
||||
)
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
user: UserModel | None = result.scalar_one_or_none()
|
||||
if user is None:
|
||||
self._logger.warning(f'User not found with user_id {user_id}')
|
||||
raise ApplicationException(status_code=status.HTTP_404_NOT_FOUND, message='User not found')
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def _to_entity(user: UserModel) -> UserEntity:
|
||||
return UserEntity(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
password_hash=None,
|
||||
first_name=user.first_name,
|
||||
middle_name=user.middle_name,
|
||||
last_name=user.last_name,
|
||||
birth_date=user.birth_date,
|
||||
crypto_wallet=user.crypto_wallet,
|
||||
phone=user.phone,
|
||||
bik=user.bik,
|
||||
account_number=user.account_number,
|
||||
card_number=user.card_number,
|
||||
inn=user.inn,
|
||||
kyc_verified_at=user.kyc_verified_at,
|
||||
kyc_verified=user.kyc_verified,
|
||||
is_deleted=user.is_deleted,
|
||||
created_at=user.created_at,
|
||||
updated_at=user.updated_at,
|
||||
)
|
||||
|
||||
async def get_user_by_id(self, user_id: str) -> UserEntity:
|
||||
try:
|
||||
user = await self._get_active_user(user_id)
|
||||
return self._to_entity(user)
|
||||
except ApplicationException:
|
||||
raise
|
||||
except SQLAlchemyError as exception:
|
||||
self._logger.exception(str(exception))
|
||||
raise ApplicationException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, message=f'Database error: {str(exception)}')
|
||||
|
||||
async def _update_field(self, user_id: str, **fields: object) -> UserEntity:
|
||||
try:
|
||||
user = await self._get_active_user(user_id)
|
||||
for key, value in fields.items():
|
||||
setattr(user, key, value)
|
||||
await self._session.flush()
|
||||
await self._session.refresh(user)
|
||||
return self._to_entity(user)
|
||||
except ApplicationException:
|
||||
raise
|
||||
except SQLAlchemyError as exception:
|
||||
self._logger.exception(str(exception))
|
||||
raise ApplicationException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, message=f'Database error: {str(exception)}')
|
||||
|
||||
async def set_phone(self, user_id: str, phone: str) -> UserEntity:
|
||||
return await self._update_field(user_id, phone=phone)
|
||||
|
||||
async def set_bank_details(self, user_id: str, **fields: str) -> UserEntity:
|
||||
return await self._update_field(user_id, **fields)
|
||||
|
||||
async def set_crypto_wallet(self, user_id: str, wallet_address: str) -> UserEntity:
|
||||
return await self._update_field(user_id, crypto_wallet=wallet_address)
|
||||
|
||||
async def get_password_hash(self, user_id: str) -> str:
|
||||
try:
|
||||
user = await self._get_active_user(user_id)
|
||||
return user.password_hash
|
||||
except ApplicationException:
|
||||
raise
|
||||
except SQLAlchemyError as exception:
|
||||
self._logger.exception(str(exception))
|
||||
raise ApplicationException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, message=f'Database error: {str(exception)}')
|
||||
|
||||
async def set_password(self, user_id: str, password_hash: str) -> UserEntity:
|
||||
return await self._update_field(user_id, password_hash=password_hash)
|
||||
|
||||
async def set_email(self, user_id: str, email: str) -> UserEntity:
|
||||
return await self._update_field(user_id, email=email)
|
||||
|
||||
async def email_exists(self, email: str) -> bool:
|
||||
try:
|
||||
stmt = (
|
||||
select(UserModel)
|
||||
.where(
|
||||
UserModel.email == email,
|
||||
UserModel.is_deleted.is_(False),
|
||||
)
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
return result.scalar_one_or_none() is not None
|
||||
except SQLAlchemyError as exception:
|
||||
self._logger.exception(str(exception))
|
||||
raise ApplicationException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, message=f'Database error: {str(exception)}')
|
||||
42
src/infrastructure/database/unit_of_work.py
Normal file
42
src/infrastructure/database/unit_of_work.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.abstractions.repositories import IUserRepository, ISessionRepository
|
||||
from src.application.contracts import ILogger
|
||||
from src.infrastructure.database.repositories import UserRepository, SessionRepository
|
||||
|
||||
|
||||
|
||||
class UnitOfWork(IUnitOfWork):
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession], logger: ILogger):
|
||||
self.session_factory = session_factory
|
||||
self._session: AsyncSession = None
|
||||
self._user_repository: IUserRepository = None
|
||||
self._session_repository: ISessionRepository = None
|
||||
self._logger: ILogger = logger
|
||||
|
||||
async def __aenter__(self):
|
||||
self._session = self.session_factory()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if exc_type:
|
||||
self._logger.error(str(exc_val))
|
||||
await self._session.rollback()
|
||||
self._logger.error(f'Rollback: str{exc_val})')
|
||||
else:
|
||||
await self._session.flush()
|
||||
await self._session.commit()
|
||||
self._logger.debug('Commit')
|
||||
await self._session.close()
|
||||
|
||||
@property
|
||||
def user_repository(self) -> IUserRepository:
|
||||
if self._user_repository is None:
|
||||
self._user_repository = UserRepository(session=self._session, logger=self._logger)
|
||||
return self._user_repository
|
||||
|
||||
@property
|
||||
def session_repository(self) -> ISessionRepository:
|
||||
if self._session_repository is None:
|
||||
self._session_repository = SessionRepository(session=self._session, logger=self._logger)
|
||||
return self._session_repository
|
||||
28
src/infrastructure/logger/__init__.py
Normal file
28
src/infrastructure/logger/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from src.application.contracts import ILogger
|
||||
from src.application.domain.enums import LogFormat
|
||||
from src.application.domain.enums import LogLevel
|
||||
from src.infrastructure.config.settings import settings
|
||||
from src.infrastructure.logger.logger import Logger
|
||||
|
||||
log_levels = {
|
||||
'DEBUG': LogLevel.DEBUG,
|
||||
'INFO': LogLevel.INFO,
|
||||
'WARNING': LogLevel.WARNING,
|
||||
'ERROR': LogLevel.ERROR,
|
||||
'CRITICAL': LogLevel.CRITICAL,
|
||||
'EXCEPTION': LogLevel.EXCEPTION,
|
||||
}
|
||||
|
||||
log_formats = {
|
||||
'JSON': LogFormat.JSON,
|
||||
'TEXT': LogFormat.TEXT,
|
||||
}
|
||||
|
||||
logger = Logger(
|
||||
min_level=log_levels.get(settings.LOG_LEVEL, LogLevel.INFO),
|
||||
log_format=log_formats.get(settings.LOG_FORMAT, LogFormat.JSON),
|
||||
)
|
||||
|
||||
|
||||
def get_logger() -> ILogger:
|
||||
return logger
|
||||
129
src/infrastructure/logger/logger.py
Normal file
129
src/infrastructure/logger/logger.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import traceback
|
||||
import inspect
|
||||
import sys
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Callable, Optional, Any
|
||||
from ulid import ULID
|
||||
from src.application.contracts import ILogger
|
||||
from src.application.domain.enums import LogFormat, LogLevel
|
||||
from src.infrastructure.context_vars import trace_id_var
|
||||
|
||||
|
||||
class Logger(ILogger):
|
||||
_instance = None
|
||||
__default_format = LogFormat.JSON
|
||||
|
||||
def __new__(cls, *args: Any, **kwargs: Any) -> "Logger":
|
||||
if cls._instance is None:
|
||||
cls._instance = super(Logger, cls).__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
log_format: LogFormat = __default_format,
|
||||
min_level: LogLevel = LogLevel.INFO,
|
||||
id_generator: Optional[Callable[[], str]] = lambda: str(ULID()),
|
||||
instance_id: str = "N/A",
|
||||
):
|
||||
self.log_format = log_format
|
||||
self.min_level = min_level
|
||||
self.id_generator = id_generator
|
||||
self.instance_id = instance_id
|
||||
|
||||
def set_instance_id(self, instance_id: str) -> None:
|
||||
self.instance_id = instance_id
|
||||
|
||||
def get_instance_id(self) -> str:
|
||||
return self.instance_id
|
||||
|
||||
def set_format(self, log_format: LogFormat) -> None:
|
||||
if not isinstance(log_format, LogFormat):
|
||||
raise ValueError("Log format must be an instance of LogFormat enum")
|
||||
self.log_format = log_format
|
||||
|
||||
def set_min_level(self, level: LogLevel) -> None:
|
||||
self.min_level = level
|
||||
|
||||
def new_trace_id(self) -> str:
|
||||
trace_id = str(ULID()) if self.id_generator is None else self.id_generator()
|
||||
trace_id_var.set(trace_id)
|
||||
return trace_id
|
||||
|
||||
def set_trace_id(self, trace_id: str) -> None:
|
||||
trace_id_var.set(trace_id)
|
||||
|
||||
def get_trace_id(self) -> str:
|
||||
return trace_id_var.get()
|
||||
|
||||
def clear_trace_id(self) -> None:
|
||||
trace_id_var.set("N/A")
|
||||
|
||||
def _prepare_log_data(self, level: LogLevel, message: str) -> dict[str, Any]:
|
||||
current_frame = inspect.currentframe()
|
||||
if (
|
||||
current_frame
|
||||
and current_frame.f_back
|
||||
and current_frame.f_back.f_back
|
||||
and current_frame.f_back.f_back.f_back
|
||||
):
|
||||
frame = current_frame.f_back.f_back.f_back
|
||||
filename = frame.f_code.co_filename
|
||||
line_number = frame.f_lineno
|
||||
else:
|
||||
filename = "unknown"
|
||||
line_number = 0
|
||||
|
||||
log_data = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"level": level.name,
|
||||
"instance_id": self.instance_id,
|
||||
"file": filename,
|
||||
"line": line_number,
|
||||
"trace_id": trace_id_var.get(),
|
||||
"message": message,
|
||||
}
|
||||
|
||||
if level == LogLevel.EXCEPTION:
|
||||
log_data["exception"] = traceback.format_exc()
|
||||
|
||||
return log_data
|
||||
|
||||
def _log(self, level: LogLevel, message: str) -> None:
|
||||
if level >= self.min_level:
|
||||
log_data = self._prepare_log_data(level, message)
|
||||
|
||||
if self.log_format == LogFormat.JSON:
|
||||
log_message = json.dumps(log_data, ensure_ascii=False)
|
||||
else:
|
||||
log_message = (
|
||||
f"{log_data['timestamp']} - {log_data['level']} - "
|
||||
f"{log_data['instance_id']} - {log_data['trace_id']} - "
|
||||
f"{log_data['file']}:{log_data['line']} - "
|
||||
f"{log_data['message']}"
|
||||
)
|
||||
if "exception" in log_data:
|
||||
log_message += f"\nTraceback:\n{log_data['exception']}"
|
||||
|
||||
self._write(log_message)
|
||||
|
||||
def _write(self, message: str) -> None:
|
||||
sys.stdout.write(message + "\n")
|
||||
|
||||
def debug(self, message: str) -> None:
|
||||
self._log(LogLevel.DEBUG, message)
|
||||
|
||||
def info(self, message: str) -> None:
|
||||
self._log(LogLevel.INFO, message)
|
||||
|
||||
def warning(self, message: str) -> None:
|
||||
self._log(LogLevel.WARNING, message)
|
||||
|
||||
def error(self, message: str) -> None:
|
||||
self._log(LogLevel.ERROR, message)
|
||||
|
||||
def critical(self, message: str) -> None:
|
||||
self._log(LogLevel.CRITICAL, message)
|
||||
|
||||
def exception(self, message: str) -> None:
|
||||
self._log(LogLevel.EXCEPTION, message)
|
||||
1
src/infrastructure/messanger/__init__.py
Normal file
1
src/infrastructure/messanger/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.messanger.rabbit_client import RabbitClient
|
||||
72
src/infrastructure/messanger/rabbit_client.py
Normal file
72
src/infrastructure/messanger/rabbit_client.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from typing import Any, Mapping
|
||||
from faststream.rabbit import RabbitBroker
|
||||
from src.application.contracts import IQueueMessanger
|
||||
from src.infrastructure.config import settings
|
||||
|
||||
|
||||
class RabbitClient(IQueueMessanger):
|
||||
def __init__(self) -> None:
|
||||
self._broker = RabbitBroker(
|
||||
settings.RABBIT_URL,
|
||||
)
|
||||
self._connected = False
|
||||
|
||||
async def connect(self) -> None:
|
||||
if self._connected:
|
||||
return
|
||||
await self._broker.connect()
|
||||
self._connected = True
|
||||
|
||||
async def close(self) -> None:
|
||||
if not self._connected:
|
||||
return
|
||||
await self._broker.close()
|
||||
self._connected = False
|
||||
|
||||
async def _ensure_connected(self) -> None:
|
||||
if not self._connected:
|
||||
await self.connect()
|
||||
|
||||
async def publish_to_queue(
|
||||
self,
|
||||
queue: str,
|
||||
message: Any,
|
||||
*,
|
||||
persist: bool | None = None,
|
||||
headers: Mapping[str, Any] | None = None,
|
||||
correlation_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> None:
|
||||
await self._ensure_connected()
|
||||
|
||||
await self._broker.publish(
|
||||
message,
|
||||
queue=queue,
|
||||
persist=settings.RABBIT_PUBLISH_PERSIST if persist is None else persist,
|
||||
headers=headers,
|
||||
correlation_id=correlation_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
async def publish(
|
||||
self,
|
||||
message: Any,
|
||||
*,
|
||||
exchange: str,
|
||||
routing_key: str,
|
||||
persist: bool | None = None,
|
||||
headers: Mapping[str, Any] | None = None,
|
||||
correlation_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> None:
|
||||
await self._ensure_connected()
|
||||
|
||||
await self._broker.publish(
|
||||
message,
|
||||
exchange=exchange,
|
||||
routing_key=routing_key,
|
||||
persist=settings.RABBIT_PUBLISH_PERSIST if persist is None else persist,
|
||||
headers=headers,
|
||||
correlation_id=correlation_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
3
src/infrastructure/security/__init__.py
Normal file
3
src/infrastructure/security/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from src.infrastructure.security.jwt import JwtService
|
||||
from src.infrastructure.security.csrf import CsrfService
|
||||
from src.infrastructure.security.hash import HashService
|
||||
81
src/infrastructure/security/csrf.py
Normal file
81
src/infrastructure/security/csrf.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from __future__ import annotations
|
||||
import secrets
|
||||
from typing import Any, Optional, Mapping
|
||||
from itsdangerous import URLSafeTimedSerializer, SignatureExpired, BadSignature
|
||||
from src.application.contracts import ICsrfService
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.config.settings import settings
|
||||
|
||||
|
||||
class CsrfService(ICsrfService):
|
||||
COOKIE_NAME = 'csrf_token'
|
||||
HEADER_NAME = 'X-CSRF-Token'
|
||||
SALT = 'csrf'
|
||||
TTL_SECONDS = 3600
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._serializer = URLSafeTimedSerializer(
|
||||
secret_key=settings.CSRF_SECRET_KEY,
|
||||
salt=self.SALT,
|
||||
)
|
||||
|
||||
@property
|
||||
def cookie_name(self) -> str:
|
||||
return self.COOKIE_NAME
|
||||
|
||||
@property
|
||||
def header_name(self) -> str:
|
||||
return self.HEADER_NAME
|
||||
|
||||
@property
|
||||
def ttl_seconds(self) -> int:
|
||||
return self.TTL_SECONDS
|
||||
|
||||
def issue(self, subject: Optional[str] = None) -> str:
|
||||
payload = {
|
||||
'sub': subject,
|
||||
'nonce': secrets.token_urlsafe(32),
|
||||
}
|
||||
return self._serializer.dumps(payload)
|
||||
|
||||
def verify(self, token: str, expected_subject: Optional[str] = None) -> dict[str, Any]:
|
||||
try:
|
||||
data = self._serializer.loads(token, max_age=self.TTL_SECONDS)
|
||||
except SignatureExpired:
|
||||
raise ApplicationException(
|
||||
status_code=403,
|
||||
message='CSRF token expired',
|
||||
)
|
||||
except BadSignature:
|
||||
raise ApplicationException(
|
||||
status_code=403,
|
||||
message='CSRF token invalid',
|
||||
)
|
||||
|
||||
if expected_subject is not None and data.get('sub') != expected_subject:
|
||||
raise ApplicationException(
|
||||
status_code=403,
|
||||
message='CSRF token subject mismatch',
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
def extract(self, cookies: Mapping[str, str], headers: Mapping[str, str]) -> tuple[Optional[str], Optional[str]]:
|
||||
cookie_token = cookies.get(self.COOKIE_NAME)
|
||||
header_token = headers.get(self.HEADER_NAME)
|
||||
return cookie_token, header_token
|
||||
|
||||
def verify_pair(self, cookie_token: Optional[str], header_token: Optional[str], expected_subject: Optional[str] = None) -> None:
|
||||
if not cookie_token or not header_token:
|
||||
raise ApplicationException(
|
||||
status_code=403,
|
||||
message='CSRF token missing',
|
||||
)
|
||||
|
||||
if not secrets.compare_digest(cookie_token, header_token):
|
||||
raise ApplicationException(
|
||||
status_code=403,
|
||||
message='CSRF token mismatch',
|
||||
)
|
||||
|
||||
self.verify(cookie_token, expected_subject=expected_subject)
|
||||
17
src/infrastructure/security/hash.py
Normal file
17
src/infrastructure/security/hash.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import bcrypt
|
||||
from src.application.contracts import IHashService, ILogger
|
||||
|
||||
|
||||
class HashService(IHashService):
|
||||
|
||||
def __init__(self, logger: ILogger):
|
||||
self._logger = logger
|
||||
|
||||
async def hash(self, value: str) -> str:
|
||||
hashed_value = bcrypt.hashpw(value.encode(), bcrypt.gensalt())
|
||||
self._logger.info(f'Hash value {hashed_value.decode()}')
|
||||
return hashed_value.decode()
|
||||
|
||||
async def verify(self, hashed_value: str, plain_value: str) -> bool:
|
||||
self._logger.info(f'Hash value {hashed_value[:10]}')
|
||||
return bcrypt.checkpw(plain_value.encode(), hashed_value.encode())
|
||||
109
src/infrastructure/security/jwt.py
Normal file
109
src/infrastructure/security/jwt.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from __future__ import annotations
|
||||
from jose import jwt, ExpiredSignatureError, JWTError
|
||||
from src.application.contracts import ILogger, IJwtService
|
||||
from src.application.domain.dto import AccessTokenPayload
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.config.settings import settings
|
||||
from src.infrastructure.vault import JwtKeyStore
|
||||
|
||||
|
||||
class JwtService(IJwtService):
|
||||
def __init__(self, logger: ILogger, key_store: JwtKeyStore) -> None:
|
||||
self._logger = logger
|
||||
self._key_store = key_store
|
||||
|
||||
async def decode_access_token(self, token: str) -> AccessTokenPayload:
|
||||
payload = await self._decode_and_verify(token)
|
||||
|
||||
if payload.get('type') != 'access':
|
||||
self._logger.warning(f'Access token invalid type received_type={payload.get('type')}')
|
||||
raise ApplicationException(status_code=401, message='Invalid token type')
|
||||
|
||||
try:
|
||||
return AccessTokenPayload(
|
||||
sub=str(payload['sub']),
|
||||
type='access',
|
||||
sid=str(payload['sid']),
|
||||
iat=int(payload['iat']),
|
||||
nbf=int(payload['nbf']),
|
||||
exp=int(payload['exp']),
|
||||
iss=payload.get('iss'),
|
||||
aud=payload.get('aud'),
|
||||
)
|
||||
except KeyError as exception:
|
||||
self._logger.warning(f'Access token missing claim error={str(exception)}')
|
||||
raise ApplicationException(status_code=401, message=f'Missing token claim: {exception}')
|
||||
|
||||
async def _decode_and_verify(self, token: str) -> dict:
|
||||
kid: str | None = None
|
||||
try:
|
||||
header = jwt.get_unverified_header(token)
|
||||
|
||||
kid = header.get('kid')
|
||||
if not kid:
|
||||
self._logger.warning(f'JWT header missing kid header={header}')
|
||||
raise ApplicationException(status_code=401, message='Missing token header: kid')
|
||||
|
||||
received_alg = header.get('alg')
|
||||
if received_alg != settings.JWT_ALGORITHM:
|
||||
self._logger.warning(f'JWT invalid algorithm kid={kid} received_alg={received_alg} expected_alg={settings.JWT_ALGORITHM}')
|
||||
raise ApplicationException(status_code=401, message='Invalid token algorithm')
|
||||
|
||||
public_pem = await self._key_store.get_public_key_for_kid(str(kid))
|
||||
|
||||
if not public_pem:
|
||||
self._logger.info(f'JWT kid miss kid={kid} forcing keystore refresh')
|
||||
await self._key_store.refresh()
|
||||
public_pem = await self._key_store.get_public_key_for_kid(str(kid))
|
||||
|
||||
if not public_pem:
|
||||
self._logger.warning(f'JWT unknown kid kid={kid}')
|
||||
raise ApplicationException(status_code=401, message='Unknown token kid')
|
||||
|
||||
options = {
|
||||
'verify_signature': True,
|
||||
'verify_exp': True,
|
||||
'verify_nbf': True,
|
||||
'verify_iat': True,
|
||||
'verify_aud': bool(settings.JWT_AUDIENCE),
|
||||
'verify_iss': bool(settings.JWT_ISSUER),
|
||||
'require_exp': True,
|
||||
'require_iat': True,
|
||||
'require_nbf': True,
|
||||
'require_sub': True,
|
||||
'leeway': 10,
|
||||
}
|
||||
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
public_pem,
|
||||
algorithms=[settings.JWT_ALGORITHM],
|
||||
audience=settings.JWT_AUDIENCE or None,
|
||||
issuer=settings.JWT_ISSUER or None,
|
||||
options=options,
|
||||
)
|
||||
|
||||
if 'sid' not in payload:
|
||||
self._logger.warning(f'JWT missing sid claim kid={kid}')
|
||||
raise ApplicationException(status_code=401, message='Missing token claim: sid')
|
||||
|
||||
if 'type' not in payload:
|
||||
self._logger.warning(f'JWT missing type claim kid={kid}')
|
||||
raise ApplicationException(status_code=401, message='Missing token claim: type')
|
||||
|
||||
return payload
|
||||
|
||||
except ExpiredSignatureError as exception:
|
||||
self._logger.info(f'JWT expired kid={kid} error={str(exception)}')
|
||||
raise ApplicationException(status_code=401, message='Token expired')
|
||||
|
||||
except ApplicationException:
|
||||
raise
|
||||
|
||||
except JWTError as exception:
|
||||
self._logger.warning(f'JWT decode failed kid={kid} error={str(exception)}')
|
||||
raise ApplicationException(status_code=401, message='Invalid token')
|
||||
|
||||
except Exception as exception:
|
||||
self._logger.error(f'Unexpected JWT decode error kid={kid} error={str(exception)}')
|
||||
raise ApplicationException(status_code=500, message='JWT decode failed')
|
||||
1
src/infrastructure/utils/__init__.py
Normal file
1
src/infrastructure/utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.utils.instance_id import generate_instance_id
|
||||
14
src/infrastructure/utils/instance_id.py
Normal file
14
src/infrastructure/utils/instance_id.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from ulid import ULID
|
||||
|
||||
|
||||
def generate_instance_id() -> str:
|
||||
"""
|
||||
Generate a process-wide instance id in ULID format.
|
||||
|
||||
ULID is 26 chars (Crockford Base32) and lexicographically sortable by time.
|
||||
"""
|
||||
|
||||
|
||||
return str(ULID())
|
||||
|
||||
|
||||
3
src/infrastructure/vault/__init__.py
Normal file
3
src/infrastructure/vault/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from src.infrastructure.vault.utils import read_kv2_secret, create_hvac_client
|
||||
from src.infrastructure.vault.keys import JwtKeyStore
|
||||
from src.infrastructure.vault.scheduler import start_jwt_keys_scheduler
|
||||
113
src/infrastructure/vault/keys.py
Normal file
113
src/infrastructure/vault/keys.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from src.application.domain.dto import JwtPublicKeySet, JwtPublicKey
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.vault import create_hvac_client, read_kv2_secret
|
||||
|
||||
|
||||
class JwtKeyStore:
|
||||
|
||||
_instance: 'JwtKeyStore | None' = None
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vault_addr: str,
|
||||
vault_token: str,
|
||||
mount_point: str,
|
||||
kid_path: str = 'jwt/kid',
|
||||
kids_prefix: str = 'jwt/kids',
|
||||
timeout_seconds: int = 5,
|
||||
refresh_ttl_seconds: int = 60,
|
||||
):
|
||||
if getattr(self, '_initialized', False):
|
||||
return
|
||||
|
||||
self._vault_addr = vault_addr
|
||||
self._vault_token = vault_token
|
||||
self._timeout = timeout_seconds
|
||||
|
||||
self._mount = mount_point
|
||||
self._kid_path = kid_path
|
||||
self._kids_prefix = kids_prefix
|
||||
|
||||
self._refresh_ttl_seconds = refresh_ttl_seconds
|
||||
|
||||
self._lock = asyncio.Lock()
|
||||
self._keyset: JwtPublicKeySet | None = None
|
||||
self._last_refresh_at: datetime | None = None
|
||||
|
||||
self._initialized = True
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> 'JwtKeyStore':
|
||||
if cls._instance is None:
|
||||
raise ApplicationException(status_code=500, message='JwtKeyStore not initialized')
|
||||
return cls._instance
|
||||
|
||||
def _read_keyset_sync(self) -> JwtPublicKeySet:
|
||||
client = create_hvac_client(url=self._vault_addr, token=self._vault_token, timeout=self._timeout)
|
||||
|
||||
kids = read_kv2_secret(client=client, mount_point=self._mount, path=self._kid_path)
|
||||
active_kid = kids.get('active')
|
||||
previous_kid = kids.get('previous')
|
||||
|
||||
if not active_kid:
|
||||
raise RuntimeError('Vault jwt/kid secret missing "active"')
|
||||
|
||||
active = self._read_public_key_sync(client, str(active_kid))
|
||||
|
||||
previous = None
|
||||
if previous_kid and previous_kid != active_kid:
|
||||
previous = self._read_public_key_sync(client, str(previous_kid))
|
||||
|
||||
return JwtPublicKeySet(active=active, previous=previous)
|
||||
|
||||
def _read_public_key_sync(self, client, kid: str) -> JwtPublicKey:
|
||||
data = read_kv2_secret(
|
||||
client=client,
|
||||
mount_point=self._mount,
|
||||
path=f'{self._kids_prefix}/{kid}',
|
||||
)
|
||||
pub = data.get('public_key')
|
||||
if not pub:
|
||||
raise RuntimeError(f'Vault jwt/kids/{kid} missing public_key')
|
||||
return JwtPublicKey(kid=kid, public_key_pem=pub)
|
||||
|
||||
async def refresh(self) -> JwtPublicKeySet:
|
||||
keyset = await asyncio.to_thread(self._read_keyset_sync)
|
||||
async with self._lock:
|
||||
self._keyset = keyset
|
||||
self._last_refresh_at = datetime.now(timezone.utc)
|
||||
return keyset
|
||||
|
||||
async def get_public_key_for_kid(self, kid: str) -> str | None:
|
||||
ks = await self._get_or_refresh()
|
||||
return ks.public_keys_by_kid().get(kid)
|
||||
|
||||
async def last_refresh_at(self) -> datetime | None:
|
||||
async with self._lock:
|
||||
return self._last_refresh_at
|
||||
|
||||
async def _get_or_refresh(self) -> JwtPublicKeySet:
|
||||
async with self._lock:
|
||||
ks = self._keyset
|
||||
last = self._last_refresh_at
|
||||
|
||||
if ks is None:
|
||||
return await self.refresh()
|
||||
|
||||
if last is None:
|
||||
return await self.refresh()
|
||||
|
||||
age = (datetime.now(timezone.utc) - last).total_seconds()
|
||||
if age >= self._refresh_ttl_seconds:
|
||||
return await self.refresh()
|
||||
|
||||
return ks
|
||||
23
src/infrastructure/vault/scheduler.py
Normal file
23
src/infrastructure/vault/scheduler.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
from src.infrastructure.vault import JwtKeyStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def start_jwt_keys_scheduler(store: JwtKeyStore, *, refresh_seconds: int = 3600) -> AsyncIOScheduler:
|
||||
scheduler = AsyncIOScheduler()
|
||||
scheduler.add_job(
|
||||
store.refresh,
|
||||
trigger=IntervalTrigger(seconds=refresh_seconds),
|
||||
id="jwt_keys_refresh",
|
||||
replace_existing=True,
|
||||
max_instances=1,
|
||||
coalesce=True,
|
||||
misfire_grace_time=60,
|
||||
)
|
||||
scheduler.start()
|
||||
logger.info("JWT keys scheduler started (interval=%s seconds)", refresh_seconds)
|
||||
return scheduler
|
||||
17
src/infrastructure/vault/utils.py
Normal file
17
src/infrastructure/vault/utils.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from __future__ import annotations
|
||||
import hvac
|
||||
|
||||
|
||||
def create_hvac_client(*, url: str, token: str, timeout: int = 5) -> hvac.Client:
|
||||
client = hvac.Client(url=url, token=token, timeout=timeout)
|
||||
if not client.is_authenticated():
|
||||
raise RuntimeError("Vault authentication failed. Check VAULT_ADDR / VAULT_TOKEN")
|
||||
return client
|
||||
|
||||
|
||||
def read_kv2_secret(*, client: hvac.Client, mount_point: str, path: str) -> dict:
|
||||
secret = client.secrets.kv.v2.read_secret_version(
|
||||
mount_point=mount_point,
|
||||
path=path,
|
||||
)
|
||||
return secret["data"]["data"]
|
||||
Reference in New Issue
Block a user