From 927b6acc09d55ed6d48e30b96d3aeddc9c40e71b Mon Sep 17 00:00:00 2001 From: Noloquideus Date: Tue, 12 May 2026 20:05:41 +0300 Subject: [PATCH] feat: update jwt logic --- src/infrastructure/security/jwt.py | 50 ++++++++++++++--------- src/infrastructure/vault/__init__.py | 60 ++++++++++++++++++++++++---- 2 files changed, 85 insertions(+), 25 deletions(-) diff --git a/src/infrastructure/security/jwt.py b/src/infrastructure/security/jwt.py index c124737..a8af06a 100644 --- a/src/infrastructure/security/jwt.py +++ b/src/infrastructure/security/jwt.py @@ -53,19 +53,17 @@ class JwtService(IJwtService): self._logger.warning(f'JWT invalid algorithm kid={kid} received_alg={received_alg} expected_alg={settings.JWT_ALGORITHM}') raise InvalidTokenException('Invalid token algorithm') - self._logger.debug(f'JWT public key lookup started kid={kid}') - public_pem = await self._key_store.get_public_key_for_kid(str(kid)) - - if not public_pem: - self._logger.info(f'JWT kid miss kid={kid} forcing keystore refresh') + self._logger.debug(f'JWT verification keys lookup started token_kid={kid}') + verification_keys = await self._key_store.get_verification_keys(str(kid)) + if not verification_keys: + self._logger.info(f'JWT verification keys miss token_kid={kid} forcing keystore refresh') await self._key_store.refresh() - public_pem = await self._key_store.get_public_key_for_kid(str(kid)) + verification_keys = await self._key_store.get_verification_keys(str(kid)) - if not public_pem: - self._logger.warning(f'JWT unknown kid kid={kid}') + if not verification_keys: + self._logger.warning(f'JWT no verification keys found token_kid={kid}') raise InvalidTokenException('Unknown token kid') - self._logger.debug(f'JWT signature verification started kid={kid}') options = { 'verify_signature': True, 'verify_exp': True, @@ -80,14 +78,30 @@ class JwtService(IJwtService): 'leeway': 10, } - payload = jwt.decode( - token, - public_pem, - algorithms=[settings.JWT_ALGORITHM], - audience=settings.JWT_AUDIENCE or None, - issuer=settings.JWT_ISSUER or None, - options=options, - ) + payload: dict | None = None + last_error: JWTError | None = None + for key_kid,public_pem in verification_keys: + try: + self._logger.debug(f'JWT signature verification started token_kid={kid} verification_kid={key_kid}') + payload = jwt.decode( + token, + public_pem, + algorithms=[settings.JWT_ALGORITHM], + audience=settings.JWT_AUDIENCE or None, + issuer=settings.JWT_ISSUER or None, + options=options, + ) + self._logger.info(f'JWT signature verification passed token_kid={kid} verification_kid={key_kid}') + break + except ExpiredSignatureError: + raise + except JWTError as exception: + last_error = exception + self._logger.warning(f'JWT signature verification failed token_kid={kid} verification_kid={key_kid} error={str(exception)}') + + if payload is None: + self._logger.warning(f'JWT decode failed token_kid={kid} error={str(last_error)}') + raise InvalidTokenException() if 'sid' not in payload: self._logger.warning(f'JWT missing sid claim kid={kid}') @@ -97,7 +111,7 @@ class JwtService(IJwtService): self._logger.warning(f'JWT missing type claim kid={kid}') raise InvalidTokenException('Missing token claim: type') - self._logger.info(f'JWT signature verification completed kid={kid} user_id={payload.get("sub")} sid={payload.get("sid")}') + self._logger.info(f'JWT signature verification completed token_kid={kid} user_id={payload.get("sub")} sid={payload.get("sid")}') return payload except ExpiredSignatureError as exception: diff --git a/src/infrastructure/vault/__init__.py b/src/infrastructure/vault/__init__.py index fc8df50..667b2c3 100644 --- a/src/infrastructure/vault/__init__.py +++ b/src/infrastructure/vault/__init__.py @@ -81,17 +81,63 @@ class JwtKeyStore: return None current = self._vault_client.read_secret(self._kid_path) - kid = current.get('kid') - if kid: - key_data = self._vault_client.read_secret(f'{self._kids_prefix}/{kid}') - public_key = key_data.get('public_key') or key_data.get('public') - if public_key: - self._keys[str(kid)] = str(public_key) + for kid in self._get_configured_kids(current): + await self.get_public_key_for_kid(kid) return None + async def get_verification_keys(self,token_kid: str | None = None) -> list[tuple[str,str]]: + if self._vault_client is None: + return [(kid,key) for kid,key in self._keys.items()] + + current = self._vault_client.read_secret(self._kid_path) + kids = self._get_configured_kids(current) + if not kids and token_kid: + kids = [token_kid] + + result: list[tuple[str,str]] = [] + for kid in kids: + public_key = await self.get_public_key_for_kid(kid) + if public_key: + result.append((kid,public_key)) + return result + + async def get_public_key_for_kid(self,kid: str) -> str | None: - return self._keys.get(kid) + cached = self._keys.get(kid) + if cached: + return cached + if self._vault_client is None: + return None + + try: + key_data = self._vault_client.read_secret(f'{self._kids_prefix}/{kid}') + except (hvac.exceptions.InvalidPath,hvac.exceptions.Forbidden): + return None + + public_key = ( + key_data.get('public_key') + or key_data.get('public') + or key_data.get('key') + or key_data.get('pem') + ) + if public_key is None and len(key_data) == 1: + public_key = next(iter(key_data.values())) + if public_key is None: + return None + + self._keys[kid] = str(public_key) + return self._keys[kid] + + + def _get_configured_kids(self,current: dict[str,Any]) -> list[str]: + active = current.get('active') or current.get('kid') or current.get('current') + previous = current.get('previous') or current.get('prev') + result: list[str] = [] + for kid in (active,previous): + if kid and str(kid) not in result: + result.append(str(kid)) + return result def start_jwt_keys_scheduler(jwt_store: JwtKeyStore,refresh_seconds: int) -> AsyncIOScheduler: