feat: approle vault
This commit is contained in:
@@ -4,19 +4,34 @@ from functools import lru_cache
|
|||||||
from typing import List, Literal
|
from typing import List, Literal
|
||||||
import os
|
import os
|
||||||
from dotenv import load_dotenv, find_dotenv
|
from dotenv import load_dotenv, find_dotenv
|
||||||
from pydantic import Field, model_validator
|
from pydantic import AliasChoices,Field,field_validator,model_validator
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
from src.infrastructure.vault import create_hvac_client, read_kv2_secret
|
from src.infrastructure.vault import create_hvac_client_from_approle,read_kv2_secret
|
||||||
|
|
||||||
env_file = find_dotenv(".env")
|
env_file = find_dotenv(".env")
|
||||||
if env_file:
|
if env_file:
|
||||||
load_dotenv(env_file)
|
load_dotenv(env_file)
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_vault_base_url(raw: str) -> str:
|
||||||
|
u = raw.strip().rstrip('/')
|
||||||
|
if not u:
|
||||||
|
return raw.strip()
|
||||||
|
if '://' not in u:
|
||||||
|
return f'https://{u}'
|
||||||
|
return u
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
VAULT_ADDR: str = Field(default="http://localhost:8200")
|
VAULT_ADDR: str = Field(default='http://localhost:8200')
|
||||||
VAULT_TOKEN: str = Field(..., description="Vault token is required")
|
VAULT_ROLE_ID: str = Field(...,description='AppRole role_id')
|
||||||
VAULT_MOUNT_POINT: str = Field(default="secrets")
|
VAULT_SECRET_ID: str = Field(
|
||||||
|
...,
|
||||||
|
description='AppRole secret_id',
|
||||||
|
validation_alias=AliasChoices('VAULT_SECRET_ID','VAULT_SECRET_TOKEN'),
|
||||||
|
)
|
||||||
|
VAULT_NAMESPACE: str | None = Field(default=None)
|
||||||
|
VAULT_MOUNT_POINT: str = Field(default='secrets')
|
||||||
|
|
||||||
VAULT_JWT_KID_PATH: str = "jwt/kid"
|
VAULT_JWT_KID_PATH: str = "jwt/kid"
|
||||||
VAULT_JWT_KIDS_PREFIX: str = "jwt/kids"
|
VAULT_JWT_KIDS_PREFIX: str = "jwt/kids"
|
||||||
@@ -77,51 +92,110 @@ class Settings(BaseSettings):
|
|||||||
env_file_encoding="utf-8",
|
env_file_encoding="utf-8",
|
||||||
case_sensitive=True,
|
case_sensitive=True,
|
||||||
extra="ignore",
|
extra="ignore",
|
||||||
|
populate_by_name=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@field_validator('VAULT_ADDR',mode='before')
|
||||||
|
@classmethod
|
||||||
|
def vault_addr_scheme(cls, v):
|
||||||
|
if v is None or not isinstance(v,str):
|
||||||
|
return v
|
||||||
|
return normalize_vault_base_url(v)
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_from_vault(cls, data: dict):
|
def load_from_vault(cls, data: dict):
|
||||||
addr = data.get("VAULT_ADDR") or os.getenv("VAULT_ADDR") or "http://localhost:8200"
|
if not isinstance(data,dict):
|
||||||
token = data.get("VAULT_TOKEN") or os.getenv("VAULT_TOKEN")
|
return data
|
||||||
mount = data.get("VAULT_MOUNT_POINT") or os.getenv("VAULT_MOUNT_POINT") or "secrets"
|
addr_raw = data.get('VAULT_ADDR') or os.getenv('VAULT_ADDR') or 'http://localhost:8200'
|
||||||
|
addr = normalize_vault_base_url(addr_raw)
|
||||||
|
data['VAULT_ADDR'] = addr
|
||||||
|
role_id = data.get('VAULT_ROLE_ID') or os.getenv('VAULT_ROLE_ID')
|
||||||
|
secret_id = (
|
||||||
|
data.get('VAULT_SECRET_ID')
|
||||||
|
or data.get('VAULT_SECRET_TOKEN')
|
||||||
|
or os.getenv('VAULT_SECRET_ID')
|
||||||
|
or os.getenv('VAULT_SECRET_TOKEN')
|
||||||
|
)
|
||||||
|
namespace = data.get('VAULT_NAMESPACE')
|
||||||
|
if namespace is None:
|
||||||
|
namespace = os.getenv('VAULT_NAMESPACE')
|
||||||
|
namespace = namespace if namespace else None
|
||||||
|
mount = data.get('VAULT_MOUNT_POINT') or os.getenv('VAULT_MOUNT_POINT') or 'secrets'
|
||||||
|
|
||||||
if not token:
|
if not role_id or not secret_id:
|
||||||
raise RuntimeError("VAULT_TOKEN is required")
|
raise RuntimeError('VAULT_ROLE_ID and VAULT_SECRET_ID (or VAULT_SECRET_TOKEN) are required for Vault AppRole')
|
||||||
|
|
||||||
client = create_hvac_client(url=addr, token=token, timeout=5)
|
data['VAULT_ROLE_ID'] = str(role_id).strip()
|
||||||
|
data['VAULT_SECRET_ID'] = str(secret_id).strip()
|
||||||
|
|
||||||
def safe_read(path: str) -> dict:
|
client = create_hvac_client_from_approle(
|
||||||
try:
|
url=addr,
|
||||||
|
role_id=role_id,
|
||||||
|
secret_id=secret_id,
|
||||||
|
namespace=namespace,
|
||||||
|
timeout=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
def read_secret(path: str) -> dict:
|
||||||
return read_kv2_secret(client=client,mount_point=mount,path=path)
|
return read_kv2_secret(client=client,mount_point=mount,path=path)
|
||||||
|
|
||||||
|
def read_secret_optional(path: str) -> dict:
|
||||||
|
try:
|
||||||
|
return read_secret(path)
|
||||||
except Exception:
|
except Exception:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
database = safe_read("database")
|
database = read_secret('database')
|
||||||
rabbitmq = safe_read("rabbitmq")
|
csrf = read_secret_optional('csrf')
|
||||||
csrf = safe_read("csrf")
|
rabbitmq = read_secret_optional('rabbitmq')
|
||||||
|
|
||||||
if database:
|
db_ci = {str(k).lower(): v for k, v in database.items()}
|
||||||
required = ["HOST", "NAME", "USER", "PASSWORD", "PORT"]
|
|
||||||
missing = [k for k in required if k not in database]
|
|
||||||
if missing:
|
|
||||||
raise RuntimeError(f"Vault database secret missing keys {missing}")
|
|
||||||
|
|
||||||
data["DATABASE_HOST"] = database["HOST"]
|
def db_nonempty(key: str) -> bool:
|
||||||
data["DATABASE_PORT"] = database["PORT"]
|
v = db_ci.get(key)
|
||||||
data["DATABASE_NAME"] = database["NAME"]
|
if v is None:
|
||||||
data["DATABASE_USER"] = database["USER"]
|
return False
|
||||||
data["DATABASE_PASSWORD"] = database["PASSWORD"]
|
if isinstance(v,str) and not v.strip():
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
if rabbitmq:
|
required_db = ['host','name','user','password','port']
|
||||||
data["RABBIT_HOST"] = rabbitmq.get("HOST", data.get("RABBIT_HOST"))
|
missing_db = [k for k in required_db if not db_nonempty(k)]
|
||||||
data["RABBIT_PORT"] = rabbitmq.get("PORT", data.get("RABBIT_PORT"))
|
if missing_db:
|
||||||
data["RABBIT_USER"] = rabbitmq.get("USER", data.get("RABBIT_USER"))
|
raise RuntimeError(f'Vault secret database missing non-empty keys: {missing_db}')
|
||||||
data["RABBIT_PASSWORD"] = rabbitmq.get("PASSWORD", data.get("RABBIT_PASSWORD"))
|
|
||||||
data["RABBIT_VHOST"] = rabbitmq.get("VHOST", data.get("RABBIT_VHOST"))
|
data['DATABASE_HOST'] = str(db_ci['host']).strip()
|
||||||
|
data['DATABASE_PORT'] = int(db_ci['port'])
|
||||||
|
data['DATABASE_NAME'] = str(db_ci['name']).strip()
|
||||||
|
data['DATABASE_USER'] = str(db_ci['user']).strip()
|
||||||
|
data['DATABASE_PASSWORD'] = str(db_ci['password']).strip()
|
||||||
|
|
||||||
if csrf:
|
if csrf:
|
||||||
data["CSRF_SECRET_KEY"] = csrf.get("KEY", data.get("CSRF_SECRET_KEY"))
|
csrf_secret = None
|
||||||
|
for entry_key, entry_val in csrf.items():
|
||||||
|
if str(entry_key).lower() == 'key' and entry_val is not None and str(entry_val).strip():
|
||||||
|
csrf_secret = str(entry_val).strip()
|
||||||
|
break
|
||||||
|
if csrf_secret:
|
||||||
|
data['CSRF_SECRET_KEY'] = csrf_secret
|
||||||
|
|
||||||
|
if rabbitmq:
|
||||||
|
r_ci = {str(k).lower(): v for k, v in rabbitmq.items()}
|
||||||
|
|
||||||
|
def rb_set(field: str, env_key: str, *, as_int: bool = False) -> None:
|
||||||
|
v = r_ci.get(field)
|
||||||
|
if v is None:
|
||||||
|
return
|
||||||
|
if isinstance(v,str) and not v.strip():
|
||||||
|
return
|
||||||
|
data[env_key] = int(v) if as_int else str(v).strip()
|
||||||
|
|
||||||
|
rb_set('host','RABBIT_HOST')
|
||||||
|
rb_set('port','RABBIT_PORT',as_int=True)
|
||||||
|
rb_set('user','RABBIT_USER')
|
||||||
|
rb_set('password','RABBIT_PASSWORD')
|
||||||
|
rb_set('vhost','RABBIT_VHOST')
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
from src.infrastructure.vault.utils import read_kv2_secret, create_hvac_client
|
from src.infrastructure.vault.utils import read_kv2_secret,create_hvac_client_from_approle
|
||||||
from src.infrastructure.vault.keys import JwtKeyStore
|
from src.infrastructure.vault.keys import JwtKeyStore
|
||||||
from src.infrastructure.vault.scheduler import start_jwt_keys_scheduler
|
from src.infrastructure.vault.scheduler import start_jwt_keys_scheduler
|
||||||
@@ -3,7 +3,7 @@ import asyncio
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from src.application.domain.dto import JwtPublicKeySet, JwtPublicKey
|
from src.application.domain.dto import JwtPublicKeySet, JwtPublicKey
|
||||||
from src.application.domain.exceptions import ApplicationException
|
from src.application.domain.exceptions import ApplicationException
|
||||||
from src.infrastructure.vault import create_hvac_client, read_kv2_secret
|
from src.infrastructure.vault import create_hvac_client_from_approle,read_kv2_secret
|
||||||
|
|
||||||
|
|
||||||
class JwtKeyStore:
|
class JwtKeyStore:
|
||||||
@@ -19,7 +19,9 @@ class JwtKeyStore:
|
|||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
vault_addr: str,
|
vault_addr: str,
|
||||||
vault_token: str,
|
vault_role_id: str,
|
||||||
|
vault_secret_id: str,
|
||||||
|
vault_namespace: str | None,
|
||||||
mount_point: str,
|
mount_point: str,
|
||||||
kid_path: str = 'jwt/kid',
|
kid_path: str = 'jwt/kid',
|
||||||
kids_prefix: str = 'jwt/kids',
|
kids_prefix: str = 'jwt/kids',
|
||||||
@@ -30,7 +32,9 @@ class JwtKeyStore:
|
|||||||
return
|
return
|
||||||
|
|
||||||
self._vault_addr = vault_addr
|
self._vault_addr = vault_addr
|
||||||
self._vault_token = vault_token
|
self._vault_role_id = vault_role_id
|
||||||
|
self._vault_secret_id = vault_secret_id
|
||||||
|
self._vault_namespace = vault_namespace
|
||||||
self._timeout = timeout_seconds
|
self._timeout = timeout_seconds
|
||||||
|
|
||||||
self._mount = mount_point
|
self._mount = mount_point
|
||||||
@@ -52,7 +56,13 @@ class JwtKeyStore:
|
|||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
def _read_keyset_sync(self) -> JwtPublicKeySet:
|
def _read_keyset_sync(self) -> JwtPublicKeySet:
|
||||||
client = create_hvac_client(url=self._vault_addr, token=self._vault_token, timeout=self._timeout)
|
client = create_hvac_client_from_approle(
|
||||||
|
url=self._vault_addr,
|
||||||
|
role_id=self._vault_role_id,
|
||||||
|
secret_id=self._vault_secret_id,
|
||||||
|
namespace=self._vault_namespace,
|
||||||
|
timeout=self._timeout,
|
||||||
|
)
|
||||||
|
|
||||||
kids = read_kv2_secret(client=client, mount_point=self._mount, path=self._kid_path)
|
kids = read_kv2_secret(client=client, mount_point=self._mount, path=self._kid_path)
|
||||||
active_kid = kids.get('active')
|
active_kid = kids.get('active')
|
||||||
|
|||||||
@@ -2,10 +2,23 @@ from __future__ import annotations
|
|||||||
import hvac
|
import hvac
|
||||||
|
|
||||||
|
|
||||||
def create_hvac_client(*, url: str, token: str, timeout: int = 5) -> hvac.Client:
|
def create_hvac_client_from_approle(
|
||||||
client = hvac.Client(url=url, token=token, timeout=timeout)
|
*,
|
||||||
|
url: str,
|
||||||
|
role_id: str,
|
||||||
|
secret_id: str,
|
||||||
|
namespace: str | None = None,
|
||||||
|
timeout: int = 5,
|
||||||
|
) -> hvac.Client:
|
||||||
|
kwargs: dict = {'url': url, 'timeout': timeout}
|
||||||
|
if namespace:
|
||||||
|
kwargs['namespace'] = namespace
|
||||||
|
client = hvac.Client(**kwargs)
|
||||||
|
client.auth.approle.login(role_id=role_id, secret_id=secret_id)
|
||||||
if not client.is_authenticated():
|
if not client.is_authenticated():
|
||||||
raise RuntimeError("Vault authentication failed. Check VAULT_ADDR / VAULT_TOKEN")
|
raise RuntimeError(
|
||||||
|
'Vault AppRole authentication failed. Check VAULT_ADDR, VAULT_ROLE_ID, VAULT_SECRET_ID'
|
||||||
|
)
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -43,7 +43,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
|
|
||||||
jwt_store = JwtKeyStore(
|
jwt_store = JwtKeyStore(
|
||||||
vault_addr=settings.VAULT_ADDR,
|
vault_addr=settings.VAULT_ADDR,
|
||||||
vault_token=settings.VAULT_TOKEN,
|
vault_role_id=settings.VAULT_ROLE_ID,
|
||||||
|
vault_secret_id=settings.VAULT_SECRET_ID,
|
||||||
|
vault_namespace=settings.VAULT_NAMESPACE,
|
||||||
mount_point=settings.VAULT_MOUNT_POINT,
|
mount_point=settings.VAULT_MOUNT_POINT,
|
||||||
kid_path=settings.VAULT_JWT_KID_PATH,
|
kid_path=settings.VAULT_JWT_KID_PATH,
|
||||||
kids_prefix=settings.VAULT_JWT_KIDS_PREFIX,
|
kids_prefix=settings.VAULT_JWT_KIDS_PREFIX,
|
||||||
|
|||||||
Reference in New Issue
Block a user