feat: approle vault

This commit is contained in:
2026-04-22 11:40:25 +03:00
parent 00e601c21a
commit bea79634b5
7 changed files with 141 additions and 42 deletions

View File

@@ -4,19 +4,34 @@ from functools import lru_cache
from typing import List, Literal from typing import List, Literal
import os import os
from dotenv import load_dotenv, find_dotenv from dotenv import load_dotenv, find_dotenv
from pydantic import Field, model_validator from pydantic import AliasChoices,Field,field_validator,model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
from src.infrastructure.vault import create_hvac_client, read_kv2_secret from src.infrastructure.vault import create_hvac_client_from_approle,read_kv2_secret
env_file = find_dotenv(".env") env_file = find_dotenv(".env")
if env_file: if env_file:
load_dotenv(env_file) load_dotenv(env_file)
def normalize_vault_base_url(raw: str) -> str:
u = raw.strip().rstrip('/')
if not u:
return raw.strip()
if '://' not in u:
return f'https://{u}'
return u
class Settings(BaseSettings): class Settings(BaseSettings):
VAULT_ADDR: str = Field(default="http://localhost:8200") VAULT_ADDR: str = Field(default='http://localhost:8200')
VAULT_TOKEN: str = Field(..., description="Vault token is required") VAULT_ROLE_ID: str = Field(...,description='AppRole role_id')
VAULT_MOUNT_POINT: str = Field(default="secrets") VAULT_SECRET_ID: str = Field(
...,
description='AppRole secret_id',
validation_alias=AliasChoices('VAULT_SECRET_ID','VAULT_SECRET_TOKEN'),
)
VAULT_NAMESPACE: str | None = Field(default=None)
VAULT_MOUNT_POINT: str = Field(default='secrets')
VAULT_JWT_KID_PATH: str = "jwt/kid" VAULT_JWT_KID_PATH: str = "jwt/kid"
VAULT_JWT_KIDS_PREFIX: str = "jwt/kids" VAULT_JWT_KIDS_PREFIX: str = "jwt/kids"
@@ -77,51 +92,110 @@ class Settings(BaseSettings):
env_file_encoding="utf-8", env_file_encoding="utf-8",
case_sensitive=True, case_sensitive=True,
extra="ignore", extra="ignore",
populate_by_name=True,
) )
@field_validator('VAULT_ADDR',mode='before')
@classmethod
def vault_addr_scheme(cls, v):
if v is None or not isinstance(v,str):
return v
return normalize_vault_base_url(v)
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def load_from_vault(cls, data: dict): def load_from_vault(cls, data: dict):
addr = data.get("VAULT_ADDR") or os.getenv("VAULT_ADDR") or "http://localhost:8200" if not isinstance(data,dict):
token = data.get("VAULT_TOKEN") or os.getenv("VAULT_TOKEN") return data
mount = data.get("VAULT_MOUNT_POINT") or os.getenv("VAULT_MOUNT_POINT") or "secrets" addr_raw = data.get('VAULT_ADDR') or os.getenv('VAULT_ADDR') or 'http://localhost:8200'
addr = normalize_vault_base_url(addr_raw)
data['VAULT_ADDR'] = addr
role_id = data.get('VAULT_ROLE_ID') or os.getenv('VAULT_ROLE_ID')
secret_id = (
data.get('VAULT_SECRET_ID')
or data.get('VAULT_SECRET_TOKEN')
or os.getenv('VAULT_SECRET_ID')
or os.getenv('VAULT_SECRET_TOKEN')
)
namespace = data.get('VAULT_NAMESPACE')
if namespace is None:
namespace = os.getenv('VAULT_NAMESPACE')
namespace = namespace if namespace else None
mount = data.get('VAULT_MOUNT_POINT') or os.getenv('VAULT_MOUNT_POINT') or 'secrets'
if not token: if not role_id or not secret_id:
raise RuntimeError("VAULT_TOKEN is required") raise RuntimeError('VAULT_ROLE_ID and VAULT_SECRET_ID (or VAULT_SECRET_TOKEN) are required for Vault AppRole')
client = create_hvac_client(url=addr, token=token, timeout=5) data['VAULT_ROLE_ID'] = str(role_id).strip()
data['VAULT_SECRET_ID'] = str(secret_id).strip()
def safe_read(path: str) -> dict: client = create_hvac_client_from_approle(
try: url=addr,
role_id=role_id,
secret_id=secret_id,
namespace=namespace,
timeout=5,
)
def read_secret(path: str) -> dict:
return read_kv2_secret(client=client,mount_point=mount,path=path) return read_kv2_secret(client=client,mount_point=mount,path=path)
def read_secret_optional(path: str) -> dict:
try:
return read_secret(path)
except Exception: except Exception:
return {} return {}
database = safe_read("database") database = read_secret('database')
rabbitmq = safe_read("rabbitmq") csrf = read_secret_optional('csrf')
csrf = safe_read("csrf") rabbitmq = read_secret_optional('rabbitmq')
if database: db_ci = {str(k).lower(): v for k, v in database.items()}
required = ["HOST", "NAME", "USER", "PASSWORD", "PORT"]
missing = [k for k in required if k not in database]
if missing:
raise RuntimeError(f"Vault database secret missing keys {missing}")
data["DATABASE_HOST"] = database["HOST"] def db_nonempty(key: str) -> bool:
data["DATABASE_PORT"] = database["PORT"] v = db_ci.get(key)
data["DATABASE_NAME"] = database["NAME"] if v is None:
data["DATABASE_USER"] = database["USER"] return False
data["DATABASE_PASSWORD"] = database["PASSWORD"] if isinstance(v,str) and not v.strip():
return False
return True
if rabbitmq: required_db = ['host','name','user','password','port']
data["RABBIT_HOST"] = rabbitmq.get("HOST", data.get("RABBIT_HOST")) missing_db = [k for k in required_db if not db_nonempty(k)]
data["RABBIT_PORT"] = rabbitmq.get("PORT", data.get("RABBIT_PORT")) if missing_db:
data["RABBIT_USER"] = rabbitmq.get("USER", data.get("RABBIT_USER")) raise RuntimeError(f'Vault secret database missing non-empty keys: {missing_db}')
data["RABBIT_PASSWORD"] = rabbitmq.get("PASSWORD", data.get("RABBIT_PASSWORD"))
data["RABBIT_VHOST"] = rabbitmq.get("VHOST", data.get("RABBIT_VHOST")) data['DATABASE_HOST'] = str(db_ci['host']).strip()
data['DATABASE_PORT'] = int(db_ci['port'])
data['DATABASE_NAME'] = str(db_ci['name']).strip()
data['DATABASE_USER'] = str(db_ci['user']).strip()
data['DATABASE_PASSWORD'] = str(db_ci['password']).strip()
if csrf: if csrf:
data["CSRF_SECRET_KEY"] = csrf.get("KEY", data.get("CSRF_SECRET_KEY")) csrf_secret = None
for entry_key, entry_val in csrf.items():
if str(entry_key).lower() == 'key' and entry_val is not None and str(entry_val).strip():
csrf_secret = str(entry_val).strip()
break
if csrf_secret:
data['CSRF_SECRET_KEY'] = csrf_secret
if rabbitmq:
r_ci = {str(k).lower(): v for k, v in rabbitmq.items()}
def rb_set(field: str, env_key: str, *, as_int: bool = False) -> None:
v = r_ci.get(field)
if v is None:
return
if isinstance(v,str) and not v.strip():
return
data[env_key] = int(v) if as_int else str(v).strip()
rb_set('host','RABBIT_HOST')
rb_set('port','RABBIT_PORT',as_int=True)
rb_set('user','RABBIT_USER')
rb_set('password','RABBIT_PASSWORD')
rb_set('vhost','RABBIT_VHOST')
return data return data

