112 lines
3.5 KiB
Python
112 lines
3.5 KiB
Python
from __future__ import annotations
|
|
import asyncio
|
|
from datetime import datetime, timezone
|
|
from src.application.domain.dto import JwtPublicKeySet, JwtPublicKey
|
|
from src.application.domain.exceptions import ApplicationException
|
|
from src.infrastructure.vault.client import VaultClient
|
|
|
|
|
|
class JwtKeyStore:
    """Process-wide cache of the JWT public keys stored in Vault.

    The class is a singleton: the first construction builds the Vault client
    and the (empty) cache, later constructions return the same object and
    ignore their arguments. The cached keyset is re-read from Vault once it
    is older than ``refresh_ttl_seconds``.
    """

    # The single shared instance; set by __new__ on first construction.
    _instance: 'JwtKeyStore | None' = None

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(
        self,
        *,
        vault_addr: str,
        vault_role_id: str,
        vault_secret_id: str,
        vault_namespace: str | None,
        mount_point: str,
        kid_path: str = 'jwt/kid',
        kids_prefix: str = 'jwt/kids',
        refresh_ttl_seconds: int = 60,
    ):
        """Configure the store; only the first successful call has any effect.

        Args:
            vault_addr: Passed to ``VaultClient`` as ``addr``.
            vault_role_id: Passed to ``VaultClient`` as ``role_id``.
            vault_secret_id: Passed to ``VaultClient`` as ``secret_id``.
            vault_namespace: Passed to ``VaultClient`` as ``namespace``.
            mount_point: Passed to ``VaultClient`` as ``mount_point``.
            kid_path: Secret path holding the ``active`` / ``previous`` key ids.
            kids_prefix: Path prefix under which each key id's secret lives.
            refresh_ttl_seconds: Age in seconds after which the cached keyset
                is considered stale.
        """
        # __init__ runs on every JwtKeyStore(...) call, even though __new__
        # returns the same object; skip re-initialisation once it succeeded.
        if getattr(self, '_initialized', False):
            return

        self._vault_client = VaultClient(
            addr=vault_addr,
            role_id=vault_role_id,
            secret_id=vault_secret_id,
            namespace=vault_namespace,
            mount_point=mount_point,
        )

        self._kid_path = kid_path
        self._kids_prefix = kids_prefix

        self._refresh_ttl_seconds = refresh_ttl_seconds

        self._lock = asyncio.Lock()
        self._keyset: JwtPublicKeySet | None = None
        self._last_refresh_at: datetime | None = None

        # Set last, so a failure above leaves the instance marked as unusable.
        self._initialized = True

    @classmethod
    def get_instance(cls) -> 'JwtKeyStore':
        """Return the configured singleton.

        Raises:
            ApplicationException: If the store was never constructed, or its
                construction failed before initialisation completed.
        """
        instance = cls._instance
        # __new__ caches the object before __init__ runs, so a constructor
        # failure leaves a half-built instance behind; never hand that out.
        if instance is None or not getattr(instance, '_initialized', False):
            raise ApplicationException(status_code=500, message='JwtKeyStore not initialized')
        return instance

    def _read_keyset_sync(self) -> JwtPublicKeySet:
        """Read the active (and, if distinct, previous) public key from Vault.

        Blocking; intended to be run off the event loop.

        Raises:
            RuntimeError: If the key-id secret has no ``active`` entry, or a
                referenced key secret has no ``public_key``.
        """
        kids = self._vault_client.read_secret(self._kid_path)
        active_kid = kids.get('active')
        previous_kid = kids.get('previous')

        if not active_kid:
            # Report the configured path, not the default one.
            raise RuntimeError(f'Vault {self._kid_path} secret missing "active"')

        active = self._read_public_key_sync(str(active_kid))

        previous = None
        # Skip the second read when "previous" is absent or equals "active".
        if previous_kid and previous_kid != active_kid:
            previous = self._read_public_key_sync(str(previous_kid))

        return JwtPublicKeySet(active=active, previous=previous)

    def _read_public_key_sync(self, kid: str) -> JwtPublicKey:
        """Read the public key stored under ``<kids_prefix>/<kid>``.

        Raises:
            RuntimeError: If the secret has no ``public_key`` entry.
        """
        secret_path = f'{self._kids_prefix}/{kid}'
        data = self._vault_client.read_secret(secret_path)
        pub = data.get('public_key')
        if not pub:
            # Report the path that was actually read (prefix is configurable).
            raise RuntimeError(f'Vault {secret_path} missing public_key')
        return JwtPublicKey(kid=kid, public_key_pem=pub)

    async def refresh(self) -> JwtPublicKeySet:
        """Unconditionally re-read the keyset from Vault and cache it.

        The blocking Vault read runs in a worker thread; the cache and the
        refresh timestamp (UTC) are then replaced together under the lock.
        """
        keyset = await asyncio.to_thread(self._read_keyset_sync)
        async with self._lock:
            self._keyset = keyset
            self._last_refresh_at = datetime.now(timezone.utc)
        return keyset

    async def get_public_key_for_kid(self, kid: str) -> str | None:
        """Look up ``kid`` in the (possibly refreshed) keyset.

        Returns whatever ``public_keys_by_kid().get(kid)`` yields — presumably
        the PEM string, or ``None`` for an unknown key id (confirm against
        ``JwtPublicKeySet``).
        """
        ks = await self._get_or_refresh()
        return ks.public_keys_by_kid().get(kid)

    async def last_refresh_at(self) -> datetime | None:
        """Return the time of the last successful refresh, or ``None``."""
        async with self._lock:
            return self._last_refresh_at

    async def _get_or_refresh(self) -> JwtPublicKeySet:
        """Return the cached keyset, refreshing it first if missing or stale."""
        # Snapshot both fields together so they describe the same refresh.
        async with self._lock:
            ks = self._keyset
            last = self._last_refresh_at

        # NOTE: concurrent callers that all observe a stale cache each trigger
        # their own Vault read; the last one to finish wins.
        if ks is None or last is None:
            return await self.refresh()

        age = (datetime.now(timezone.utc) - last).total_seconds()
        if age >= self._refresh_ttl_seconds:
            return await self.refresh()

        return ks
|