feat: update jwt logic

This commit is contained in:
2026-05-12 20:05:41 +03:00
parent 2ee33e773f
commit 927b6acc09
2 changed files with 85 additions and 25 deletions

View File

@@ -53,19 +53,17 @@ class JwtService(IJwtService):
self._logger.warning(f'JWT invalid algorithm kid={kid} received_alg={received_alg} expected_alg={settings.JWT_ALGORITHM}')
raise InvalidTokenException('Invalid token algorithm')
self._logger.debug(f'JWT public key lookup started kid={kid}')
public_pem = await self._key_store.get_public_key_for_kid(str(kid))
if not public_pem:
self._logger.info(f'JWT kid miss kid={kid} forcing keystore refresh')
self._logger.debug(f'JWT verification keys lookup started token_kid={kid}')
verification_keys = await self._key_store.get_verification_keys(str(kid))
if not verification_keys:
self._logger.info(f'JWT verification keys miss token_kid={kid} forcing keystore refresh')
await self._key_store.refresh()
public_pem = await self._key_store.get_public_key_for_kid(str(kid))
verification_keys = await self._key_store.get_verification_keys(str(kid))
if not public_pem:
self._logger.warning(f'JWT unknown kid kid={kid}')
if not verification_keys:
self._logger.warning(f'JWT no verification keys found token_kid={kid}')
raise InvalidTokenException('Unknown token kid')
self._logger.debug(f'JWT signature verification started kid={kid}')
options = {
'verify_signature': True,
'verify_exp': True,
@@ -80,6 +78,11 @@ class JwtService(IJwtService):
'leeway': 10,
}
payload: dict | None = None
last_error: JWTError | None = None
for key_kid,public_pem in verification_keys:
try:
self._logger.debug(f'JWT signature verification started token_kid={kid} verification_kid={key_kid}')
payload = jwt.decode(
token,
public_pem,
@@ -88,6 +91,17 @@ class JwtService(IJwtService):
issuer=settings.JWT_ISSUER or None,
options=options,
)
self._logger.info(f'JWT signature verification passed token_kid={kid} verification_kid={key_kid}')
break
except ExpiredSignatureError:
raise
except JWTError as exception:
last_error = exception
self._logger.warning(f'JWT signature verification failed token_kid={kid} verification_kid={key_kid} error={str(exception)}')
if payload is None:
self._logger.warning(f'JWT decode failed token_kid={kid} error={str(last_error)}')
raise InvalidTokenException()
if 'sid' not in payload:
self._logger.warning(f'JWT missing sid claim kid={kid}')
@@ -97,7 +111,7 @@ class JwtService(IJwtService):
self._logger.warning(f'JWT missing type claim kid={kid}')
raise InvalidTokenException('Missing token claim: type')
self._logger.info(f'JWT signature verification completed kid={kid} user_id={payload.get("sub")} sid={payload.get("sid")}')
self._logger.info(f'JWT signature verification completed token_kid={kid} user_id={payload.get("sub")} sid={payload.get("sid")}')
return payload
except ExpiredSignatureError as exception:

View File

@@ -81,17 +81,63 @@ class JwtKeyStore:
return None
current = self._vault_client.read_secret(self._kid_path)
kid = current.get('kid')
if kid:
key_data = self._vault_client.read_secret(f'{self._kids_prefix}/{kid}')
public_key = key_data.get('public_key') or key_data.get('public')
if public_key:
self._keys[str(kid)] = str(public_key)
for kid in self._get_configured_kids(current):
await self.get_public_key_for_kid(kid)
return None
async def get_verification_keys(self,token_kid: str | None = None) -> list[tuple[str,str]]:
if self._vault_client is None:
return [(kid,key) for kid,key in self._keys.items()]
current = self._vault_client.read_secret(self._kid_path)
kids = self._get_configured_kids(current)
if not kids and token_kid:
kids = [token_kid]
result: list[tuple[str,str]] = []
for kid in kids:
public_key = await self.get_public_key_for_kid(kid)
if public_key:
result.append((kid,public_key))
return result
async def get_public_key_for_kid(self,kid: str) -> str | None:
return self._keys.get(kid)
cached = self._keys.get(kid)
if cached:
return cached
if self._vault_client is None:
return None
try:
key_data = self._vault_client.read_secret(f'{self._kids_prefix}/{kid}')
except (hvac.exceptions.InvalidPath,hvac.exceptions.Forbidden):
return None
public_key = (
key_data.get('public_key')
or key_data.get('public')
or key_data.get('key')
or key_data.get('pem')
)
if public_key is None and len(key_data) == 1:
public_key = next(iter(key_data.values()))
if public_key is None:
return None
self._keys[kid] = str(public_key)
return self._keys[kid]
def _get_configured_kids(self,current: dict[str,Any]) -> list[str]:
active = current.get('active') or current.get('kid') or current.get('current')
previous = current.get('previous') or current.get('prev')
result: list[str] = []
for kid in (active,previous):
if kid and str(kid) not in result:
result.append(str(kid))
return result
def start_jwt_keys_scheduler(jwt_store: JwtKeyStore,refresh_seconds: int) -> AsyncIOScheduler: