diff --git a/src/command/__init__.py b/src/application/command/__init__.py similarity index 100% rename from src/command/__init__.py rename to src/application/command/__init__.py diff --git a/src/command/create_order_command.py b/src/application/command/create_order_command.py similarity index 100% rename from src/command/create_order_command.py rename to src/application/command/create_order_command.py diff --git a/src/infrastructure/config/settings.py b/src/infrastructure/config/settings.py index 18e9592..7ee4f5a 100644 --- a/src/infrastructure/config/settings.py +++ b/src/infrastructure/config/settings.py @@ -4,19 +4,34 @@ from functools import lru_cache from typing import List, Literal import os from dotenv import load_dotenv, find_dotenv -from pydantic import Field, model_validator +from pydantic import AliasChoices,Field,field_validator,model_validator from pydantic_settings import BaseSettings, SettingsConfigDict -from src.infrastructure.vault import create_hvac_client, read_kv2_secret +from src.infrastructure.vault import create_hvac_client_from_approle,read_kv2_secret env_file = find_dotenv(".env") if env_file: load_dotenv(env_file) +def normalize_vault_base_url(raw: str) -> str: + u = raw.strip().rstrip('/') + if not u: + return raw.strip() + if '://' not in u: + return f'https://{u}' + return u + + class Settings(BaseSettings): - VAULT_ADDR: str = Field(default="http://localhost:8200") - VAULT_TOKEN: str = Field(..., description="Vault token is required") - VAULT_MOUNT_POINT: str = Field(default="secrets") + VAULT_ADDR: str = Field(default='http://localhost:8200') + VAULT_ROLE_ID: str = Field(...,description='AppRole role_id') + VAULT_SECRET_ID: str = Field( + ..., + description='AppRole secret_id', + validation_alias=AliasChoices('VAULT_SECRET_ID','VAULT_SECRET_TOKEN'), + ) + VAULT_NAMESPACE: str | None = Field(default=None) + VAULT_MOUNT_POINT: str = Field(default='secrets') VAULT_JWT_KID_PATH: str = "jwt/kid" VAULT_JWT_KIDS_PREFIX: str = "jwt/kids" @@ -77,51 +92,110 @@ class Settings(BaseSettings): env_file_encoding="utf-8", case_sensitive=True, extra="ignore", + populate_by_name=True, ) + @field_validator('VAULT_ADDR',mode='before') + @classmethod + def vault_addr_scheme(cls, v): + if v is None or not isinstance(v,str): + return v + return normalize_vault_base_url(v) + @model_validator(mode="before") @classmethod def load_from_vault(cls, data: dict): - addr = data.get("VAULT_ADDR") or os.getenv("VAULT_ADDR") or "http://localhost:8200" - token = data.get("VAULT_TOKEN") or os.getenv("VAULT_TOKEN") - mount = data.get("VAULT_MOUNT_POINT") or os.getenv("VAULT_MOUNT_POINT") or "secrets" + if not isinstance(data,dict): + return data + addr_raw = data.get('VAULT_ADDR') or os.getenv('VAULT_ADDR') or 'http://localhost:8200' + addr = normalize_vault_base_url(addr_raw) + data['VAULT_ADDR'] = addr + role_id = data.get('VAULT_ROLE_ID') or os.getenv('VAULT_ROLE_ID') + secret_id = ( + data.get('VAULT_SECRET_ID') + or data.get('VAULT_SECRET_TOKEN') + or os.getenv('VAULT_SECRET_ID') + or os.getenv('VAULT_SECRET_TOKEN') + ) + namespace = data.get('VAULT_NAMESPACE') + if namespace is None: + namespace = os.getenv('VAULT_NAMESPACE') + namespace = namespace if namespace else None + mount = data.get('VAULT_MOUNT_POINT') or os.getenv('VAULT_MOUNT_POINT') or 'secrets' - if not token: - raise RuntimeError("VAULT_TOKEN is required") + if not role_id or not secret_id: + raise RuntimeError('VAULT_ROLE_ID and VAULT_SECRET_ID (or VAULT_SECRET_TOKEN) are required for Vault AppRole') - client = create_hvac_client(url=addr, token=token, timeout=5) + data['VAULT_ROLE_ID'] = str(role_id).strip() + data['VAULT_SECRET_ID'] = str(secret_id).strip() - def safe_read(path: str) -> dict: + client = create_hvac_client_from_approle( + url=addr, + role_id=role_id, + secret_id=secret_id, + namespace=namespace, + timeout=5, + ) + + def read_secret(path: str) -> dict: + return read_kv2_secret(client=client,mount_point=mount,path=path) + + def read_secret_optional(path: str) -> dict: try: - return read_kv2_secret(client=client, mount_point=mount, path=path) + return read_secret(path) except Exception: return {} - database = safe_read("database") - rabbitmq = safe_read("rabbitmq") - csrf = safe_read("csrf") + database = read_secret('database') + csrf = read_secret_optional('csrf') + rabbitmq = read_secret_optional('rabbitmq') - if database: - required = ["HOST", "NAME", "USER", "PASSWORD", "PORT"] - missing = [k for k in required if k not in database] - if missing: - raise RuntimeError(f"Vault database secret missing keys {missing}") + db_ci = {str(k).lower(): v for k, v in database.items()} - data["DATABASE_HOST"] = database["HOST"] - data["DATABASE_PORT"] = database["PORT"] - data["DATABASE_NAME"] = database["NAME"] - data["DATABASE_USER"] = database["USER"] - data["DATABASE_PASSWORD"] = database["PASSWORD"] + def db_nonempty(key: str) -> bool: + v = db_ci.get(key) + if v is None: + return False + if isinstance(v,str) and not v.strip(): + return False + return True - if rabbitmq: - data["RABBIT_HOST"] = rabbitmq.get("HOST", data.get("RABBIT_HOST")) - data["RABBIT_PORT"] = rabbitmq.get("PORT", data.get("RABBIT_PORT")) - data["RABBIT_USER"] = rabbitmq.get("USER", data.get("RABBIT_USER")) - data["RABBIT_PASSWORD"] = rabbitmq.get("PASSWORD", data.get("RABBIT_PASSWORD")) - data["RABBIT_VHOST"] = rabbitmq.get("VHOST", data.get("RABBIT_VHOST")) + required_db = ['host','name','user','password','port'] + missing_db = [k for k in required_db if not db_nonempty(k)] + if missing_db: + raise RuntimeError(f'Vault secret database missing non-empty keys: {missing_db}') + + data['DATABASE_HOST'] = str(db_ci['host']).strip() + data['DATABASE_PORT'] = int(db_ci['port']) + data['DATABASE_NAME'] = str(db_ci['name']).strip() + data['DATABASE_USER'] = str(db_ci['user']).strip() + data['DATABASE_PASSWORD'] = str(db_ci['password']).strip() if csrf: - data["CSRF_SECRET_KEY"] = csrf.get("KEY", data.get("CSRF_SECRET_KEY")) + csrf_secret = None + for entry_key, entry_val in csrf.items(): + if str(entry_key).lower() == 'key' and entry_val is not None and str(entry_val).strip(): + csrf_secret = str(entry_val).strip() + break + if csrf_secret: + data['CSRF_SECRET_KEY'] = csrf_secret + + if rabbitmq: + r_ci = {str(k).lower(): v for k, v in rabbitmq.items()} + + def rb_set(field: str, env_key: str, *, as_int: bool = False) -> None: + v = r_ci.get(field) + if v is None: + return + if isinstance(v,str) and not v.strip(): + return + data[env_key] = int(v) if as_int else str(v).strip() + + rb_set('host','RABBIT_HOST') + rb_set('port','RABBIT_PORT',as_int=True) + rb_set('user','RABBIT_USER') + rb_set('password','RABBIT_PASSWORD') + rb_set('vhost','RABBIT_VHOST') return data diff --git a/src/infrastructure/vault/__init__.py b/src/infrastructure/vault/__init__.py index 5206af7..4c12dd3 100644 --- a/src/infrastructure/vault/__init__.py +++ b/src/infrastructure/vault/__init__.py @@ -1,3 +1,3 @@ -from src.infrastructure.vault.utils import read_kv2_secret, create_hvac_client +from src.infrastructure.vault.utils import read_kv2_secret,create_hvac_client_from_approle from src.infrastructure.vault.keys import JwtKeyStore from src.infrastructure.vault.scheduler import start_jwt_keys_scheduler \ No newline at end of file diff --git a/src/infrastructure/vault/keys.py b/src/infrastructure/vault/keys.py index 6e12f76..fefb0b1 100644 --- a/src/infrastructure/vault/keys.py +++ b/src/infrastructure/vault/keys.py @@ -3,7 +3,7 @@ import asyncio from datetime import datetime, timezone from src.application.domain.dto import JwtPublicKeySet, JwtPublicKey from src.application.domain.exceptions import ApplicationException -from src.infrastructure.vault import create_hvac_client, read_kv2_secret +from src.infrastructure.vault import create_hvac_client_from_approle,read_kv2_secret class JwtKeyStore: @@ -19,7 +19,9 @@ class JwtKeyStore: self, *, vault_addr: str, - vault_token: str, + vault_role_id: str, + vault_secret_id: str, + vault_namespace: str | None, mount_point: str, kid_path: str = 'jwt/kid', kids_prefix: str = 'jwt/kids', @@ -30,7 +32,9 @@ class JwtKeyStore: return self._vault_addr = vault_addr - self._vault_token = vault_token + self._vault_role_id = vault_role_id + self._vault_secret_id = vault_secret_id + self._vault_namespace = vault_namespace self._timeout = timeout_seconds self._mount = mount_point @@ -52,7 +56,13 @@ class JwtKeyStore: return cls._instance def _read_keyset_sync(self) -> JwtPublicKeySet: - client = create_hvac_client(url=self._vault_addr, token=self._vault_token, timeout=self._timeout) + client = create_hvac_client_from_approle( + url=self._vault_addr, + role_id=self._vault_role_id, + secret_id=self._vault_secret_id, + namespace=self._vault_namespace, + timeout=self._timeout, + ) kids = read_kv2_secret(client=client, mount_point=self._mount, path=self._kid_path) active_kid = kids.get('active') diff --git a/src/infrastructure/vault/utils.py b/src/infrastructure/vault/utils.py index 27b3ba9..c0dee85 100644 --- a/src/infrastructure/vault/utils.py +++ b/src/infrastructure/vault/utils.py @@ -2,10 +2,23 @@ from __future__ import annotations import hvac -def create_hvac_client(*, url: str, token: str, timeout: int = 5) -> hvac.Client: - client = hvac.Client(url=url, token=token, timeout=timeout) +def create_hvac_client_from_approle( + *, + url: str, + role_id: str, + secret_id: str, + namespace: str | None = None, + timeout: int = 5, +) -> hvac.Client: + kwargs: dict = {'url': url, 'timeout': timeout} + if namespace: + kwargs['namespace'] = namespace + client = hvac.Client(**kwargs) + client.auth.approle.login(role_id=role_id, secret_id=secret_id) if not client.is_authenticated(): - raise RuntimeError("Vault authentication failed. Check VAULT_ADDR / VAULT_TOKEN") + raise RuntimeError( + 'Vault AppRole authentication failed. Check VAULT_ADDR, VAULT_ROLE_ID, VAULT_SECRET_ID' + ) return client diff --git a/src/main.py b/src/main.py index d4286d4..bbf186a 100644 --- a/src/main.py +++ b/src/main.py @@ -43,7 +43,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: jwt_store = JwtKeyStore( vault_addr=settings.VAULT_ADDR, - vault_token=settings.VAULT_TOKEN, + vault_role_id=settings.VAULT_ROLE_ID, + vault_secret_id=settings.VAULT_SECRET_ID, + vault_namespace=settings.VAULT_NAMESPACE, mount_point=settings.VAULT_MOUNT_POINT, kid_path=settings.VAULT_JWT_KID_PATH, kids_prefix=settings.VAULT_JWT_KIDS_PREFIX,