21 Commits

Author SHA1 Message Date
e9b99535b9 feat: update auth for b2b 2026-06-02 23:43:44 +03:00
caf7f003fa feat: add validation 2026-05-19 22:29:02 +03:00
666f2f67cb feat: more workers 2026-05-14 23:46:54 +03:00
d3fedc3f91 feat: add new column and delete old 2026-05-13 11:19:34 +03:00
0cca0fedbe fix: add await 2026-05-12 22:47:50 +03:00
9166b21249 feat: update domains 2026-05-12 22:11:00 +03:00
603efa55e6 fix: crsf cookie 2026-05-12 22:03:02 +03:00
cec8d896b6 feat: update users 2026-05-12 21:44:18 +03:00
2d9f44979c feat: add domain in cookie and cors 2026-05-12 19:25:16 +03:00
8347ff40f4 fix: delete bank data and add passport data 2026-05-12 18:07:43 +03:00
4cfab85812 fix: rename path login 2026-05-12 12:33:57 +03:00
9d56b7f6f5 feat: add custom exceptions 2026-05-09 16:25:31 +03:00
bedce9e910 fix: add debug logger 2026-05-08 23:10:16 +03:00
1724d4e37d fix: off rate-limiting 2026-05-08 22:40:26 +03:00
57bafec204 feat: add csrf 2026-04-17 13:05:11 +03:00
0130912555 feat: add csrf 2026-04-17 12:02:26 +03:00
54ebcaeb81 fix: add request in csrf endpoint 2026-04-16 08:09:24 +03:00
3e3b9eb030 fix: delete origins 2026-04-15 15:13:08 +03:00
f92eadf8fa feat: add email code log 2026-04-13 13:35:32 +03:00
949b57e425 feat: add redis password 2026-04-12 16:20:07 +03:00
3fc1b455d2 chore: delete docker-compose 2026-04-12 14:18:09 +03:00
41 changed files with 1153 additions and 710 deletions

2
.gitignore vendored
View File

@@ -2,7 +2,7 @@
__pycache__/
*.py[cod]
*$py.class
generate_password_hash.py
# C extensions
*.so
*.pyd

View File

@@ -1,83 +0,0 @@
services:
auth:
container_name: auth-service
build:
context: .
dockerfile: Dockerfile
ports:
- "8000:8000"
environment:
PYTHONUNBUFFERED: "1"
APP_MODULE: "src.main:app"
APP_HOST: "0.0.0.0"
APP_PORT: "8000"
APP_WORKERS: "1"
env_file:
- .env
depends_on:
keydb:
condition: service_healthy
restart: no
keydb:
image: eqalpha/keydb
container_name: keydb
restart: no
expose:
- "6379"
volumes:
- keydb_data:/data
command:
- keydb-server
- --requirepass
- keydb
- --dir
- /data
- --appendonly
- "yes"
- --appendfsync
- everysec
- --save
- "900"
- "1"
- --save
- "300"
- "10"
- --save
- "60"
- "10000"
healthcheck:
test: [ "CMD", "redis-cli", "-a", "keydb", "ping" ]
interval: 5s
timeout: 2s
retries: 20
# keydb:
# image: eqalpha/keydb
# container_name: keydb
# restart: no
# expose:
# - "6379"
# volumes:
# - keydb_data:/data
# environment:
# KEYDB_PASSWORD: keydb
# command: >
# sh -c "
# keydb-server
# --requirepass $$KEYDB_PASSWORD
# --dir /data
# --appendonly yes
# --appendfsync everysec
# --save 900 1
# --save 300 10
# --save 60 10000
# "
# healthcheck:
# test: ["CMD", "redis-cli", "ping"]
# interval: 5s
# timeout: 2s
# retries: 20
volumes:
keydb_data:

View File

