Initial commit

This commit is contained in:
2026-04-16 13:51:10 +03:00
commit a0724af6f1
38 changed files with 2453 additions and 0 deletions

View File

@@ -0,0 +1 @@
from src.application.contracts.i_logger import ILogger

View File

@@ -0,0 +1,68 @@
from typing import Protocol, Optional, Callable
from src.application.domain.enums.log_format import LogFormat
from src.application.domain.enums.log_level import LogLevel
class ILogger(Protocol):
"""Interface for synchronous logger with ContextVar support for trace_id."""
log_format: LogFormat
min_level: LogLevel
id_generator: Optional[Callable[[], str]]
instance_id: str
def set_format(self, log_format: LogFormat) -> None:
"""Set log format using LogFormat enum"""
...
def set_min_level(self, level: LogLevel) -> None:
"""Set minimum log level"""
...
def new_trace_id(self) -> str:
"""Create and set new trace_id in context"""
...
def set_trace_id(self, trace_id: str) -> None:
"""Set existing trace_id in context"""
...
def get_trace_id(self) -> str:
"""Get current trace_id from context"""
...
def clear_trace_id(self) -> None:
"""Clear the trace_id in the context"""
...
def set_instance_id(self, instance_id: str) -> None:
"""Set service instance id (ULID recommended)"""
...
def get_instance_id(self) -> str:
"""Get current service instance id"""
...
def debug(self, message: str) -> None:
"""Log debug message"""
...
def info(self, message: str) -> None:
"""Log info message"""
...
def warning(self, message: str) -> None:
"""Log warning message"""
...
def error(self, message: str) -> None:
"""Log error message"""
...
def critical(self, message: str) -> None:
"""Log critical message"""
...
def exception(self, message: str) -> None:
"""Log exception with traceback"""
...

View File

@@ -0,0 +1,8 @@
from abc import ABC, abstractmethod
class ISender(ABC):
@abstractmethod
async def send(self, to: str, subject: str, body: str, plain: str | None = None) -> None:
raise NotImplementedError

View File

@@ -0,0 +1,2 @@
from src.application.domain.enums.log_level import LogLevel
from src.application.domain.enums.log_format import LogFormat

View File

@@ -0,0 +1,7 @@
from enum import Enum
class LogFormat(Enum):
"""Enum for supported log formats"""
TEXT = 'text'
JSON = 'json'

View File

@@ -0,0 +1,54 @@
from enum import Enum
class LogLevel(Enum):
DEBUG = 10
INFO = 20
WARNING = 30
ERROR = 40
CRITICAL = 50
EXCEPTION = 60
def __str__(self) -> str:
return self.name
def __repr__(self) -> str:
return f"[{self.value}, '{self.name}']"
def __eq__(self, other: object) -> bool:
if isinstance(other, LogLevel):
return self.value == other.value
if isinstance(other, int):
return self.value == other
return NotImplemented
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
def __lt__(self, other: object) -> bool:
if isinstance(other, LogLevel):
return self.value < other.value
if isinstance(other, int):
return self.value < other
return NotImplemented
def __le__(self, other: object) -> bool:
if isinstance(other, LogLevel):
return self.value <= other.value
if isinstance(other, int):
return self.value <= other
return NotImplemented
def __gt__(self, other: object) -> bool:
if isinstance(other, LogLevel):
return self.value > other.value
if isinstance(other, int):
return self.value > other
return NotImplemented
def __ge__(self, other: object) -> bool:
if isinstance(other, LogLevel):
return self.value >= other.value
if isinstance(other, int):
return self.value >= other
return NotImplemented

View File

@@ -0,0 +1 @@
from src.application.domain.exceptions.application_exceptions import ApplicationException

View File

@@ -0,0 +1,18 @@
from __future__ import annotations
from typing import Mapping
class ApplicationException(Exception):
def __init__(
self,
status_code: int,
message: str,
headers: Mapping[str, str] | None = None,
):
super().__init__(message)
self.status_code = status_code
self.message = message
self.headers = headers
def __str__(self):
return f"{self.status_code}: {self.message}"

View File

@@ -0,0 +1 @@
from src.infrastructure.config.settings import settings

View File

