Compare commits
21 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e9b99535b9 | |||
| caf7f003fa | |||
| 666f2f67cb | |||
| d3fedc3f91 | |||
| 0cca0fedbe | |||
| 9166b21249 | |||
| 603efa55e6 | |||
| cec8d896b6 | |||
| 2d9f44979c | |||
| 8347ff40f4 | |||
| 4cfab85812 | |||
| 9d56b7f6f5 | |||
| bedce9e910 | |||
| 1724d4e37d | |||
| 57bafec204 | |||
| 0130912555 | |||
| 54ebcaeb81 | |||
| 3e3b9eb030 | |||
| f92eadf8fa | |||
| 949b57e425 | |||
| 3fc1b455d2 |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -2,7 +2,7 @@
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
generate_password_hash.py
|
||||
# C extensions
|
||||
*.so
|
||||
*.pyd
|
||||
|
||||
@@ -1,83 +0,0 @@
|
||||
services:
|
||||
auth:
|
||||
container_name: auth-service
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
ports:
|
||||
- "8000:8000"
|
||||
environment:
|
||||
PYTHONUNBUFFERED: "1"
|
||||
APP_MODULE: "src.main:app"
|
||||
APP_HOST: "0.0.0.0"
|
||||
APP_PORT: "8000"
|
||||
APP_WORKERS: "1"
|
||||
env_file:
|
||||
- .env
|
||||
depends_on:
|
||||
keydb:
|
||||
condition: service_healthy
|
||||
restart: no
|
||||
|
||||
keydb:
|
||||
image: eqalpha/keydb
|
||||
container_name: keydb
|
||||
restart: no
|
||||
expose:
|
||||
- "6379"
|
||||
volumes:
|
||||
- keydb_data:/data
|
||||
command:
|
||||
- keydb-server
|
||||
- --requirepass
|
||||
- keydb
|
||||
- --dir
|
||||
- /data
|
||||
- --appendonly
|
||||
- "yes"
|
||||
- --appendfsync
|
||||
- everysec
|
||||
- --save
|
||||
- "900"
|
||||
- "1"
|
||||
- --save
|
||||
- "300"
|
||||
- "10"
|
||||
- --save
|
||||
- "60"
|
||||
- "10000"
|
||||
healthcheck:
|
||||
test: [ "CMD", "redis-cli", "-a", "keydb", "ping" ]
|
||||
interval: 5s
|
||||
timeout: 2s
|
||||
retries: 20
|
||||
|
||||
# keydb:
|
||||
# image: eqalpha/keydb
|
||||
# container_name: keydb
|
||||
# restart: no
|
||||
# expose:
|
||||
# - "6379"
|
||||
# volumes:
|
||||
# - keydb_data:/data
|
||||
# environment:
|
||||
# KEYDB_PASSWORD: keydb
|
||||
# command: >
|
||||
# sh -c "
|
||||
# keydb-server
|
||||
# --requirepass $$KEYDB_PASSWORD
|
||||
# --dir /data
|
||||
# --appendonly yes
|
||||
# --appendfsync everysec
|
||||
# --save 900 1
|
||||
# --save 300 10
|
||||
# --save 60 10000
|
||||
# "
|
||||
# healthcheck:
|
||||
# test: ["CMD", "redis-cli", "ping"]
|
||||
# interval: 5s
|
||||
# timeout: 2s
|
||||
# retries: 20
|
||||
|
||||
volumes:
|
||||
keydb_data:
|
||||
@@ -4,6 +4,8 @@ version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
requires-python = "==3.12.*"
|
||||
dependencies = [
|
||||
"acryl-datahub>=1.5.0.19",
|
||||
"acryl-sqlglot>=25.25.2.dev9",
|
||||
"apscheduler==3.11.2",
|
||||
"asyncpg==0.31.0",
|
||||
"bcrypt==5.0.0",
|
||||
|
||||
@@ -1,18 +1,32 @@
|
||||
import asyncio
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from ulid import ULID
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.contracts import IHashService, IJwtService, ILogger
|
||||
from src.application.contracts import IHashService, IJwtService, ILogger, ICache
|
||||
from src.application.domain.dto import RefreshTokenPayload
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.application.domain.exceptions import ApplicationException, RefreshConcurrentException
|
||||
from src.infrastructure.config import settings
|
||||
from src.infrastructure.database.decorators import transactional
|
||||
|
||||
|
||||
class JwtRefreshCommand:
|
||||
def __init__(self, unit_of_work: IUnitOfWork, hash_service: IHashService, jwt_service: IJwtService, logger: ILogger):
|
||||
_LOCK_PREFIX = 'jwt:refresh:lock:'
|
||||
_LOCK_TTL_SECONDS = 15
|
||||
_LOCK_WAIT_ATTEMPTS = 40
|
||||
_LOCK_WAIT_INTERVAL_SECONDS = 0.05
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
unit_of_work: IUnitOfWork,
|
||||
hash_service: IHashService,
|
||||
jwt_service: IJwtService,
|
||||
cache: ICache,
|
||||
logger: ILogger,
|
||||
):
|
||||
self._unit_of_work = unit_of_work
|
||||
self._hash_service = hash_service
|
||||
self._jwt_service = jwt_service
|
||||
self._cache = cache
|
||||
self._logger = logger
|
||||
|
||||
@transactional
|
||||
@@ -25,6 +39,39 @@ class JwtRefreshCommand:
|
||||
user_id = payload.sub
|
||||
jti = payload.jti
|
||||
|
||||
lock_key = f'{self._LOCK_PREFIX}{sid}'
|
||||
locked = await self._cache.set_nx(lock_key, '1', self._LOCK_TTL_SECONDS)
|
||||
|
||||
if not locked:
|
||||
for _ in range(self._LOCK_WAIT_ATTEMPTS):
|
||||
await asyncio.sleep(self._LOCK_WAIT_INTERVAL_SECONDS)
|
||||
if await self._cache.get(lock_key) is None:
|
||||
self._logger.info(f'Concurrent refresh skipped (sid={sid})')
|
||||
raise RefreshConcurrentException()
|
||||
raise ApplicationException(status_code=429, message='Refresh in progress')
|
||||
|
||||
try:
|
||||
return await self._refresh_locked(
|
||||
sid=sid,
|
||||
user_id=user_id,
|
||||
jti=jti,
|
||||
now=now,
|
||||
ip=ip,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
finally:
|
||||
await self._cache.delete(lock_key)
|
||||
|
||||
async def _refresh_locked(
|
||||
self,
|
||||
*,
|
||||
sid: str,
|
||||
user_id: str,
|
||||
jti: str,
|
||||
now: datetime,
|
||||
ip: str | None,
|
||||
user_agent: str | None,
|
||||
) -> tuple[str, str]:
|
||||
sess = await self._unit_of_work.session_repository.get_by_sid(sid)
|
||||
if sess is None:
|
||||
raise ApplicationException(status_code=401, message='Session not found')
|
||||
@@ -61,7 +108,8 @@ class JwtRefreshCommand:
|
||||
)
|
||||
|
||||
if not rotated:
|
||||
raise ApplicationException(status_code=401, message='Refresh already rotated')
|
||||
self._logger.info(f'Refresh already rotated (sid={sid})')
|
||||
raise RefreshConcurrentException()
|
||||
|
||||
access = await self._jwt_service.create_access_token(user_id=user_id, sid=sid)
|
||||
refresh = await self._jwt_service.create_refresh_token(user_id=user_id, sid=sid, refresh_jti=new_jti)
|
||||
|
||||
@@ -102,12 +102,12 @@ class UserLoginCompleteCommand:
|
||||
middle_name=user.middle_name,
|
||||
last_name=user.last_name,
|
||||
birth_date=user.birth_date,
|
||||
crypto_wallet=user.crypto_wallet,
|
||||
encrypted_mnemonic=user.encrypted_mnemonic,
|
||||
phone=user.phone,
|
||||
bik=user.bik,
|
||||
account_number=user.account_number,
|
||||
card_number=user.card_number,
|
||||
passport_data=user.passport_data,
|
||||
inn=user.inn,
|
||||
erc20=user.erc20,
|
||||
avatar_link=user.avatar_link,
|
||||
kyc_verified=user.kyc_verified,
|
||||
kyc_verified_at=user.kyc_verified_at,
|
||||
created_at=user.created_at,
|
||||
|
||||
@@ -102,8 +102,6 @@ class UserLoginStartCommand:
|
||||
'metadata': metadata,
|
||||
}
|
||||
|
||||
self._logger.info(f'payload: {payload})')
|
||||
|
||||
try:
|
||||
await self._messanger.publish_to_queue(
|
||||
queue=settings.RABBIT_EMAIL_CODE_QUEUE,
|
||||
@@ -123,7 +121,7 @@ class UserLoginStartCommand:
|
||||
self._logger.error(f'Failed to publish login email event for {email}: {str(exception)}')
|
||||
raise ApplicationException(503, 'Temporary error. Please try again.')
|
||||
|
||||
self._logger.info(f'login code created for {email}')
|
||||
self._logger.info(f'Login email verification code queued email={email}')
|
||||
return True
|
||||
|
||||
self._logger.error(f'login start failed: code space exhausted for {email}')
|
||||
|
||||
@@ -18,7 +18,7 @@ class UserLogoutCommand:
|
||||
if not refresh_token:
|
||||
return
|
||||
try:
|
||||
payload: RefreshTokenPayload = self._jwt_service.decode_refresh_token(refresh_token)
|
||||
payload: RefreshTokenPayload = await self._jwt_service.decode_refresh_token(refresh_token)
|
||||
except ApplicationException:
|
||||
self._logger.debug('Logout: refresh token invalid/expired, skipping revoke')
|
||||
return
|
||||
|
||||
@@ -45,7 +45,7 @@ class UserRegistrationCompleteCommand:
|
||||
|
||||
cached_email = await self._cache.get(code_key)
|
||||
if not cached_email:
|
||||
self._logger.info(f'Registration failed: code not found (email={email}, code={code})')
|
||||
self._logger.info(f'Registration failed: code not found (email={email})')
|
||||
raise ApplicationException(400, 'Invalid or expired code')
|
||||
|
||||
if cached_email != email:
|
||||
|
||||
@@ -102,8 +102,6 @@ class UserRegistrationStartCommand:
|
||||
'metadata': metadata,
|
||||
}
|
||||
|
||||
self._logger.info(f'payload: {payload})')
|
||||
|
||||
try:
|
||||
await self._messanger.publish_to_queue(
|
||||
queue=settings.RABBIT_EMAIL_CODE_QUEUE,
|
||||
@@ -123,7 +121,7 @@ class UserRegistrationStartCommand:
|
||||
self._logger.error(f'Failed to publish registration email event for {email}: {str(exception)}')
|
||||
raise ApplicationException(503, 'Temporary error. Please try again.')
|
||||
|
||||
self._logger.info(f'Registration code created for {email}')
|
||||
self._logger.info(f'Registration email verification code queued email={email}')
|
||||
return True
|
||||
|
||||
self._logger.error(f'Registration start failed: code space exhausted for {email}')
|
||||
|
||||
@@ -18,12 +18,12 @@ class UserLoginDto:
|
||||
middle_name: str | None = None
|
||||
last_name: str | None = None
|
||||
birth_date: date | None = None
|
||||
crypto_wallet: str | None = None
|
||||
encrypted_mnemonic: str | None = None
|
||||
phone: str | None = None
|
||||
bik: str | None = None
|
||||
account_number: str | None = None
|
||||
card_number: str | None = None
|
||||
passport_data: str | None = None
|
||||
inn: str | None = None
|
||||
erc20: str | None = None
|
||||
avatar_link: str | None = None
|
||||
kyc_verified: bool | None = None
|
||||
access_token: str | None = None
|
||||
refresh_token: str | None = None
|
||||
|
||||
@@ -14,13 +14,14 @@ class UserEntity:
|
||||
last_name: str | None = None
|
||||
birth_date: date | None = None
|
||||
|
||||
crypto_wallet: str | None = None
|
||||
encrypted_mnemonic: str | None = None
|
||||
phone: str | None = None
|
||||
|
||||
bik: str | None = None
|
||||
account_number: str | None = None
|
||||
card_number: str | None = None
|
||||
passport_data: str | None = None
|
||||
inn: str | None = None
|
||||
erc20: str | None = None
|
||||
|
||||
avatar_link: str | None = None
|
||||
|
||||
kyc_verified: bool | None = None
|
||||
is_deleted: bool | None = None
|
||||
@@ -28,3 +29,7 @@ class UserEntity:
|
||||
created_at: datetime | None = None
|
||||
updated_at: datetime | None = None
|
||||
kyc_verified_at: datetime | None = None
|
||||
|
||||
account_type: str = 'individual'
|
||||
provisioned_by: str | None = None
|
||||
provisioned_at: datetime | None = None
|
||||
|
||||
6
src/application/domain/enums/account_type.py
Normal file
6
src/application/domain/enums/account_type.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class AccountType(StrEnum):
|
||||
INDIVIDUAL = 'individual'
|
||||
LEGAL_ENTITY = 'legal_entity'
|
||||
@@ -1 +1,10 @@
|
||||
from src.application.domain.exceptions.application_exceptions import ApplicationException
|
||||
from src.application.domain.exceptions.application_exception import ApplicationException
|
||||
from src.application.domain.exceptions.bad_request_exception import BadRequestException
|
||||
from src.application.domain.exceptions.conflict_exception import ConflictException
|
||||
from src.application.domain.exceptions.forbidden_exception import ForbiddenException
|
||||
from src.application.domain.exceptions.internal_server_exception import InternalServerException
|
||||
from src.application.domain.exceptions.not_found_exception import NotFoundException
|
||||
from src.application.domain.exceptions.service_unavailable_exception import ServiceUnavailableException
|
||||
from src.application.domain.exceptions.too_many_requests_exception import TooManyRequestsException
|
||||
from src.application.domain.exceptions.unauthorized_exception import UnauthorizedException
|
||||
from src.application.domain.exceptions.refresh_concurrent_exception import RefreshConcurrentException
|
||||
@@ -15,4 +15,4 @@ class ApplicationException(Exception):
|
||||
self.headers = headers
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.status_code}: {self.message}"
|
||||
return f'{self.status_code}: {self.message}'
|
||||
16
src/application/domain/exceptions/bad_request_exception.py
Normal file
16
src/application/domain/exceptions/bad_request_exception.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from typing import Mapping
|
||||
from starlette import status
|
||||
from src.application.domain.exceptions.application_exception import ApplicationException
|
||||
|
||||
|
||||
class BadRequestException(ApplicationException):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = 'Bad Request',
|
||||
headers: Mapping[str, str] | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
message=message,
|
||||
headers=headers,
|
||||
)
|
||||
16
src/application/domain/exceptions/conflict_exception.py
Normal file
16
src/application/domain/exceptions/conflict_exception.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from typing import Mapping
|
||||
from starlette import status
|
||||
from src.application.domain.exceptions.application_exception import ApplicationException
|
||||
|
||||
|
||||
class ConflictException(ApplicationException):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = 'Conflict',
|
||||
headers: Mapping[str, str] | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
message=message,
|
||||
headers=headers,
|
||||
)
|
||||
16
src/application/domain/exceptions/forbidden_exception.py
Normal file
16
src/application/domain/exceptions/forbidden_exception.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from typing import Mapping
|
||||
from starlette import status
|
||||
from src.application.domain.exceptions.application_exception import ApplicationException
|
||||
|
||||
|
||||
class ForbiddenException(ApplicationException):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = 'Forbidden',
|
||||
headers: Mapping[str, str] | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
message=message,
|
||||
headers=headers,
|
||||
)
|
||||
@@ -0,0 +1,16 @@
|
||||
from typing import Mapping
|
||||
from starlette import status
|
||||
from src.application.domain.exceptions.application_exception import ApplicationException
|
||||
|
||||
|
||||
class InternalServerException(ApplicationException):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = 'Internal Server Error',
|
||||
headers: Mapping[str, str] | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
message=message,
|
||||
headers=headers,
|
||||
)
|
||||
16
src/application/domain/exceptions/not_found_exception.py
Normal file
16
src/application/domain/exceptions/not_found_exception.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from typing import Mapping
|
||||
from starlette import status
|
||||
from src.application.domain.exceptions.application_exception import ApplicationException
|
||||
|
||||
|
||||
class NotFoundException(ApplicationException):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = 'Not Found',
|
||||
headers: Mapping[str, str] | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
message=message,
|
||||
headers=headers,
|
||||
)
|
||||
@@ -0,0 +1,10 @@
|
||||
from starlette import status
|
||||
from src.application.domain.exceptions.application_exception import ApplicationException
|
||||
|
||||
|
||||
class RefreshConcurrentException(ApplicationException):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
status_code=status.HTTP_200_OK,
|
||||
message='Refresh already handled',
|
||||
)
|
||||
@@ -0,0 +1,16 @@
|
||||
from typing import Mapping
|
||||
from starlette import status
|
||||
from src.application.domain.exceptions.application_exception import ApplicationException
|
||||
|
||||
|
||||
class ServiceUnavailableException(ApplicationException):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = 'Service Unavailable',
|
||||
headers: Mapping[str, str] | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
message=message,
|
||||
headers=headers,
|
||||
)
|
||||
@@ -0,0 +1,16 @@
|
||||
from typing import Mapping
|
||||
from starlette import status
|
||||
from src.application.domain.exceptions.application_exception import ApplicationException
|
||||
|
||||
|
||||
class TooManyRequestsException(ApplicationException):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = 'Too Many Requests',
|
||||
headers: Mapping[str, str] | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
message=message,
|
||||
headers=headers,
|
||||
)
|
||||
16
src/application/domain/exceptions/unauthorized_exception.py
Normal file
16
src/application/domain/exceptions/unauthorized_exception.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from typing import Mapping
|
||||
from starlette import status
|
||||
from src.application.domain.exceptions.application_exception import ApplicationException
|
||||
|
||||
|
||||
class UnauthorizedException(ApplicationException):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = 'Unauthorized',
|
||||
headers: Mapping[str, str] | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
message=message,
|
||||
headers=headers,
|
||||
)
|
||||
22
src/infrastructure/cache/client.py
vendored
22
src/infrastructure/cache/client.py
vendored
@@ -4,13 +4,15 @@ from src.infrastructure.config import settings
|
||||
|
||||
|
||||
def create_redis_client() -> Redis:
|
||||
return redis.from_url(
|
||||
settings.REDIS_URL,
|
||||
max_connections=50,
|
||||
decode_responses=True,
|
||||
socket_timeout=5,
|
||||
socket_connect_timeout=5,
|
||||
health_check_interval=30,
|
||||
retry_on_timeout=True,
|
||||
socket_keepalive=True,
|
||||
)
|
||||
kw = {
|
||||
'max_connections': 50,
|
||||
'decode_responses': True,
|
||||
'socket_timeout': 5,
|
||||
'socket_connect_timeout': 5,
|
||||
'health_check_interval': 30,
|
||||
'retry_on_timeout': True,
|
||||
'socket_keepalive': True,
|
||||
}
|
||||
if settings.REDIS_PASSWORD:
|
||||
kw['password'] = settings.REDIS_PASSWORD
|
||||
return redis.from_url(settings.REDIS_URL, **kw)
|
||||
@@ -55,7 +55,11 @@ class Settings(BaseSettings):
|
||||
CSRF_COOKIE_HTTPONLY: bool = True
|
||||
CSRF_COOKIE_SAMESITE: Literal['Lax', 'Strict', 'None'] = 'Lax'
|
||||
CSRF_COOKIE_PATH: str = '/'
|
||||
CSRF_COOKIE_DOMAIN: str | None = None
|
||||
CSRF_COOKIE_DOMAIN: str | None = '.elcsa.ru'
|
||||
|
||||
AUTH_COOKIE_SECURE: bool = False
|
||||
AUTH_COOKIE_DOMAIN: str | None = '.elcsa.ru'
|
||||
CORS_ALLOW_ORIGIN_REGEX: str = r'https?://([a-z0-9-]+\.)*elcsa\.ru(:\d+)?$'
|
||||
|
||||
DOCS_USERNAME: str = 'admin'
|
||||
DOCS_PASSWORD: str = 'admin'
|
||||
@@ -81,9 +85,6 @@ class Settings(BaseSettings):
|
||||
RABBIT_CONNECT_TIMEOUT: int = 5
|
||||
RABBIT_EMAIL_CODE_QUEUE: str = 'email.verification_code'
|
||||
|
||||
CORS_ORIGINS: str = 'http://localhost:3000'
|
||||
CORS_ALLOW_CREDENTIALS: bool = True
|
||||
|
||||
RATE_LIMIT_REQUESTS: int = 60
|
||||
RATE_LIMIT_WINDOW: int = 60
|
||||
|
||||
@@ -99,7 +100,33 @@ class Settings(BaseSettings):
|
||||
|
||||
@field_validator('CSRF_COOKIE_DOMAIN', mode='before')
|
||||
@classmethod
|
||||
def empty_csrf_domain_to_none(cls, v):
|
||||
def normalize_csrf_cookie_domain(cls, v):
|
||||
if v is None or (isinstance(v, str) and not v.strip()):
|
||||
return '.elcsa.ru'
|
||||
s = str(v).strip()
|
||||
sl = s.lower()
|
||||
if sl in ('.elcsa.ru', 'elcsa.ru'):
|
||||
return '.elcsa.ru'
|
||||
if sl.endswith('.elcsa.ru') and not sl.startswith('.'):
|
||||
return '.elcsa.ru'
|
||||
return s
|
||||
|
||||
@field_validator('AUTH_COOKIE_DOMAIN', mode='before')
|
||||
@classmethod
|
||||
def normalize_auth_cookie_domain(cls, v):
|
||||
if v is None or (isinstance(v, str) and not v.strip()):
|
||||
return '.elcsa.ru'
|
||||
s = str(v).strip()
|
||||
sl = s.lower()
|
||||
if sl in ('.elcsa.ru', 'elcsa.ru'):
|
||||
return '.elcsa.ru'
|
||||
if sl.endswith('.elcsa.ru') and not sl.startswith('.'):
|
||||
return '.elcsa.ru'
|
||||
return s
|
||||
|
||||
@field_validator('REDIS_PASSWORD', mode='before')
|
||||
@classmethod
|
||||
def empty_redis_password_to_none(cls, v):
|
||||
if v is None or (isinstance(v, str) and not v.strip()):
|
||||
return None
|
||||
return v
|
||||
@@ -215,13 +242,29 @@ class Settings(BaseSettings):
|
||||
rb_set('password', 'RABBIT_PASSWORD')
|
||||
rb_set('vhost', 'RABBIT_VHOST')
|
||||
|
||||
redis_secret = read_secret_optional('redis')
|
||||
if redis_secret:
|
||||
rd_ci = {str(k).lower(): v for k, v in redis_secret.items()}
|
||||
|
||||
def rd_set(field: str, env_key: str, *, as_int: bool = False) -> None:
|
||||
v = rd_ci.get(field)
|
||||
if v is None:
|
||||
return
|
||||
if isinstance(v, str) and not v.strip():
|
||||
return
|
||||
if as_int:
|
||||
data[env_key] = int(v)
|
||||
else:
|
||||
data[env_key] = str(v).strip()
|
||||
|
||||
rd_set('host', 'REDIS_HOST')
|
||||
rd_set('port', 'REDIS_PORT', as_int=True)
|
||||
rd_set('password', 'REDIS_PASSWORD')
|
||||
rd_set('db', 'REDIS_DB', as_int=True)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def cors_origins_list(self) -> List[str]:
|
||||
return [o.strip() for o in self.CORS_ORIGINS.split(',') if o.strip()]
|
||||
|
||||
|
||||
@property
|
||||
def DATABASE_URL(self) -> str:
|
||||
return (
|
||||
@@ -231,8 +274,7 @@ class Settings(BaseSettings):
|
||||
|
||||
@property
|
||||
def REDIS_URL(self) -> str:
|
||||
auth = f":{self.REDIS_PASSWORD}@" if self.REDIS_PASSWORD else ""
|
||||
return f"redis://{auth}{self.REDIS_HOST}:{self.REDIS_PORT}/{self.REDIS_DB}"
|
||||
return f'redis://{self.REDIS_HOST}:{self.REDIS_PORT}/{self.REDIS_DB}'
|
||||
|
||||
@property
|
||||
def RABBIT_URL(self) -> str:
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from __future__ import annotations
|
||||
from sqlalchemy import Boolean, Date, String, DateTime
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, Date, String, DateTime, Text
|
||||
from sqlalchemy.orm import Mapped,mapped_column
|
||||
from src.infrastructure.database.models.base import Base
|
||||
from src.infrastructure.database.models.mixins import UlidPrimaryKeyMixin, AuditTimestampsMixin, SoftDeleteMixin
|
||||
from src.infrastructure.database.models.mixins import UlidPrimaryKeyMixin,AuditTimestampsMixin,SoftDeleteMixin
|
||||
|
||||
|
||||
class UserModel(Base, UlidPrimaryKeyMixin, AuditTimestampsMixin, SoftDeleteMixin):
|
||||
@@ -16,13 +19,18 @@ class UserModel(Base, UlidPrimaryKeyMixin, AuditTimestampsMixin, SoftDeleteMixin
|
||||
middle_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
birth_date: Mapped[Date | None] = mapped_column(Date, nullable=True)
|
||||
|
||||
crypto_wallet: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
encrypted_mnemonic: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
phone: Mapped[str | None] = mapped_column(String(16), nullable=True)
|
||||
|
||||
bik: Mapped[str | None] = mapped_column(String(9), nullable=True)
|
||||
account_number: Mapped[str | None] = mapped_column(String(20), nullable=True)
|
||||
card_number: Mapped[str | None] = mapped_column(String(19), nullable=True)
|
||||
passport_data: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
inn: Mapped[str | None] = mapped_column(String(12), nullable=True)
|
||||
erc20: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
|
||||
avatar_link: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
|
||||
kyc_verified: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default='false', default=False)
|
||||
kyc_verified_at: Mapped[DateTime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
account_type: Mapped[str] = mapped_column(String(20), nullable=False, server_default='individual', default='individual')
|
||||
provisioned_by: Mapped[str | None] = mapped_column(String(26), nullable=True)
|
||||
provisioned_at: Mapped[DateTime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
@@ -7,6 +7,7 @@ from src.application.contracts import ILogger
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.application.abstractions.repositories import IUserRepository
|
||||
from src.application.domain.entities import UserEntity
|
||||
from src.application.domain.enums.account_type import AccountType
|
||||
from src.infrastructure.database.models import UserModel
|
||||
|
||||
|
||||
@@ -16,7 +17,11 @@ class UserRepository(IUserRepository):
|
||||
self._logger = logger
|
||||
|
||||
async def create_user(self, email: str, password_hash: str) -> UserEntity:
|
||||
user = UserModel(email=email, password_hash=password_hash)
|
||||
user = UserModel(
|
||||
email=email,
|
||||
password_hash=password_hash,
|
||||
account_type=AccountType.INDIVIDUAL,
|
||||
)
|
||||
self._session.add(user)
|
||||
try:
|
||||
await self._session.flush()
|
||||
@@ -25,7 +30,9 @@ class UserRepository(IUserRepository):
|
||||
email=user.email,
|
||||
created_at=user.created_at,
|
||||
kyc_verified=user.kyc_verified,
|
||||
is_deleted=user.is_deleted
|
||||
is_deleted=user.is_deleted,
|
||||
avatar_link=user.avatar_link,
|
||||
account_type=user.account_type,
|
||||
)
|
||||
|
||||
except IntegrityError:
|
||||
@@ -67,17 +74,20 @@ class UserRepository(IUserRepository):
|
||||
middle_name=user.middle_name,
|
||||
last_name=user.last_name,
|
||||
birth_date=user.birth_date,
|
||||
crypto_wallet=user.crypto_wallet,
|
||||
encrypted_mnemonic=user.encrypted_mnemonic,
|
||||
phone=user.phone,
|
||||
bik=user.bik,
|
||||
account_number=user.account_number,
|
||||
card_number=user.card_number,
|
||||
passport_data=user.passport_data,
|
||||
inn=user.inn,
|
||||
kyc_verified_at=user.kyc_verified_at,
|
||||
erc20=user.erc20,
|
||||
avatar_link=user.avatar_link,
|
||||
kyc_verified=user.kyc_verified,
|
||||
is_deleted=user.is_deleted,
|
||||
created_at=user.created_at,
|
||||
updated_at=user.updated_at,
|
||||
kyc_verified_at=user.kyc_verified_at,
|
||||
account_type=user.account_type,
|
||||
provisioned_by=user.provisioned_by,
|
||||
provisioned_at=user.provisioned_at,
|
||||
)
|
||||
|
||||
except ApplicationException:
|
||||
|
||||
@@ -2,6 +2,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.abstractions.repositories import IUserRepository, ISessionRepository
|
||||
from src.application.contracts import ILogger
|
||||
from src.application.domain.exceptions import RefreshConcurrentException
|
||||
from src.infrastructure.database.repositories import UserRepository, SessionRepository
|
||||
|
||||
|
||||
@@ -20,8 +21,10 @@ class UnitOfWork(IUnitOfWork):
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if exc_type:
|
||||
if not isinstance(exc_val, RefreshConcurrentException):
|
||||
self._logger.error(str(exc_val))
|
||||
await self._session.rollback()
|
||||
if not isinstance(exc_val, RefreshConcurrentException):
|
||||
self._logger.error(f'Rollback: str{exc_val})')
|
||||
else:
|
||||
await self._session.flush()
|
||||
|
||||
63
src/main.py
63
src/main.py
@@ -4,6 +4,7 @@ import secrets
|
||||
from typing import AsyncGenerator
|
||||
from fastapi import Depends, FastAPI, status
|
||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
@@ -15,13 +16,66 @@ from src.infrastructure.utils import generate_instance_id
|
||||
from src.infrastructure.logger import logger
|
||||
from src.infrastructure.config import settings
|
||||
from src.presentation.dependencies import get_rabbit
|
||||
from src.presentation.handlers import application_exception_handler, unhandled_exception_handler
|
||||
from src.presentation.handler import application_exception_handler
|
||||
from src.presentation.handler import unhandled_exception_handler
|
||||
from src.presentation.middleware import TraceIDMiddleware, SecurityHeadersMiddleware
|
||||
from src.presentation.routing import v1_router
|
||||
from src.presentation.schemas import ErrorResponse
|
||||
|
||||
|
||||
security = HTTPBasic()
|
||||
|
||||
ERROR_RESPONSES: dict[int, str] = {
|
||||
status.HTTP_400_BAD_REQUEST: 'Bad Request',
|
||||
status.HTTP_401_UNAUTHORIZED: 'Unauthorized',
|
||||
status.HTTP_403_FORBIDDEN: 'Forbidden',
|
||||
status.HTTP_404_NOT_FOUND: 'Not Found',
|
||||
status.HTTP_409_CONFLICT: 'Conflict',
|
||||
status.HTTP_429_TOO_MANY_REQUESTS: 'Too Many Requests',
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR: 'Internal Server Error',
|
||||
status.HTTP_503_SERVICE_UNAVAILABLE: 'Service Unavailable',
|
||||
}
|
||||
|
||||
|
||||
def custom_openapi() -> dict:
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
|
||||
openapi_schema = get_openapi(
|
||||
title=app.title,
|
||||
version=app.version,
|
||||
description=app.description,
|
||||
routes=app.routes,
|
||||
license_info=app.license_info,
|
||||
)
|
||||
components = openapi_schema.setdefault('components', {})
|
||||
schemas = components.setdefault('schemas', {})
|
||||
schemas['ErrorResponse'] = ErrorResponse.model_json_schema()
|
||||
|
||||
for path_item in openapi_schema.get('paths', {}).values():
|
||||
for operation in path_item.values():
|
||||
if not isinstance(operation, dict):
|
||||
continue
|
||||
|
||||
responses = operation.setdefault('responses', {})
|
||||
for status_code, description in ERROR_RESPONSES.items():
|
||||
responses.setdefault(
|
||||
str(status_code),
|
||||
{
|
||||
'description': description,
|
||||
'content': {
|
||||
'application/json': {
|
||||
'schema': {
|
||||
'$ref': '#/components/schemas/ErrorResponse',
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
|
||||
async def verify_credentials(credentials: HTTPBasicCredentials = Depends(security)) -> HTTPBasicCredentials:
|
||||
user_ok = secrets.compare_digest(credentials.username, settings.DOCS_USERNAME)
|
||||
@@ -117,8 +171,8 @@ app.add_middleware(
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.cors_origins_list(),
|
||||
allow_credentials=settings.CORS_ALLOW_CREDENTIALS,
|
||||
allow_origin_regex=settings.CORS_ALLOW_ORIGIN_REGEX,
|
||||
allow_credentials=True,
|
||||
allow_methods=['*'],
|
||||
allow_headers=['*'],
|
||||
)
|
||||
@@ -152,3 +206,6 @@ async def ping() -> dict[str, str]:
|
||||
'message': 'pong',
|
||||
'status': 'ok',
|
||||
}
|
||||
|
||||
|
||||
app.openapi = custom_openapi
|
||||
|
||||
@@ -29,7 +29,7 @@ async def require_access_token(
|
||||
if not token:
|
||||
raise ApplicationException(status_code=401, message="Not authenticated")
|
||||
|
||||
payload = jwt_service.decode_access_token(token)
|
||||
payload = await jwt_service.decode_access_token(token)
|
||||
if payload.type != "access":
|
||||
raise ApplicationException(status_code=401, message="Invalid token type")
|
||||
|
||||
|
||||
@@ -93,6 +93,7 @@ def get_jwt_refresh_command(
|
||||
uow: IUnitOfWork = Depends(get_unit_of_work),
|
||||
hash_service: IHashService = Depends(get_hash_service),
|
||||
jwt_service: IJwtService = Depends(get_jwt_service),
|
||||
cache: ICache = Depends(get_cache),
|
||||
logger: ILogger = Depends(get_logger),
|
||||
) -> JwtRefreshCommand:
|
||||
return JwtRefreshCommand(uow, hash_service, jwt_service, logger)
|
||||
return JwtRefreshCommand(uow, hash_service, jwt_service, cache, logger)
|
||||
|
||||
2
src/presentation/handler/__init__.py
Normal file
2
src/presentation/handler/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from src.presentation.handler.application_exception_handler import application_exception_handler
|
||||
from src.presentation.handler.unhandled_exception_handler import unhandled_exception_handler
|
||||
@@ -1,17 +1,15 @@
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from fastapi import Request
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
|
||||
|
||||
async def application_exception_handler(_request: Request, exc: ApplicationException) -> ORJSONResponse:
|
||||
detail = exc.message
|
||||
if 500 <= exc.status_code:
|
||||
detail = "Internal Server Error"
|
||||
detail = 'Internal Server Error'
|
||||
|
||||
return ORJSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={"detail": detail},
|
||||
content={'detail': detail},
|
||||
headers=dict(exc.headers) if exc.headers else None,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from fastapi import Request
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from starlette import status
|
||||
from src.infrastructure.logger import logger
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
from src.presentation.handlers.unhandled_handler import unhandled_exception_handler
|
||||
from src.presentation.handlers.application_handler import application_exception_handler
|
||||
@@ -12,7 +12,7 @@ from src.application.contracts import ILogger
|
||||
from src.application.domain.dto import UserLoginDto
|
||||
from src.infrastructure.config import settings
|
||||
from src.infrastructure.logger import get_logger
|
||||
from src.presentation.decorators import rate_limit, email_rl_key
|
||||
from src.presentation.decorators import csrf_protect,rate_limit,email_rl_key
|
||||
from src.presentation.dependencies import (
|
||||
get_user_registration_complete_command,
|
||||
get_user_logout_command,
|
||||
@@ -31,19 +31,23 @@ auth_router = APIRouter(prefix='/auth', tags=['auth'])
|
||||
response_class=ORJSONResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
)
|
||||
@rate_limit(limit=5, window_seconds=60, scope='ip')
|
||||
@rate_limit(limit=3, window_seconds=600, scope='key', key_prefix='rl:reg_start', key_builder=email_rl_key)
|
||||
#@rate_limit(limit=5, window_seconds=60, scope='ip')
|
||||
#@rate_limit(limit=3, window_seconds=600, scope='key', key_prefix='rl:reg_start', key_builder=email_rl_key)
|
||||
@csrf_protect()
|
||||
async def registration_start(
|
||||
request: Request,
|
||||
body: RegistrationStart,
|
||||
logger: ILogger = Depends(get_logger),
|
||||
command: UserRegistrationStartCommand = Depends(get_user_registration_start_command),
|
||||
):
|
||||
logger.info('AHAHAHAHAHAHHAHAAH')
|
||||
result = await command(body.email)
|
||||
|
||||
return {'success': result}
|
||||
|
||||
@auth_router.post(path='/registration/complete', response_class=ORJSONResponse, status_code=status.HTTP_201_CREATED)
|
||||
@rate_limit(limit=10, window_seconds=300, scope='ip')
|
||||
#@rate_limit(limit=10, window_seconds=300, scope='ip')
|
||||
@csrf_protect()
|
||||
async def registration(
|
||||
request: Request,
|
||||
user: RegistrationComplete,
|
||||
@@ -76,9 +80,10 @@ async def registration(
|
||||
key='device_id',
|
||||
value=device_id,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
secure=settings.AUTH_COOKIE_SECURE,
|
||||
samesite='lax',
|
||||
path='/',
|
||||
domain=settings.AUTH_COOKIE_DOMAIN,
|
||||
max_age=60 * 60 * 24 * 365 * 5
|
||||
)
|
||||
|
||||
@@ -86,25 +91,28 @@ async def registration(
|
||||
key='access_token',
|
||||
value=created.access_token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
secure=settings.AUTH_COOKIE_SECURE,
|
||||
samesite='lax',
|
||||
path='/',
|
||||
domain=settings.AUTH_COOKIE_DOMAIN,
|
||||
max_age=int(settings.JWT_ACCESS_TTL_SECONDS),
|
||||
)
|
||||
response.set_cookie(
|
||||
key='refresh_token',
|
||||
value=created.refresh_token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
secure=settings.AUTH_COOKIE_SECURE,
|
||||
samesite='lax',
|
||||
path='/',
|
||||
domain=settings.AUTH_COOKIE_DOMAIN,
|
||||
max_age=int(settings.JWT_REFRESH_TTL_SECONDS),
|
||||
)
|
||||
return response
|
||||
|
||||
@auth_router.post(path='/login/start', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
|
||||
@rate_limit(limit=5, window_seconds=60, scope='ip')
|
||||
@rate_limit(limit=3, window_seconds=600, scope='key', key_prefix='rl:login_start', key_builder=email_rl_key)
|
||||
#@rate_limit(limit=5, window_seconds=60, scope='ip')
|
||||
#@rate_limit(limit=3, window_seconds=600, scope='key', key_prefix='rl:login_start', key_builder=email_rl_key)
|
||||
@csrf_protect()
|
||||
async def login_start(
|
||||
request: Request,
|
||||
body: LoginStart,
|
||||
@@ -114,8 +122,9 @@ async def login_start(
|
||||
|
||||
return {'success': result}
|
||||
|
||||
@auth_router.post(path='/login/compete', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
|
||||
@rate_limit(limit=10, window_seconds=300, scope='ip')
|
||||
@auth_router.post(path='/login/complete', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
|
||||
#@rate_limit(limit=10, window_seconds=300, scope='ip')
|
||||
@csrf_protect()
|
||||
async def login(
|
||||
request: Request,
|
||||
user: UserLogin,
|
||||
@@ -150,12 +159,12 @@ async def login(
|
||||
'middle_name': dto.middle_name,
|
||||
'last_name': dto.last_name,
|
||||
'birth_date': dto.birth_date.isoformat() if dto.birth_date else None,
|
||||
'crypto_wallet': dto.crypto_wallet,
|
||||
'encrypted_mnemonic': dto.encrypted_mnemonic,
|
||||
'phone': dto.phone,
|
||||
'bik': dto.bik,
|
||||
'account_number': dto.account_number,
|
||||
'card_number': dto.card_number,
|
||||
'passport_data': dto.passport_data,
|
||||
'inn': dto.inn,
|
||||
'erc20': dto.erc20,
|
||||
'avatar_link': dto.avatar_link,
|
||||
'kyc_verified': dto.kyc_verified,
|
||||
'kyc_verified_at': dto.kyc_verified_at,
|
||||
'created_at': dto.created_at.isoformat() if dto.created_at else None,
|
||||
@@ -167,9 +176,10 @@ async def login(
|
||||
key='device_id',
|
||||
value=device_id,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
secure=settings.AUTH_COOKIE_SECURE,
|
||||
samesite='lax',
|
||||
path='/',
|
||||
domain=settings.AUTH_COOKIE_DOMAIN,
|
||||
max_age=60 * 60 * 24 * 365 * 5
|
||||
)
|
||||
|
||||
@@ -177,9 +187,10 @@ async def login(
|
||||
key='access_token',
|
||||
value=dto.access_token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
secure=settings.AUTH_COOKIE_SECURE,
|
||||
samesite='lax',
|
||||
path='/',
|
||||
domain=settings.AUTH_COOKIE_DOMAIN,
|
||||
max_age=int(settings.JWT_ACCESS_TTL_SECONDS),
|
||||
)
|
||||
|
||||
@@ -187,16 +198,18 @@ async def login(
|
||||
key='refresh_token',
|
||||
value=dto.refresh_token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
secure=settings.AUTH_COOKIE_SECURE,
|
||||
samesite='lax',
|
||||
path='/',
|
||||
domain=settings.AUTH_COOKIE_DOMAIN,
|
||||
max_age=int(settings.JWT_REFRESH_TTL_SECONDS),
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@auth_router.post(path='/logout', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
|
||||
@rate_limit(limit=settings.RATE_LIMIT_REQUESTS, window_seconds=settings.RATE_LIMIT_WINDOW, scope='ip')
|
||||
#@rate_limit(limit=settings.RATE_LIMIT_REQUESTS, window_seconds=settings.RATE_LIMIT_WINDOW, scope='ip')
|
||||
@csrf_protect()
|
||||
async def logout_current(
|
||||
request: Request,
|
||||
command: UserLogoutCommand = Depends(get_user_logout_command),
|
||||
@@ -206,8 +219,8 @@ async def logout_current(
|
||||
await command(refresh_token=refresh_token)
|
||||
|
||||
response = ORJSONResponse({'ok': True})
|
||||
response.delete_cookie('access_token', path='/')
|
||||
response.delete_cookie('refresh_token', path='/')
|
||||
response.delete_cookie('access_token', path='/', domain=settings.AUTH_COOKIE_DOMAIN)
|
||||
response.delete_cookie('refresh_token', path='/', domain=settings.AUTH_COOKIE_DOMAIN)
|
||||
return response
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from fastapi import APIRouter
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from starlette import status
|
||||
from src.infrastructure.security import CsrfService
|
||||
@@ -11,7 +11,7 @@ csrf_router = APIRouter(prefix='/csrf', tags=['csrf'])
|
||||
|
||||
@csrf_router.get('/token', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
|
||||
@rate_limit(limit=settings.RATE_LIMIT_REQUESTS, window_seconds=settings.RATE_LIMIT_WINDOW, scope='ip')
|
||||
async def issue_csrf_token():
|
||||
async def issue_csrf_token(request: Request):
|
||||
csrf = CsrfService()
|
||||
|
||||
token = csrf.issue()
|
||||
@@ -30,7 +30,7 @@ async def issue_csrf_token():
|
||||
httponly=settings.CSRF_COOKIE_HTTPONLY,
|
||||
samesite=settings.CSRF_COOKIE_SAMESITE,
|
||||
path=settings.CSRF_COOKIE_PATH,
|
||||
domain=settings.CSRF_COOKIE_DOMAIN,
|
||||
domain=settings.CSRF_COOKIE_DOMAIN or '.elcsa.ru',
|
||||
max_age=csrf.ttl_seconds,
|
||||
)
|
||||
|
||||
|
||||
@@ -2,63 +2,72 @@ from fastapi import APIRouter, Request, Depends
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from starlette import status
|
||||
from src.application.commands import JwtRefreshCommand
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.application.domain.exceptions import ApplicationException, RefreshConcurrentException
|
||||
from src.infrastructure.config import settings
|
||||
from src.presentation.decorators import rate_limit
|
||||
from src.presentation.decorators import csrf_protect, rate_limit
|
||||
from src.presentation.dependencies import get_jwt_refresh_command
|
||||
|
||||
|
||||
jwt_router = APIRouter(prefix='/jwt', tags=['Jwt'])
|
||||
|
||||
|
||||
@jwt_router.post('/refresh', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
|
||||
@rate_limit(limit=settings.RATE_LIMIT_REQUESTS, window_seconds=settings.RATE_LIMIT_WINDOW, scope='ip')
|
||||
async def refresh_tokens(
|
||||
request: Request,
|
||||
command: JwtRefreshCommand = Depends(get_jwt_refresh_command)
|
||||
):
|
||||
refresh_token = request.cookies.get('refresh_token')
|
||||
def _clear_auth_cookies(response: ORJSONResponse) -> None:
|
||||
response.delete_cookie('access_token', path='/', domain=settings.AUTH_COOKIE_DOMAIN)
|
||||
response.delete_cookie('refresh_token', path='/', domain=settings.AUTH_COOKIE_DOMAIN)
|
||||
|
||||
if not refresh_token:
|
||||
response = ORJSONResponse({'ok': False, 'error': 'No refresh token'}, status_code=401)
|
||||
response.delete_cookie('access_token', path='/')
|
||||
response.delete_cookie('refresh_token', path='/')
|
||||
return response
|
||||
|
||||
ip = request.client.host if request.client else None
|
||||
user_agent = request.headers.get('user-agent')
|
||||
|
||||
try:
|
||||
access, refresh = await command(refresh_token=refresh_token, ip=ip, user_agent=user_agent)
|
||||
except ApplicationException:
|
||||
response = ORJSONResponse({'result': False}, status_code=401)
|
||||
response.delete_cookie('access_token', path='/')
|
||||
response.delete_cookie('refresh_token', path='/')
|
||||
return response
|
||||
|
||||
response = ORJSONResponse({'result': True})
|
||||
|
||||
def _set_auth_cookies(response: ORJSONResponse, access: str, refresh: str) -> None:
|
||||
response.set_cookie(
|
||||
key='access_token',
|
||||
value=access,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
secure=settings.AUTH_COOKIE_SECURE,
|
||||
samesite='lax',
|
||||
path='/',
|
||||
domain=settings.AUTH_COOKIE_DOMAIN,
|
||||
max_age=int(settings.JWT_ACCESS_TTL_SECONDS),
|
||||
)
|
||||
response.set_cookie(
|
||||
key='refresh_token',
|
||||
value=refresh,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
secure=settings.AUTH_COOKIE_SECURE,
|
||||
samesite='lax',
|
||||
path='/',
|
||||
domain=settings.AUTH_COOKIE_DOMAIN,
|
||||
max_age=int(settings.JWT_REFRESH_TTL_SECONDS),
|
||||
)
|
||||
|
||||
|
||||
@jwt_router.post('/refresh', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
|
||||
@rate_limit(limit=settings.RATE_LIMIT_REQUESTS, window_seconds=settings.RATE_LIMIT_WINDOW, scope='ip')
|
||||
@csrf_protect()
|
||||
async def refresh_tokens(
|
||||
request: Request,
|
||||
command: JwtRefreshCommand = Depends(get_jwt_refresh_command),
|
||||
):
|
||||
refresh_token = request.cookies.get('refresh_token')
|
||||
|
||||
if not refresh_token:
|
||||
response = ORJSONResponse({'ok': False, 'error': 'No refresh token'}, status_code=401)
|
||||
_clear_auth_cookies(response)
|
||||
return response
|
||||
|
||||
# Usage
|
||||
# @jwt_router.get("/test")
|
||||
# async def profile(auth: AuthContext = Depends(require_access_token)):
|
||||
# return 'ok'
|
||||
ip = request.client.host if request.client else None
|
||||
user_agent = request.headers.get('user-agent')
|
||||
|
||||
try:
|
||||
tokens = await command(refresh_token=refresh_token, ip=ip, user_agent=user_agent)
|
||||
except RefreshConcurrentException:
|
||||
return ORJSONResponse({'result': True, 'concurrent': True}, status_code=status.HTTP_200_OK)
|
||||
except ApplicationException as exc:
|
||||
if exc.status_code == status.HTTP_401_UNAUTHORIZED:
|
||||
response = ORJSONResponse({'result': False}, status_code=401)
|
||||
_clear_auth_cookies(response)
|
||||
return response
|
||||
raise
|
||||
|
||||
access, refresh = tokens
|
||||
response = ORJSONResponse({'result': True})
|
||||
_set_auth_cookies(response, access, refresh)
|
||||
return response
|
||||
|
||||
@@ -1 +1,5 @@
|
||||
from src.presentation.schemas.user import RegistrationStart, RegistrationComplete, UserLogin, LoginStart
|
||||
from src.presentation.schemas.error import ErrorResponse
|
||||
from src.presentation.schemas.user import RegistrationComplete
|
||||
from src.presentation.schemas.user import RegistrationStart
|
||||
from src.presentation.schemas.user import LoginStart
|
||||
from src.presentation.schemas.user import UserLogin
|
||||
6
src/presentation/schemas/error.py
Normal file
6
src/presentation/schemas/error.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
detail: str = Field(title='Detail')
|
||||
Reference in New Issue
Block a user