@@ -4,6 +4,8 @@ version = "0.1.0"
description = "Add your description here"
requires-python = "==3.12.*"
dependencies = [
"acryl-datahub>=1.5.0.19",
"acryl-sqlglot>=25.25.2.dev9",
"apscheduler==3.11.2",
"asyncpg==0.31.0",
"bcrypt==5.0.0",

View File

@@ -1,18 +1,32 @@
import asyncio
from datetime import datetime, timezone, timedelta
from ulid import ULID
from src.application.abstractions import IUnitOfWork
from src.application.contracts import IHashService, IJwtService, ILogger
from src.application.contracts import IHashService, IJwtService, ILogger, ICache
from src.application.domain.dto import RefreshTokenPayload
from src.application.domain.exceptions import ApplicationException
from src.application.domain.exceptions import ApplicationException, RefreshConcurrentException
from src.infrastructure.config import settings
from src.infrastructure.database.decorators import transactional
class JwtRefreshCommand:
def __init__(self, unit_of_work: IUnitOfWork, hash_service: IHashService, jwt_service: IJwtService, logger: ILogger):
_LOCK_PREFIX = 'jwt:refresh:lock:'
_LOCK_TTL_SECONDS = 15
_LOCK_WAIT_ATTEMPTS = 40
_LOCK_WAIT_INTERVAL_SECONDS = 0.05
def __init__(
self,
unit_of_work: IUnitOfWork,
hash_service: IHashService,
jwt_service: IJwtService,
cache: ICache,
logger: ILogger,
):
self._unit_of_work = unit_of_work
self._hash_service = hash_service
self._jwt_service = jwt_service
self._cache = cache
self._logger = logger
@transactional
@@ -25,6 +39,39 @@ class JwtRefreshCommand:
user_id = payload.sub
jti = payload.jti
lock_key = f'{self._LOCK_PREFIX}{sid}'
locked = await self._cache.set_nx(lock_key, '1', self._LOCK_TTL_SECONDS)
if not locked:
for _ in range(self._LOCK_WAIT_ATTEMPTS):
await asyncio.sleep(self._LOCK_WAIT_INTERVAL_SECONDS)
if await self._cache.get(lock_key) is None:
self._logger.info(f'Concurrent refresh skipped (sid={sid})')
raise RefreshConcurrentException()
raise ApplicationException(status_code=429, message='Refresh in progress')
try:
return await self._refresh_locked(
sid=sid,
user_id=user_id,
jti=jti,
now=now,
ip=ip,
user_agent=user_agent,
)
finally:
await self._cache.delete(lock_key)
async def _refresh_locked(
self,
*,
sid: str,
user_id: str,
jti: str,
now: datetime,
ip: str | None,
user_agent: str | None,
) -> tuple[str, str]:
sess = await self._unit_of_work.session_repository.get_by_sid(sid)
if sess is None:
raise ApplicationException(status_code=401, message='Session not found')
@@ -61,7 +108,8 @@ class JwtRefreshCommand:
)
if not rotated:
raise ApplicationException(status_code=401, message='Refresh already rotated')
self._logger.info(f'Refresh already rotated (sid={sid})')
raise RefreshConcurrentException()
access = await self._jwt_service.create_access_token(user_id=user_id, sid=sid)
refresh = await self._jwt_service.create_refresh_token(user_id=user_id, sid=sid, refresh_jti=new_jti)

View File

@@ -102,12 +102,12 @@ class UserLoginCompleteCommand:
middle_name=user.middle_name,
last_name=user.last_name,
birth_date=user.birth_date,
crypto_wallet=user.crypto_wallet,
encrypted_mnemonic=user.encrypted_mnemonic,
phone=user.phone,
bik=user.bik,
account_number=user.account_number,
card_number=user.card_number,
passport_data=user.passport_data,
inn=user.inn,
erc20=user.erc20,
avatar_link=user.avatar_link,
kyc_verified=user.kyc_verified,
kyc_verified_at=user.kyc_verified_at,
created_at=user.created_at,

View File

