feat: add import
This commit is contained in:
@@ -1,61 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import inspect
|
||||
from functools import wraps
|
||||
from typing import Callable, Awaitable, Any, Optional, Annotated
|
||||
from fastapi import Request, Header
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.infrastructure.security import CsrfService
|
||||
|
||||
|
||||
def csrf_protect(
|
||||
expected_subject_getter: Optional[Callable[[Request], Optional[str]]] = None,
|
||||
):
|
||||
def decorator(func: Callable[..., Awaitable[Any]]):
|
||||
sig = inspect.signature(func)
|
||||
params = list(sig.parameters.values())
|
||||
|
||||
has_request = any(p.annotation is Request or p.name == 'request' for p in params)
|
||||
if not has_request:
|
||||
raise RuntimeError('csrf_protect requires endpoint to accept `request: Request`')
|
||||
|
||||
has_header = any(p.name == 'x_csrf_token' for p in params)
|
||||
if not has_header:
|
||||
params.append(
|
||||
inspect.Parameter(
|
||||
name='x_csrf_token',
|
||||
kind=inspect.Parameter.KEYWORD_ONLY,
|
||||
default=None,
|
||||
annotation=Annotated[str | None, Header(alias='X-CSRF-Token')],
|
||||
)
|
||||
)
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
request: Request | None = kwargs.get('request')
|
||||
if request is None:
|
||||
for arg in args:
|
||||
if isinstance(arg, Request):
|
||||
request = arg
|
||||
break
|
||||
|
||||
if request is None:
|
||||
raise ApplicationException(
|
||||
status_code=500,
|
||||
message='Request is required for CSRF protection',
|
||||
)
|
||||
|
||||
csrf = CsrfService()
|
||||
|
||||
cookie_token, _ = csrf.extract(request.cookies, request.headers)
|
||||
header_token = kwargs.get('x_csrf_token')
|
||||
|
||||
expected_subject = expected_subject_getter(request) if expected_subject_getter else None
|
||||
csrf.verify_pair(cookie_token, header_token, expected_subject)
|
||||
|
||||
kwargs.pop('x_csrf_token', None)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
wrapper.__signature__ = sig.replace(parameters=params)
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
@@ -5,6 +5,8 @@ from fastapi import Depends
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.commands import (
|
||||
AdminLoginCommand,
|
||||
AdminLogoutCommand,
|
||||
AdminJwtRefreshCommand,
|
||||
GetAdminMeCommand,
|
||||
CreateOrganizationCommand,
|
||||
CreateOrganizationWalletsCommand,
|
||||
@@ -19,9 +21,10 @@ from src.application.commands import (
|
||||
UpdatePurchaseRequestStatusCommand,
|
||||
UploadOrganizationDocumentCommand,
|
||||
)
|
||||
from src.application.contracts import IHashService, IJwtService, ILogger
|
||||
from src.application.contracts import ICache, IHashService, IJwtService, ILogger
|
||||
from src.infrastructure.config import settings
|
||||
from src.infrastructure.storage.s3_documents_service import S3DocumentsService
|
||||
from src.presentation.dependencies.cache import get_cache
|
||||
from src.presentation.dependencies.logger import get_logger
|
||||
from src.presentation.dependencies.security import get_hash_service, get_jwt_service
|
||||
from src.presentation.dependencies.unit_of_work import get_unit_of_work
|
||||
@@ -60,6 +63,24 @@ def get_admin_me_command(
|
||||
return GetAdminMeCommand(uow, logger)
|
||||
|
||||
|
||||
def get_admin_logout_command(
|
||||
uow: IUnitOfWork = Depends(get_unit_of_work),
|
||||
jwt_service: IJwtService = Depends(get_jwt_service),
|
||||
logger: ILogger = Depends(get_logger),
|
||||
) -> AdminLogoutCommand:
|
||||
return AdminLogoutCommand(uow, jwt_service, logger)
|
||||
|
||||
|
||||
def get_admin_jwt_refresh_command(
|
||||
uow: IUnitOfWork = Depends(get_unit_of_work),
|
||||
hash_service: IHashService = Depends(get_hash_service),
|
||||
jwt_service: IJwtService = Depends(get_jwt_service),
|
||||
cache: ICache = Depends(get_cache),
|
||||
logger: ILogger = Depends(get_logger),
|
||||
) -> AdminJwtRefreshCommand:
|
||||
return AdminJwtRefreshCommand(uow, hash_service, jwt_service, cache, logger)
|
||||
|
||||
|
||||
def get_create_organization_command(
|
||||
uow: IUnitOfWork = Depends(get_unit_of_work),
|
||||
hash_service: IHashService = Depends(get_hash_service),
|
||||
|
||||
Reference in New Issue
Block a user