From 0a2773488fdd86e2d75ee5d5f4f33a5a9568c367 Mon Sep 17 00:00:00 2001 From: Noloquideus Date: Tue, 12 May 2026 21:05:03 +0300 Subject: [PATCH] feat: add update --- src/infrastructure/config/settings.py | 235 +++++++++++++++++--------- src/infrastructure/vault/__init__.py | 3 +- src/infrastructure/vault/client.py | 75 ++++++++ src/infrastructure/vault/keys.py | 36 ++-- src/main.py | 8 +- 5 files changed, 252 insertions(+), 105 deletions(-) create mode 100644 src/infrastructure/vault/client.py diff --git a/src/infrastructure/config/settings.py b/src/infrastructure/config/settings.py index 18e9592..7740b34 100644 --- a/src/infrastructure/config/settings.py +++ b/src/infrastructure/config/settings.py @@ -1,32 +1,51 @@ from __future__ import annotations from functools import lru_cache -from typing import List, Literal -import os -from dotenv import load_dotenv, find_dotenv -from pydantic import Field, model_validator -from pydantic_settings import BaseSettings, SettingsConfigDict -from src.infrastructure.vault import create_hvac_client, read_kv2_secret +from typing import Any, List, Literal +from urllib.parse import quote -env_file = find_dotenv(".env") +from dotenv import find_dotenv, load_dotenv +from pydantic import Field, PrivateAttr +from pydantic_settings import BaseSettings, SettingsConfigDict + +from src.infrastructure.vault.client import VaultClient + +env_file = find_dotenv('.env') if env_file: load_dotenv(env_file) +def _as_int(value: object, default: int) -> int: + if value is None: + return default + if isinstance(value, int): + return value + return int(str(value).strip()) + + class Settings(BaseSettings): - VAULT_ADDR: str = Field(default="http://localhost:8200") - VAULT_TOKEN: str = Field(..., description="Vault token is required") - VAULT_MOUNT_POINT: str = Field(default="secrets") + model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8', case_sensitive=True, extra='ignore') - VAULT_JWT_KID_PATH: str = "jwt/kid" - VAULT_JWT_KIDS_PREFIX: str = "jwt/kids" - JWT_KEYS_REFRESH_SECONDS: int = 3600 + _vault_database_secrets: dict[str, Any] = PrivateAttr(default_factory=dict) - DATABASE_HOST: str + VAULT_ADDR: str = 'https://corp.vault.elcsa.ru' + VAULT_ROLE_ID: str = '' + VAULT_SECRET_ID: str = '' + VAULT_NAMESPACE: str | None = None + VAULT_MOUNT_POINT: str = 'dev-secrets' + VAULT_DATABASE_SECRET_PATH: str = 'database' + VAULT_RABBIT_SECRET_PATH: str = 'rabbitmq' + VAULT_CSRF_SECRET_PATH: str = 'csrf' + VAULT_DOCS_SECRET_PATH: str = 'docs' + VAULT_JWT_KID_PATH: str = 'jwt/kid' + VAULT_JWT_KIDS_PREFIX: str = 'jwt/kids' + + DATABASE_URL_DIRECT: str | None = Field(default=None, validation_alias='DATABASE_URL') + DATABASE_HOST: str = '' DATABASE_PORT: int = Field(default=5432, ge=1, le=65535) - DATABASE_NAME: str - DATABASE_USER: str - DATABASE_PASSWORD: str + DATABASE_NAME: str = '' + DATABASE_USER: str = '' + DATABASE_PASSWORD: str = '' DATABASE_POOL_SIZE: int = 10 DATABASE_MAX_OVERFLOW: int = 20 @@ -35,116 +54,166 @@ class Settings(BaseSettings): DATABASE_ECHO: bool = False CSRF_SECRET_KEY: str = Field( - default="change-me-change-me-change-me-change-me", + default='change-me-change-me-change-me-change-me', min_length=32, ) CSRF_COOKIE_SECURE: bool = False CSRF_COOKIE_HTTPONLY: bool = True - CSRF_COOKIE_SAMESITE: Literal["Lax", "Strict", "None"] = "Lax" - CSRF_COOKIE_PATH: str = "/" + CSRF_COOKIE_SAMESITE: Literal['Lax', 'Strict', 'None'] = 'Lax' + CSRF_COOKIE_PATH: str = '/' CSRF_COOKIE_DOMAIN: str | None = None - DOCS_USERNAME: str = "admin" - DOCS_PASSWORD: str = "admin" + DOCS_USERNAME: str = 'admin' + DOCS_PASSWORD: str = 'admin' JWT_ACCESS_TTL_SECONDS: int = 15 * 60 JWT_REFRESH_TTL_SECONDS: int = 30 * 24 * 60 * 60 JWT_ISSUER: str | None = None JWT_AUDIENCE: str | None = None - JWT_ALGORITHM: str = "RS256" + JWT_ALGORITHM: str = 'RS256' + JWT_KEYS_REFRESH_SECONDS: int = 3600 - REDIS_HOST: str = "localhost" + REDIS_HOST: str = 'keydb' REDIS_PORT: int = 6379 REDIS_PASSWORD: str | None = None REDIS_DB: int = 0 - RABBIT_HOST: str = "localhost" + RABBIT_HOST: str = 'localhost' RABBIT_PORT: int = 5672 - RABBIT_USER: str = "guest" - RABBIT_PASSWORD: str = "guest" - RABBIT_VHOST: str = "/" + RABBIT_USER: str = 'guest' + RABBIT_PASSWORD: str = 'guest' + RABBIT_VHOST: str = '/' RABBIT_PUBLISH_PERSIST: bool = True RABBIT_CONNECT_TIMEOUT: int = 5 - RABBIT_EMAIL_CODE_QUEUE: str = "email.verification_code" + RABBIT_EMAIL_CODE_QUEUE: str = 'email.verification_code' - LOG_LEVEL: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO" - LOG_FORMAT: Literal["JSON", "TEXT"] = "TEXT" + LOG_LEVEL: Literal['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] = 'INFO' + LOG_FORMAT: Literal['JSON', 'TEXT'] = 'JSON' - model_config = SettingsConfigDict( - env_file=".env", - env_file_encoding="utf-8", - case_sensitive=True, - extra="ignore", - ) + def _get_vault_secret(self, secrets: dict[str, Any], *keys: str) -> str: + for key in keys: + value = secrets.get(key) + if value is not None and str(value).strip() != '': + return str(value) + return '' - @model_validator(mode="before") - @classmethod - def load_from_vault(cls, data: dict): - addr = data.get("VAULT_ADDR") or os.getenv("VAULT_ADDR") or "http://localhost:8200" - token = data.get("VAULT_TOKEN") or os.getenv("VAULT_TOKEN") - mount = data.get("VAULT_MOUNT_POINT") or os.getenv("VAULT_MOUNT_POINT") or "secrets" + def model_post_init(self, __context: Any) -> None: + if not self.VAULT_ROLE_ID.strip() or not self.VAULT_SECRET_ID.strip(): + if not self.DATABASE_URL: + raise ValueError( + 'Set VAULT_ROLE_ID and VAULT_SECRET_ID for Vault, or set DATABASE_URL ' + '(or DATABASE_HOST, DATABASE_USER, DATABASE_PASSWORD, DATABASE_NAME) in the environment', + ) + return - if not token: - raise RuntimeError("VAULT_TOKEN is required") + client = VaultClient( + addr=self.VAULT_ADDR, + role_id=self.VAULT_ROLE_ID, + secret_id=self.VAULT_SECRET_ID, + namespace=self.VAULT_NAMESPACE, + mount_point=self.VAULT_MOUNT_POINT, + ) - client = create_hvac_client(url=addr, token=token, timeout=5) + db = client.read_secret(self.VAULT_DATABASE_SECRET_PATH) + object.__setattr__(self, '_vault_database_secrets', db) - def safe_read(path: str) -> dict: - try: - return read_kv2_secret(client=client, mount_point=mount, path=path) - except Exception: - return {} + def kv(d: dict[str, Any], *keys: str) -> Any: + for k in keys: + if k in d and d[k] is not None: + return d[k] + return None - database = safe_read("database") - rabbitmq = safe_read("rabbitmq") - csrf = safe_read("csrf") + if kv(db, 'HOST', 'host') is not None: + object.__setattr__(self, 'DATABASE_HOST', str(kv(db, 'HOST', 'host'))) + if kv(db, 'PORT', 'port') is not None: + object.__setattr__(self, 'DATABASE_PORT', _as_int(kv(db, 'PORT', 'port'), self.DATABASE_PORT)) + if kv(db, 'NAME', 'name') is not None: + object.__setattr__(self, 'DATABASE_NAME', str(kv(db, 'NAME', 'name'))) + if kv(db, 'USER', 'user') is not None: + object.__setattr__(self, 'DATABASE_USER', str(kv(db, 'USER', 'user'))) + if kv(db, 'PASSWORD', 'password') is not None: + object.__setattr__(self, 'DATABASE_PASSWORD', str(kv(db, 'PASSWORD', 'password'))) - if database: - required = ["HOST", "NAME", "USER", "PASSWORD", "PORT"] - missing = [k for k in required if k not in database] - if missing: - raise RuntimeError(f"Vault database secret missing keys {missing}") + rabbit = client.read_secret_optional(self.VAULT_RABBIT_SECRET_PATH) + if rabbit: + if kv(rabbit, 'HOST', 'host') is not None: + object.__setattr__(self, 'RABBIT_HOST', str(kv(rabbit, 'HOST', 'host'))) + if kv(rabbit, 'PORT', 'port') is not None: + object.__setattr__(self, 'RABBIT_PORT', _as_int(kv(rabbit, 'PORT', 'port'), self.RABBIT_PORT)) + if kv(rabbit, 'USER', 'user') is not None: + object.__setattr__(self, 'RABBIT_USER', str(kv(rabbit, 'USER', 'user'))) + if kv(rabbit, 'PASSWORD', 'password') is not None: + object.__setattr__(self, 'RABBIT_PASSWORD', str(kv(rabbit, 'PASSWORD', 'password'))) + if kv(rabbit, 'VHOST', 'vhost') is not None: + object.__setattr__(self, 'RABBIT_VHOST', str(kv(rabbit, 'VHOST', 'vhost'))) - data["DATABASE_HOST"] = database["HOST"] - data["DATABASE_PORT"] = database["PORT"] - data["DATABASE_NAME"] = database["NAME"] - data["DATABASE_USER"] = database["USER"] - data["DATABASE_PASSWORD"] = database["PASSWORD"] + csrf = client.read_secret_optional(self.VAULT_CSRF_SECRET_PATH) + if csrf and kv(csrf, 'KEY', 'key') is not None: + key = str(kv(csrf, 'KEY', 'key')) + if len(key) >= 32: + object.__setattr__(self, 'CSRF_SECRET_KEY', key) - if rabbitmq: - data["RABBIT_HOST"] = rabbitmq.get("HOST", data.get("RABBIT_HOST")) - data["RABBIT_PORT"] = rabbitmq.get("PORT", data.get("RABBIT_PORT")) - data["RABBIT_USER"] = rabbitmq.get("USER", data.get("RABBIT_USER")) - data["RABBIT_PASSWORD"] = rabbitmq.get("PASSWORD", data.get("RABBIT_PASSWORD")) - data["RABBIT_VHOST"] = rabbitmq.get("VHOST", data.get("RABBIT_VHOST")) + docs = client.read_secret_optional(self.VAULT_DOCS_SECRET_PATH) + if docs: + u = docs.get('DOCS_USERNAME') or docs.get('USERNAME') + p = docs.get('DOCS_PASSWORD') or docs.get('PASSWORD') + if u is not None: + object.__setattr__(self, 'DOCS_USERNAME', str(u)) + if p is not None: + object.__setattr__(self, 'DOCS_PASSWORD', str(p)) - if csrf: - data["CSRF_SECRET_KEY"] = csrf.get("KEY", data.get("CSRF_SECRET_KEY")) - - return data + if not self.DATABASE_URL: + raise ValueError('Database URL could not be built from Vault database secret') @property def DATABASE_URL(self) -> str: - return ( - f"postgresql+asyncpg://{self.DATABASE_USER}:{self.DATABASE_PASSWORD}" - f"@{self.DATABASE_HOST}:{self.DATABASE_PORT}/{self.DATABASE_NAME}" + direct = (self.DATABASE_URL_DIRECT or '').strip() + if direct: + return direct + + ready_url = self._get_vault_secret( + self._vault_database_secrets, + 'DATABASE_URL', + 'database_url', ) + if ready_url: + return ready_url + + host = self._get_vault_secret(self._vault_database_secrets, 'host', 'HOST') + port = self._get_vault_secret(self._vault_database_secrets, 'port', 'PORT') or str(self.DATABASE_PORT) + user = self._get_vault_secret(self._vault_database_secrets, 'user', 'USER') + password = self._get_vault_secret(self._vault_database_secrets, 'password', 'PASSWORD') + name = self._get_vault_secret(self._vault_database_secrets, 'name', 'NAME', 'database', 'DATABASE') + if not host or not user or not password or not name: + h = (self.DATABASE_HOST or '').strip() + u = (self.DATABASE_USER or '').strip() + p = (self.DATABASE_PASSWORD or '').strip() + n = (self.DATABASE_NAME or '').strip() + if h and u and p and n: + quoted_user = quote(u, safe='') + quoted_password = quote(p, safe='') + po = str(self.DATABASE_PORT) + return f'postgresql+asyncpg://{quoted_user}:{quoted_password}@{h}:{po}/{n}' + return '' + quoted_user = quote(user, safe='') + quoted_password = quote(password, safe='') + return f'postgresql+asyncpg://{quoted_user}:{quoted_password}@{host}:{port}/{name}' @property def REDIS_URL(self) -> str: - auth = f":{self.REDIS_PASSWORD}@" if self.REDIS_PASSWORD else "" - return f"redis://{auth}{self.REDIS_HOST}:{self.REDIS_PORT}/{self.REDIS_DB}" + auth = f':{self.REDIS_PASSWORD}@' if self.REDIS_PASSWORD else '' + return f'redis://{auth}{self.REDIS_HOST}:{self.REDIS_PORT}/{self.REDIS_DB}' @property def RABBIT_URL(self) -> str: - vhost = "%2F" if self.RABBIT_VHOST == "/" else self.RABBIT_VHOST.lstrip("/") - return f"amqp://{self.RABBIT_USER}:{self.RABBIT_PASSWORD}@{self.RABBIT_HOST}:{self.RABBIT_PORT}/{vhost}" + vhost = '%2F' if self.RABBIT_VHOST == '/' else self.RABBIT_VHOST.lstrip('/') + return f'amqp://{self.RABBIT_USER}:{self.RABBIT_PASSWORD}@{self.RABBIT_HOST}:{self.RABBIT_PORT}/{vhost}' @property def EXCLUDED_PATHS(self) -> List[str]: - return ["/docs", "/redoc", "/openapi.json", "/ping", "/health"] + return ['/docs', '/redoc', '/openapi.json', '/ping', '/health'] @lru_cache(maxsize=1) @@ -152,4 +221,4 @@ def get_settings() -> Settings: return Settings() -settings = get_settings() \ No newline at end of file +settings = get_settings() diff --git a/src/infrastructure/vault/__init__.py b/src/infrastructure/vault/__init__.py index 5206af7..dc30bb1 100644 --- a/src/infrastructure/vault/__init__.py +++ b/src/infrastructure/vault/__init__.py @@ -1,3 +1,4 @@ -from src.infrastructure.vault.utils import read_kv2_secret, create_hvac_client +from src.infrastructure.vault.client import VaultClient +from src.infrastructure.vault.utils import create_hvac_client, read_kv2_secret from src.infrastructure.vault.keys import JwtKeyStore from src.infrastructure.vault.scheduler import start_jwt_keys_scheduler \ No newline at end of file diff --git a/src/infrastructure/vault/client.py b/src/infrastructure/vault/client.py new file mode 100644 index 0000000..d474d6e --- /dev/null +++ b/src/infrastructure/vault/client.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from typing import Any + +import hvac + + +def _vault_token_renew_failed(exception: Exception) -> bool: + if isinstance(exception, (hvac.exceptions.Forbidden, hvac.exceptions.Unauthorized)): + return True + message = getattr(exception, 'message', None) or str(exception) + if isinstance(message, str): + lower = message.lower() + return 'permission denied' in lower or 'invalid token' in lower or '403' in lower + return False + + +class VaultClient: + + def __init__( + self, + *, + addr: str, + role_id: str, + secret_id: str, + namespace: str | None, + mount_point: str, + ) -> None: + self._mount_point = mount_point + self._addr = addr + self._role_id = role_id + self._secret_id = secret_id + self._namespace = namespace + self._client = hvac.Client(url=addr, namespace=namespace) + self._approle_login() + + def _approle_login(self) -> None: + self._client.auth.approle.login(role_id=self._role_id, secret_id=self._secret_id) + + def _renew_or_login(self) -> None: + try: + self._client.auth.token.renew_self() + except Exception: + self._approle_login() + + def read_secret(self, path: str) -> dict[str, Any]: + for attempt in range(2): + try: + secret = self._client.secrets.kv.v2.read_secret_version( + path=path, + mount_point=self._mount_point, + ) + return dict(secret.get('data', {}).get('data', {})) + except Exception as exc: + if attempt == 0 and _vault_token_renew_failed(exc): + self._renew_or_login() + continue + raise + + def read_secret_optional(self, path: str) -> dict[str, Any]: + if not path: + return {} + try: + return self.read_secret(path) + except (hvac.exceptions.InvalidPath, hvac.exceptions.Forbidden, hvac.exceptions.Unauthorized): + return {} + result: dict[str, Any] = {} + for path in paths: + if not path: + continue + try: + result.update(self.read_secret(path)) + except (hvac.exceptions.InvalidPath, hvac.exceptions.Forbidden, hvac.exceptions.Unauthorized): + continue + return result diff --git a/src/infrastructure/vault/keys.py b/src/infrastructure/vault/keys.py index 6e12f76..473f580 100644 --- a/src/infrastructure/vault/keys.py +++ b/src/infrastructure/vault/keys.py @@ -3,7 +3,7 @@ import asyncio from datetime import datetime, timezone from src.application.domain.dto import JwtPublicKeySet, JwtPublicKey from src.application.domain.exceptions import ApplicationException -from src.infrastructure.vault import create_hvac_client, read_kv2_secret +from src.infrastructure.vault.client import VaultClient class JwtKeyStore: @@ -19,21 +19,25 @@ class JwtKeyStore: self, *, vault_addr: str, - vault_token: str, + vault_role_id: str, + vault_secret_id: str, + vault_namespace: str | None, mount_point: str, kid_path: str = 'jwt/kid', kids_prefix: str = 'jwt/kids', - timeout_seconds: int = 5, refresh_ttl_seconds: int = 60, ): if getattr(self, '_initialized', False): return - self._vault_addr = vault_addr - self._vault_token = vault_token - self._timeout = timeout_seconds + self._vault_client = VaultClient( + addr=vault_addr, + role_id=vault_role_id, + secret_id=vault_secret_id, + namespace=vault_namespace, + mount_point=mount_point, + ) - self._mount = mount_point self._kid_path = kid_path self._kids_prefix = kids_prefix @@ -52,29 +56,23 @@ class JwtKeyStore: return cls._instance def _read_keyset_sync(self) -> JwtPublicKeySet: - client = create_hvac_client(url=self._vault_addr, token=self._vault_token, timeout=self._timeout) - - kids = read_kv2_secret(client=client, mount_point=self._mount, path=self._kid_path) + kids = self._vault_client.read_secret(self._kid_path) active_kid = kids.get('active') previous_kid = kids.get('previous') if not active_kid: raise RuntimeError('Vault jwt/kid secret missing "active"') - active = self._read_public_key_sync(client, str(active_kid)) + active = self._read_public_key_sync(str(active_kid)) previous = None if previous_kid and previous_kid != active_kid: - previous = self._read_public_key_sync(client, str(previous_kid)) + previous = self._read_public_key_sync(str(previous_kid)) return JwtPublicKeySet(active=active, previous=previous) - def _read_public_key_sync(self, client, kid: str) -> JwtPublicKey: - data = read_kv2_secret( - client=client, - mount_point=self._mount, - path=f'{self._kids_prefix}/{kid}', - ) + def _read_public_key_sync(self, kid: str) -> JwtPublicKey: + data = self._vault_client.read_secret(f'{self._kids_prefix}/{kid}') pub = data.get('public_key') if not pub: raise RuntimeError(f'Vault jwt/kids/{kid} missing public_key') @@ -110,4 +108,4 @@ class JwtKeyStore: if age >= self._refresh_ttl_seconds: return await self.refresh() - return ks \ No newline at end of file + return ks diff --git a/src/main.py b/src/main.py index 471161d..04681ce 100644 --- a/src/main.py +++ b/src/main.py @@ -9,7 +9,6 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials from starlette.middleware.cors import CORSMiddleware from src.application.domain.exceptions import ApplicationException from src.infrastructure.cache import create_redis_client -from src.infrastructure.config.settings import get_settings from src.infrastructure.vault import JwtKeyStore, start_jwt_keys_scheduler from src.infrastructure.utils import generate_instance_id from src.infrastructure.logger import logger @@ -42,9 +41,14 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: app.state.redis = create_redis_client() + if not settings.VAULT_ROLE_ID.strip() or not settings.VAULT_SECRET_ID.strip(): + raise RuntimeError('VAULT_ROLE_ID and VAULT_SECRET_ID must be set') + jwt_store = JwtKeyStore( vault_addr=settings.VAULT_ADDR, - vault_token=settings.VAULT_TOKEN, + vault_role_id=settings.VAULT_ROLE_ID, + vault_secret_id=settings.VAULT_SECRET_ID, + vault_namespace=settings.VAULT_NAMESPACE, mount_point=settings.VAULT_MOUNT_POINT, kid_path=settings.VAULT_JWT_KID_PATH, kids_prefix=settings.VAULT_JWT_KIDS_PREFIX,