@@ -102,8 +102,6 @@ class UserLoginStartCommand:
'metadata': metadata,
}
self._logger.info(f'payload: {payload})')
try:
await self._messanger.publish_to_queue(
queue=settings.RABBIT_EMAIL_CODE_QUEUE,
@@ -123,7 +121,7 @@ class UserLoginStartCommand:
self._logger.error(f'Failed to publish login email event for {email}: {str(exception)}')
raise ApplicationException(503, 'Temporary error. Please try again.')
self._logger.info(f'login code created for {email}')
self._logger.info(f'Login email verification code queued email={email}')
return True
self._logger.error(f'login start failed: code space exhausted for {email}')

View File

@@ -18,7 +18,7 @@ class UserLogoutCommand:
if not refresh_token:
return
try:
payload: RefreshTokenPayload = self._jwt_service.decode_refresh_token(refresh_token)
payload: RefreshTokenPayload = await self._jwt_service.decode_refresh_token(refresh_token)
except ApplicationException:
self._logger.debug('Logout: refresh token invalid/expired, skipping revoke')
return

View File

@@ -45,7 +45,7 @@ class UserRegistrationCompleteCommand:
cached_email = await self._cache.get(code_key)
if not cached_email:
self._logger.info(f'Registration failed: code not found (email={email}, code={code})')
self._logger.info(f'Registration failed: code not found (email={email})')
raise ApplicationException(400, 'Invalid or expired code')
if cached_email != email:

View File

@@ -102,8 +102,6 @@ class UserRegistrationStartCommand:
'metadata': metadata,
}
self._logger.info(f'payload: {payload})')
try:
await self._messanger.publish_to_queue(
queue=settings.RABBIT_EMAIL_CODE_QUEUE,
@@ -123,7 +121,7 @@ class UserRegistrationStartCommand:
self._logger.error(f'Failed to publish registration email event for {email}: {str(exception)}')
raise ApplicationException(503, 'Temporary error. Please try again.')
self._logger.info(f'Registration code created for {email}')
self._logger.info(f'Registration email verification code queued email={email}')
return True
self._logger.error(f'Registration start failed: code space exhausted for {email}')

View File

@@ -18,12 +18,12 @@ class UserLoginDto:
middle_name: str | None = None
last_name: str | None = None
birth_date: date | None = None
crypto_wallet: str | None = None
encrypted_mnemonic: str | None = None
phone: str | None = None
bik: str | None = None
account_number: str | None = None
card_number: str | None = None
passport_data: str | None = None
inn: str | None = None
erc20: str | None = None
avatar_link: str | None = None
kyc_verified: bool | None = None
access_token: str | None = None
refresh_token: str | None = None

View File

@@ -14,13 +14,14 @@ class UserEntity:
last_name: str | None = None
birth_date: date | None = None
crypto_wallet: str | None = None
encrypted_mnemonic: str | None = None
phone: str | None = None
bik: str | None = None
account_number: str | None = None
card_number: str | None = None
passport_data: str | None = None
inn: str | None = None
erc20: str | None = None
avatar_link: str | None = None
kyc_verified: bool | None = None
is_deleted: bool | None = None
@@ -28,3 +29,7 @@ class UserEntity:
created_at: datetime | None = None
updated_at: datetime | None = None
kyc_verified_at: datetime | None = None
account_type: str = 'individual'
provisioned_by: str | None = None
provisioned_at: datetime | None = None

View File

@@ -0,0 +1,6 @@
from enum import StrEnum
class AccountType(StrEnum):
INDIVIDUAL = 'individual'
LEGAL_ENTITY = 'legal_entity'

View File

@@ -1 +1,10 @@
from src.application.domain.exceptions.application_exceptions import ApplicationException
from src.application.domain.exceptions.application_exception import ApplicationException
from src.application.domain.exceptions.bad_request_exception import BadRequestException
from src.application.domain.exceptions.conflict_exception import ConflictException
from src.application.domain.exceptions.forbidden_exception import ForbiddenException
from src.application.domain.exceptions.internal_server_exception import InternalServerException
from src.application.domain.exceptions.not_found_exception import NotFoundException
from src.application.domain.exceptions.service_unavailable_exception import ServiceUnavailableException
from src.application.domain.exceptions.too_many_requests_exception import TooManyRequestsException
from src.application.domain.exceptions.unauthorized_exception import UnauthorizedException
from src.application.domain.exceptions.refresh_concurrent_exception import RefreshConcurrentException

View File

@@ -15,4 +15,4 @@ class ApplicationException(Exception):
self.headers = headers
def __str__(self):
return f"{self.status_code}: {self.message}"
return f'{self.status_code}: {self.message}'

View File

@@ -0,0 +1,16 @@
from typing import Mapping
from starlette import status
from src.application.domain.exceptions.application_exception import ApplicationException
class BadRequestException(ApplicationException):
def __init__(
self,
message: str = 'Bad Request',
headers: Mapping[str, str] | None = None,
):
super().__init__(
status_code=status.HTTP_400_BAD_REQUEST,
message=message,
headers=headers,
)

View File

@@ -0,0 +1,16 @@
from typing import Mapping
from starlette import status
from src.application.domain.exceptions.application_exception import ApplicationException
class ConflictException(ApplicationException):
def __init__(
self,
message: str = 'Conflict',
headers: Mapping[str, str] | None = None,
):
super().__init__(
status_code=status.HTTP_409_CONFLICT,
message=message,
headers=headers,
)

View File

@@ -0,0 +1,16 @@
from typing import Mapping
from starlette import status
from src.application.domain.exceptions.application_exception import ApplicationException
class ForbiddenException(ApplicationException):
def __init__(
self,
message: str = 'Forbidden',
headers: Mapping[str, str] | None = None,
):
super().__init__(
status_code=status.HTTP_403_FORBIDDEN,
message=message,
headers=headers,
)

View File

@@ -0,0 +1,16 @@
from typing import Mapping
from starlette import status
from src.application.domain.exceptions.application_exception import ApplicationException
class InternalServerException(ApplicationException):
def __init__(
self,
message: str = 'Internal Server Error',
headers: Mapping[str, str] | None = None,
):
super().__init__(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
message=message,
headers=headers,
)

View File

@@ -0,0 +1,16 @@
from typing import Mapping
from starlette import status
from src.application.domain.exceptions.application_exception import ApplicationException
class NotFoundException(ApplicationException):
def __init__(
self,
message: str = 'Not Found',
headers: Mapping[str, str] | None = None,
):
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
message=message,
headers=headers,
)

View File

@@ -0,0 +1,10 @@
from starlette import status
from src.application.domain.exceptions.application_exception import ApplicationException
class RefreshConcurrentException(ApplicationException):
def __init__(self) -> None:
super().__init__(
status_code=status.HTTP_200_OK,
message='Refresh already handled',
)

View File

@@ -0,0 +1,16 @@
from typing import Mapping
from starlette import status
from src.application.domain.exceptions.application_exception import ApplicationException
class ServiceUnavailableException(ApplicationException):
def __init__(
self,
message: str = 'Service Unavailable',
headers: Mapping[str, str] | None = None,
):
super().__init__(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
message=message,
headers=headers,
)

View File

@@ -0,0 +1,16 @@
from typing import Mapping
from starlette import status
from src.application.domain.exceptions.application_exception import ApplicationException
class TooManyRequestsException(ApplicationException):
def __init__(
self,
message: str = 'Too Many Requests',
headers: Mapping[str, str] | None = None,
):
super().__init__(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
message=message,
headers=headers,
)

View File

@@ -0,0 +1,16 @@
from typing import Mapping
from starlette import status
from src.application.domain.exceptions.application_exception import ApplicationException
class UnauthorizedException(ApplicationException):
def __init__(
self,
message: str = 'Unauthorized',
headers: Mapping[str, str] | None = None,
):
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
message=message,
headers=headers,
)

View File

@@ -4,13 +4,15 @@ from src.infrastructure.config import settings
def create_redis_client() -> Redis:
return redis.from_url(
settings.REDIS_URL,
max_connections=50,
decode_responses=True,
socket_timeout=5,
socket_connect_timeout=5,
health_check_interval=30,
retry_on_timeout=True,
socket_keepalive=True,
)
kw = {
'max_connections': 50,
'decode_responses': True,
'socket_timeout': 5,
'socket_connect_timeout': 5,
'health_check_interval': 30,
'retry_on_timeout': True,
'socket_keepalive': True,
}
if settings.REDIS_PASSWORD:
kw['password'] = settings.REDIS_PASSWORD
return redis.from_url(settings.REDIS_URL, **kw)