@@ -0,0 +1,135 @@
from __future__ import annotations
from functools import lru_cache
from typing import List, Literal
import os
from dotenv import load_dotenv, find_dotenv
from pydantic import Field, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from src.infrastructure.vault import create_hvac_client, read_kv2_secret
env_file = find_dotenv('.env')
if env_file:
load_dotenv(env_file)
def _vault_key_get(mapping: dict, key: str, default=None):
ku = key.upper()
kl = key.lower()
if ku in mapping:
return mapping[ku]
if kl in mapping:
return mapping[kl]
return default
class Settings(BaseSettings):
VAULT_ADDR: str = Field(default='http://localhost:8200')
VAULT_ROLE_ID: str = Field(..., description='Vault AppRole role_id')
VAULT_SECRET_ID: str = Field(..., description='Vault AppRole secret_id')
VAULT_AUTH_MOUNT: str = Field(default='approle')
VAULT_MOUNT_POINT: str = Field(default='secrets')
DOCS_USERNAME: str = "admin"
DOCS_PASSWORD: str = "admin"
SMTP_FROM: str = ""
SMTP_HOST: str = "localhost"
SMTP_PASSWORD: str = ""
SMTP_PORT: int = 587
SMTP_USER: str = ""
RABBIT_HOST: str = "localhost"
RABBIT_PORT: int = 5672
RABBIT_USER: str = "guest"
RABBIT_PASSWORD: str = "guest"
RABBIT_VHOST: str = "/"
RABBIT_PUBLISH_PERSIST: bool = True
RABBIT_CONNECT_TIMEOUT: int = 5
RABBIT_EMAIL_CODE_QUEUE: str = "email.verification_code"
LOG_LEVEL: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
LOG_FORMAT: Literal["JSON", "TEXT"] = "TEXT"
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
case_sensitive=True,
extra="ignore",
)
@model_validator(mode="before")
@classmethod
def load_from_vault(cls, data: dict):
addr = data.get('VAULT_ADDR') or os.getenv('VAULT_ADDR') or 'http://localhost:8200'
role_id = data.get('VAULT_ROLE_ID') or os.getenv('VAULT_ROLE_ID')
secret_id = data.get('VAULT_SECRET_ID') or os.getenv('VAULT_SECRET_ID')
auth_mount = data.get('VAULT_AUTH_MOUNT') or os.getenv('VAULT_AUTH_MOUNT') or 'approle'
mount = data.get('VAULT_MOUNT_POINT') or os.getenv('VAULT_MOUNT_POINT') or 'secrets'
if not role_id or not secret_id:
raise RuntimeError('VAULT_ROLE_ID and VAULT_SECRET_ID are required')
client = create_hvac_client(
url=addr,
role_id=role_id,
secret_id=secret_id,
auth_mount_point=auth_mount,
timeout=5,
)
def safe_read(path: str) -> dict:
try:
return read_kv2_secret(client=client, mount_point=mount, path=path)
except Exception:
return {}
rabbitmq = safe_read("rabbitmq")
email = safe_read("email")
if rabbitmq:
data['RABBIT_HOST'] = _vault_key_get(rabbitmq, 'HOST', data.get('RABBIT_HOST'))
data['RABBIT_PORT'] = _vault_key_get(rabbitmq, 'PORT', data.get('RABBIT_PORT'))
data['RABBIT_USER'] = _vault_key_get(rabbitmq, 'USER', data.get('RABBIT_USER'))
data['RABBIT_PASSWORD'] = _vault_key_get(rabbitmq, 'PASSWORD', data.get('RABBIT_PASSWORD'))
data['RABBIT_VHOST'] = _vault_key_get(rabbitmq, 'VHOST', data.get('RABBIT_VHOST'))
if email:
data['SMTP_FROM'] = _vault_key_get(email, 'FROM', data.get('SMTP_FROM'))
data['SMTP_HOST'] = _vault_key_get(email, 'HOST', data.get('SMTP_HOST'))
data['SMTP_PASSWORD'] = _vault_key_get(email, 'PASSWORD', data.get('SMTP_PASSWORD'))
data['SMTP_PORT'] = _vault_key_get(email, 'PORT', data.get('SMTP_PORT'))
data['SMTP_USER'] = _vault_key_get(email, 'USER', data.get('SMTP_USER'))
return data
@property
def DATABASE_URL(self) -> str:
return (
f"postgresql+asyncpg://{self.DATABASE_USER}:{self.DATABASE_PASSWORD}"
f"@{self.DATABASE_HOST}:{self.DATABASE_PORT}/{self.DATABASE_NAME}"
)
@property
def REDIS_URL(self) -> str:
auth = f":{self.REDIS_PASSWORD}@" if self.REDIS_PASSWORD else ""
return f"redis://{auth}{self.REDIS_HOST}:{self.REDIS_PORT}/{self.REDIS_DB}"
@property
def RABBIT_URL(self) -> str:
vhost = "%2F" if self.RABBIT_VHOST == "/" else self.RABBIT_VHOST.lstrip("/")
return f"amqp://{self.RABBIT_USER}:{self.RABBIT_PASSWORD}@{self.RABBIT_HOST}:{self.RABBIT_PORT}/{vhost}"
@property
def EXCLUDED_PATHS(self) -> List[str]:
return ["/docs", "/redoc", "/openapi.json", "/ping", "/health"]
@lru_cache(maxsize=1)
def get_settings() -> Settings:
return Settings()
settings = get_settings()

View File

@@ -0,0 +1 @@
from src.infrastructure.context_vars.trace_id import trace_id_var

View File

@@ -0,0 +1,4 @@
from contextvars import ContextVar
trace_id_var: ContextVar[str] = ContextVar('trace_id', default='N/A')

View File

@@ -0,0 +1,30 @@
from anyio.functools import lru_cache
from src.application.contracts import ILogger
from src.application.domain.enums import LogFormat
from src.application.domain.enums import LogLevel
from src.infrastructure.config.settings import settings
from src.infrastructure.logger.logger import Logger
log_levels = {
'DEBUG': LogLevel.DEBUG,
'INFO': LogLevel.INFO,
'WARNING': LogLevel.WARNING,
'ERROR': LogLevel.ERROR,
'CRITICAL': LogLevel.CRITICAL,
'EXCEPTION': LogLevel.EXCEPTION,
}
log_formats = {
'JSON': LogFormat.JSON,
'TEXT': LogFormat.TEXT,
}
logger = Logger(
min_level=log_levels.get(settings.LOG_LEVEL, LogLevel.INFO),
log_format=log_formats.get(settings.LOG_FORMAT, LogFormat.JSON),
)
@lru_cache
def get_logger() -> ILogger:
return logger

View File

@@ -0,0 +1,129 @@
import traceback
import inspect
import sys
import json
from datetime import datetime
from typing import Callable, Optional, Any
from ulid import ULID
from src.application.contracts import ILogger
from src.application.domain.enums import LogFormat, LogLevel
from src.infrastructure.context_vars import trace_id_var
class Logger(ILogger):
_instance = None
__default_format = LogFormat.JSON
def __new__(cls, *args: Any, **kwargs: Any) -> "Logger":
if cls._instance is None:
cls._instance = super(Logger, cls).__new__(cls)
return cls._instance
def __init__(
self,
log_format: LogFormat = __default_format,
min_level: LogLevel = LogLevel.INFO,
id_generator: Optional[Callable[[], str]] = lambda: str(ULID()),
instance_id: str = "N/A",
):
self.log_format = log_format
self.min_level = min_level
self.id_generator = id_generator
self.instance_id = instance_id
def set_instance_id(self, instance_id: str) -> None:
self.instance_id = instance_id
def get_instance_id(self) -> str:
return self.instance_id
def set_format(self, log_format: LogFormat) -> None:
if not isinstance(log_format, LogFormat):
raise ValueError("Log format must be an instance of LogFormat enum")
self.log_format = log_format
def set_min_level(self, level: LogLevel) -> None:
self.min_level = level
def new_trace_id(self) -> str:
trace_id = str(ULID()) if self.id_generator is None else self.id_generator()
trace_id_var.set(trace_id)
return trace_id
def set_trace_id(self, trace_id: str) -> None:
trace_id_var.set(trace_id)
def get_trace_id(self) -> str:
return trace_id_var.get()
def clear_trace_id(self) -> None:
trace_id_var.set("N/A")
def _prepare_log_data(self, level: LogLevel, message: str) -> dict[str, Any]:
current_frame = inspect.currentframe()
if (
current_frame
and current_frame.f_back
and current_frame.f_back.f_back
and current_frame.f_back.f_back.f_back
):
frame = current_frame.f_back.f_back.f_back
filename = frame.f_code.co_filename
line_number = frame.f_lineno
else:
filename = "unknown"
line_number = 0
log_data = {
"timestamp": datetime.now().isoformat(),
"level": level.name,
"instance_id": self.instance_id,
"file": filename,
"line": line_number,
"trace_id": trace_id_var.get(),
"message": message,
}
if level == LogLevel.EXCEPTION:
log_data["exception"] = traceback.format_exc()
return log_data
def _log(self, level: LogLevel, message: str) -> None:
if level >= self.min_level:
log_data = self._prepare_log_data(level, message)
if self.log_format == LogFormat.JSON:
log_message = json.dumps(log_data, ensure_ascii=False)
else:
log_message = (
f"{log_data['timestamp']} - {log_data['level']} - "
f"{log_data['instance_id']} - {log_data['trace_id']} - "
f"{log_data['file']}:{log_data['line']} - "
f"{log_data['message']}"
)
if "exception" in log_data:
log_message += f"\nTraceback:\n{log_data['exception']}"
self._write(log_message)
def _write(self, message: str) -> None:
sys.stdout.write(message + "\n")
def debug(self, message: str) -> None:
self._log(LogLevel.DEBUG, message)
def info(self, message: str) -> None:
self._log(LogLevel.INFO, message)
def warning(self, message: str) -> None:
self._log(LogLevel.WARNING, message)
def error(self, message: str) -> None:
self._log(LogLevel.ERROR, message)
def critical(self, message: str) -> None:
self._log(LogLevel.CRITICAL, message)
def exception(self, message: str) -> None:
self._log(LogLevel.EXCEPTION, message)

View File

@@ -0,0 +1,21 @@
from anyio.functools import lru_cache
from src.application.contracts.i_sender import ISender
from src.infrastructure.config import settings
from src.infrastructure.mail.render import TemplateRenderer
from src.infrastructure.mail.sender import EmailSender
@lru_cache(maxsize=1)
def get_renderer() -> TemplateRenderer:
return TemplateRenderer()
@lru_cache(maxsize=1)
def get_email_sender() -> ISender:
return EmailSender(
host=settings.SMTP_HOST,
port=settings.SMTP_PORT,
username=str(settings.SMTP_USER),
password=settings.SMTP_PASSWORD,
from_addr=settings.SMTP_FROM,
)

View File

@@ -0,0 +1,17 @@
from pathlib import Path
from jinja2 import Environment, FileSystemLoader, select_autoescape
class TemplateRenderer:
def __init__(self, templates_dir: Path | str | None = None):
if templates_dir is None:
templates_dir = Path(__file__).parent / "templates"
self._env = Environment(
loader=FileSystemLoader(templates_dir),
autoescape=select_autoescape(["html"]),
trim_blocks=True,
lstrip_blocks=True,
)
def render(self, template_name: str, **kwargs: object) -> str:
return self._env.get_template(template_name).render(**kwargs)

View File

@@ -0,0 +1,46 @@
import aiosmtplib
from email.message import EmailMessage
from src.application.contracts.i_sender import ISender
class EmailSender(ISender):
def __init__(
self,
host: str,
port: int,
username: str,
password: str,
from_addr: str,
use_tls: bool = True,
timeout: int = 10,
):
self._host = host
self._port = port
self._username = username
self._password = password
self._from_addr = from_addr
self._use_tls = use_tls
self._timeout = timeout
async def send(self, to: str, subject: str, body: str, plain: str | None = None) -> None:
message = EmailMessage()
message["From"] = self._from_addr
message["To"] = to
message["Subject"] = subject
if plain:
message.set_content(plain)
message.add_alternative(body, subtype="html")
else:
message.set_content(body, subtype="html")
await aiosmtplib.send(
message,
hostname=self._host,
port=self._port,
username=self._username,
password=self._password,
use_tls=True,
timeout=self._timeout,
)

View File

