Initial commit
This commit is contained in:
1
src/application/contracts/__init__.py
Normal file
1
src/application/contracts/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.application.contracts.i_logger import ILogger
|
||||
68
src/application/contracts/i_logger.py
Normal file
68
src/application/contracts/i_logger.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from typing import Protocol, Optional, Callable
|
||||
from src.application.domain.enums.log_format import LogFormat
|
||||
from src.application.domain.enums.log_level import LogLevel
|
||||
|
||||
|
||||
class ILogger(Protocol):
|
||||
"""Interface for synchronous logger with ContextVar support for trace_id."""
|
||||
|
||||
log_format: LogFormat
|
||||
min_level: LogLevel
|
||||
id_generator: Optional[Callable[[], str]]
|
||||
instance_id: str
|
||||
|
||||
def set_format(self, log_format: LogFormat) -> None:
|
||||
"""Set log format using LogFormat enum"""
|
||||
...
|
||||
|
||||
def set_min_level(self, level: LogLevel) -> None:
|
||||
"""Set minimum log level"""
|
||||
...
|
||||
|
||||
def new_trace_id(self) -> str:
|
||||
"""Create and set new trace_id in context"""
|
||||
...
|
||||
|
||||
def set_trace_id(self, trace_id: str) -> None:
|
||||
"""Set existing trace_id in context"""
|
||||
...
|
||||
|
||||
def get_trace_id(self) -> str:
|
||||
"""Get current trace_id from context"""
|
||||
...
|
||||
|
||||
def clear_trace_id(self) -> None:
|
||||
"""Clear the trace_id in the context"""
|
||||
...
|
||||
|
||||
def set_instance_id(self, instance_id: str) -> None:
|
||||
"""Set service instance id (ULID recommended)"""
|
||||
...
|
||||
|
||||
def get_instance_id(self) -> str:
|
||||
"""Get current service instance id"""
|
||||
...
|
||||
|
||||
def debug(self, message: str) -> None:
|
||||
"""Log debug message"""
|
||||
...
|
||||
|
||||
def info(self, message: str) -> None:
|
||||
"""Log info message"""
|
||||
...
|
||||
|
||||
def warning(self, message: str) -> None:
|
||||
"""Log warning message"""
|
||||
...
|
||||
|
||||
def error(self, message: str) -> None:
|
||||
"""Log error message"""
|
||||
...
|
||||
|
||||
def critical(self, message: str) -> None:
|
||||
"""Log critical message"""
|
||||
...
|
||||
|
||||
def exception(self, message: str) -> None:
|
||||
"""Log exception with traceback"""
|
||||
...
|
||||
8
src/application/contracts/i_sender.py
Normal file
8
src/application/contracts/i_sender.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class ISender(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def send(self, to: str, subject: str, body: str, plain: str | None = None) -> None:
|
||||
raise NotImplementedError
|
||||
2
src/application/domain/enums/__init__.py
Normal file
2
src/application/domain/enums/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from src.application.domain.enums.log_level import LogLevel
|
||||
from src.application.domain.enums.log_format import LogFormat
|
||||
7
src/application/domain/enums/log_format.py
Normal file
7
src/application/domain/enums/log_format.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class LogFormat(Enum):
|
||||
"""Enum for supported log formats"""
|
||||
TEXT = 'text'
|
||||
JSON = 'json'
|
||||
54
src/application/domain/enums/log_level.py
Normal file
54
src/application/domain/enums/log_level.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class LogLevel(Enum):
|
||||
DEBUG = 10
|
||||
INFO = 20
|
||||
WARNING = 30
|
||||
ERROR = 40
|
||||
CRITICAL = 50
|
||||
EXCEPTION = 60
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"[{self.value}, '{self.name}']"
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if isinstance(other, LogLevel):
|
||||
return self.value == other.value
|
||||
if isinstance(other, int):
|
||||
return self.value == other
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other: object) -> bool:
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __lt__(self, other: object) -> bool:
|
||||
if isinstance(other, LogLevel):
|
||||
return self.value < other.value
|
||||
if isinstance(other, int):
|
||||
return self.value < other
|
||||
return NotImplemented
|
||||
|
||||
def __le__(self, other: object) -> bool:
|
||||
if isinstance(other, LogLevel):
|
||||
return self.value <= other.value
|
||||
if isinstance(other, int):
|
||||
return self.value <= other
|
||||
return NotImplemented
|
||||
|
||||
def __gt__(self, other: object) -> bool:
|
||||
if isinstance(other, LogLevel):
|
||||
return self.value > other.value
|
||||
if isinstance(other, int):
|
||||
return self.value > other
|
||||
return NotImplemented
|
||||
|
||||
def __ge__(self, other: object) -> bool:
|
||||
if isinstance(other, LogLevel):
|
||||
return self.value >= other.value
|
||||
if isinstance(other, int):
|
||||
return self.value >= other
|
||||
return NotImplemented
|
||||
1
src/application/domain/exceptions/__init__.py
Normal file
1
src/application/domain/exceptions/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.application.domain.exceptions.application_exceptions import ApplicationException
|
||||
18
src/application/domain/exceptions/application_exceptions.py
Normal file
18
src/application/domain/exceptions/application_exceptions.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from __future__ import annotations
|
||||
from typing import Mapping
|
||||
|
||||
|
||||
class ApplicationException(Exception):
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int,
|
||||
message: str,
|
||||
headers: Mapping[str, str] | None = None,
|
||||
):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.headers = headers
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.status_code}: {self.message}"
|
||||
1
src/infrastructure/config/__init__.py
Normal file
1
src/infrastructure/config/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.config.settings import settings
|
||||
135
src/infrastructure/config/settings.py
Normal file
135
src/infrastructure/config/settings.py
Normal file
@@ -0,0 +1,135 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import List, Literal
|
||||
import os
|
||||
from dotenv import load_dotenv, find_dotenv
|
||||
from pydantic import Field, model_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from src.infrastructure.vault import create_hvac_client, read_kv2_secret
|
||||
|
||||
env_file = find_dotenv('.env')
|
||||
if env_file:
|
||||
load_dotenv(env_file)
|
||||
|
||||
|
||||
def _vault_key_get(mapping: dict, key: str, default=None):
|
||||
ku = key.upper()
|
||||
kl = key.lower()
|
||||
if ku in mapping:
|
||||
return mapping[ku]
|
||||
if kl in mapping:
|
||||
return mapping[kl]
|
||||
return default
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
VAULT_ADDR: str = Field(default='http://localhost:8200')
|
||||
VAULT_ROLE_ID: str = Field(..., description='Vault AppRole role_id')
|
||||
VAULT_SECRET_ID: str = Field(..., description='Vault AppRole secret_id')
|
||||
VAULT_AUTH_MOUNT: str = Field(default='approle')
|
||||
VAULT_MOUNT_POINT: str = Field(default='secrets')
|
||||
|
||||
DOCS_USERNAME: str = "admin"
|
||||
DOCS_PASSWORD: str = "admin"
|
||||
|
||||
SMTP_FROM: str = ""
|
||||
SMTP_HOST: str = "localhost"
|
||||
SMTP_PASSWORD: str = ""
|
||||
SMTP_PORT: int = 587
|
||||
SMTP_USER: str = ""
|
||||
|
||||
RABBIT_HOST: str = "localhost"
|
||||
RABBIT_PORT: int = 5672
|
||||
RABBIT_USER: str = "guest"
|
||||
RABBIT_PASSWORD: str = "guest"
|
||||
RABBIT_VHOST: str = "/"
|
||||
|
||||
RABBIT_PUBLISH_PERSIST: bool = True
|
||||
RABBIT_CONNECT_TIMEOUT: int = 5
|
||||
RABBIT_EMAIL_CODE_QUEUE: str = "email.verification_code"
|
||||
|
||||
LOG_LEVEL: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
|
||||
LOG_FORMAT: Literal["JSON", "TEXT"] = "TEXT"
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
case_sensitive=True,
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def load_from_vault(cls, data: dict):
|
||||
addr = data.get('VAULT_ADDR') or os.getenv('VAULT_ADDR') or 'http://localhost:8200'
|
||||
role_id = data.get('VAULT_ROLE_ID') or os.getenv('VAULT_ROLE_ID')
|
||||
secret_id = data.get('VAULT_SECRET_ID') or os.getenv('VAULT_SECRET_ID')
|
||||
auth_mount = data.get('VAULT_AUTH_MOUNT') or os.getenv('VAULT_AUTH_MOUNT') or 'approle'
|
||||
mount = data.get('VAULT_MOUNT_POINT') or os.getenv('VAULT_MOUNT_POINT') or 'secrets'
|
||||
|
||||
if not role_id or not secret_id:
|
||||
raise RuntimeError('VAULT_ROLE_ID and VAULT_SECRET_ID are required')
|
||||
|
||||
client = create_hvac_client(
|
||||
url=addr,
|
||||
role_id=role_id,
|
||||
secret_id=secret_id,
|
||||
auth_mount_point=auth_mount,
|
||||
timeout=5,
|
||||
)
|
||||
|
||||
def safe_read(path: str) -> dict:
|
||||
try:
|
||||
return read_kv2_secret(client=client, mount_point=mount, path=path)
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
rabbitmq = safe_read("rabbitmq")
|
||||
email = safe_read("email")
|
||||
|
||||
|
||||
if rabbitmq:
|
||||
data['RABBIT_HOST'] = _vault_key_get(rabbitmq, 'HOST', data.get('RABBIT_HOST'))
|
||||
data['RABBIT_PORT'] = _vault_key_get(rabbitmq, 'PORT', data.get('RABBIT_PORT'))
|
||||
data['RABBIT_USER'] = _vault_key_get(rabbitmq, 'USER', data.get('RABBIT_USER'))
|
||||
data['RABBIT_PASSWORD'] = _vault_key_get(rabbitmq, 'PASSWORD', data.get('RABBIT_PASSWORD'))
|
||||
data['RABBIT_VHOST'] = _vault_key_get(rabbitmq, 'VHOST', data.get('RABBIT_VHOST'))
|
||||
|
||||
if email:
|
||||
data['SMTP_FROM'] = _vault_key_get(email, 'FROM', data.get('SMTP_FROM'))
|
||||
data['SMTP_HOST'] = _vault_key_get(email, 'HOST', data.get('SMTP_HOST'))
|
||||
data['SMTP_PASSWORD'] = _vault_key_get(email, 'PASSWORD', data.get('SMTP_PASSWORD'))
|
||||
data['SMTP_PORT'] = _vault_key_get(email, 'PORT', data.get('SMTP_PORT'))
|
||||
data['SMTP_USER'] = _vault_key_get(email, 'USER', data.get('SMTP_USER'))
|
||||
|
||||
return data
|
||||
|
||||
@property
|
||||
def DATABASE_URL(self) -> str:
|
||||
return (
|
||||
f"postgresql+asyncpg://{self.DATABASE_USER}:{self.DATABASE_PASSWORD}"
|
||||
f"@{self.DATABASE_HOST}:{self.DATABASE_PORT}/{self.DATABASE_NAME}"
|
||||
)
|
||||
|
||||
@property
|
||||
def REDIS_URL(self) -> str:
|
||||
auth = f":{self.REDIS_PASSWORD}@" if self.REDIS_PASSWORD else ""
|
||||
return f"redis://{auth}{self.REDIS_HOST}:{self.REDIS_PORT}/{self.REDIS_DB}"
|
||||
|
||||
@property
|
||||
def RABBIT_URL(self) -> str:
|
||||
vhost = "%2F" if self.RABBIT_VHOST == "/" else self.RABBIT_VHOST.lstrip("/")
|
||||
return f"amqp://{self.RABBIT_USER}:{self.RABBIT_PASSWORD}@{self.RABBIT_HOST}:{self.RABBIT_PORT}/{vhost}"
|
||||
|
||||
@property
|
||||
def EXCLUDED_PATHS(self) -> List[str]:
|
||||
return ["/docs", "/redoc", "/openapi.json", "/ping", "/health"]
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_settings() -> Settings:
|
||||
return Settings()
|
||||
|
||||
|
||||
settings = get_settings()
|
||||
1
src/infrastructure/context_vars/__init__.py
Normal file
1
src/infrastructure/context_vars/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.context_vars.trace_id import trace_id_var
|
||||
4
src/infrastructure/context_vars/trace_id.py
Normal file
4
src/infrastructure/context_vars/trace_id.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from contextvars import ContextVar
|
||||
|
||||
|
||||
trace_id_var: ContextVar[str] = ContextVar('trace_id', default='N/A')
|
||||
30
src/infrastructure/logger/__init__.py
Normal file
30
src/infrastructure/logger/__init__.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from anyio.functools import lru_cache
|
||||
|
||||
from src.application.contracts import ILogger
|
||||
from src.application.domain.enums import LogFormat
|
||||
from src.application.domain.enums import LogLevel
|
||||
from src.infrastructure.config.settings import settings
|
||||
from src.infrastructure.logger.logger import Logger
|
||||
|
||||
log_levels = {
|
||||
'DEBUG': LogLevel.DEBUG,
|
||||
'INFO': LogLevel.INFO,
|
||||
'WARNING': LogLevel.WARNING,
|
||||
'ERROR': LogLevel.ERROR,
|
||||
'CRITICAL': LogLevel.CRITICAL,
|
||||
'EXCEPTION': LogLevel.EXCEPTION,
|
||||
}
|
||||
|
||||
log_formats = {
|
||||
'JSON': LogFormat.JSON,
|
||||
'TEXT': LogFormat.TEXT,
|
||||
}
|
||||
|
||||
logger = Logger(
|
||||
min_level=log_levels.get(settings.LOG_LEVEL, LogLevel.INFO),
|
||||
log_format=log_formats.get(settings.LOG_FORMAT, LogFormat.JSON),
|
||||
)
|
||||
|
||||
@lru_cache
|
||||
def get_logger() -> ILogger:
|
||||
return logger
|
||||
129
src/infrastructure/logger/logger.py
Normal file
129
src/infrastructure/logger/logger.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import traceback
|
||||
import inspect
|
||||
import sys
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Callable, Optional, Any
|
||||
from ulid import ULID
|
||||
from src.application.contracts import ILogger
|
||||
from src.application.domain.enums import LogFormat, LogLevel
|
||||
from src.infrastructure.context_vars import trace_id_var
|
||||
|
||||
|
||||
class Logger(ILogger):
|
||||
_instance = None
|
||||
__default_format = LogFormat.JSON
|
||||
|
||||
def __new__(cls, *args: Any, **kwargs: Any) -> "Logger":
|
||||
if cls._instance is None:
|
||||
cls._instance = super(Logger, cls).__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
log_format: LogFormat = __default_format,
|
||||
min_level: LogLevel = LogLevel.INFO,
|
||||
id_generator: Optional[Callable[[], str]] = lambda: str(ULID()),
|
||||
instance_id: str = "N/A",
|
||||
):
|
||||
self.log_format = log_format
|
||||
self.min_level = min_level
|
||||
self.id_generator = id_generator
|
||||
self.instance_id = instance_id
|
||||
|
||||
def set_instance_id(self, instance_id: str) -> None:
|
||||
self.instance_id = instance_id
|
||||
|
||||
def get_instance_id(self) -> str:
|
||||
return self.instance_id
|
||||
|
||||
def set_format(self, log_format: LogFormat) -> None:
|
||||
if not isinstance(log_format, LogFormat):
|
||||
raise ValueError("Log format must be an instance of LogFormat enum")
|
||||
self.log_format = log_format
|
||||
|
||||
def set_min_level(self, level: LogLevel) -> None:
|
||||
self.min_level = level
|
||||
|
||||
def new_trace_id(self) -> str:
|
||||
trace_id = str(ULID()) if self.id_generator is None else self.id_generator()
|
||||
trace_id_var.set(trace_id)
|
||||
return trace_id
|
||||
|
||||
def set_trace_id(self, trace_id: str) -> None:
|
||||
trace_id_var.set(trace_id)
|
||||
|
||||
def get_trace_id(self) -> str:
|
||||
return trace_id_var.get()
|
||||
|
||||
def clear_trace_id(self) -> None:
|
||||
trace_id_var.set("N/A")
|
||||
|
||||
def _prepare_log_data(self, level: LogLevel, message: str) -> dict[str, Any]:
|
||||
current_frame = inspect.currentframe()
|
||||
if (
|
||||
current_frame
|
||||
and current_frame.f_back
|
||||
and current_frame.f_back.f_back
|
||||
and current_frame.f_back.f_back.f_back
|
||||
):
|
||||
frame = current_frame.f_back.f_back.f_back
|
||||
filename = frame.f_code.co_filename
|
||||
line_number = frame.f_lineno
|
||||
else:
|
||||
filename = "unknown"
|
||||
line_number = 0
|
||||
|
||||
log_data = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"level": level.name,
|
||||
"instance_id": self.instance_id,
|
||||
"file": filename,
|
||||
"line": line_number,
|
||||
"trace_id": trace_id_var.get(),
|
||||
"message": message,
|
||||
}
|
||||
|
||||
if level == LogLevel.EXCEPTION:
|
||||
log_data["exception"] = traceback.format_exc()
|
||||
|
||||
return log_data
|
||||
|
||||
def _log(self, level: LogLevel, message: str) -> None:
|
||||
if level >= self.min_level:
|
||||
log_data = self._prepare_log_data(level, message)
|
||||
|
||||
if self.log_format == LogFormat.JSON:
|
||||
log_message = json.dumps(log_data, ensure_ascii=False)
|
||||
else:
|
||||
log_message = (
|
||||
f"{log_data['timestamp']} - {log_data['level']} - "
|
||||
f"{log_data['instance_id']} - {log_data['trace_id']} - "
|
||||
f"{log_data['file']}:{log_data['line']} - "
|
||||
f"{log_data['message']}"
|
||||
)
|
||||
if "exception" in log_data:
|
||||
log_message += f"\nTraceback:\n{log_data['exception']}"
|
||||
|
||||
self._write(log_message)
|
||||
|
||||
def _write(self, message: str) -> None:
|
||||
sys.stdout.write(message + "\n")
|
||||
|
||||
def debug(self, message: str) -> None:
|
||||
self._log(LogLevel.DEBUG, message)
|
||||
|
||||
def info(self, message: str) -> None:
|
||||
self._log(LogLevel.INFO, message)
|
||||
|
||||
def warning(self, message: str) -> None:
|
||||
self._log(LogLevel.WARNING, message)
|
||||
|
||||
def error(self, message: str) -> None:
|
||||
self._log(LogLevel.ERROR, message)
|
||||
|
||||
def critical(self, message: str) -> None:
|
||||
self._log(LogLevel.CRITICAL, message)
|
||||
|
||||
def exception(self, message: str) -> None:
|
||||
self._log(LogLevel.EXCEPTION, message)
|
||||
21
src/infrastructure/mail/__init__.py
Normal file
21
src/infrastructure/mail/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from anyio.functools import lru_cache
|
||||
from src.application.contracts.i_sender import ISender
|
||||
from src.infrastructure.config import settings
|
||||
from src.infrastructure.mail.render import TemplateRenderer
|
||||
from src.infrastructure.mail.sender import EmailSender
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_renderer() -> TemplateRenderer:
|
||||
return TemplateRenderer()
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_email_sender() -> ISender:
|
||||
return EmailSender(
|
||||
host=settings.SMTP_HOST,
|
||||
port=settings.SMTP_PORT,
|
||||
username=str(settings.SMTP_USER),
|
||||
password=settings.SMTP_PASSWORD,
|
||||
from_addr=settings.SMTP_FROM,
|
||||
)
|
||||
17
src/infrastructure/mail/render.py
Normal file
17
src/infrastructure/mail/render.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from pathlib import Path
|
||||
from jinja2 import Environment, FileSystemLoader, select_autoescape
|
||||
|
||||
|
||||
class TemplateRenderer:
|
||||
def __init__(self, templates_dir: Path | str | None = None):
|
||||
if templates_dir is None:
|
||||
templates_dir = Path(__file__).parent / "templates"
|
||||
self._env = Environment(
|
||||
loader=FileSystemLoader(templates_dir),
|
||||
autoescape=select_autoescape(["html"]),
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True,
|
||||
)
|
||||
|
||||
def render(self, template_name: str, **kwargs: object) -> str:
|
||||
return self._env.get_template(template_name).render(**kwargs)
|
||||
46
src/infrastructure/mail/sender.py
Normal file
46
src/infrastructure/mail/sender.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import aiosmtplib
|
||||
from email.message import EmailMessage
|
||||
from src.application.contracts.i_sender import ISender
|
||||
|
||||
|
||||
class EmailSender(ISender):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
port: int,
|
||||
username: str,
|
||||
password: str,
|
||||
from_addr: str,
|
||||
use_tls: bool = True,
|
||||
timeout: int = 10,
|
||||
):
|
||||
self._host = host
|
||||
self._port = port
|
||||
self._username = username
|
||||
self._password = password
|
||||
self._from_addr = from_addr
|
||||
self._use_tls = use_tls
|
||||
self._timeout = timeout
|
||||
|
||||
async def send(self, to: str, subject: str, body: str, plain: str | None = None) -> None:
|
||||
message = EmailMessage()
|
||||
message["From"] = self._from_addr
|
||||
message["To"] = to
|
||||
message["Subject"] = subject
|
||||
|
||||
if plain:
|
||||
message.set_content(plain)
|
||||
message.add_alternative(body, subtype="html")
|
||||
else:
|
||||
message.set_content(body, subtype="html")
|
||||
|
||||
await aiosmtplib.send(
|
||||
message,
|
||||
hostname=self._host,
|
||||
port=self._port,
|
||||
username=self._username,
|
||||
password=self._password,
|
||||
use_tls=True,
|
||||
timeout=self._timeout,
|
||||
)
|
||||
171
src/infrastructure/mail/templates/email_code.html
Normal file
171
src/infrastructure/mail/templates/email_code.html
Normal file
@@ -0,0 +1,171 @@
|
||||
<!doctype html>
|
||||
<html lang="ru">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<meta name="color-scheme" content="dark" />
|
||||
<title>{{ subject }}</title>
|
||||
</head>
|
||||
<body style="margin:0;padding:0;background:#0E1126;-webkit-text-size-adjust:100%;-ms-text-size-adjust:100%;">
|
||||
|
||||
<!--[if mso]><style>table,td{font-family:Arial,Helvetica,sans-serif !important;}</style><![endif]-->
|
||||
|
||||
<!-- Preheader -->
|
||||
<div style="display:none;max-height:0;overflow:hidden;mso-hide:all;">
|
||||
Ваш код: {{ code }}. Действует {{ ttl_minutes }} мин.  ͏
|
||||
</div>
|
||||
|
||||
<table role="presentation" cellpadding="0" cellspacing="0" border="0" width="100%"
|
||||
style="background:#0E1126;padding:32px 16px;">
|
||||
<tr>
|
||||
<td align="center">
|
||||
|
||||
<!-- Outer card 600px -->
|
||||
<table role="presentation" cellpadding="0" cellspacing="0" border="0" width="600"
|
||||
style="max-width:600px;width:100%;border-radius:20px;overflow:hidden;
|
||||
border:1px solid rgba(93,4,217,0.30);">
|
||||
|
||||
<!-- ====== HEADER — gradient bar ====== -->
|
||||
<tr>
|
||||
<td style="padding:28px 32px;
|
||||
background:linear-gradient(135deg,#260E59 0%,#5D04D9 50%,#056CF2 100%);">
|
||||
<table role="presentation" cellpadding="0" cellspacing="0" border="0" width="100%">
|
||||
<tr>
|
||||
<td>
|
||||
<div style="font-family:'Segoe UI',Arial,Helvetica,sans-serif;
|
||||
color:#ffffff;font-size:22px;font-weight:800;
|
||||
letter-spacing:0.4px;">
|
||||
{{ brand }}
|
||||
</div>
|
||||
<div style="font-family:'Segoe UI',Arial,Helvetica,sans-serif;
|
||||
color:rgba(255,255,255,0.80);font-size:13px;
|
||||
margin-top:6px;letter-spacing:0.2px;">
|
||||
Подтверждение · Безопасность
|
||||
</div>
|
||||
</td>
|
||||
<td align="right" valign="middle">
|
||||
<!-- Shield icon via CSS -->
|
||||
<div style="width:44px;height:44px;border-radius:12px;
|
||||
background:rgba(255,255,255,0.12);
|
||||
text-align:center;line-height:44px;font-size:22px;">
|
||||
🔐
|
||||
</div>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<!-- ====== BODY ====== -->
|
||||
<tr>
|
||||
<td style="padding:28px 32px;background:#0E1126;">
|
||||
|
||||
<!-- Greeting -->
|
||||
<div style="font-family:'Segoe UI',Arial,Helvetica,sans-serif;
|
||||
color:#ffffff;font-size:20px;font-weight:700;">
|
||||
Ваш код подтверждения
|
||||
</div>
|
||||
<div style="font-family:'Segoe UI',Arial,Helvetica,sans-serif;
|
||||
color:rgba(255,255,255,0.70);font-size:14px;
|
||||
line-height:22px;margin-top:10px;">
|
||||
Введите этот код в приложении, чтобы завершить действие.
|
||||
Никому не сообщайте его.
|
||||
</div>
|
||||
|
||||
<!-- ── Code card ── -->
|
||||
<table role="presentation" cellpadding="0" cellspacing="0" border="0" width="100%"
|
||||
style="margin-top:24px;">
|
||||
<tr>
|
||||
<td style="background:linear-gradient(160deg,#260E59 0%,#1a0a3e 100%);
|
||||
border:1px solid rgba(93,4,217,0.50);
|
||||
border-radius:16px;padding:24px 20px;text-align:center;">
|
||||
|
||||
<div style="font-family:'Segoe UI',Arial,Helvetica,sans-serif;
|
||||
color:rgba(255,255,255,0.55);font-size:11px;
|
||||
letter-spacing:2px;text-transform:uppercase;">
|
||||
код подтверждения
|
||||
</div>
|
||||
|
||||
<div style="font-family:'SF Mono','Cascadia Code','Fira Code',monospace;
|
||||
color:#ffffff;font-size:38px;font-weight:800;
|
||||
letter-spacing:10px;margin-top:12px;
|
||||
padding:12px 0;
|
||||
background:linear-gradient(90deg,#05C7F2,#5D04D9,#056CF2);
|
||||
-webkit-background-clip:text;
|
||||
-webkit-text-fill-color:transparent;
|
||||
background-clip:text;">
|
||||
{{ code }}
|
||||
</div>
|
||||
|
||||
<!-- TTL pill -->
|
||||
<table role="presentation" cellpadding="0" cellspacing="0" border="0"
|
||||
style="margin:14px auto 0;">
|
||||
<tr>
|
||||
<td style="background:rgba(5,199,242,0.12);
|
||||
border:1px solid rgba(5,199,242,0.30);
|
||||
border-radius:20px;padding:6px 16px;">
|
||||
<span style="font-family:'Segoe UI',Arial,Helvetica,sans-serif;
|
||||
color:#05C7F2;font-size:13px;font-weight:600;">
|
||||
⏱ Действует {{ ttl_minutes }} мин
|
||||
</span>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
<!-- Warning -->
|
||||
<table role="presentation" cellpadding="0" cellspacing="0" border="0" width="100%"
|
||||
style="margin-top:20px;">
|
||||
<tr>
|
||||
<td style="background:rgba(93,4,217,0.08);
|
||||
border-left:3px solid #5D04D9;
|
||||
border-radius:0 10px 10px 0;padding:14px 16px;">
|
||||
<div style="font-family:'Segoe UI',Arial,Helvetica,sans-serif;
|
||||
color:rgba(255,255,255,0.70);font-size:13px;line-height:20px;">
|
||||
⚠️ Если вы не запрашивали код — просто проигнорируйте
|
||||
это письмо. Никогда не сообщайте код третьим лицам.
|
||||
</div>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<!-- ====== DIVIDER ====== -->
|
||||
<tr>
|
||||
<td style="padding:0 32px;">
|
||||
<div style="height:1px;
|
||||
background:linear-gradient(90deg,transparent,rgba(5,199,242,0.25),transparent);">
|
||||
</div>
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
<!-- ====== FOOTER ====== -->
|
||||
<tr>
|
||||
<td style="padding:20px 32px 28px;background:#0E1126;">
|
||||
<table role="presentation" cellpadding="0" cellspacing="0" border="0" width="100%">
|
||||
<tr>
|
||||
<td align="right" valign="bottom"
|
||||
style="font-family:'Segoe UI',Arial,Helvetica,sans-serif;
|
||||
color:rgba(255,255,255,0.30);font-size:11px;">
|
||||
© {{ year }} {{ brand }}<br />
|
||||
Ref: {{ trace_id }}
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
</table>
|
||||
<!-- /Outer card -->
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
</body>
|
||||
</html>
|
||||
7
src/infrastructure/mail/templates/email_code.txt
Normal file
7
src/infrastructure/mail/templates/email_code.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
{{ brand }}
|
||||
Ваш код подтверждения: {{ code }}
|
||||
Срок действия: {{ ttl_minutes }} минут
|
||||
|
||||
Если вы не запрашивали код — игнорируйте это письмо.
|
||||
|
||||
Ref: {{ trace_id }}
|
||||
11
src/infrastructure/rabbit/broker.py
Normal file
11
src/infrastructure/rabbit/broker.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from faststream.rabbit import RabbitBroker
|
||||
from faststream.rabbit.schemas import RabbitQueue
|
||||
from src.infrastructure.config import settings
|
||||
|
||||
|
||||
broker = RabbitBroker(settings.RABBIT_URL)
|
||||
|
||||
email_code_queue = RabbitQueue(
|
||||
name=settings.RABBIT_EMAIL_CODE_QUEUE,
|
||||
durable=True,
|
||||
)
|
||||
1
src/infrastructure/utils/__init__.py
Normal file
1
src/infrastructure/utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.utils.instance_id import generate_instance_id
|
||||
14
src/infrastructure/utils/instance_id.py
Normal file
14
src/infrastructure/utils/instance_id.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from ulid import ULID
|
||||
|
||||
|
||||
def generate_instance_id() -> str:
|
||||
"""
|
||||
Generate a process-wide instance id in ULID format.
|
||||
|
||||
ULID is 26 chars (Crockford Base32) and lexicographically sortable by time.
|
||||
"""
|
||||
|
||||
|
||||
return str(ULID())
|
||||
|
||||
|
||||
2
src/infrastructure/vault/__init__.py
Normal file
2
src/infrastructure/vault/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from src.infrastructure.vault.utils import read_kv2_secret, create_hvac_client
|
||||
from src.infrastructure.vault.keys import JwtKeyStore
|
||||
139
src/infrastructure/vault/keys.py
Normal file
139
src/infrastructure/vault/keys.py
Normal file
@@ -0,0 +1,139 @@
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Dict
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.vault import create_hvac_client, read_kv2_secret
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class JwtKeyPair:
|
||||
kid: str
|
||||
private_key_pem: str
|
||||
public_key_pem: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class JwtKeySet:
|
||||
active: JwtKeyPair
|
||||
previous: Optional[JwtKeyPair] = None
|
||||
|
||||
def public_keys_by_kid(self) -> Dict[str, str]:
|
||||
out = {self.active.kid: self.active.public_key_pem}
|
||||
if self.previous:
|
||||
out[self.previous.kid] = self.previous.public_key_pem
|
||||
return out
|
||||
|
||||
|
||||
class JwtKeyStore:
|
||||
|
||||
_instance: "JwtKeyStore | None" = None
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vault_addr: str,
|
||||
vault_role_id: str,
|
||||
vault_secret_id: str,
|
||||
vault_auth_mount: str = 'approle',
|
||||
mount_point: str,
|
||||
kid_path: str = 'jwt/kid',
|
||||
kids_prefix: str = 'jwt/kids',
|
||||
timeout_seconds: int = 5,
|
||||
):
|
||||
if getattr(self, '_initialized', False):
|
||||
return
|
||||
|
||||
self._vault_addr = vault_addr
|
||||
self._vault_role_id = vault_role_id
|
||||
self._vault_secret_id = vault_secret_id
|
||||
self._vault_auth_mount = vault_auth_mount
|
||||
self._timeout = timeout_seconds
|
||||
|
||||
self._mount = mount_point
|
||||
self._kid_path = kid_path
|
||||
self._kids_prefix = kids_prefix
|
||||
|
||||
self._lock = asyncio.Lock()
|
||||
self._keyset: JwtKeySet | None = None
|
||||
self._last_refresh_at: datetime | None = None
|
||||
|
||||
self._initialized = True
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> 'JwtKeyStore':
|
||||
|
||||
if cls._instance is None:
|
||||
raise ApplicationException(status_code=500, message="JwtKeyStore not initialized")
|
||||
|
||||
return cls._instance
|
||||
|
||||
def _read_keyset_sync(self) -> JwtKeySet:
|
||||
client = create_hvac_client(
|
||||
url=self._vault_addr,
|
||||
role_id=self._vault_role_id,
|
||||
secret_id=self._vault_secret_id,
|
||||
auth_mount_point=self._vault_auth_mount,
|
||||
timeout=self._timeout,
|
||||
)
|
||||
|
||||
kids = read_kv2_secret(client=client, mount_point=self._mount, path=self._kid_path)
|
||||
active_kid = kids.get("active")
|
||||
previous_kid = kids.get("previous")
|
||||
|
||||
print("VAULT kids:", {"active": active_kid, "previous": previous_kid})
|
||||
|
||||
if not active_kid:
|
||||
raise RuntimeError("Vault jwt/kid secret missing 'active'")
|
||||
|
||||
active_pair = self._read_keypair_sync(client, active_kid)
|
||||
|
||||
prev_pair = None
|
||||
if previous_kid and previous_kid != active_kid:
|
||||
prev_pair = self._read_keypair_sync(client, previous_kid)
|
||||
|
||||
return JwtKeySet(active=active_pair, previous=prev_pair)
|
||||
|
||||
def _read_keypair_sync(self, client, kid: str) -> JwtKeyPair:
|
||||
data = read_kv2_secret(
|
||||
client=client,
|
||||
mount_point=self._mount,
|
||||
path=f"{self._kids_prefix}/{kid}",
|
||||
)
|
||||
priv = data.get("private_key")
|
||||
pub = data.get("public_key")
|
||||
if not priv or not pub:
|
||||
raise RuntimeError(f"Vault jwt/kids/{kid} missing private_key/public_key")
|
||||
return JwtKeyPair(kid=kid, private_key_pem=priv, public_key_pem=pub)
|
||||
|
||||
|
||||
async def refresh(self) -> JwtKeySet:
|
||||
keyset = await asyncio.to_thread(self._read_keyset_sync)
|
||||
async with self._lock:
|
||||
self._keyset = keyset
|
||||
self._last_refresh_at = datetime.now(timezone.utc)
|
||||
return keyset
|
||||
|
||||
async def get_signing_key(self) -> tuple[str, str]:
|
||||
ks = await self._get_or_refresh()
|
||||
return ks.active.kid, ks.active.private_key_pem
|
||||
|
||||
async def get_public_key_for_kid(self, kid: str) -> str | None:
|
||||
ks = await self._get_or_refresh()
|
||||
return ks.public_keys_by_kid().get(kid)
|
||||
|
||||
async def last_refresh_at(self) -> datetime | None:
|
||||
async with self._lock:
|
||||
return self._last_refresh_at
|
||||
|
||||
async def _get_or_refresh(self) -> JwtKeySet:
|
||||
async with self._lock:
|
||||
ks = self._keyset
|
||||
return ks if ks else await self.refresh()
|
||||
31
src/infrastructure/vault/utils.py
Normal file
31
src/infrastructure/vault/utils.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
import hvac
|
||||
|
||||
|
||||
def create_hvac_client(
|
||||
*,
|
||||
url: str,
|
||||
role_id: str,
|
||||
secret_id: str,
|
||||
auth_mount_point: str = 'approle',
|
||||
timeout: int = 5,
|
||||
) -> hvac.Client:
|
||||
client = hvac.Client(url=url, timeout=timeout)
|
||||
client.auth.approle.login(
|
||||
role_id=role_id,
|
||||
secret_id=secret_id,
|
||||
mount_point=auth_mount_point,
|
||||
)
|
||||
if not client.is_authenticated():
|
||||
raise RuntimeError(
|
||||
'Vault AppRole login failed. Check VAULT_ADDR, VAULT_ROLE_ID, VAULT_SECRET_ID, VAULT_AUTH_MOUNT',
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
def read_kv2_secret(*, client: hvac.Client, mount_point: str, path: str) -> dict:
|
||||
secret = client.secrets.kv.v2.read_secret_version(
|
||||
mount_point=mount_point,
|
||||
path=path,
|
||||
)
|
||||
return secret["data"]["data"]
|
||||
103
src/main.py
Normal file
103
src/main.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from __future__ import annotations
|
||||
from contextlib import asynccontextmanager
|
||||
import secrets
|
||||
from typing import AsyncGenerator
|
||||
from fastapi import Depends, FastAPI, status, APIRouter
|
||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.utils import generate_instance_id
|
||||
from src.infrastructure.logger import logger
|
||||
from src.infrastructure.config import settings
|
||||
from src.presentation.handlers import application_exception_handler, unhandled_exception_handler
|
||||
from src.presentation.messaging.code import code_broker
|
||||
from src.presentation.middleware import TraceIDMiddleware, SecurityHeadersMiddleware
|
||||
|
||||
|
||||
security = HTTPBasic()
|
||||
|
||||
|
||||
async def verify_credentials(credentials: HTTPBasicCredentials = Depends(security)) -> HTTPBasicCredentials:
|
||||
user_ok = secrets.compare_digest(credentials.username, settings.DOCS_USERNAME)
|
||||
pass_ok = secrets.compare_digest(credentials.password, settings.DOCS_PASSWORD)
|
||||
if not (user_ok and pass_ok):
|
||||
raise ApplicationException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
message='Unauthorized',
|
||||
headers={'WWW-Authenticate': 'Basic'},
|
||||
)
|
||||
return credentials
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
|
||||
instance_id = generate_instance_id()
|
||||
logger.set_instance_id(instance_id)
|
||||
logger.info(f'Auth service instance started with id {instance_id}')
|
||||
yield
|
||||
logger.info(f'Auth service instance ended with id {instance_id}')
|
||||
|
||||
|
||||
app: FastAPI = FastAPI(
|
||||
redoc_url=None,
|
||||
docs_url=None,
|
||||
lifespan=lifespan,
|
||||
title='Bitforce. Notify Service',
|
||||
version='1.0.0',
|
||||
description='',
|
||||
license_info={
|
||||
'name': 'MIT',
|
||||
'url': 'https://opensource.org/licenses/MIT',
|
||||
},
|
||||
)
|
||||
|
||||
app.add_exception_handler(ApplicationException, application_exception_handler)
|
||||
app.add_exception_handler(Exception, unhandled_exception_handler)
|
||||
|
||||
v1_router = APIRouter(prefix='/v1')
|
||||
app.include_router(v1_router)
|
||||
app.include_router(code_broker)
|
||||
|
||||
|
||||
# Added middleware
|
||||
app.add_middleware(TraceIDMiddleware, logger=logger)
|
||||
app.add_middleware(
|
||||
SecurityHeadersMiddleware,
|
||||
hsts=True,
|
||||
hsts_preload=False,
|
||||
frame_options='DENY',
|
||||
referrer_policy='strict-origin-when-cross-origin',
|
||||
content_security_policy="default-src 'self'; frame-ancestors 'none'; base-uri 'self'; object-src 'none'",
|
||||
)
|
||||
|
||||
|
||||
@app.get('/docs', include_in_schema=False)
|
||||
async def custom_swagger_ui_html(_credentials: HTTPBasicCredentials = Depends(verify_credentials)) -> HTMLResponse:
|
||||
"""Custom Swagger documentation, optionally protected with basic authentication."""
|
||||
return get_swagger_ui_html(
|
||||
openapi_url=getattr(app, 'openapi_url', '/openapi.json'),
|
||||
title=getattr(app, 'title', 'FastAPI') + ' - Swagger UI',
|
||||
oauth2_redirect_url=getattr(app, 'swagger_ui_oauth2_redirect_url', None),
|
||||
swagger_js_url='https://unpkg.com/swagger-ui-dist@5/swagger-ui-bundle.js',
|
||||
swagger_css_url='https://unpkg.com/swagger-ui-dist@5/swagger-ui.css',
|
||||
)
|
||||
|
||||
|
||||
@app.get('/redoc', include_in_schema=False)
|
||||
async def custom_redoc_html(_credentials: HTTPBasicCredentials = Depends(verify_credentials)) -> HTMLResponse:
|
||||
"""Custom ReDoc documentation, optionally protected with basic authentication."""
|
||||
return get_redoc_html(
|
||||
openapi_url=getattr(app, 'openapi_url', '/openapi.json'),
|
||||
title=getattr(app, 'title', 'FastAPI') + ' - ReDoc',
|
||||
redoc_js_url='https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js',
|
||||
)
|
||||
|
||||
|
||||
@app.post('/ping')
|
||||
async def ping() -> dict[str, str]:
|
||||
return {
|
||||
'message': 'pong',
|
||||
'status': 'ok',
|
||||
}
|
||||
2
src/presentation/handlers/__init__.py
Normal file
2
src/presentation/handlers/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from src.presentation.handlers.unhandled_handler import unhandled_exception_handler
|
||||
from src.presentation.handlers.application_handler import application_exception_handler
|
||||
17
src/presentation/handlers/application_handler.py
Normal file
17
src/presentation/handlers/application_handler.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from fastapi import Request
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
|
||||
|
||||
async def application_exception_handler(_request: Request, exc: ApplicationException) -> ORJSONResponse:
|
||||
detail = exc.message
|
||||
if 500 <= exc.status_code:
|
||||
detail = "Internal Server Error"
|
||||
|
||||
return ORJSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={"detail": detail},
|
||||
headers=dict(exc.headers) if exc.headers else None,
|
||||
)
|
||||
|
||||
|
||||
12
src/presentation/handlers/unhandled_handler.py
Normal file
12
src/presentation/handlers/unhandled_handler.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from fastapi import Request
|
||||
from starlette import status
|
||||
from src.infrastructure.logger import logger
|
||||
|
||||
|
||||
async def unhandled_exception_handler(_request: Request, exc: Exception) -> ORJSONResponse:
|
||||
logger.exception(f'Unhandled exception: {type(exc).__name__}')
|
||||
return ORJSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={'detail': 'Internal Server Error'},
|
||||
)
|
||||
78
src/presentation/messaging/code.py
Normal file
78
src/presentation/messaging/code.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from fastapi import Depends
|
||||
from faststream.rabbit.fastapi import RabbitRouter, RabbitMessage
|
||||
from src.infrastructure.config import settings
|
||||
from src.infrastructure.logger import logger
|
||||
from src.infrastructure.mail import TemplateRenderer, get_email_sender, get_renderer
|
||||
from src.infrastructure.mail.sender import EmailSender
|
||||
from datetime import datetime
|
||||
from typing import Literal, Optional
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
from src.infrastructure.rabbit.broker import email_code_queue
|
||||
|
||||
|
||||
code_broker = RabbitRouter(settings.RABBIT_URL)
|
||||
|
||||
|
||||
class Metadata(BaseModel):
|
||||
trace_id: Optional[str] = None
|
||||
source: str
|
||||
timestamp: datetime
|
||||
message_id: str
|
||||
|
||||
|
||||
class Payload(BaseModel):
|
||||
email: EmailStr
|
||||
code: str = Field(min_length=1, max_length=64)
|
||||
ttl_seconds: int = Field(gt=0, lt=24 * 3600)
|
||||
|
||||
|
||||
class LoginCodeCreated(BaseModel):
|
||||
event: Literal["login", "registration", "bank_details_update"]
|
||||
payload: Payload
|
||||
metadata: Metadata
|
||||
|
||||
|
||||
@code_broker.subscriber(email_code_queue)
|
||||
async def consume_email_code(
|
||||
msg_body: LoginCodeCreated,
|
||||
message: RabbitMessage,
|
||||
sender: EmailSender = Depends(get_email_sender),
|
||||
renderer: TemplateRenderer = Depends(get_renderer),
|
||||
):
|
||||
trace_id = (
|
||||
(message.headers or {}).get("trace_id")
|
||||
or message.correlation_id
|
||||
or msg_body.metadata.trace_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"received event={msg_body.event} "
|
||||
f"email={msg_body.payload.email} "
|
||||
f"ttl={msg_body.payload.ttl_seconds} "
|
||||
f"trace_id={trace_id}"
|
||||
)
|
||||
|
||||
html = renderer.render(
|
||||
"email_code.html",
|
||||
subject="Код подтверждения",
|
||||
code=msg_body.payload.code,
|
||||
ttl_minutes=msg_body.payload.ttl_seconds // 60,
|
||||
brand="Bitforce",
|
||||
trace_id=trace_id,
|
||||
year=datetime.now().year,
|
||||
)
|
||||
|
||||
text = renderer.render(
|
||||
"email_code.txt",
|
||||
code=msg_body.payload.code,
|
||||
ttl_minutes=msg_body.payload.ttl_seconds // 60,
|
||||
brand="Bitforce",
|
||||
trace_id=trace_id,
|
||||
)
|
||||
|
||||
await sender.send(
|
||||
to=msg_body.payload.email,
|
||||
subject="Код подтверждения",
|
||||
body=html,
|
||||
plain=text,
|
||||
)
|
||||
2
src/presentation/middleware/__init__.py
Normal file
2
src/presentation/middleware/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from src.presentation.middleware.trace_id import TraceIDMiddleware
|
||||
from src.presentation.middleware.security_headers import SecurityHeadersMiddleware
|
||||
51
src/presentation/middleware/security_headers.py
Normal file
51
src/presentation/middleware/security_headers.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
def __init__(
|
||||
self,
|
||||
app,
|
||||
*,
|
||||
hsts: bool = True,
|
||||
hsts_max_age: int = 31536000, # 1 год
|
||||
hsts_include_subdomains: bool = True,
|
||||
hsts_preload: bool = False,
|
||||
frame_options: str = 'DENY', # или 'SAMEORIGIN'
|
||||
referrer_policy: str = 'strict-origin-when-cross-origin',
|
||||
content_security_policy: str | None = None,
|
||||
):
|
||||
super().__init__(app)
|
||||
self.hsts = hsts
|
||||
self.hsts_max_age = hsts_max_age
|
||||
self.hsts_include_subdomains = hsts_include_subdomains
|
||||
self.hsts_preload = hsts_preload
|
||||
self.frame_options = frame_options
|
||||
self.referrer_policy = referrer_policy
|
||||
self.csp = content_security_policy
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
response: Response = await call_next(request)
|
||||
|
||||
if request.url.path in ('/docs', '/redoc', '/openapi.json'):
|
||||
return response
|
||||
|
||||
if self.hsts and request.url.scheme == 'https':
|
||||
hsts = f'max-age={self.hsts_max_age}'
|
||||
if self.hsts_include_subdomains:
|
||||
hsts += '; includeSubDomains'
|
||||
if self.hsts_preload:
|
||||
hsts += '; preload'
|
||||
response.headers['Strict-Transport-Security'] = hsts
|
||||
|
||||
response.headers['X-Content-Type-Options'] = 'nosniff'
|
||||
|
||||
response.headers['X-Frame-Options'] = self.frame_options
|
||||
|
||||
response.headers['Referrer-Policy'] = self.referrer_policy
|
||||
|
||||
if self.csp:
|
||||
response.headers['Content-Security-Policy'] = self.csp
|
||||
|
||||
return response
|
||||
135
src/presentation/middleware/trace_id.py
Normal file
135
src/presentation/middleware/trace_id.py
Normal file
@@ -0,0 +1,135 @@
|
||||
from __future__ import annotations
|
||||
from typing import Optional
|
||||
from contextvars import Token
|
||||
from starlette.requests import Request
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
from ulid import ULID
|
||||
from src.application.contracts import ILogger
|
||||
from src.infrastructure.config import settings
|
||||
from src.infrastructure.context_vars import trace_id_var
|
||||
|
||||
|
||||
class TraceIDMiddleware:
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
logger: ILogger,
|
||||
response_header_name: str = "X-Trace-ID",
|
||||
attach_response_header: bool = True,
|
||||
) -> None:
|
||||
self.app = app
|
||||
self.logger = logger
|
||||
self.response_header_name = response_header_name
|
||||
self.attach_response_header = attach_response_header
|
||||
|
||||
def _is_excluded(self, path: str) -> bool:
|
||||
return any(path == p or path.startswith(p.rstrip("/") + "/") for p in settings.EXCLUDED_PATHS)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
request = Request(scope)
|
||||
|
||||
if self._is_excluded(request.url.path):
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
trace_id = request.headers.get("X-Trace-ID") or request.headers.get("X-Request-ID")
|
||||
if not trace_id:
|
||||
trace_id = str(ULID())
|
||||
|
||||
request.state.trace_id = trace_id
|
||||
|
||||
token: Token = trace_id_var.set(trace_id)
|
||||
|
||||
self.logger.debug(f"Request started: {request.method} {request.url} - TraceID: {trace_id}")
|
||||
|
||||
status_code_holder: dict[str, Optional[int]] = {"status": None}
|
||||
|
||||
async def send_wrapper(message: Message) -> None:
|
||||
if message["type"] == "http.response.start":
|
||||
status_code_holder["status"] = int(message["status"])
|
||||
|
||||
if self.attach_response_header:
|
||||
headers = list(message.get("headers", []))
|
||||
headers.append((self.response_header_name.lower().encode(), trace_id.encode()))
|
||||
message["headers"] = headers
|
||||
await send(message)
|
||||
|
||||
try:
|
||||
await self.app(scope, receive, send_wrapper)
|
||||
finally:
|
||||
status = status_code_holder["status"]
|
||||
status_part = f"{status}" if status is not None else "unknown"
|
||||
self.logger.debug(
|
||||
f"Request finished: {request.method} {request.url} - TraceID: {trace_id} - Status: {status_part}"
|
||||
)
|
||||
trace_id_var.reset(token)
|
||||
|
||||
|
||||
# from __future__ import annotations
|
||||
# from typing import Optional
|
||||
# from starlette.requests import Request
|
||||
# from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
# from ulid import ULID
|
||||
# from src.application.contracts import ILogger
|
||||
# from src.infrastructure.config.settings import settings
|
||||
#
|
||||
#
|
||||
# class TraceIDMiddleware:
|
||||
# def __init__(
|
||||
# self,
|
||||
# app: ASGIApp,
|
||||
# logger: ILogger,
|
||||
# response_header_name: str = 'X-Trace-ID',
|
||||
# attach_response_header: bool = True,
|
||||
# ) -> None:
|
||||
# self.app = app
|
||||
# self.logger = logger
|
||||
# self.response_header_name = response_header_name
|
||||
# self.attach_response_header = attach_response_header
|
||||
#
|
||||
# def _is_excluded(self, path: str) -> bool:
|
||||
# return any(path == p or path.startswith(p.rstrip('/') + '/') for p in settings.EXCLUDED_PATHS)
|
||||
#
|
||||
# async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
# if scope['type'] != 'http':
|
||||
# await self.app(scope, receive, send)
|
||||
# return
|
||||
#
|
||||
# request = Request(scope)
|
||||
#
|
||||
# if self._is_excluded(request.url.path):
|
||||
# await self.app(scope, receive, send)
|
||||
# return
|
||||
#
|
||||
# trace_id = request.headers.get('X-Trace-ID') or request.headers.get('X-Request-ID')
|
||||
# if not trace_id:
|
||||
# trace_id = str(ULID())
|
||||
#
|
||||
# request.state.trace_id = trace_id
|
||||
# self.logger.set_trace_id(trace_id)
|
||||
#
|
||||
# self.logger.debug(f'Request started: {request.method} {request.url} - TraceID: {trace_id}')
|
||||
#
|
||||
# status_code_holder: dict[str, Optional[int]] = {'status': None}
|
||||
#
|
||||
# async def send_wrapper(message: Message) -> None:
|
||||
# if message['type'] == 'http.response.start':
|
||||
# status_code_holder['status'] = int(message['status'])
|
||||
#
|
||||
# if self.attach_response_header:
|
||||
# headers = list(message.get('headers', []))
|
||||
# headers.append((self.response_header_name.lower().encode(), trace_id.encode()))
|
||||
# message['headers'] = headers
|
||||
# await send(message)
|
||||
#
|
||||
# try:
|
||||
# await self.app(scope, receive, send_wrapper)
|
||||
# finally:
|
||||
# status = status_code_holder['status']
|
||||
# status_part = f'{status}' if status is not None else 'unknown'
|
||||
# self.logger.debug(f'Request finished: {request.method} {request.url} - TraceID: {trace_id} - Status: {status_part}')
|
||||
# self.logger.clear_trace_id()
|
||||
Reference in New Issue
Block a user