View File

@@ -1,3 +1,3 @@
from src.infrastructure.vault.utils import read_kv2_secret, create_hvac_client from src.infrastructure.vault.utils import read_kv2_secret,create_hvac_client_from_approle
from src.infrastructure.vault.keys import JwtKeyStore from src.infrastructure.vault.keys import JwtKeyStore
from src.infrastructure.vault.scheduler import start_jwt_keys_scheduler from src.infrastructure.vault.scheduler import start_jwt_keys_scheduler

View File

@@ -3,7 +3,7 @@ import asyncio
from datetime import datetime, timezone from datetime import datetime, timezone
from src.application.domain.dto import JwtPublicKeySet, JwtPublicKey from src.application.domain.dto import JwtPublicKeySet, JwtPublicKey
from src.application.domain.exceptions import ApplicationException from src.application.domain.exceptions import ApplicationException
from src.infrastructure.vault import create_hvac_client, read_kv2_secret from src.infrastructure.vault import create_hvac_client_from_approle,read_kv2_secret
class JwtKeyStore: class JwtKeyStore:
@@ -19,7 +19,9 @@ class JwtKeyStore:
self, self,
*, *,
vault_addr: str, vault_addr: str,
vault_token: str, vault_role_id: str,
vault_secret_id: str,
vault_namespace: str | None,
mount_point: str, mount_point: str,
kid_path: str = 'jwt/kid', kid_path: str = 'jwt/kid',
kids_prefix: str = 'jwt/kids', kids_prefix: str = 'jwt/kids',
@@ -30,7 +32,9 @@ class JwtKeyStore:
return return
self._vault_addr = vault_addr self._vault_addr = vault_addr
self._vault_token = vault_token self._vault_role_id = vault_role_id
self._vault_secret_id = vault_secret_id
self._vault_namespace = vault_namespace
self._timeout = timeout_seconds self._timeout = timeout_seconds
self._mount = mount_point self._mount = mount_point
@@ -52,7 +56,13 @@ class JwtKeyStore:
return cls._instance return cls._instance
def _read_keyset_sync(self) -> JwtPublicKeySet: def _read_keyset_sync(self) -> JwtPublicKeySet:
client = create_hvac_client(url=self._vault_addr, token=self._vault_token, timeout=self._timeout) client = create_hvac_client_from_approle(
url=self._vault_addr,
role_id=self._vault_role_id,
secret_id=self._vault_secret_id,
namespace=self._vault_namespace,
timeout=self._timeout,
)
kids = read_kv2_secret(client=client, mount_point=self._mount, path=self._kid_path) kids = read_kv2_secret(client=client, mount_point=self._mount, path=self._kid_path)
active_kid = kids.get('active') active_kid = kids.get('active')

