21 Commits

Author SHA1 Message Date
e9b99535b9 feat: update auth for b2b 2026-06-02 23:43:44 +03:00
caf7f003fa feat: add validation 2026-05-19 22:29:02 +03:00
666f2f67cb feat: more workers 2026-05-14 23:46:54 +03:00
d3fedc3f91 feat: add new column and delete old 2026-05-13 11:19:34 +03:00
0cca0fedbe fix: add await 2026-05-12 22:47:50 +03:00
9166b21249 feat: update domains 2026-05-12 22:11:00 +03:00
603efa55e6 fix: crsf cookie 2026-05-12 22:03:02 +03:00
cec8d896b6 feat: update users 2026-05-12 21:44:18 +03:00
2d9f44979c feat: add domain in cookie and cors 2026-05-12 19:25:16 +03:00
8347ff40f4 fix: delete bank data and add passport data 2026-05-12 18:07:43 +03:00
4cfab85812 fix: rename path login 2026-05-12 12:33:57 +03:00
9d56b7f6f5 feat: add custom exceptions 2026-05-09 16:25:31 +03:00
bedce9e910 fix: add debug logger 2026-05-08 23:10:16 +03:00
1724d4e37d fix: off rate-limiting 2026-05-08 22:40:26 +03:00
57bafec204 feat: add csrf 2026-04-17 13:05:11 +03:00
0130912555 feat: add csrf 2026-04-17 12:02:26 +03:00
54ebcaeb81 fix: add request in csrf endpoint 2026-04-16 08:09:24 +03:00
3e3b9eb030 fix: delete origins 2026-04-15 15:13:08 +03:00
f92eadf8fa feat: add email code log 2026-04-13 13:35:32 +03:00
949b57e425 feat: add redis password 2026-04-12 16:20:07 +03:00
3fc1b455d2 chore: delete docker-compose 2026-04-12 14:18:09 +03:00
41 changed files with 1153 additions and 710 deletions

2
.gitignore vendored
View File

@@ -2,7 +2,7 @@
__pycache__/ __pycache__/
*.py[cod] *.py[cod]
*$py.class *$py.class
generate_password_hash.py
# C extensions # C extensions
*.so *.so
*.pyd *.pyd

View File

@@ -1,83 +0,0 @@
services:
auth:
container_name: auth-service
build:
context: .
dockerfile: Dockerfile
ports:
- "8000:8000"
environment:
PYTHONUNBUFFERED: "1"
APP_MODULE: "src.main:app"
APP_HOST: "0.0.0.0"
APP_PORT: "8000"
APP_WORKERS: "1"
env_file:
- .env
depends_on:
keydb:
condition: service_healthy
restart: no
keydb:
image: eqalpha/keydb
container_name: keydb
restart: no
expose:
- "6379"
volumes:
- keydb_data:/data
command:
- keydb-server
- --requirepass
- keydb
- --dir
- /data
- --appendonly
- "yes"
- --appendfsync
- everysec
- --save
- "900"
- "1"
- --save
- "300"
- "10"
- --save
- "60"
- "10000"
healthcheck:
test: [ "CMD", "redis-cli", "-a", "keydb", "ping" ]
interval: 5s
timeout: 2s
retries: 20
# keydb:
# image: eqalpha/keydb
# container_name: keydb
# restart: no
# expose:
# - "6379"
# volumes:
# - keydb_data:/data
# environment:
# KEYDB_PASSWORD: keydb
# command: >
# sh -c "
# keydb-server
# --requirepass $$KEYDB_PASSWORD
# --dir /data
# --appendonly yes
# --appendfsync everysec
# --save 900 1
# --save 300 10
# --save 60 10000
# "
# healthcheck:
# test: ["CMD", "redis-cli", "ping"]
# interval: 5s
# timeout: 2s
# retries: 20
volumes:
keydb_data:

View File

@@ -4,6 +4,8 @@ version = "0.1.0"
description = "Add your description here" description = "Add your description here"
requires-python = "==3.12.*" requires-python = "==3.12.*"
dependencies = [ dependencies = [
"acryl-datahub>=1.5.0.19",
"acryl-sqlglot>=25.25.2.dev9",
"apscheduler==3.11.2", "apscheduler==3.11.2",
"asyncpg==0.31.0", "asyncpg==0.31.0",
"bcrypt==5.0.0", "bcrypt==5.0.0",

View File

@@ -1,18 +1,32 @@
import asyncio
from datetime import datetime, timezone, timedelta from datetime import datetime, timezone, timedelta
from ulid import ULID from ulid import ULID
from src.application.abstractions import IUnitOfWork from src.application.abstractions import IUnitOfWork
from src.application.contracts import IHashService, IJwtService, ILogger from src.application.contracts import IHashService, IJwtService, ILogger, ICache
from src.application.domain.dto import RefreshTokenPayload from src.application.domain.dto import RefreshTokenPayload
from src.application.domain.exceptions import ApplicationException from src.application.domain.exceptions import ApplicationException, RefreshConcurrentException
from src.infrastructure.config import settings from src.infrastructure.config import settings
from src.infrastructure.database.decorators import transactional from src.infrastructure.database.decorators import transactional
class JwtRefreshCommand: class JwtRefreshCommand:
def __init__(self, unit_of_work: IUnitOfWork, hash_service: IHashService, jwt_service: IJwtService, logger: ILogger): _LOCK_PREFIX = 'jwt:refresh:lock:'
_LOCK_TTL_SECONDS = 15
_LOCK_WAIT_ATTEMPTS = 40
_LOCK_WAIT_INTERVAL_SECONDS = 0.05
def __init__(
self,
unit_of_work: IUnitOfWork,
hash_service: IHashService,
jwt_service: IJwtService,
cache: ICache,
logger: ILogger,
):
self._unit_of_work = unit_of_work self._unit_of_work = unit_of_work
self._hash_service = hash_service self._hash_service = hash_service
self._jwt_service = jwt_service self._jwt_service = jwt_service
self._cache = cache
self._logger = logger self._logger = logger
@transactional @transactional
@@ -25,6 +39,39 @@ class JwtRefreshCommand:
user_id = payload.sub user_id = payload.sub
jti = payload.jti jti = payload.jti
lock_key = f'{self._LOCK_PREFIX}{sid}'
locked = await self._cache.set_nx(lock_key, '1', self._LOCK_TTL_SECONDS)
if not locked:
for _ in range(self._LOCK_WAIT_ATTEMPTS):
await asyncio.sleep(self._LOCK_WAIT_INTERVAL_SECONDS)
if await self._cache.get(lock_key) is None:
self._logger.info(f'Concurrent refresh skipped (sid={sid})')
raise RefreshConcurrentException()
raise ApplicationException(status_code=429, message='Refresh in progress')
try:
return await self._refresh_locked(
sid=sid,
user_id=user_id,
jti=jti,
now=now,
ip=ip,
user_agent=user_agent,
)
finally:
await self._cache.delete(lock_key)
async def _refresh_locked(
self,
*,
sid: str,
user_id: str,
jti: str,
now: datetime,
ip: str | None,
user_agent: str | None,
) -> tuple[str, str]:
sess = await self._unit_of_work.session_repository.get_by_sid(sid) sess = await self._unit_of_work.session_repository.get_by_sid(sid)
if sess is None: if sess is None:
raise ApplicationException(status_code=401, message='Session not found') raise ApplicationException(status_code=401, message='Session not found')
@@ -61,7 +108,8 @@ class JwtRefreshCommand:
) )
if not rotated: if not rotated:
raise ApplicationException(status_code=401, message='Refresh already rotated') self._logger.info(f'Refresh already rotated (sid={sid})')
raise RefreshConcurrentException()
access = await self._jwt_service.create_access_token(user_id=user_id, sid=sid) access = await self._jwt_service.create_access_token(user_id=user_id, sid=sid)
refresh = await self._jwt_service.create_refresh_token(user_id=user_id, sid=sid, refresh_jti=new_jti) refresh = await self._jwt_service.create_refresh_token(user_id=user_id, sid=sid, refresh_jti=new_jti)

View File

@@ -102,12 +102,12 @@ class UserLoginCompleteCommand:
middle_name=user.middle_name, middle_name=user.middle_name,
last_name=user.last_name, last_name=user.last_name,
birth_date=user.birth_date, birth_date=user.birth_date,
crypto_wallet=user.crypto_wallet, encrypted_mnemonic=user.encrypted_mnemonic,
phone=user.phone, phone=user.phone,
bik=user.bik, passport_data=user.passport_data,
account_number=user.account_number,
card_number=user.card_number,
inn=user.inn, inn=user.inn,
erc20=user.erc20,
avatar_link=user.avatar_link,
kyc_verified=user.kyc_verified, kyc_verified=user.kyc_verified,
kyc_verified_at=user.kyc_verified_at, kyc_verified_at=user.kyc_verified_at,
created_at=user.created_at, created_at=user.created_at,

View File

@@ -102,8 +102,6 @@ class UserLoginStartCommand:
'metadata': metadata, 'metadata': metadata,
} }
self._logger.info(f'payload: {payload})')
try: try:
await self._messanger.publish_to_queue( await self._messanger.publish_to_queue(
queue=settings.RABBIT_EMAIL_CODE_QUEUE, queue=settings.RABBIT_EMAIL_CODE_QUEUE,
@@ -123,7 +121,7 @@ class UserLoginStartCommand:
self._logger.error(f'Failed to publish login email event for {email}: {str(exception)}') self._logger.error(f'Failed to publish login email event for {email}: {str(exception)}')
raise ApplicationException(503, 'Temporary error. Please try again.') raise ApplicationException(503, 'Temporary error. Please try again.')
self._logger.info(f'login code created for {email}') self._logger.info(f'Login email verification code queued email={email}')
return True return True
self._logger.error(f'login start failed: code space exhausted for {email}') self._logger.error(f'login start failed: code space exhausted for {email}')

View File

@@ -18,7 +18,7 @@ class UserLogoutCommand:
if not refresh_token: if not refresh_token:
return return
try: try:
payload: RefreshTokenPayload = self._jwt_service.decode_refresh_token(refresh_token) payload: RefreshTokenPayload = await self._jwt_service.decode_refresh_token(refresh_token)
except ApplicationException: except ApplicationException:
self._logger.debug('Logout: refresh token invalid/expired, skipping revoke') self._logger.debug('Logout: refresh token invalid/expired, skipping revoke')
return return

View File

@@ -45,7 +45,7 @@ class UserRegistrationCompleteCommand:
cached_email = await self._cache.get(code_key) cached_email = await self._cache.get(code_key)
if not cached_email: if not cached_email:
self._logger.info(f'Registration failed: code not found (email={email}, code={code})') self._logger.info(f'Registration failed: code not found (email={email})')
raise ApplicationException(400, 'Invalid or expired code') raise ApplicationException(400, 'Invalid or expired code')
if cached_email != email: if cached_email != email:

View File

@@ -102,8 +102,6 @@ class UserRegistrationStartCommand:
'metadata': metadata, 'metadata': metadata,
} }
self._logger.info(f'payload: {payload})')
try: try:
await self._messanger.publish_to_queue( await self._messanger.publish_to_queue(
queue=settings.RABBIT_EMAIL_CODE_QUEUE, queue=settings.RABBIT_EMAIL_CODE_QUEUE,
@@ -123,7 +121,7 @@ class UserRegistrationStartCommand:
self._logger.error(f'Failed to publish registration email event for {email}: {str(exception)}') self._logger.error(f'Failed to publish registration email event for {email}: {str(exception)}')
raise ApplicationException(503, 'Temporary error. Please try again.') raise ApplicationException(503, 'Temporary error. Please try again.')
self._logger.info(f'Registration code created for {email}') self._logger.info(f'Registration email verification code queued email={email}')
return True return True
self._logger.error(f'Registration start failed: code space exhausted for {email}') self._logger.error(f'Registration start failed: code space exhausted for {email}')

View File

@@ -18,12 +18,12 @@ class UserLoginDto:
middle_name: str | None = None middle_name: str | None = None
last_name: str | None = None last_name: str | None = None
birth_date: date | None = None birth_date: date | None = None
crypto_wallet: str | None = None encrypted_mnemonic: str | None = None
phone: str | None = None phone: str | None = None
bik: str | None = None passport_data: str | None = None
account_number: str | None = None
card_number: str | None = None
inn: str | None = None inn: str | None = None
erc20: str | None = None
avatar_link: str | None = None
kyc_verified: bool | None = None kyc_verified: bool | None = None
access_token: str | None = None access_token: str | None = None
refresh_token: str | None = None refresh_token: str | None = None

View File

@@ -14,13 +14,14 @@ class UserEntity:
last_name: str | None = None last_name: str | None = None
birth_date: date | None = None birth_date: date | None = None
crypto_wallet: str | None = None encrypted_mnemonic: str | None = None
phone: str | None = None phone: str | None = None
bik: str | None = None passport_data: str | None = None
account_number: str | None = None
card_number: str | None = None
inn: str | None = None inn: str | None = None
erc20: str | None = None
avatar_link: str | None = None
kyc_verified: bool | None = None kyc_verified: bool | None = None
is_deleted: bool | None = None is_deleted: bool | None = None
@@ -28,3 +29,7 @@ class UserEntity:
created_at: datetime | None = None created_at: datetime | None = None
updated_at: datetime | None = None updated_at: datetime | None = None
kyc_verified_at: datetime | None = None kyc_verified_at: datetime | None = None
account_type: str = 'individual'
provisioned_by: str | None = None
provisioned_at: datetime | None = None

View File

@@ -0,0 +1,6 @@
from enum import StrEnum
class AccountType(StrEnum):
INDIVIDUAL = 'individual'
LEGAL_ENTITY = 'legal_entity'

View File

@@ -1 +1,10 @@
from src.application.domain.exceptions.application_exceptions import ApplicationException from src.application.domain.exceptions.application_exception import ApplicationException
from src.application.domain.exceptions.bad_request_exception import BadRequestException
from src.application.domain.exceptions.conflict_exception import ConflictException
from src.application.domain.exceptions.forbidden_exception import ForbiddenException
from src.application.domain.exceptions.internal_server_exception import InternalServerException
from src.application.domain.exceptions.not_found_exception import NotFoundException
from src.application.domain.exceptions.service_unavailable_exception import ServiceUnavailableException
from src.application.domain.exceptions.too_many_requests_exception import TooManyRequestsException
from src.application.domain.exceptions.unauthorized_exception import UnauthorizedException
from src.application.domain.exceptions.refresh_concurrent_exception import RefreshConcurrentException

View File

@@ -15,4 +15,4 @@ class ApplicationException(Exception):
self.headers = headers self.headers = headers
def __str__(self): def __str__(self):
return f"{self.status_code}: {self.message}" return f'{self.status_code}: {self.message}'

View File

@@ -0,0 +1,16 @@
from typing import Mapping
from starlette import status
from src.application.domain.exceptions.application_exception import ApplicationException
class BadRequestException(ApplicationException):
def __init__(
self,
message: str = 'Bad Request',
headers: Mapping[str, str] | None = None,
):
super().__init__(
status_code=status.HTTP_400_BAD_REQUEST,
message=message,
headers=headers,
)

View File

@@ -0,0 +1,16 @@
from typing import Mapping
from starlette import status
from src.application.domain.exceptions.application_exception import ApplicationException
class ConflictException(ApplicationException):
def __init__(
self,
message: str = 'Conflict',
headers: Mapping[str, str] | None = None,
):
super().__init__(
status_code=status.HTTP_409_CONFLICT,
message=message,
headers=headers,
)

View File

@@ -0,0 +1,16 @@
from typing import Mapping
from starlette import status
from src.application.domain.exceptions.application_exception import ApplicationException
class ForbiddenException(ApplicationException):
def __init__(
self,
message: str = 'Forbidden',
headers: Mapping[str, str] | None = None,
):
super().__init__(
status_code=status.HTTP_403_FORBIDDEN,
message=message,
headers=headers,
)

View File

@@ -0,0 +1,16 @@
from typing import Mapping
from starlette import status
from src.application.domain.exceptions.application_exception import ApplicationException
class InternalServerException(ApplicationException):
def __init__(
self,
message: str = 'Internal Server Error',
headers: Mapping[str, str] | None = None,
):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
message=message,
headers=headers,
)

View File

@@ -0,0 +1,16 @@
from typing import Mapping
from starlette import status
from src.application.domain.exceptions.application_exception import ApplicationException
class NotFoundException(ApplicationException):
def __init__(
self,
message: str = 'Not Found',
headers: Mapping[str, str] | None = None,
):
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
message=message,
headers=headers,
)

View File

@@ -0,0 +1,10 @@
from starlette import status
from src.application.domain.exceptions.application_exception import ApplicationException
class RefreshConcurrentException(ApplicationException):
def __init__(self) -> None:
super().__init__(
status_code=status.HTTP_200_OK,
message='Refresh already handled',
)

View File

@@ -0,0 +1,16 @@
from typing import Mapping
from starlette import status
from src.application.domain.exceptions.application_exception import ApplicationException
class ServiceUnavailableException(ApplicationException):
def __init__(
self,
message: str = 'Service Unavailable',
headers: Mapping[str, str] | None = None,
):
super().__init__(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
message=message,
headers=headers,
)

View File

@@ -0,0 +1,16 @@
from typing import Mapping
from starlette import status
from src.application.domain.exceptions.application_exception import ApplicationException
class TooManyRequestsException(ApplicationException):
def __init__(
self,
message: str = 'Too Many Requests',
headers: Mapping[str, str] | None = None,
):
super().__init__(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
message=message,
headers=headers,
)

View File

@@ -0,0 +1,16 @@
from typing import Mapping
from starlette import status
from src.application.domain.exceptions.application_exception import ApplicationException
class UnauthorizedException(ApplicationException):
def __init__(
self,
message: str = 'Unauthorized',
headers: Mapping[str, str] | None = None,
):
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
message=message,
headers=headers,
)

View File

@@ -4,13 +4,15 @@ from src.infrastructure.config import settings
def create_redis_client() -> Redis: def create_redis_client() -> Redis:
return redis.from_url( kw = {
settings.REDIS_URL, 'max_connections': 50,
max_connections=50, 'decode_responses': True,
decode_responses=True, 'socket_timeout': 5,
socket_timeout=5, 'socket_connect_timeout': 5,
socket_connect_timeout=5, 'health_check_interval': 30,
health_check_interval=30, 'retry_on_timeout': True,
retry_on_timeout=True, 'socket_keepalive': True,
socket_keepalive=True, }
) if settings.REDIS_PASSWORD:
kw['password'] = settings.REDIS_PASSWORD
return redis.from_url(settings.REDIS_URL, **kw)

View File

@@ -55,7 +55,11 @@ class Settings(BaseSettings):
CSRF_COOKIE_HTTPONLY: bool = True CSRF_COOKIE_HTTPONLY: bool = True
CSRF_COOKIE_SAMESITE: Literal['Lax', 'Strict', 'None'] = 'Lax' CSRF_COOKIE_SAMESITE: Literal['Lax', 'Strict', 'None'] = 'Lax'
CSRF_COOKIE_PATH: str = '/' CSRF_COOKIE_PATH: str = '/'
CSRF_COOKIE_DOMAIN: str | None = None CSRF_COOKIE_DOMAIN: str | None = '.elcsa.ru'
AUTH_COOKIE_SECURE: bool = False
AUTH_COOKIE_DOMAIN: str | None = '.elcsa.ru'
CORS_ALLOW_ORIGIN_REGEX: str = r'https?://([a-z0-9-]+\.)*elcsa\.ru(:\d+)?$'
DOCS_USERNAME: str = 'admin' DOCS_USERNAME: str = 'admin'
DOCS_PASSWORD: str = 'admin' DOCS_PASSWORD: str = 'admin'
@@ -81,9 +85,6 @@ class Settings(BaseSettings):
RABBIT_CONNECT_TIMEOUT: int = 5 RABBIT_CONNECT_TIMEOUT: int = 5
RABBIT_EMAIL_CODE_QUEUE: str = 'email.verification_code' RABBIT_EMAIL_CODE_QUEUE: str = 'email.verification_code'
CORS_ORIGINS: str = 'http://localhost:3000'
CORS_ALLOW_CREDENTIALS: bool = True
RATE_LIMIT_REQUESTS: int = 60 RATE_LIMIT_REQUESTS: int = 60
RATE_LIMIT_WINDOW: int = 60 RATE_LIMIT_WINDOW: int = 60
@@ -99,7 +100,33 @@ class Settings(BaseSettings):
@field_validator('CSRF_COOKIE_DOMAIN', mode='before') @field_validator('CSRF_COOKIE_DOMAIN', mode='before')
@classmethod @classmethod
def empty_csrf_domain_to_none(cls, v): def normalize_csrf_cookie_domain(cls, v):
if v is None or (isinstance(v, str) and not v.strip()):
return '.elcsa.ru'
s = str(v).strip()
sl = s.lower()
if sl in ('.elcsa.ru', 'elcsa.ru'):
return '.elcsa.ru'
if sl.endswith('.elcsa.ru') and not sl.startswith('.'):
return '.elcsa.ru'
return s
@field_validator('AUTH_COOKIE_DOMAIN', mode='before')
@classmethod
def normalize_auth_cookie_domain(cls, v):
if v is None or (isinstance(v, str) and not v.strip()):
return '.elcsa.ru'
s = str(v).strip()
sl = s.lower()
if sl in ('.elcsa.ru', 'elcsa.ru'):
return '.elcsa.ru'
if sl.endswith('.elcsa.ru') and not sl.startswith('.'):
return '.elcsa.ru'
return s
@field_validator('REDIS_PASSWORD', mode='before')
@classmethod
def empty_redis_password_to_none(cls, v):
if v is None or (isinstance(v, str) and not v.strip()): if v is None or (isinstance(v, str) and not v.strip()):
return None return None
return v return v
@@ -215,13 +242,29 @@ class Settings(BaseSettings):
rb_set('password', 'RABBIT_PASSWORD') rb_set('password', 'RABBIT_PASSWORD')
rb_set('vhost', 'RABBIT_VHOST') rb_set('vhost', 'RABBIT_VHOST')
redis_secret = read_secret_optional('redis')
if redis_secret:
rd_ci = {str(k).lower(): v for k, v in redis_secret.items()}
def rd_set(field: str, env_key: str, *, as_int: bool = False) -> None:
v = rd_ci.get(field)
if v is None:
return
if isinstance(v, str) and not v.strip():
return
if as_int:
data[env_key] = int(v)
else:
data[env_key] = str(v).strip()
rd_set('host', 'REDIS_HOST')
rd_set('port', 'REDIS_PORT', as_int=True)
rd_set('password', 'REDIS_PASSWORD')
rd_set('db', 'REDIS_DB', as_int=True)
return data return data
def cors_origins_list(self) -> List[str]:
return [o.strip() for o in self.CORS_ORIGINS.split(',') if o.strip()]
@property @property
def DATABASE_URL(self) -> str: def DATABASE_URL(self) -> str:
return ( return (
@@ -231,8 +274,7 @@ class Settings(BaseSettings):
@property @property
def REDIS_URL(self) -> str: def REDIS_URL(self) -> str:
auth = f":{self.REDIS_PASSWORD}@" if self.REDIS_PASSWORD else "" return f'redis://{self.REDIS_HOST}:{self.REDIS_PORT}/{self.REDIS_DB}'
return f"redis://{auth}{self.REDIS_HOST}:{self.REDIS_PORT}/{self.REDIS_DB}"
@property @property
def RABBIT_URL(self) -> str: def RABBIT_URL(self) -> str:

View File

@@ -1,8 +1,11 @@
from __future__ import annotations from __future__ import annotations
from sqlalchemy import Boolean, Date, String, DateTime
from sqlalchemy.orm import Mapped, mapped_column from datetime import datetime
from sqlalchemy import Boolean, Date, String, DateTime, Text
from sqlalchemy.orm import Mapped,mapped_column
from src.infrastructure.database.models.base import Base from src.infrastructure.database.models.base import Base
from src.infrastructure.database.models.mixins import UlidPrimaryKeyMixin, AuditTimestampsMixin, SoftDeleteMixin from src.infrastructure.database.models.mixins import UlidPrimaryKeyMixin,AuditTimestampsMixin,SoftDeleteMixin
class UserModel(Base, UlidPrimaryKeyMixin, AuditTimestampsMixin, SoftDeleteMixin): class UserModel(Base, UlidPrimaryKeyMixin, AuditTimestampsMixin, SoftDeleteMixin):
@@ -16,13 +19,18 @@ class UserModel(Base, UlidPrimaryKeyMixin, AuditTimestampsMixin, SoftDeleteMixin
middle_name: Mapped[str | None] = mapped_column(String(128), nullable=True) middle_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
birth_date: Mapped[Date | None] = mapped_column(Date, nullable=True) birth_date: Mapped[Date | None] = mapped_column(Date, nullable=True)
crypto_wallet: Mapped[str | None] = mapped_column(String(255), nullable=True) encrypted_mnemonic: Mapped[str | None] = mapped_column(Text, nullable=True)
phone: Mapped[str | None] = mapped_column(String(16), nullable=True) phone: Mapped[str | None] = mapped_column(String(16), nullable=True)
bik: Mapped[str | None] = mapped_column(String(9), nullable=True) passport_data: Mapped[str | None] = mapped_column(String(255), nullable=True)
account_number: Mapped[str | None] = mapped_column(String(20), nullable=True)
card_number: Mapped[str | None] = mapped_column(String(19), nullable=True)
inn: Mapped[str | None] = mapped_column(String(12), nullable=True) inn: Mapped[str | None] = mapped_column(String(12), nullable=True)
erc20: Mapped[str | None] = mapped_column(String(255), nullable=True)
avatar_link: Mapped[str | None] = mapped_column(Text, nullable=True)
kyc_verified: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default='false', default=False) kyc_verified: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default='false', default=False)
kyc_verified_at: Mapped[DateTime | None] = mapped_column(DateTime(timezone=True), nullable=True) kyc_verified_at: Mapped[DateTime | None] = mapped_column(DateTime(timezone=True), nullable=True)
account_type: Mapped[str] = mapped_column(String(20), nullable=False, server_default='individual', default='individual')
provisioned_by: Mapped[str | None] = mapped_column(String(26), nullable=True)
provisioned_at: Mapped[DateTime | None] = mapped_column(DateTime(timezone=True), nullable=True)

View File

@@ -7,6 +7,7 @@ from src.application.contracts import ILogger
from src.application.domain.exceptions import ApplicationException from src.application.domain.exceptions import ApplicationException
from src.application.abstractions.repositories import IUserRepository from src.application.abstractions.repositories import IUserRepository
from src.application.domain.entities import UserEntity from src.application.domain.entities import UserEntity
from src.application.domain.enums.account_type import AccountType
from src.infrastructure.database.models import UserModel from src.infrastructure.database.models import UserModel
@@ -16,7 +17,11 @@ class UserRepository(IUserRepository):
self._logger = logger self._logger = logger
async def create_user(self, email: str, password_hash: str) -> UserEntity: async def create_user(self, email: str, password_hash: str) -> UserEntity:
user = UserModel(email=email, password_hash=password_hash) user = UserModel(
email=email,
password_hash=password_hash,
account_type=AccountType.INDIVIDUAL,
)
self._session.add(user) self._session.add(user)
try: try:
await self._session.flush() await self._session.flush()
@@ -25,7 +30,9 @@ class UserRepository(IUserRepository):
email=user.email, email=user.email,
created_at=user.created_at, created_at=user.created_at,
kyc_verified=user.kyc_verified, kyc_verified=user.kyc_verified,
is_deleted=user.is_deleted is_deleted=user.is_deleted,
avatar_link=user.avatar_link,
account_type=user.account_type,
) )
except IntegrityError: except IntegrityError:
@@ -67,17 +74,20 @@ class UserRepository(IUserRepository):
middle_name=user.middle_name, middle_name=user.middle_name,
last_name=user.last_name, last_name=user.last_name,
birth_date=user.birth_date, birth_date=user.birth_date,
crypto_wallet=user.crypto_wallet, encrypted_mnemonic=user.encrypted_mnemonic,
phone=user.phone, phone=user.phone,
bik=user.bik, passport_data=user.passport_data,
account_number=user.account_number,
card_number=user.card_number,
inn=user.inn, inn=user.inn,
kyc_verified_at=user.kyc_verified_at, erc20=user.erc20,
avatar_link=user.avatar_link,
kyc_verified=user.kyc_verified, kyc_verified=user.kyc_verified,
is_deleted=user.is_deleted, is_deleted=user.is_deleted,
created_at=user.created_at, created_at=user.created_at,
updated_at=user.updated_at, updated_at=user.updated_at,
kyc_verified_at=user.kyc_verified_at,
account_type=user.account_type,
provisioned_by=user.provisioned_by,
provisioned_at=user.provisioned_at,
) )
except ApplicationException: except ApplicationException:

View File

@@ -2,6 +2,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from src.application.abstractions import IUnitOfWork from src.application.abstractions import IUnitOfWork
from src.application.abstractions.repositories import IUserRepository, ISessionRepository from src.application.abstractions.repositories import IUserRepository, ISessionRepository
from src.application.contracts import ILogger from src.application.contracts import ILogger
from src.application.domain.exceptions import RefreshConcurrentException
from src.infrastructure.database.repositories import UserRepository, SessionRepository from src.infrastructure.database.repositories import UserRepository, SessionRepository
@@ -20,8 +21,10 @@ class UnitOfWork(IUnitOfWork):
async def __aexit__(self, exc_type, exc_val, exc_tb): async def __aexit__(self, exc_type, exc_val, exc_tb):
if exc_type: if exc_type:
if not isinstance(exc_val, RefreshConcurrentException):
self._logger.error(str(exc_val)) self._logger.error(str(exc_val))
await self._session.rollback() await self._session.rollback()
if not isinstance(exc_val, RefreshConcurrentException):
self._logger.error(f'Rollback: str{exc_val})') self._logger.error(f'Rollback: str{exc_val})')
else: else:
await self._session.flush() await self._session.flush()

View File

@@ -4,6 +4,7 @@ import secrets
from typing import AsyncGenerator from typing import AsyncGenerator
from fastapi import Depends, FastAPI, status from fastapi import Depends, FastAPI, status
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials from fastapi.security import HTTPBasic, HTTPBasicCredentials
from starlette.middleware.cors import CORSMiddleware from starlette.middleware.cors import CORSMiddleware
@@ -15,13 +16,66 @@ from src.infrastructure.utils import generate_instance_id
from src.infrastructure.logger import logger from src.infrastructure.logger import logger
from src.infrastructure.config import settings from src.infrastructure.config import settings
from src.presentation.dependencies import get_rabbit from src.presentation.dependencies import get_rabbit
from src.presentation.handlers import application_exception_handler, unhandled_exception_handler from src.presentation.handler import application_exception_handler
from src.presentation.handler import unhandled_exception_handler
from src.presentation.middleware import TraceIDMiddleware, SecurityHeadersMiddleware from src.presentation.middleware import TraceIDMiddleware, SecurityHeadersMiddleware
from src.presentation.routing import v1_router from src.presentation.routing import v1_router
from src.presentation.schemas import ErrorResponse
security = HTTPBasic() security = HTTPBasic()
ERROR_RESPONSES: dict[int, str] = {
status.HTTP_400_BAD_REQUEST: 'Bad Request',
status.HTTP_401_UNAUTHORIZED: 'Unauthorized',
status.HTTP_403_FORBIDDEN: 'Forbidden',
status.HTTP_404_NOT_FOUND: 'Not Found',
status.HTTP_409_CONFLICT: 'Conflict',
status.HTTP_429_TOO_MANY_REQUESTS: 'Too Many Requests',
status.HTTP_500_INTERNAL_SERVER_ERROR: 'Internal Server Error',
status.HTTP_503_SERVICE_UNAVAILABLE: 'Service Unavailable',
}
def custom_openapi() -> dict:
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title=app.title,
version=app.version,
description=app.description,
routes=app.routes,
license_info=app.license_info,
)
components = openapi_schema.setdefault('components', {})
schemas = components.setdefault('schemas', {})
schemas['ErrorResponse'] = ErrorResponse.model_json_schema()
for path_item in openapi_schema.get('paths', {}).values():
for operation in path_item.values():
if not isinstance(operation, dict):
continue
responses = operation.setdefault('responses', {})
for status_code, description in ERROR_RESPONSES.items():
responses.setdefault(
str(status_code),
{
'description': description,
'content': {
'application/json': {
'schema': {
'$ref': '#/components/schemas/ErrorResponse',
},
},
},
},
)
app.openapi_schema = openapi_schema
return app.openapi_schema
async def verify_credentials(credentials: HTTPBasicCredentials = Depends(security)) -> HTTPBasicCredentials: async def verify_credentials(credentials: HTTPBasicCredentials = Depends(security)) -> HTTPBasicCredentials:
user_ok = secrets.compare_digest(credentials.username, settings.DOCS_USERNAME) user_ok = secrets.compare_digest(credentials.username, settings.DOCS_USERNAME)
@@ -117,8 +171,8 @@ app.add_middleware(
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=settings.cors_origins_list(), allow_origin_regex=settings.CORS_ALLOW_ORIGIN_REGEX,
allow_credentials=settings.CORS_ALLOW_CREDENTIALS, allow_credentials=True,
allow_methods=['*'], allow_methods=['*'],
allow_headers=['*'], allow_headers=['*'],
) )
@@ -152,3 +206,6 @@ async def ping() -> dict[str, str]:
'message': 'pong', 'message': 'pong',
'status': 'ok', 'status': 'ok',
} }
app.openapi = custom_openapi

View File

@@ -29,7 +29,7 @@ async def require_access_token(
if not token: if not token:
raise ApplicationException(status_code=401, message="Not authenticated") raise ApplicationException(status_code=401, message="Not authenticated")
payload = jwt_service.decode_access_token(token) payload = await jwt_service.decode_access_token(token)
if payload.type != "access": if payload.type != "access":
raise ApplicationException(status_code=401, message="Invalid token type") raise ApplicationException(status_code=401, message="Invalid token type")

View File

@@ -93,6 +93,7 @@ def get_jwt_refresh_command(
uow: IUnitOfWork = Depends(get_unit_of_work), uow: IUnitOfWork = Depends(get_unit_of_work),
hash_service: IHashService = Depends(get_hash_service), hash_service: IHashService = Depends(get_hash_service),
jwt_service: IJwtService = Depends(get_jwt_service), jwt_service: IJwtService = Depends(get_jwt_service),
cache: ICache = Depends(get_cache),
logger: ILogger = Depends(get_logger), logger: ILogger = Depends(get_logger),
) -> JwtRefreshCommand: ) -> JwtRefreshCommand:
return JwtRefreshCommand(uow, hash_service, jwt_service, logger) return JwtRefreshCommand(uow, hash_service, jwt_service, cache, logger)

View File

@@ -0,0 +1,2 @@
from src.presentation.handler.application_exception_handler import application_exception_handler
from src.presentation.handler.unhandled_exception_handler import unhandled_exception_handler

View File

@@ -1,17 +1,15 @@
from fastapi.responses import ORJSONResponse
from fastapi import Request from fastapi import Request
from fastapi.responses import ORJSONResponse
from src.application.domain.exceptions import ApplicationException from src.application.domain.exceptions import ApplicationException
async def application_exception_handler(_request: Request, exc: ApplicationException) -> ORJSONResponse: async def application_exception_handler(_request: Request, exc: ApplicationException) -> ORJSONResponse:
detail = exc.message detail = exc.message
if 500 <= exc.status_code: if 500 <= exc.status_code:
detail = "Internal Server Error" detail = 'Internal Server Error'
return ORJSONResponse( return ORJSONResponse(
status_code=exc.status_code, status_code=exc.status_code,
content={"detail": detail}, content={'detail': detail},
headers=dict(exc.headers) if exc.headers else None, headers=dict(exc.headers) if exc.headers else None,
) )

View File

@@ -1,5 +1,5 @@
from fastapi.responses import ORJSONResponse
from fastapi import Request from fastapi import Request
from fastapi.responses import ORJSONResponse
from starlette import status from starlette import status
from src.infrastructure.logger import logger from src.infrastructure.logger import logger

View File

@@ -1,2 +0,0 @@
from src.presentation.handlers.unhandled_handler import unhandled_exception_handler
from src.presentation.handlers.application_handler import application_exception_handler

View File

@@ -12,7 +12,7 @@ from src.application.contracts import ILogger
from src.application.domain.dto import UserLoginDto from src.application.domain.dto import UserLoginDto
from src.infrastructure.config import settings from src.infrastructure.config import settings
from src.infrastructure.logger import get_logger from src.infrastructure.logger import get_logger
from src.presentation.decorators import rate_limit, email_rl_key from src.presentation.decorators import csrf_protect,rate_limit,email_rl_key
from src.presentation.dependencies import ( from src.presentation.dependencies import (
get_user_registration_complete_command, get_user_registration_complete_command,
get_user_logout_command, get_user_logout_command,
@@ -31,19 +31,23 @@ auth_router = APIRouter(prefix='/auth', tags=['auth'])
response_class=ORJSONResponse, response_class=ORJSONResponse,
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
) )
@rate_limit(limit=5, window_seconds=60, scope='ip') #@rate_limit(limit=5, window_seconds=60, scope='ip')
@rate_limit(limit=3, window_seconds=600, scope='key', key_prefix='rl:reg_start', key_builder=email_rl_key) #@rate_limit(limit=3, window_seconds=600, scope='key', key_prefix='rl:reg_start', key_builder=email_rl_key)
@csrf_protect()
async def registration_start( async def registration_start(
request: Request, request: Request,
body: RegistrationStart, body: RegistrationStart,
logger: ILogger = Depends(get_logger),
command: UserRegistrationStartCommand = Depends(get_user_registration_start_command), command: UserRegistrationStartCommand = Depends(get_user_registration_start_command),
): ):
logger.info('AHAHAHAHAHAHHAHAAH')
result = await command(body.email) result = await command(body.email)
return {'success': result} return {'success': result}
@auth_router.post(path='/registration/complete', response_class=ORJSONResponse, status_code=status.HTTP_201_CREATED) @auth_router.post(path='/registration/complete', response_class=ORJSONResponse, status_code=status.HTTP_201_CREATED)
@rate_limit(limit=10, window_seconds=300, scope='ip') #@rate_limit(limit=10, window_seconds=300, scope='ip')
@csrf_protect()
async def registration( async def registration(
request: Request, request: Request,
user: RegistrationComplete, user: RegistrationComplete,
@@ -76,9 +80,10 @@ async def registration(
key='device_id', key='device_id',
value=device_id, value=device_id,
httponly=True, httponly=True,
secure=True, secure=settings.AUTH_COOKIE_SECURE,
samesite='lax', samesite='lax',
path='/', path='/',
domain=settings.AUTH_COOKIE_DOMAIN,
max_age=60 * 60 * 24 * 365 * 5 max_age=60 * 60 * 24 * 365 * 5
) )
@@ -86,25 +91,28 @@ async def registration(
key='access_token', key='access_token',
value=created.access_token, value=created.access_token,
httponly=True, httponly=True,
secure=True, secure=settings.AUTH_COOKIE_SECURE,
samesite='lax', samesite='lax',
path='/', path='/',
domain=settings.AUTH_COOKIE_DOMAIN,
max_age=int(settings.JWT_ACCESS_TTL_SECONDS), max_age=int(settings.JWT_ACCESS_TTL_SECONDS),
) )
response.set_cookie( response.set_cookie(
key='refresh_token', key='refresh_token',
value=created.refresh_token, value=created.refresh_token,
httponly=True, httponly=True,
secure=True, secure=settings.AUTH_COOKIE_SECURE,
samesite='lax', samesite='lax',
path='/', path='/',
domain=settings.AUTH_COOKIE_DOMAIN,
max_age=int(settings.JWT_REFRESH_TTL_SECONDS), max_age=int(settings.JWT_REFRESH_TTL_SECONDS),
) )
return response return response
@auth_router.post(path='/login/start', response_class=ORJSONResponse, status_code=status.HTTP_200_OK) @auth_router.post(path='/login/start', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
@rate_limit(limit=5, window_seconds=60, scope='ip') #@rate_limit(limit=5, window_seconds=60, scope='ip')
@rate_limit(limit=3, window_seconds=600, scope='key', key_prefix='rl:login_start', key_builder=email_rl_key) #@rate_limit(limit=3, window_seconds=600, scope='key', key_prefix='rl:login_start', key_builder=email_rl_key)
@csrf_protect()
async def login_start( async def login_start(
request: Request, request: Request,
body: LoginStart, body: LoginStart,
@@ -114,8 +122,9 @@ async def login_start(
return {'success': result} return {'success': result}
@auth_router.post(path='/login/compete', response_class=ORJSONResponse, status_code=status.HTTP_200_OK) @auth_router.post(path='/login/complete', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
@rate_limit(limit=10, window_seconds=300, scope='ip') #@rate_limit(limit=10, window_seconds=300, scope='ip')
@csrf_protect()
async def login( async def login(
request: Request, request: Request,
user: UserLogin, user: UserLogin,
@@ -150,12 +159,12 @@ async def login(
'middle_name': dto.middle_name, 'middle_name': dto.middle_name,
'last_name': dto.last_name, 'last_name': dto.last_name,
'birth_date': dto.birth_date.isoformat() if dto.birth_date else None, 'birth_date': dto.birth_date.isoformat() if dto.birth_date else None,
'crypto_wallet': dto.crypto_wallet, 'encrypted_mnemonic': dto.encrypted_mnemonic,
'phone': dto.phone, 'phone': dto.phone,
'bik': dto.bik, 'passport_data': dto.passport_data,
'account_number': dto.account_number,
'card_number': dto.card_number,
'inn': dto.inn, 'inn': dto.inn,
'erc20': dto.erc20,
'avatar_link': dto.avatar_link,
'kyc_verified': dto.kyc_verified, 'kyc_verified': dto.kyc_verified,
'kyc_verified_at': dto.kyc_verified_at, 'kyc_verified_at': dto.kyc_verified_at,
'created_at': dto.created_at.isoformat() if dto.created_at else None, 'created_at': dto.created_at.isoformat() if dto.created_at else None,
@@ -167,9 +176,10 @@ async def login(
key='device_id', key='device_id',
value=device_id, value=device_id,
httponly=True, httponly=True,
secure=True, secure=settings.AUTH_COOKIE_SECURE,
samesite='lax', samesite='lax',
path='/', path='/',
domain=settings.AUTH_COOKIE_DOMAIN,
max_age=60 * 60 * 24 * 365 * 5 max_age=60 * 60 * 24 * 365 * 5
) )
@@ -177,9 +187,10 @@ async def login(
key='access_token', key='access_token',
value=dto.access_token, value=dto.access_token,
httponly=True, httponly=True,
secure=True, secure=settings.AUTH_COOKIE_SECURE,
samesite='lax', samesite='lax',
path='/', path='/',
domain=settings.AUTH_COOKIE_DOMAIN,
max_age=int(settings.JWT_ACCESS_TTL_SECONDS), max_age=int(settings.JWT_ACCESS_TTL_SECONDS),
) )
@@ -187,16 +198,18 @@ async def login(
key='refresh_token', key='refresh_token',
value=dto.refresh_token, value=dto.refresh_token,
httponly=True, httponly=True,
secure=True, secure=settings.AUTH_COOKIE_SECURE,
samesite='lax', samesite='lax',
path='/', path='/',
domain=settings.AUTH_COOKIE_DOMAIN,
max_age=int(settings.JWT_REFRESH_TTL_SECONDS), max_age=int(settings.JWT_REFRESH_TTL_SECONDS),
) )
return response return response
@auth_router.post(path='/logout', response_class=ORJSONResponse, status_code=status.HTTP_200_OK) @auth_router.post(path='/logout', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
@rate_limit(limit=settings.RATE_LIMIT_REQUESTS, window_seconds=settings.RATE_LIMIT_WINDOW, scope='ip') #@rate_limit(limit=settings.RATE_LIMIT_REQUESTS, window_seconds=settings.RATE_LIMIT_WINDOW, scope='ip')
@csrf_protect()
async def logout_current( async def logout_current(
request: Request, request: Request,
command: UserLogoutCommand = Depends(get_user_logout_command), command: UserLogoutCommand = Depends(get_user_logout_command),
@@ -206,8 +219,8 @@ async def logout_current(
await command(refresh_token=refresh_token) await command(refresh_token=refresh_token)
response = ORJSONResponse({'ok': True}) response = ORJSONResponse({'ok': True})
response.delete_cookie('access_token', path='/') response.delete_cookie('access_token', path='/', domain=settings.AUTH_COOKIE_DOMAIN)
response.delete_cookie('refresh_token', path='/') response.delete_cookie('refresh_token', path='/', domain=settings.AUTH_COOKIE_DOMAIN)
return response return response

View File

@@ -1,5 +1,5 @@
from __future__ import annotations from __future__ import annotations
from fastapi import APIRouter from fastapi import APIRouter, Request
from fastapi.responses import ORJSONResponse from fastapi.responses import ORJSONResponse
from starlette import status from starlette import status
from src.infrastructure.security import CsrfService from src.infrastructure.security import CsrfService
@@ -11,7 +11,7 @@ csrf_router = APIRouter(prefix='/csrf', tags=['csrf'])
@csrf_router.get('/token', response_class=ORJSONResponse, status_code=status.HTTP_200_OK) @csrf_router.get('/token', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
@rate_limit(limit=settings.RATE_LIMIT_REQUESTS, window_seconds=settings.RATE_LIMIT_WINDOW, scope='ip') @rate_limit(limit=settings.RATE_LIMIT_REQUESTS, window_seconds=settings.RATE_LIMIT_WINDOW, scope='ip')
async def issue_csrf_token(): async def issue_csrf_token(request: Request):
csrf = CsrfService() csrf = CsrfService()
token = csrf.issue() token = csrf.issue()
@@ -30,7 +30,7 @@ async def issue_csrf_token():
httponly=settings.CSRF_COOKIE_HTTPONLY, httponly=settings.CSRF_COOKIE_HTTPONLY,
samesite=settings.CSRF_COOKIE_SAMESITE, samesite=settings.CSRF_COOKIE_SAMESITE,
path=settings.CSRF_COOKIE_PATH, path=settings.CSRF_COOKIE_PATH,
domain=settings.CSRF_COOKIE_DOMAIN, domain=settings.CSRF_COOKIE_DOMAIN or '.elcsa.ru',
max_age=csrf.ttl_seconds, max_age=csrf.ttl_seconds,
) )

View File

@@ -2,63 +2,72 @@ from fastapi import APIRouter, Request, Depends
from fastapi.responses import ORJSONResponse from fastapi.responses import ORJSONResponse
from starlette import status from starlette import status
from src.application.commands import JwtRefreshCommand from src.application.commands import JwtRefreshCommand
from src.application.domain.exceptions import ApplicationException from src.application.domain.exceptions import ApplicationException, RefreshConcurrentException
from src.infrastructure.config import settings from src.infrastructure.config import settings
from src.presentation.decorators import rate_limit from src.presentation.decorators import csrf_protect, rate_limit
from src.presentation.dependencies import get_jwt_refresh_command from src.presentation.dependencies import get_jwt_refresh_command
jwt_router = APIRouter(prefix='/jwt', tags=['Jwt']) jwt_router = APIRouter(prefix='/jwt', tags=['Jwt'])
@jwt_router.post('/refresh', response_class=ORJSONResponse, status_code=status.HTTP_200_OK) def _clear_auth_cookies(response: ORJSONResponse) -> None:
@rate_limit(limit=settings.RATE_LIMIT_REQUESTS, window_seconds=settings.RATE_LIMIT_WINDOW, scope='ip') response.delete_cookie('access_token', path='/', domain=settings.AUTH_COOKIE_DOMAIN)
async def refresh_tokens( response.delete_cookie('refresh_token', path='/', domain=settings.AUTH_COOKIE_DOMAIN)
request: Request,
command: JwtRefreshCommand = Depends(get_jwt_refresh_command)
):
refresh_token = request.cookies.get('refresh_token')
if not refresh_token:
response = ORJSONResponse({'ok': False, 'error': 'No refresh token'}, status_code=401)
response.delete_cookie('access_token', path='/')
response.delete_cookie('refresh_token', path='/')
return response
ip = request.client.host if request.client else None
user_agent = request.headers.get('user-agent')
try:
access, refresh = await command(refresh_token=refresh_token, ip=ip, user_agent=user_agent)
except ApplicationException:
response = ORJSONResponse({'result': False}, status_code=401)
response.delete_cookie('access_token', path='/')
response.delete_cookie('refresh_token', path='/')
return response
response = ORJSONResponse({'result': True})
def _set_auth_cookies(response: ORJSONResponse, access: str, refresh: str) -> None:
response.set_cookie( response.set_cookie(
key='access_token', key='access_token',
value=access, value=access,
httponly=True, httponly=True,
secure=True, secure=settings.AUTH_COOKIE_SECURE,
samesite='lax', samesite='lax',
path='/', path='/',
domain=settings.AUTH_COOKIE_DOMAIN,
max_age=int(settings.JWT_ACCESS_TTL_SECONDS), max_age=int(settings.JWT_ACCESS_TTL_SECONDS),
) )
response.set_cookie( response.set_cookie(
key='refresh_token', key='refresh_token',
value=refresh, value=refresh,
httponly=True, httponly=True,
secure=True, secure=settings.AUTH_COOKIE_SECURE,
samesite='lax', samesite='lax',
path='/', path='/',
domain=settings.AUTH_COOKIE_DOMAIN,
max_age=int(settings.JWT_REFRESH_TTL_SECONDS), max_age=int(settings.JWT_REFRESH_TTL_SECONDS),
) )
@jwt_router.post('/refresh', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
@rate_limit(limit=settings.RATE_LIMIT_REQUESTS, window_seconds=settings.RATE_LIMIT_WINDOW, scope='ip')
@csrf_protect()
async def refresh_tokens(
request: Request,
command: JwtRefreshCommand = Depends(get_jwt_refresh_command),
):
refresh_token = request.cookies.get('refresh_token')
if not refresh_token:
response = ORJSONResponse({'ok': False, 'error': 'No refresh token'}, status_code=401)
_clear_auth_cookies(response)
return response return response
# Usage ip = request.client.host if request.client else None
# @jwt_router.get("/test") user_agent = request.headers.get('user-agent')
# async def profile(auth: AuthContext = Depends(require_access_token)):
# return 'ok' try:
tokens = await command(refresh_token=refresh_token, ip=ip, user_agent=user_agent)
except RefreshConcurrentException:
return ORJSONResponse({'result': True, 'concurrent': True}, status_code=status.HTTP_200_OK)
except ApplicationException as exc:
if exc.status_code == status.HTTP_401_UNAUTHORIZED:
response = ORJSONResponse({'result': False}, status_code=401)
_clear_auth_cookies(response)
return response
raise
access, refresh = tokens
response = ORJSONResponse({'result': True})
_set_auth_cookies(response, access, refresh)
return response

View File

@@ -1 +1,5 @@
from src.presentation.schemas.user import RegistrationStart, RegistrationComplete, UserLogin, LoginStart from src.presentation.schemas.error import ErrorResponse
from src.presentation.schemas.user import RegistrationComplete
from src.presentation.schemas.user import RegistrationStart
from src.presentation.schemas.user import LoginStart
from src.presentation.schemas.user import UserLogin

View File

@@ -0,0 +1,6 @@
from pydantic import BaseModel
from pydantic import Field
class ErrorResponse(BaseModel):
detail: str = Field(title='Detail')

1149
uv.lock generated

File diff suppressed because it is too large Load Diff