feat: update jwt logic
This commit is contained in:
@@ -53,19 +53,17 @@ class JwtService(IJwtService):
|
|||||||
self._logger.warning(f'JWT invalid algorithm kid={kid} received_alg={received_alg} expected_alg={settings.JWT_ALGORITHM}')
|
self._logger.warning(f'JWT invalid algorithm kid={kid} received_alg={received_alg} expected_alg={settings.JWT_ALGORITHM}')
|
||||||
raise InvalidTokenException('Invalid token algorithm')
|
raise InvalidTokenException('Invalid token algorithm')
|
||||||
|
|
||||||
self._logger.debug(f'JWT public key lookup started kid={kid}')
|
self._logger.debug(f'JWT verification keys lookup started token_kid={kid}')
|
||||||
public_pem = await self._key_store.get_public_key_for_kid(str(kid))
|
verification_keys = await self._key_store.get_verification_keys(str(kid))
|
||||||
|
if not verification_keys:
|
||||||
if not public_pem:
|
self._logger.info(f'JWT verification keys miss token_kid={kid} forcing keystore refresh')
|
||||||
self._logger.info(f'JWT kid miss kid={kid} forcing keystore refresh')
|
|
||||||
await self._key_store.refresh()
|
await self._key_store.refresh()
|
||||||
public_pem = await self._key_store.get_public_key_for_kid(str(kid))
|
verification_keys = await self._key_store.get_verification_keys(str(kid))
|
||||||
|
|
||||||
if not public_pem:
|
if not verification_keys:
|
||||||
self._logger.warning(f'JWT unknown kid kid={kid}')
|
self._logger.warning(f'JWT no verification keys found token_kid={kid}')
|
||||||
raise InvalidTokenException('Unknown token kid')
|
raise InvalidTokenException('Unknown token kid')
|
||||||
|
|
||||||
self._logger.debug(f'JWT signature verification started kid={kid}')
|
|
||||||
options = {
|
options = {
|
||||||
'verify_signature': True,
|
'verify_signature': True,
|
||||||
'verify_exp': True,
|
'verify_exp': True,
|
||||||
@@ -80,6 +78,11 @@ class JwtService(IJwtService):
|
|||||||
'leeway': 10,
|
'leeway': 10,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
payload: dict | None = None
|
||||||
|
last_error: JWTError | None = None
|
||||||
|
for key_kid,public_pem in verification_keys:
|
||||||
|
try:
|
||||||
|
self._logger.debug(f'JWT signature verification started token_kid={kid} verification_kid={key_kid}')
|
||||||
payload = jwt.decode(
|
payload = jwt.decode(
|
||||||
token,
|
token,
|
||||||
public_pem,
|
public_pem,
|
||||||
@@ -88,6 +91,17 @@ class JwtService(IJwtService):
|
|||||||
issuer=settings.JWT_ISSUER or None,
|
issuer=settings.JWT_ISSUER or None,
|
||||||
options=options,
|
options=options,
|
||||||
)
|
)
|
||||||
|
self._logger.info(f'JWT signature verification passed token_kid={kid} verification_kid={key_kid}')
|
||||||
|
break
|
||||||
|
except ExpiredSignatureError:
|
||||||
|
raise
|
||||||
|
except JWTError as exception:
|
||||||
|
last_error = exception
|
||||||
|
self._logger.warning(f'JWT signature verification failed token_kid={kid} verification_kid={key_kid} error={str(exception)}')
|
||||||
|
|
||||||
|
if payload is None:
|
||||||
|
self._logger.warning(f'JWT decode failed token_kid={kid} error={str(last_error)}')
|
||||||
|
raise InvalidTokenException()
|
||||||
|
|
||||||
if 'sid' not in payload:
|
if 'sid' not in payload:
|
||||||
self._logger.warning(f'JWT missing sid claim kid={kid}')
|
self._logger.warning(f'JWT missing sid claim kid={kid}')
|
||||||
@@ -97,7 +111,7 @@ class JwtService(IJwtService):
|
|||||||
self._logger.warning(f'JWT missing type claim kid={kid}')
|
self._logger.warning(f'JWT missing type claim kid={kid}')
|
||||||
raise InvalidTokenException('Missing token claim: type')
|
raise InvalidTokenException('Missing token claim: type')
|
||||||
|
|
||||||
self._logger.info(f'JWT signature verification completed kid={kid} user_id={payload.get("sub")} sid={payload.get("sid")}')
|
self._logger.info(f'JWT signature verification completed token_kid={kid} user_id={payload.get("sub")} sid={payload.get("sid")}')
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
except ExpiredSignatureError as exception:
|
except ExpiredSignatureError as exception:
|
||||||
|
|||||||
@@ -81,17 +81,63 @@ class JwtKeyStore:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
current = self._vault_client.read_secret(self._kid_path)
|
current = self._vault_client.read_secret(self._kid_path)
|
||||||
kid = current.get('kid')
|
for kid in self._get_configured_kids(current):
|
||||||
if kid:
|
await self.get_public_key_for_kid(kid)
|
||||||
key_data = self._vault_client.read_secret(f'{self._kids_prefix}/{kid}')
|
|
||||||
public_key = key_data.get('public_key') or key_data.get('public')
|
|
||||||
if public_key:
|
|
||||||
self._keys[str(kid)] = str(public_key)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_verification_keys(self,token_kid: str | None = None) -> list[tuple[str,str]]:
|
||||||
|
if self._vault_client is None:
|
||||||
|
return [(kid,key) for kid,key in self._keys.items()]
|
||||||
|
|
||||||
|
current = self._vault_client.read_secret(self._kid_path)
|
||||||
|
kids = self._get_configured_kids(current)
|
||||||
|
if not kids and token_kid:
|
||||||
|
kids = [token_kid]
|
||||||
|
|
||||||
|
result: list[tuple[str,str]] = []
|
||||||
|
for kid in kids:
|
||||||
|
public_key = await self.get_public_key_for_kid(kid)
|
||||||
|
if public_key:
|
||||||
|
result.append((kid,public_key))
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def get_public_key_for_kid(self,kid: str) -> str | None:
|
async def get_public_key_for_kid(self,kid: str) -> str | None:
|
||||||
return self._keys.get(kid)
|
cached = self._keys.get(kid)
|
||||||
|
if cached:
|
||||||
|
return cached
|
||||||
|
if self._vault_client is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
key_data = self._vault_client.read_secret(f'{self._kids_prefix}/{kid}')
|
||||||
|
except (hvac.exceptions.InvalidPath,hvac.exceptions.Forbidden):
|
||||||
|
return None
|
||||||
|
|
||||||
|
public_key = (
|
||||||
|
key_data.get('public_key')
|
||||||
|
or key_data.get('public')
|
||||||
|
or key_data.get('key')
|
||||||
|
or key_data.get('pem')
|
||||||
|
)
|
||||||
|
if public_key is None and len(key_data) == 1:
|
||||||
|
public_key = next(iter(key_data.values()))
|
||||||
|
if public_key is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
self._keys[kid] = str(public_key)
|
||||||
|
return self._keys[kid]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_configured_kids(self,current: dict[str,Any]) -> list[str]:
|
||||||
|
active = current.get('active') or current.get('kid') or current.get('current')
|
||||||
|
previous = current.get('previous') or current.get('prev')
|
||||||
|
result: list[str] = []
|
||||||
|
for kid in (active,previous):
|
||||||
|
if kid and str(kid) not in result:
|
||||||
|
result.append(str(kid))
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def start_jwt_keys_scheduler(jwt_store: JwtKeyStore,refresh_seconds: int) -> AsyncIOScheduler:
|
def start_jwt_keys_scheduler(jwt_store: JwtKeyStore,refresh_seconds: int) -> AsyncIOScheduler:
|
||||||
|
|||||||
Reference in New Issue
Block a user