View File

@@ -2,10 +2,23 @@ from __future__ import annotations
import hvac import hvac
def create_hvac_client(*, url: str, token: str, timeout: int = 5) -> hvac.Client: def create_hvac_client_from_approle(
client = hvac.Client(url=url, token=token, timeout=timeout) *,
url: str,
role_id: str,
secret_id: str,
namespace: str | None = None,
timeout: int = 5,
) -> hvac.Client:
kwargs: dict = {'url': url, 'timeout': timeout}
if namespace:
kwargs['namespace'] = namespace
client = hvac.Client(**kwargs)
client.auth.approle.login(role_id=role_id, secret_id=secret_id)
if not client.is_authenticated(): if not client.is_authenticated():
raise RuntimeError("Vault authentication failed. Check VAULT_ADDR / VAULT_TOKEN") raise RuntimeError(
'Vault AppRole authentication failed. Check VAULT_ADDR, VAULT_ROLE_ID, VAULT_SECRET_ID'
)
return client return client

View File

@@ -43,7 +43,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
jwt_store = JwtKeyStore( jwt_store = JwtKeyStore(
vault_addr=settings.VAULT_ADDR, vault_addr=settings.VAULT_ADDR,
vault_token=settings.VAULT_TOKEN, vault_role_id=settings.VAULT_ROLE_ID,
vault_secret_id=settings.VAULT_SECRET_ID,
vault_namespace=settings.VAULT_NAMESPACE,
mount_point=settings.VAULT_MOUNT_POINT, mount_point=settings.VAULT_MOUNT_POINT,
kid_path=settings.VAULT_JWT_KID_PATH, kid_path=settings.VAULT_JWT_KID_PATH,
kids_prefix=settings.VAULT_JWT_KIDS_PREFIX, kids_prefix=settings.VAULT_JWT_KIDS_PREFIX,