View File

@@ -55,7 +55,11 @@ class Settings(BaseSettings):
CSRF_COOKIE_HTTPONLY: bool = True
CSRF_COOKIE_SAMESITE: Literal['Lax', 'Strict', 'None'] = 'Lax'
CSRF_COOKIE_PATH: str = '/'
CSRF_COOKIE_DOMAIN: str | None = None
CSRF_COOKIE_DOMAIN: str | None = '.elcsa.ru'
AUTH_COOKIE_SECURE: bool = False
AUTH_COOKIE_DOMAIN: str | None = '.elcsa.ru'
CORS_ALLOW_ORIGIN_REGEX: str = r'https?://([a-z0-9-]+\.)*elcsa\.ru(:\d+)?$'
DOCS_USERNAME: str = 'admin'
DOCS_PASSWORD: str = 'admin'
@@ -81,9 +85,6 @@ class Settings(BaseSettings):
RABBIT_CONNECT_TIMEOUT: int = 5
RABBIT_EMAIL_CODE_QUEUE: str = 'email.verification_code'
CORS_ORIGINS: str = 'http://localhost:3000'
CORS_ALLOW_CREDENTIALS: bool = True
RATE_LIMIT_REQUESTS: int = 60
RATE_LIMIT_WINDOW: int = 60
@@ -99,7 +100,33 @@ class Settings(BaseSettings):
@field_validator('CSRF_COOKIE_DOMAIN', mode='before')
@classmethod
def empty_csrf_domain_to_none(cls, v):
def normalize_csrf_cookie_domain(cls, v):
if v is None or (isinstance(v, str) and not v.strip()):
return '.elcsa.ru'
s = str(v).strip()
sl = s.lower()
if sl in ('.elcsa.ru', 'elcsa.ru'):
return '.elcsa.ru'
if sl.endswith('.elcsa.ru') and not sl.startswith('.'):
return '.elcsa.ru'
return s
@field_validator('AUTH_COOKIE_DOMAIN', mode='before')
@classmethod
def normalize_auth_cookie_domain(cls, v):
if v is None or (isinstance(v, str) and not v.strip()):
return '.elcsa.ru'
s = str(v).strip()
sl = s.lower()
if sl in ('.elcsa.ru', 'elcsa.ru'):
return '.elcsa.ru'
if sl.endswith('.elcsa.ru') and not sl.startswith('.'):
return '.elcsa.ru'
return s
@field_validator('REDIS_PASSWORD', mode='before')
@classmethod
def empty_redis_password_to_none(cls, v):
if v is None or (isinstance(v, str) and not v.strip()):
return None
return v
@@ -215,13 +242,29 @@ class Settings(BaseSettings):
rb_set('password', 'RABBIT_PASSWORD')
rb_set('vhost', 'RABBIT_VHOST')
redis_secret = read_secret_optional('redis')
if redis_secret:
rd_ci = {str(k).lower(): v for k, v in redis_secret.items()}
def rd_set(field: str, env_key: str, *, as_int: bool = False) -> None:
v = rd_ci.get(field)
if v is None:
return
if isinstance(v, str) and not v.strip():
return
if as_int:
data[env_key] = int(v)
else:
data[env_key] = str(v).strip()
rd_set('host', 'REDIS_HOST')
rd_set('port', 'REDIS_PORT', as_int=True)
rd_set('password', 'REDIS_PASSWORD')
rd_set('db', 'REDIS_DB', as_int=True)
return data
def cors_origins_list(self) -> List[str]:
return [o.strip() for o in self.CORS_ORIGINS.split(',') if o.strip()]
@property
def DATABASE_URL(self) -> str:
return (
@@ -231,8 +274,7 @@ class Settings(BaseSettings):
@property
def REDIS_URL(self) -> str:
auth = f":{self.REDIS_PASSWORD}@" if self.REDIS_PASSWORD else ""
return f"redis://{auth}{self.REDIS_HOST}:{self.REDIS_PORT}/{self.REDIS_DB}"
return f'redis://{self.REDIS_HOST}:{self.REDIS_PORT}/{self.REDIS_DB}'
@property
def RABBIT_URL(self) -> str:

View File

@@ -1,8 +1,11 @@
from __future__ import annotations
from sqlalchemy import Boolean, Date, String, DateTime
from sqlalchemy.orm import Mapped, mapped_column
from datetime import datetime
from sqlalchemy import Boolean, Date, String, DateTime, Text
from sqlalchemy.orm import Mapped,mapped_column
from src.infrastructure.database.models.base import Base
from src.infrastructure.database.models.mixins import UlidPrimaryKeyMixin, AuditTimestampsMixin, SoftDeleteMixin
from src.infrastructure.database.models.mixins import UlidPrimaryKeyMixin,AuditTimestampsMixin,SoftDeleteMixin
class UserModel(Base, UlidPrimaryKeyMixin, AuditTimestampsMixin, SoftDeleteMixin):
@@ -16,13 +19,18 @@ class UserModel(Base, UlidPrimaryKeyMixin, AuditTimestampsMixin, SoftDeleteMixin
middle_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
birth_date: Mapped[Date | None] = mapped_column(Date, nullable=True)
crypto_wallet: Mapped[str | None] = mapped_column(String(255), nullable=True)
encrypted_mnemonic: Mapped[str | None] = mapped_column(Text, nullable=True)
phone: Mapped[str | None] = mapped_column(String(16), nullable=True)
bik: Mapped[str | None] = mapped_column(String(9), nullable=True)
account_number: Mapped[str | None] = mapped_column(String(20), nullable=True)
card_number: Mapped[str | None] = mapped_column(String(19), nullable=True)
passport_data: Mapped[str | None] = mapped_column(String(255), nullable=True)
inn: Mapped[str | None] = mapped_column(String(12), nullable=True)
erc20: Mapped[str | None] = mapped_column(String(255), nullable=True)
avatar_link: Mapped[str | None] = mapped_column(Text, nullable=True)
kyc_verified: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default='false', default=False)
kyc_verified_at: Mapped[DateTime | None] = mapped_column(DateTime(timezone=True), nullable=True)
account_type: Mapped[str] = mapped_column(String(20), nullable=False, server_default='individual', default='individual')
provisioned_by: Mapped[str | None] = mapped_column(String(26), nullable=True)
provisioned_at: Mapped[DateTime | None] = mapped_column(DateTime(timezone=True), nullable=True)

View File

@@ -7,6 +7,7 @@ from src.application.contracts import ILogger
from src.application.domain.exceptions import ApplicationException
from src.application.abstractions.repositories import IUserRepository
from src.application.domain.entities import UserEntity
from src.application.domain.enums.account_type import AccountType
from src.infrastructure.database.models import UserModel
@@ -16,7 +17,11 @@ class UserRepository(IUserRepository):
self._logger = logger
async def create_user(self, email: str, password_hash: str) -> UserEntity:
user = UserModel(email=email, password_hash=password_hash)
user = UserModel(
email=email,
password_hash=password_hash,
account_type=AccountType.INDIVIDUAL,
)
self._session.add(user)
try:
await self._session.flush()
@@ -25,7 +30,9 @@ class UserRepository(IUserRepository):
email=user.email,
created_at=user.created_at,
kyc_verified=user.kyc_verified,
is_deleted=user.is_deleted
is_deleted=user.is_deleted,
avatar_link=user.avatar_link,
account_type=user.account_type,
)
except IntegrityError:
@@ -67,17 +74,20 @@ class UserRepository(IUserRepository):
middle_name=user.middle_name,
last_name=user.last_name,
birth_date=user.birth_date,
crypto_wallet=user.crypto_wallet,
encrypted_mnemonic=user.encrypted_mnemonic,
phone=user.phone,
bik=user.bik,
account_number=user.account_number,
card_number=user.card_number,
passport_data=user.passport_data,
inn=user.inn,
kyc_verified_at=user.kyc_verified_at,
erc20=user.erc20,
avatar_link=user.avatar_link,
kyc_verified=user.kyc_verified,
is_deleted=user.is_deleted,
created_at=user.created_at,
updated_at=user.updated_at,
kyc_verified_at=user.kyc_verified_at,
account_type=user.account_type,
provisioned_by=user.provisioned_by,
provisioned_at=user.provisioned_at,
)
except ApplicationException:

View File

@@ -2,6 +2,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from src.application.abstractions import IUnitOfWork
from src.application.abstractions.repositories import IUserRepository, ISessionRepository
from src.application.contracts import ILogger
from src.application.domain.exceptions import RefreshConcurrentException
from src.infrastructure.database.repositories import UserRepository, SessionRepository
@@ -20,9 +21,11 @@ class UnitOfWork(IUnitOfWork):
async def __aexit__(self, exc_type, exc_val, exc_tb):
if exc_type:
self._logger.error(str(exc_val))
if not isinstance(exc_val, RefreshConcurrentException):
self._logger.error(str(exc_val))
await self._session.rollback()
self._logger.error(f'Rollback: str{exc_val})')
if not isinstance(exc_val, RefreshConcurrentException):
self._logger.error(f'Rollback: str{exc_val})')
else:
await self._session.flush()
await self._session.commit()

View File

@@ -4,6 +4,7 @@ import secrets
from typing import AsyncGenerator
from fastapi import Depends, FastAPI, status
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from starlette.middleware.cors import CORSMiddleware
@@ -15,13 +16,66 @@ from src.infrastructure.utils import generate_instance_id
from src.infrastructure.logger import logger
from src.infrastructure.config import settings
from src.presentation.dependencies import get_rabbit
from src.presentation.handlers import application_exception_handler, unhandled_exception_handler
from src.presentation.handler import application_exception_handler
from src.presentation.handler import unhandled_exception_handler
from src.presentation.middleware import TraceIDMiddleware, SecurityHeadersMiddleware
from src.presentation.routing import v1_router
from src.presentation.schemas import ErrorResponse
security = HTTPBasic()
ERROR_RESPONSES: dict[int, str] = {
status.HTTP_400_BAD_REQUEST: 'Bad Request',
status.HTTP_401_UNAUTHORIZED: 'Unauthorized',
status.HTTP_403_FORBIDDEN: 'Forbidden',
status.HTTP_404_NOT_FOUND: 'Not Found',
status.HTTP_409_CONFLICT: 'Conflict',
status.HTTP_429_TOO_MANY_REQUESTS: 'Too Many Requests',
status.HTTP_500_INTERNAL_SERVER_ERROR: 'Internal Server Error',
status.HTTP_503_SERVICE_UNAVAILABLE: 'Service Unavailable',
}
def custom_openapi() -> dict:
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title=app.title,
version=app.version,
description=app.description,
routes=app.routes,
license_info=app.license_info,
)
components = openapi_schema.setdefault('components', {})
schemas = components.setdefault('schemas', {})
schemas['ErrorResponse'] = ErrorResponse.model_json_schema()
for path_item in openapi_schema.get('paths', {}).values():
for operation in path_item.values():
if not isinstance(operation, dict):
continue
responses = operation.setdefault('responses', {})
for status_code, description in ERROR_RESPONSES.items():
responses.setdefault(
str(status_code),
{
'description': description,
'content': {
'application/json': {
'schema': {
'$ref': '#/components/schemas/ErrorResponse',
},
},
},
},
)
app.openapi_schema = openapi_schema
return app.openapi_schema
async def verify_credentials(credentials: HTTPBasicCredentials = Depends(security)) -> HTTPBasicCredentials:
user_ok = secrets.compare_digest(credentials.username, settings.DOCS_USERNAME)
@@ -117,8 +171,8 @@ app.add_middleware(
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins_list(),
allow_credentials=settings.CORS_ALLOW_CREDENTIALS,
allow_origin_regex=settings.CORS_ALLOW_ORIGIN_REGEX,
allow_credentials=True,
allow_methods=['*'],
allow_headers=['*'],
)
@@ -152,3 +206,6 @@ async def ping() -> dict[str, str]:
'message': 'pong',
'status': 'ok',
}
app.openapi = custom_openapi

View File

@@ -29,7 +29,7 @@ async def require_access_token(
if not token:
raise ApplicationException(status_code=401, message="Not authenticated")
payload = jwt_service.decode_access_token(token)
payload = await jwt_service.decode_access_token(token)
if payload.type != "access":
raise ApplicationException(status_code=401, message="Invalid token type")

View File

@@ -93,6 +93,7 @@ def get_jwt_refresh_command(
uow: IUnitOfWork = Depends(get_unit_of_work),
hash_service: IHashService = Depends(get_hash_service),
jwt_service: IJwtService = Depends(get_jwt_service),
cache: ICache = Depends(get_cache),
logger: ILogger = Depends(get_logger),
) -> JwtRefreshCommand:
return JwtRefreshCommand(uow, hash_service, jwt_service, logger)
return JwtRefreshCommand(uow, hash_service, jwt_service, cache, logger)

View File

@@ -0,0 +1,2 @@
from src.presentation.handler.application_exception_handler import application_exception_handler
from src.presentation.handler.unhandled_exception_handler import unhandled_exception_handler

View File

@@ -1,17 +1,15 @@
from fastapi.responses import ORJSONResponse
from fastapi import Request
from fastapi.responses import ORJSONResponse
from src.application.domain.exceptions import ApplicationException
async def application_exception_handler(_request: Request, exc: ApplicationException) -> ORJSONResponse:
detail = exc.message
if 500 <= exc.status_code:
detail = "Internal Server Error"
detail = 'Internal Server Error'
return ORJSONResponse(
status_code=exc.status_code,
content={"detail": detail},
content={'detail': detail},
headers=dict(exc.headers) if exc.headers else None,
)

View File

@@ -1,5 +1,5 @@
from fastapi.responses import ORJSONResponse
from fastapi import Request
from fastapi.responses import ORJSONResponse
from starlette import status
from src.infrastructure.logger import logger
@@ -9,4 +9,4 @@ async def unhandled_exception_handler(_request: Request, exc: Exception) -> ORJS
return ORJSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={'detail': 'Internal Server Error'},
)
)

View File

@@ -1,2 +0,0 @@
from src.presentation.handlers.unhandled_handler import unhandled_exception_handler
from src.presentation.handlers.application_handler import application_exception_handler

View File

@@ -12,7 +12,7 @@ from src.application.contracts import ILogger
from src.application.domain.dto import UserLoginDto
from src.infrastructure.config import settings
from src.infrastructure.logger import get_logger
from src.presentation.decorators import rate_limit, email_rl_key
from src.presentation.decorators import csrf_protect,rate_limit,email_rl_key
from src.presentation.dependencies import (
get_user_registration_complete_command,
get_user_logout_command,
@@ -31,19 +31,23 @@ auth_router = APIRouter(prefix='/auth', tags=['auth'])
response_class=ORJSONResponse,
status_code=status.HTTP_200_OK,
)
@rate_limit(limit=5, window_seconds=60, scope='ip')
@rate_limit(limit=3, window_seconds=600, scope='key', key_prefix='rl:reg_start', key_builder=email_rl_key)
#@rate_limit(limit=5, window_seconds=60, scope='ip')
#@rate_limit(limit=3, window_seconds=600, scope='key', key_prefix='rl:reg_start', key_builder=email_rl_key)
@csrf_protect()
async def registration_start(
request: Request,
body: RegistrationStart,
logger: ILogger = Depends(get_logger),
command: UserRegistrationStartCommand = Depends(get_user_registration_start_command),
):
logger.info('AHAHAHAHAHAHHAHAAH')
result = await command(body.email)
return {'success': result}
@auth_router.post(path='/registration/complete', response_class=ORJSONResponse, status_code=status.HTTP_201_CREATED)
@rate_limit(limit=10, window_seconds=300, scope='ip')
#@rate_limit(limit=10, window_seconds=300, scope='ip')
@csrf_protect()
async def registration(
request: Request,
user: RegistrationComplete,
@@ -76,9 +80,10 @@ async def registration(
key='device_id',
value=device_id,
httponly=True,
secure=True,
secure=settings.AUTH_COOKIE_SECURE,
samesite='lax',
path='/',
domain=settings.AUTH_COOKIE_DOMAIN,
max_age=60 * 60 * 24 * 365 * 5
)
@@ -86,25 +91,28 @@ async def registration(
key='access_token',
value=created.access_token,
httponly=True,
secure=True,
secure=settings.AUTH_COOKIE_SECURE,
samesite='lax',
path='/',
domain=settings.AUTH_COOKIE_DOMAIN,
max_age=int(settings.JWT_ACCESS_TTL_SECONDS),
)
response.set_cookie(
key='refresh_token',
value=created.refresh_token,
httponly=True,
secure=True,
secure=settings.AUTH_COOKIE_SECURE,
samesite='lax',
path='/',
domain=settings.AUTH_COOKIE_DOMAIN,
max_age=int(settings.JWT_REFRESH_TTL_SECONDS),
)
return response
@auth_router.post(path='/login/start', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
@rate_limit(limit=5, window_seconds=60, scope='ip')
@rate_limit(limit=3, window_seconds=600, scope='key', key_prefix='rl:login_start', key_builder=email_rl_key)
#@rate_limit(limit=5, window_seconds=60, scope='ip')
#@rate_limit(limit=3, window_seconds=600, scope='key', key_prefix='rl:login_start', key_builder=email_rl_key)
@csrf_protect()
async def login_start(
request: Request,
body: LoginStart,
@@ -114,8 +122,9 @@ async def login_start(
return {'success': result}
@auth_router.post(path='/login/compete', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
@rate_limit(limit=10, window_seconds=300, scope='ip')
@auth_router.post(path='/login/complete', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
#@rate_limit(limit=10, window_seconds=300, scope='ip')
@csrf_protect()
async def login(
request: Request,
user: UserLogin,
@@ -150,12 +159,12 @@ async def login(
'middle_name': dto.middle_name,
'last_name': dto.last_name,
'birth_date': dto.birth_date.isoformat() if dto.birth_date else None,
'crypto_wallet': dto.crypto_wallet,
'encrypted_mnemonic': dto.encrypted_mnemonic,
'phone': dto.phone,
'bik': dto.bik,
'account_number': dto.account_number,
'card_number': dto.card_number,
'passport_data': dto.passport_data,
'inn': dto.inn,
'erc20': dto.erc20,
'avatar_link': dto.avatar_link,
'kyc_verified': dto.kyc_verified,
'kyc_verified_at': dto.kyc_verified_at,
'created_at': dto.created_at.isoformat() if dto.created_at else None,
@@ -167,9 +176,10 @@ async def login(
key='device_id',
value=device_id,
httponly=True,
secure=True,
secure=settings.AUTH_COOKIE_SECURE,
samesite='lax',
path='/',
domain=settings.AUTH_COOKIE_DOMAIN,
max_age=60 * 60 * 24 * 365 * 5
)
@@ -177,9 +187,10 @@ async def login(
key='access_token',
value=dto.access_token,
httponly=True,
secure=True,
secure=settings.AUTH_COOKIE_SECURE,
samesite='lax',
path='/',
domain=settings.AUTH_COOKIE_DOMAIN,
max_age=int(settings.JWT_ACCESS_TTL_SECONDS),
)
@@ -187,16 +198,18 @@ async def login(
key='refresh_token',
value=dto.refresh_token,
httponly=True,
secure=True,
secure=settings.AUTH_COOKIE_SECURE,
samesite='lax',
path='/',
domain=settings.AUTH_COOKIE_DOMAIN,
max_age=int(settings.JWT_REFRESH_TTL_SECONDS),
)
return response
@auth_router.post(path='/logout', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
@rate_limit(limit=settings.RATE_LIMIT_REQUESTS, window_seconds=settings.RATE_LIMIT_WINDOW, scope='ip')
#@rate_limit(limit=settings.RATE_LIMIT_REQUESTS, window_seconds=settings.RATE_LIMIT_WINDOW, scope='ip')
@csrf_protect()
async def logout_current(
request: Request,
command: UserLogoutCommand = Depends(get_user_logout_command),
@@ -206,8 +219,8 @@ async def logout_current(
await command(refresh_token=refresh_token)
response = ORJSONResponse({'ok': True})
response.delete_cookie('access_token', path='/')
response.delete_cookie('refresh_token', path='/')
response.delete_cookie('access_token', path='/', domain=settings.AUTH_COOKIE_DOMAIN)
response.delete_cookie('refresh_token', path='/', domain=settings.AUTH_COOKIE_DOMAIN)
return response

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
from fastapi import APIRouter
from fastapi import APIRouter, Request
from fastapi.responses import ORJSONResponse
from starlette import status
from src.infrastructure.security import CsrfService
@@ -11,7 +11,7 @@ csrf_router = APIRouter(prefix='/csrf', tags=['csrf'])
@csrf_router.get('/token', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
@rate_limit(limit=settings.RATE_LIMIT_REQUESTS, window_seconds=settings.RATE_LIMIT_WINDOW, scope='ip')
async def issue_csrf_token():
async def issue_csrf_token(request: Request):
csrf = CsrfService()
token = csrf.issue()
@@ -30,7 +30,7 @@ async def issue_csrf_token():
httponly=settings.CSRF_COOKIE_HTTPONLY,
samesite=settings.CSRF_COOKIE_SAMESITE,
path=settings.CSRF_COOKIE_PATH,
domain=settings.CSRF_COOKIE_DOMAIN,
domain=settings.CSRF_COOKIE_DOMAIN or '.elcsa.ru',
max_age=csrf.ttl_seconds,
)

View File

@@ -2,63 +2,72 @@ from fastapi import APIRouter, Request, Depends
from fastapi.responses import ORJSONResponse
from starlette import status
from src.application.commands import JwtRefreshCommand
from src.application.domain.exceptions import ApplicationException
from src.application.domain.exceptions import ApplicationException, RefreshConcurrentException
from src.infrastructure.config import settings
from src.presentation.decorators import rate_limit
from src.presentation.decorators import csrf_protect, rate_limit
from src.presentation.dependencies import get_jwt_refresh_command
jwt_router = APIRouter(prefix='/jwt', tags=['Jwt'])
@jwt_router.post('/refresh', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
@rate_limit(limit=settings.RATE_LIMIT_REQUESTS, window_seconds=settings.RATE_LIMIT_WINDOW, scope='ip')
async def refresh_tokens(
request: Request,
command: JwtRefreshCommand = Depends(get_jwt_refresh_command)
):
refresh_token = request.cookies.get('refresh_token')
def _clear_auth_cookies(response: ORJSONResponse) -> None:
response.delete_cookie('access_token', path='/', domain=settings.AUTH_COOKIE_DOMAIN)
response.delete_cookie('refresh_token', path='/', domain=settings.AUTH_COOKIE_DOMAIN)
if not refresh_token:
response = ORJSONResponse({'ok': False, 'error': 'No refresh token'}, status_code=401)
response.delete_cookie('access_token', path='/')
response.delete_cookie('refresh_token', path='/')
return response
ip = request.client.host if request.client else None
user_agent = request.headers.get('user-agent')
try:
access, refresh = await command(refresh_token=refresh_token, ip=ip, user_agent=user_agent)
except ApplicationException:
response = ORJSONResponse({'result': False}, status_code=401)
response.delete_cookie('access_token', path='/')
response.delete_cookie('refresh_token', path='/')
return response
response = ORJSONResponse({'result': True})
def _set_auth_cookies(response: ORJSONResponse, access: str, refresh: str) -> None:
response.set_cookie(
key='access_token',
value=access,
httponly=True,
secure=True,
secure=settings.AUTH_COOKIE_SECURE,
samesite='lax',
path='/',
domain=settings.AUTH_COOKIE_DOMAIN,
max_age=int(settings.JWT_ACCESS_TTL_SECONDS),
)
response.set_cookie(
key='refresh_token',
value=refresh,
httponly=True,
secure=True,
secure=settings.AUTH_COOKIE_SECURE,
samesite='lax',
path='/',
domain=settings.AUTH_COOKIE_DOMAIN,
max_age=int(settings.JWT_REFRESH_TTL_SECONDS),
)
return response
# Usage
# @jwt_router.get("/test")
# async def profile(auth: AuthContext = Depends(require_access_token)):
# return 'ok'
@jwt_router.post('/refresh', response_class=ORJSONResponse, status_code=status.HTTP_200_OK)
@rate_limit(limit=settings.RATE_LIMIT_REQUESTS, window_seconds=settings.RATE_LIMIT_WINDOW, scope='ip')
@csrf_protect()
async def refresh_tokens(
request: Request,
command: JwtRefreshCommand = Depends(get_jwt_refresh_command),
):
refresh_token = request.cookies.get('refresh_token')
if not refresh_token:
response = ORJSONResponse({'ok': False, 'error': 'No refresh token'}, status_code=401)
_clear_auth_cookies(response)
return response
ip = request.client.host if request.client else None
user_agent = request.headers.get('user-agent')
try:
tokens = await command(refresh_token=refresh_token, ip=ip, user_agent=user_agent)
except RefreshConcurrentException:
return ORJSONResponse({'result': True, 'concurrent': True}, status_code=status.HTTP_200_OK)
except ApplicationException as exc:
if exc.status_code == status.HTTP_401_UNAUTHORIZED:
response = ORJSONResponse({'result': False}, status_code=401)
_clear_auth_cookies(response)
return response
raise
access, refresh = tokens
response = ORJSONResponse({'result': True})
_set_auth_cookies(response, access, refresh)
return response

View File

@@ -1 +1,5 @@
from src.presentation.schemas.user import RegistrationStart, RegistrationComplete, UserLogin, LoginStart
from src.presentation.schemas.error import ErrorResponse
from src.presentation.schemas.user import RegistrationComplete
from src.presentation.schemas.user import RegistrationStart
from src.presentation.schemas.user import LoginStart
from src.presentation.schemas.user import UserLogin

View File

@@ -0,0 +1,6 @@
from pydantic import BaseModel
from pydantic import Field
class ErrorResponse(BaseModel):
detail: str = Field(title='Detail')

1149
uv.lock generated

File diff suppressed because it is too large Load Diff