from __future__ import annotations import asyncio from datetime import datetime, timezone from src.application.domain.dto import JwtPublicKeySet, JwtPublicKey from src.application.domain.exceptions import ApplicationException from src.infrastructure.vault.client import VaultClient class JwtKeyStore: _instance: 'JwtKeyStore | None' = None def __new__(cls, *args, **kwargs): if cls._instance is None: cls._instance = super().__new__(cls) return cls._instance def __init__( self, *, vault_addr: str, vault_role_id: str, vault_secret_id: str, vault_namespace: str | None, mount_point: str, kid_path: str = 'jwt/kid', kids_prefix: str = 'jwt/kids', refresh_ttl_seconds: int = 60, ): if getattr(self, '_initialized', False): return self._vault_client = VaultClient( addr=vault_addr, role_id=vault_role_id, secret_id=vault_secret_id, namespace=vault_namespace, mount_point=mount_point, ) self._kid_path = kid_path self._kids_prefix = kids_prefix self._refresh_ttl_seconds = refresh_ttl_seconds self._lock = asyncio.Lock() self._keyset: JwtPublicKeySet | None = None self._last_refresh_at: datetime | None = None self._initialized = True @classmethod def get_instance(cls) -> 'JwtKeyStore': if cls._instance is None: raise ApplicationException(status_code=500, message='JwtKeyStore not initialized') return cls._instance def _read_keyset_sync(self) -> JwtPublicKeySet: kids = self._vault_client.read_secret(self._kid_path) active_kid = kids.get('active') previous_kid = kids.get('previous') if not active_kid: raise RuntimeError('Vault jwt/kid secret missing "active"') active = self._read_public_key_sync(str(active_kid)) previous = None if previous_kid and previous_kid != active_kid: previous = self._read_public_key_sync(str(previous_kid)) return JwtPublicKeySet(active=active, previous=previous) def _read_public_key_sync(self, kid: str) -> JwtPublicKey: data = self._vault_client.read_secret(f'{self._kids_prefix}/{kid}') pub = data.get('public_key') if not pub: raise RuntimeError(f'Vault jwt/kids/{kid} missing public_key') return JwtPublicKey(kid=kid, public_key_pem=pub) async def refresh(self) -> JwtPublicKeySet: keyset = await asyncio.to_thread(self._read_keyset_sync) async with self._lock: self._keyset = keyset self._last_refresh_at = datetime.now(timezone.utc) return keyset async def get_public_key_for_kid(self, kid: str) -> str | None: ks = await self._get_or_refresh() return ks.public_keys_by_kid().get(kid) async def last_refresh_at(self) -> datetime | None: async with self._lock: return self._last_refresh_at async def _get_or_refresh(self) -> JwtPublicKeySet: async with self._lock: ks = self._keyset last = self._last_refresh_at if ks is None: return await self.refresh() if last is None: return await self.refresh() age = (datetime.now(timezone.utc) - last).total_seconds() if age >= self._refresh_ttl_seconds: return await self.refresh() return ks