@@ -0,0 +1,171 @@
<!doctype html>
<html lang="ru">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<meta name="color-scheme" content="dark" />
<title>{{ subject }}</title>
</head>
<body style="margin:0;padding:0;background:#0E1126;-webkit-text-size-adjust:100%;-ms-text-size-adjust:100%;">
<!--[if mso]><style>table,td{font-family:Arial,Helvetica,sans-serif !important;}</style><![endif]-->
<!-- Preheader -->
<div style="display:none;max-height:0;overflow:hidden;mso-hide:all;">
Ваш код: {{ code }}. Действует {{ ttl_minutes }} мин. &#8199;&#65279;&#847;
</div>
<table role="presentation" cellpadding="0" cellspacing="0" border="0" width="100%"
style="background:#0E1126;padding:32px 16px;">
<tr>
<td align="center">
<!-- Outer card 600px -->
<table role="presentation" cellpadding="0" cellspacing="0" border="0" width="600"
style="max-width:600px;width:100%;border-radius:20px;overflow:hidden;
border:1px solid rgba(93,4,217,0.30);">
<!-- ====== HEADER — gradient bar ====== -->
<tr>
<td style="padding:28px 32px;
background:linear-gradient(135deg,#260E59 0%,#5D04D9 50%,#056CF2 100%);">
<table role="presentation" cellpadding="0" cellspacing="0" border="0" width="100%">
<tr>
<td>
<div style="font-family:'Segoe UI',Arial,Helvetica,sans-serif;
color:#ffffff;font-size:22px;font-weight:800;
letter-spacing:0.4px;">
{{ brand }}
</div>
<div style="font-family:'Segoe UI',Arial,Helvetica,sans-serif;
color:rgba(255,255,255,0.80);font-size:13px;
margin-top:6px;letter-spacing:0.2px;">
Подтверждение · Безопасность
</div>
</td>
<td align="right" valign="middle">
<!-- Shield icon via CSS -->
<div style="width:44px;height:44px;border-radius:12px;
background:rgba(255,255,255,0.12);
text-align:center;line-height:44px;font-size:22px;">
🔐
</div>
</td>
</tr>
</table>
</td>
</tr>
<!-- ====== BODY ====== -->
<tr>
<td style="padding:28px 32px;background:#0E1126;">
<!-- Greeting -->
<div style="font-family:'Segoe UI',Arial,Helvetica,sans-serif;
color:#ffffff;font-size:20px;font-weight:700;">
Ваш код подтверждения
</div>
<div style="font-family:'Segoe UI',Arial,Helvetica,sans-serif;
color:rgba(255,255,255,0.70);font-size:14px;
line-height:22px;margin-top:10px;">
Введите этот код в приложении, чтобы завершить действие.
Никому не сообщайте его.
</div>
<!-- ── Code card ── -->
<table role="presentation" cellpadding="0" cellspacing="0" border="0" width="100%"
style="margin-top:24px;">
<tr>
<td style="background:linear-gradient(160deg,#260E59 0%,#1a0a3e 100%);
border:1px solid rgba(93,4,217,0.50);
border-radius:16px;padding:24px 20px;text-align:center;">
<div style="font-family:'Segoe UI',Arial,Helvetica,sans-serif;
color:rgba(255,255,255,0.55);font-size:11px;
letter-spacing:2px;text-transform:uppercase;">
код подтверждения
</div>
<div style="font-family:'SF Mono','Cascadia Code','Fira Code',monospace;
color:#ffffff;font-size:38px;font-weight:800;
letter-spacing:10px;margin-top:12px;
padding:12px 0;
background:linear-gradient(90deg,#05C7F2,#5D04D9,#056CF2);
-webkit-background-clip:text;
-webkit-text-fill-color:transparent;
background-clip:text;">
{{ code }}
</div>
<!-- TTL pill -->
<table role="presentation" cellpadding="0" cellspacing="0" border="0"
style="margin:14px auto 0;">
<tr>
<td style="background:rgba(5,199,242,0.12);
border:1px solid rgba(5,199,242,0.30);
border-radius:20px;padding:6px 16px;">
<span style="font-family:'Segoe UI',Arial,Helvetica,sans-serif;
color:#05C7F2;font-size:13px;font-weight:600;">
⏱ Действует {{ ttl_minutes }} мин
</span>
</td>
</tr>
</table>
</td>
</tr>
</table>
<!-- Warning -->
<table role="presentation" cellpadding="0" cellspacing="0" border="0" width="100%"
style="margin-top:20px;">
<tr>
<td style="background:rgba(93,4,217,0.08);
border-left:3px solid #5D04D9;
border-radius:0 10px 10px 0;padding:14px 16px;">
<div style="font-family:'Segoe UI',Arial,Helvetica,sans-serif;
color:rgba(255,255,255,0.70);font-size:13px;line-height:20px;">
⚠️ Если вы не запрашивали код — просто проигнорируйте
это письмо. Никогда не сообщайте код третьим лицам.
</div>
</td>
</tr>
</table>
</td>
</tr>
<!-- ====== DIVIDER ====== -->
<tr>
<td style="padding:0 32px;">
<div style="height:1px;
background:linear-gradient(90deg,transparent,rgba(5,199,242,0.25),transparent);">
</div>
</td>
</tr>
<!-- ====== FOOTER ====== -->
<tr>
<td style="padding:20px 32px 28px;background:#0E1126;">
<table role="presentation" cellpadding="0" cellspacing="0" border="0" width="100%">
<tr>
<td align="right" valign="bottom"
style="font-family:'Segoe UI',Arial,Helvetica,sans-serif;
color:rgba(255,255,255,0.30);font-size:11px;">
© {{ year }} {{ brand }}<br />
Ref: {{ trace_id }}
</td>
</tr>
</table>
</td>
</tr>
</table>
<!-- /Outer card -->
</td>
</tr>
</table>
</body>
</html>

View File

@@ -0,0 +1,7 @@
{{ brand }}
Ваш код подтверждения: {{ code }}
Срок действия: {{ ttl_minutes }} минут
Если вы не запрашивали код — игнорируйте это письмо.
Ref: {{ trace_id }}

View File

@@ -0,0 +1,11 @@
from faststream.rabbit import RabbitBroker
from faststream.rabbit.schemas import RabbitQueue
from src.infrastructure.config import settings
broker = RabbitBroker(settings.RABBIT_URL)
email_code_queue = RabbitQueue(
name=settings.RABBIT_EMAIL_CODE_QUEUE,
durable=True,
)

View File

@@ -0,0 +1 @@
from src.infrastructure.utils.instance_id import generate_instance_id

View File

@@ -0,0 +1,14 @@
from ulid import ULID
def generate_instance_id() -> str:
"""
Generate a process-wide instance id in ULID format.
ULID is 26 chars (Crockford Base32) and lexicographically sortable by time.
"""
return str(ULID())

View File

@@ -0,0 +1,2 @@
from src.infrastructure.vault.utils import read_kv2_secret, create_hvac_client
from src.infrastructure.vault.keys import JwtKeyStore

View File

@@ -0,0 +1,139 @@
from __future__ import annotations
import asyncio
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Optional, Dict
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.vault import create_hvac_client, read_kv2_secret
@dataclass(frozen=True)
class JwtKeyPair:
kid: str
private_key_pem: str
public_key_pem: str
@dataclass(frozen=True)
class JwtKeySet:
active: JwtKeyPair
previous: Optional[JwtKeyPair] = None
def public_keys_by_kid(self) -> Dict[str, str]:
out = {self.active.kid: self.active.public_key_pem}
if self.previous:
out[self.previous.kid] = self.previous.public_key_pem
return out
class JwtKeyStore:
_instance: "JwtKeyStore | None" = None
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(
self,
*,
vault_addr: str,
vault_role_id: str,
vault_secret_id: str,
vault_auth_mount: str = 'approle',
mount_point: str,
kid_path: str = 'jwt/kid',
kids_prefix: str = 'jwt/kids',
timeout_seconds: int = 5,
):
if getattr(self, '_initialized', False):
return
self._vault_addr = vault_addr
self._vault_role_id = vault_role_id
self._vault_secret_id = vault_secret_id
self._vault_auth_mount = vault_auth_mount
self._timeout = timeout_seconds
self._mount = mount_point
self._kid_path = kid_path
self._kids_prefix = kids_prefix
self._lock = asyncio.Lock()
self._keyset: JwtKeySet | None = None
self._last_refresh_at: datetime | None = None
self._initialized = True
@classmethod
def get_instance(cls) -> 'JwtKeyStore':
if cls._instance is None:
raise ApplicationException(status_code=500, message="JwtKeyStore not initialized")
return cls._instance
def _read_keyset_sync(self) -> JwtKeySet:
client = create_hvac_client(
url=self._vault_addr,
role_id=self._vault_role_id,
secret_id=self._vault_secret_id,
auth_mount_point=self._vault_auth_mount,
timeout=self._timeout,
)
kids = read_kv2_secret(client=client, mount_point=self._mount, path=self._kid_path)
active_kid = kids.get("active")
previous_kid = kids.get("previous")
print("VAULT kids:", {"active": active_kid, "previous": previous_kid})
if not active_kid:
raise RuntimeError("Vault jwt/kid secret missing 'active'")
active_pair = self._read_keypair_sync(client, active_kid)
prev_pair = None
if previous_kid and previous_kid != active_kid:
prev_pair = self._read_keypair_sync(client, previous_kid)
return JwtKeySet(active=active_pair, previous=prev_pair)
def _read_keypair_sync(self, client, kid: str) -> JwtKeyPair:
data = read_kv2_secret(
client=client,
mount_point=self._mount,
path=f"{self._kids_prefix}/{kid}",
)
priv = data.get("private_key")
pub = data.get("public_key")
if not priv or not pub:
raise RuntimeError(f"Vault jwt/kids/{kid} missing private_key/public_key")
return JwtKeyPair(kid=kid, private_key_pem=priv, public_key_pem=pub)
async def refresh(self) -> JwtKeySet:
keyset = await asyncio.to_thread(self._read_keyset_sync)
async with self._lock:
self._keyset = keyset
self._last_refresh_at = datetime.now(timezone.utc)
return keyset
async def get_signing_key(self) -> tuple[str, str]:
ks = await self._get_or_refresh()
return ks.active.kid, ks.active.private_key_pem
async def get_public_key_for_kid(self, kid: str) -> str | None:
ks = await self._get_or_refresh()
return ks.public_keys_by_kid().get(kid)
async def last_refresh_at(self) -> datetime | None:
async with self._lock:
return self._last_refresh_at
async def _get_or_refresh(self) -> JwtKeySet:
async with self._lock:
ks = self._keyset
return ks if ks else await self.refresh()

View File

@@ -0,0 +1,31 @@
from __future__ import annotations
import hvac
def create_hvac_client(
*,
url: str,
role_id: str,
secret_id: str,
auth_mount_point: str = 'approle',
timeout: int = 5,
) -> hvac.Client:
client = hvac.Client(url=url, timeout=timeout)
client.auth.approle.login(
role_id=role_id,
secret_id=secret_id,
mount_point=auth_mount_point,
)
if not client.is_authenticated():
raise RuntimeError(
'Vault AppRole login failed. Check VAULT_ADDR, VAULT_ROLE_ID, VAULT_SECRET_ID, VAULT_AUTH_MOUNT',
)
return client
def read_kv2_secret(*, client: hvac.Client, mount_point: str, path: str) -> dict:
secret = client.secrets.kv.v2.read_secret_version(
mount_point=mount_point,
path=path,
)
return secret["data"]["data"]

103
src/main.py Normal file
View File

@@ -0,0 +1,103 @@
from __future__ import annotations
from contextlib import asynccontextmanager
import secrets
from typing import AsyncGenerator
from fastapi import Depends, FastAPI, status, APIRouter
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.responses import HTMLResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from src.application.domain.exceptions import ApplicationException
from src.infrastructure.utils import generate_instance_id
from src.infrastructure.logger import logger
from src.infrastructure.config import settings
from src.presentation.handlers import application_exception_handler, unhandled_exception_handler
from src.presentation.messaging.code import code_broker
from src.presentation.middleware import TraceIDMiddleware, SecurityHeadersMiddleware
security = HTTPBasic()
async def verify_credentials(credentials: HTTPBasicCredentials = Depends(security)) -> HTTPBasicCredentials:
user_ok = secrets.compare_digest(credentials.username, settings.DOCS_USERNAME)
pass_ok = secrets.compare_digest(credentials.password, settings.DOCS_PASSWORD)
if not (user_ok and pass_ok):
raise ApplicationException(
status_code=status.HTTP_401_UNAUTHORIZED,
message='Unauthorized',
headers={'WWW-Authenticate': 'Basic'},
)
return credentials
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
instance_id = generate_instance_id()
logger.set_instance_id(instance_id)
logger.info(f'Auth service instance started with id {instance_id}')
yield
logger.info(f'Auth service instance ended with id {instance_id}')
app: FastAPI = FastAPI(
redoc_url=None,
docs_url=None,
lifespan=lifespan,
title='Bitforce. Notify Service',
version='1.0.0',
description='',
license_info={
'name': 'MIT',
'url': 'https://opensource.org/licenses/MIT',
},
)
app.add_exception_handler(ApplicationException, application_exception_handler)
app.add_exception_handler(Exception, unhandled_exception_handler)
v1_router = APIRouter(prefix='/v1')
app.include_router(v1_router)
app.include_router(code_broker)
# Added middleware
app.add_middleware(TraceIDMiddleware, logger=logger)
app.add_middleware(
SecurityHeadersMiddleware,
hsts=True,
hsts_preload=False,
frame_options='DENY',
referrer_policy='strict-origin-when-cross-origin',
content_security_policy="default-src 'self'; frame-ancestors 'none'; base-uri 'self'; object-src 'none'",
)
@app.get('/docs', include_in_schema=False)
async def custom_swagger_ui_html(_credentials: HTTPBasicCredentials = Depends(verify_credentials)) -> HTMLResponse:
"""Custom Swagger documentation, optionally protected with basic authentication."""
return get_swagger_ui_html(
openapi_url=getattr(app, 'openapi_url', '/openapi.json'),
title=getattr(app, 'title', 'FastAPI') + ' - Swagger UI',
oauth2_redirect_url=getattr(app, 'swagger_ui_oauth2_redirect_url', None),
swagger_js_url='https://unpkg.com/swagger-ui-dist@5/swagger-ui-bundle.js',
swagger_css_url='https://unpkg.com/swagger-ui-dist@5/swagger-ui.css',
)
@app.get('/redoc', include_in_schema=False)
async def custom_redoc_html(_credentials: HTTPBasicCredentials = Depends(verify_credentials)) -> HTMLResponse:
"""Custom ReDoc documentation, optionally protected with basic authentication."""
return get_redoc_html(
openapi_url=getattr(app, 'openapi_url', '/openapi.json'),
title=getattr(app, 'title', 'FastAPI') + ' - ReDoc',
redoc_js_url='https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js',
)
@app.post('/ping')
async def ping() -> dict[str, str]:
return {
'message': 'pong',
'status': 'ok',
}

View File

@@ -0,0 +1,2 @@
from src.presentation.handlers.unhandled_handler import unhandled_exception_handler
from src.presentation.handlers.application_handler import application_exception_handler

View File

@@ -0,0 +1,17 @@
from fastapi.responses import ORJSONResponse
from fastapi import Request
from src.application.domain.exceptions import ApplicationException
async def application_exception_handler(_request: Request, exc: ApplicationException) -> ORJSONResponse:
detail = exc.message
if 500 <= exc.status_code:
detail = "Internal Server Error"
return ORJSONResponse(
status_code=exc.status_code,
content={"detail": detail},
headers=dict(exc.headers) if exc.headers else None,
)

View File

@@ -0,0 +1,12 @@
from fastapi.responses import ORJSONResponse
from fastapi import Request
from starlette import status
from src.infrastructure.logger import logger
async def unhandled_exception_handler(_request: Request, exc: Exception) -> ORJSONResponse:
logger.exception(f'Unhandled exception: {type(exc).__name__}')
return ORJSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={'detail': 'Internal Server Error'},
)

View File

@@ -0,0 +1,78 @@
from fastapi import Depends
from faststream.rabbit.fastapi import RabbitRouter, RabbitMessage
from src.infrastructure.config import settings
from src.infrastructure.logger import logger
from src.infrastructure.mail import TemplateRenderer, get_email_sender, get_renderer
from src.infrastructure.mail.sender import EmailSender
from datetime import datetime
from typing import Literal, Optional
from pydantic import BaseModel, EmailStr, Field
from src.infrastructure.rabbit.broker import email_code_queue
code_broker = RabbitRouter(settings.RABBIT_URL)
class Metadata(BaseModel):
trace_id: Optional[str] = None
source: str
timestamp: datetime
message_id: str
class Payload(BaseModel):
email: EmailStr
code: str = Field(min_length=1, max_length=64)
ttl_seconds: int = Field(gt=0, lt=24 * 3600)
class LoginCodeCreated(BaseModel):
event: Literal["login", "registration", "bank_details_update"]
payload: Payload
metadata: Metadata
@code_broker.subscriber(email_code_queue)
async def consume_email_code(
msg_body: LoginCodeCreated,
message: RabbitMessage,
sender: EmailSender = Depends(get_email_sender),
renderer: TemplateRenderer = Depends(get_renderer),
):
trace_id = (
(message.headers or {}).get("trace_id")
or message.correlation_id
or msg_body.metadata.trace_id
)
logger.info(
f"received event={msg_body.event} "
f"email={msg_body.payload.email} "
f"ttl={msg_body.payload.ttl_seconds} "
f"trace_id={trace_id}"
)
html = renderer.render(
"email_code.html",
subject="Код подтверждения",
code=msg_body.payload.code,
ttl_minutes=msg_body.payload.ttl_seconds // 60,
brand="Bitforce",
trace_id=trace_id,
year=datetime.now().year,
)
text = renderer.render(
"email_code.txt",
code=msg_body.payload.code,
ttl_minutes=msg_body.payload.ttl_seconds // 60,
brand="Bitforce",
trace_id=trace_id,
)
await sender.send(
to=msg_body.payload.email,
subject="Код подтверждения",
body=html,
plain=text,
)

View File

@@ -0,0 +1,2 @@
from src.presentation.middleware.trace_id import TraceIDMiddleware
from src.presentation.middleware.security_headers import SecurityHeadersMiddleware

View File

@@ -0,0 +1,51 @@
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
def __init__(
self,
app,
*,
hsts: bool = True,
hsts_max_age: int = 31536000, # 1 год
hsts_include_subdomains: bool = True,
hsts_preload: bool = False,
frame_options: str = 'DENY', # или 'SAMEORIGIN'
referrer_policy: str = 'strict-origin-when-cross-origin',
content_security_policy: str | None = None,
):
super().__init__(app)
self.hsts = hsts
self.hsts_max_age = hsts_max_age
self.hsts_include_subdomains = hsts_include_subdomains
self.hsts_preload = hsts_preload
self.frame_options = frame_options
self.referrer_policy = referrer_policy
self.csp = content_security_policy
async def dispatch(self, request: Request, call_next) -> Response:
response: Response = await call_next(request)
if request.url.path in ('/docs', '/redoc', '/openapi.json'):
return response
if self.hsts and request.url.scheme == 'https':
hsts = f'max-age={self.hsts_max_age}'
if self.hsts_include_subdomains:
hsts += '; includeSubDomains'
if self.hsts_preload:
hsts += '; preload'
response.headers['Strict-Transport-Security'] = hsts
response.headers['X-Content-Type-Options'] = 'nosniff'
response.headers['X-Frame-Options'] = self.frame_options
response.headers['Referrer-Policy'] = self.referrer_policy
if self.csp:
response.headers['Content-Security-Policy'] = self.csp
return response

View File

@@ -0,0 +1,135 @@
from __future__ import annotations
from typing import Optional
from contextvars import Token
from starlette.requests import Request
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from ulid import ULID
from src.application.contracts import ILogger
from src.infrastructure.config import settings
from src.infrastructure.context_vars import trace_id_var
class TraceIDMiddleware:
def __init__(
self,
app: ASGIApp,
logger: ILogger,
response_header_name: str = "X-Trace-ID",
attach_response_header: bool = True,
) -> None:
self.app = app
self.logger = logger
self.response_header_name = response_header_name
self.attach_response_header = attach_response_header
def _is_excluded(self, path: str) -> bool:
return any(path == p or path.startswith(p.rstrip("/") + "/") for p in settings.EXCLUDED_PATHS)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return
request = Request(scope)
if self._is_excluded(request.url.path):
await self.app(scope, receive, send)
return
trace_id = request.headers.get("X-Trace-ID") or request.headers.get("X-Request-ID")
if not trace_id:
trace_id = str(ULID())
request.state.trace_id = trace_id
token: Token = trace_id_var.set(trace_id)
self.logger.debug(f"Request started: {request.method} {request.url} - TraceID: {trace_id}")
status_code_holder: dict[str, Optional[int]] = {"status": None}
async def send_wrapper(message: Message) -> None:
if message["type"] == "http.response.start":
status_code_holder["status"] = int(message["status"])
if self.attach_response_header:
headers = list(message.get("headers", []))
headers.append((self.response_header_name.lower().encode(), trace_id.encode()))
message["headers"] = headers
await send(message)
try:
await self.app(scope, receive, send_wrapper)
finally:
status = status_code_holder["status"]
status_part = f"{status}" if status is not None else "unknown"
self.logger.debug(
f"Request finished: {request.method} {request.url} - TraceID: {trace_id} - Status: {status_part}"
)
trace_id_var.reset(token)
# from __future__ import annotations
# from typing import Optional
# from starlette.requests import Request
# from starlette.types import ASGIApp, Message, Receive, Scope, Send
# from ulid import ULID
# from src.application.contracts import ILogger
# from src.infrastructure.config.settings import settings
#
#
# class TraceIDMiddleware:
# def __init__(
# self,
# app: ASGIApp,
# logger: ILogger,
# response_header_name: str = 'X-Trace-ID',
# attach_response_header: bool = True,
# ) -> None:
# self.app = app
# self.logger = logger
# self.response_header_name = response_header_name
# self.attach_response_header = attach_response_header
#
# def _is_excluded(self, path: str) -> bool:
# return any(path == p or path.startswith(p.rstrip('/') + '/') for p in settings.EXCLUDED_PATHS)
#
# async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
# if scope['type'] != 'http':
# await self.app(scope, receive, send)
# return
#
# request = Request(scope)
#
# if self._is_excluded(request.url.path):
# await self.app(scope, receive, send)
# return
#
# trace_id = request.headers.get('X-Trace-ID') or request.headers.get('X-Request-ID')
# if not trace_id:
# trace_id = str(ULID())
#
# request.state.trace_id = trace_id
# self.logger.set_trace_id(trace_id)
#
# self.logger.debug(f'Request started: {request.method} {request.url} - TraceID: {trace_id}')
#
# status_code_holder: dict[str, Optional[int]] = {'status': None}
#
# async def send_wrapper(message: Message) -> None:
# if message['type'] == 'http.response.start':
# status_code_holder['status'] = int(message['status'])
#
# if self.attach_response_header:
# headers = list(message.get('headers', []))
# headers.append((self.response_header_name.lower().encode(), trace_id.encode()))
# message['headers'] = headers
# await send(message)
#
# try:
# await self.app(scope, receive, send_wrapper)
# finally:
# status = status_code_holder['status']
# status_part = f'{status}' if status is not None else 'unknown'
# self.logger.debug(f'Request finished: {request.method} {request.url} - TraceID: {trace_id} - Status: {status_part}')
# self.logger.clear_trace_id()