feat(account): GET /me user endpoint only, disable cache and extra routers
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
1
src/infrastructure/database/__init__.py
Normal file
1
src/infrastructure/database/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.database.unit_of_work import UnitOfWork
|
||||
22
src/infrastructure/database/context.py
Normal file
22
src/infrastructure/database/context.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
from sqlalchemy.ext.asyncio.engine import create_async_engine
|
||||
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||
from typing import AsyncGenerator
|
||||
from src.infrastructure.config import settings
|
||||
|
||||
|
||||
engine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
pool_size=settings.DATABASE_POOL_SIZE,
|
||||
max_overflow=settings.DATABASE_MAX_OVERFLOW,
|
||||
pool_timeout=settings.DATABASE_POOL_TIMEOUT,
|
||||
pool_recycle=settings.DATABASE_POOL_RECYCLE,
|
||||
echo=settings.DATABASE_ECHO
|
||||
)
|
||||
|
||||
async_session_maker = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
|
||||
|
||||
|
||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with async_session_maker() as session:
|
||||
yield session
|
||||
1
src/infrastructure/database/decorators/__init__.py
Normal file
1
src/infrastructure/database/decorators/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from src.infrastructure.database.decorators.transactional import transactional
|
||||
15
src/infrastructure/database/decorators/transactional.py
Normal file
15
src/infrastructure/database/decorators/transactional.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
from functools import wraps
|
||||
from typing import Callable, Awaitable, TypeVar, ParamSpec
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def transactional(method: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
||||
@wraps(method)
|
||||
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
async with self._unit_of_work:
|
||||
return await method(self, *args, **kwargs)
|
||||
return wrapper
|
||||
6
src/infrastructure/database/models/__init__.py
Normal file
6
src/infrastructure/database/models/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from src.infrastructure.database.models.base import Base
|
||||
from src.infrastructure.database.models.user import UserModel
|
||||
from src.infrastructure.database.models.sessions import Session
|
||||
|
||||
__all__ = ['Base', 'UserModel', 'Session']
|
||||
|
||||
19
src/infrastructure/database/models/base.py
Normal file
19
src/infrastructure/database/models/base.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
|
||||
class Base(AsyncAttrs, DeclarativeBase):
|
||||
__abstract__ = True
|
||||
|
||||
def __repr__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
attributes = ', '.join(f"{col.name}={getattr(self, col.name, None)!r}"
|
||||
for col in self.__table__.columns)
|
||||
return f"<{class_name}({attributes})>"
|
||||
|
||||
def __str__(self) -> str:
|
||||
class_name = self.__class__.__name__
|
||||
attributes = ', '.join(f"{col.name}={getattr(self, col.name)}"
|
||||
for col in self.__table__.columns
|
||||
if getattr(self, col.name) is not None)
|
||||
return f"{class_name}({attributes})"
|
||||
3
src/infrastructure/database/models/mixins/__init__.py
Normal file
3
src/infrastructure/database/models/mixins/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from src.infrastructure.database.models.mixins.audit import AuditTimestampsMixin
|
||||
from src.infrastructure.database.models.mixins.ulid import UlidPrimaryKeyMixin
|
||||
from src.infrastructure.database.models.mixins.soft_delete import SoftDeleteMixin
|
||||
16
src/infrastructure/database/models/mixins/audit.py
Normal file
16
src/infrastructure/database/models/mixins/audit.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from sqlalchemy import DateTime, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
|
||||
class AuditTimestampsMixin:
|
||||
created_at: Mapped[DateTime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=func.now(),
|
||||
)
|
||||
updated_at: Mapped[DateTime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
6
src/infrastructure/database/models/mixins/soft_delete.py
Normal file
6
src/infrastructure/database/models/mixins/soft_delete.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from sqlalchemy import Boolean
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
|
||||
class SoftDeleteMixin:
|
||||
is_deleted: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default='false', default=False)
|
||||
8
src/infrastructure/database/models/mixins/ulid.py
Normal file
8
src/infrastructure/database/models/mixins/ulid.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from ulid import ULID
|
||||
|
||||
|
||||
class UlidPrimaryKeyMixin:
|
||||
|
||||
id: Mapped[str] = mapped_column(String(26), primary_key=True, default=lambda: str(ULID()))
|
||||
50
src/infrastructure/database/models/sessions.py
Normal file
50
src/infrastructure/database/models/sessions.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy import String, DateTime, ForeignKey, Index
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from ulid import ULID
|
||||
from src.infrastructure.database.models import Base
|
||||
from src.infrastructure.database.models.mixins import UlidPrimaryKeyMixin, AuditTimestampsMixin
|
||||
|
||||
|
||||
class Session(Base, UlidPrimaryKeyMixin, AuditTimestampsMixin):
|
||||
__tablename__ = "sessions"
|
||||
|
||||
sid: Mapped[str] = mapped_column(
|
||||
String(26),
|
||||
unique=True,
|
||||
index=True,
|
||||
nullable=False,
|
||||
default=lambda: str(ULID()),
|
||||
)
|
||||
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(26),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
index=True,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
device_id: Mapped[str] = mapped_column(
|
||||
String(26),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
user_agent: Mapped[str | None] = mapped_column(String(500))
|
||||
first_ip: Mapped[str | None] = mapped_column(String(64))
|
||||
last_ip: Mapped[str | None] = mapped_column(String(64))
|
||||
|
||||
last_seen_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
revoked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||
|
||||
refresh_jti_hash: Mapped[str | None] = mapped_column(String(255))
|
||||
refresh_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||
|
||||
|
||||
Index("ux_sessions_user_device", Session.user_id, Session.device_id, unique=True)
|
||||
Index("ix_sessions_user_active", Session.user_id, Session.revoked_at)
|
||||
28
src/infrastructure/database/models/user.py
Normal file
28
src/infrastructure/database/models/user.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import Boolean,Date,String,DateTime
|
||||
from sqlalchemy.orm import Mapped,mapped_column
|
||||
from src.infrastructure.database.models.base import Base
|
||||
from src.infrastructure.database.models.mixins import UlidPrimaryKeyMixin,AuditTimestampsMixin,SoftDeleteMixin
|
||||
|
||||
|
||||
class UserModel(Base, UlidPrimaryKeyMixin, AuditTimestampsMixin, SoftDeleteMixin):
|
||||
__tablename__ = 'users'
|
||||
|
||||
email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True, index=True)
|
||||
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
|
||||
last_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
first_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
middle_name: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
birth_date: Mapped[Date | None] = mapped_column(Date, nullable=True)
|
||||
|
||||
crypto_wallet: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
phone: Mapped[str | None] = mapped_column(String(16), nullable=True)
|
||||
|
||||
passport_data: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
inn: Mapped[str | None] = mapped_column(String(12), nullable=True)
|
||||
erc20: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
|
||||
kyc_verified: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default='false', default=False)
|
||||
kyc_verified_at: Mapped[DateTime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
2
src/infrastructure/database/repositories/__init__.py
Normal file
2
src/infrastructure/database/repositories/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from src.infrastructure.database.repositories.user_repository import UserRepository
|
||||
from src.infrastructure.database.repositories.session_repository import SessionRepository
|
||||
198
src/infrastructure/database/repositories/session_repository.py
Normal file
198
src/infrastructure/database/repositories/session_repository.py
Normal file
@@ -0,0 +1,198 @@
|
||||
from __future__ import annotations
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from src.application.contracts import ILogger
|
||||
from src.application.domain.entities import SessionEntity
|
||||
from src.application.abstractions.repositories import ISessionRepository
|
||||
from src.infrastructure.database.models import Session
|
||||
|
||||
|
||||
class SessionRepository(ISessionRepository):
|
||||
def __init__(self, session: AsyncSession, logger: ILogger):
|
||||
self._session = session
|
||||
self._logger = logger
|
||||
|
||||
async def get_by_sid(self, sid: str) -> Optional[SessionEntity]:
|
||||
res = await self._session.execute(select(Session).where(Session.sid == sid))
|
||||
m = res.scalar_one_or_none()
|
||||
if m is None:
|
||||
return None
|
||||
|
||||
return SessionEntity(
|
||||
sid=m.sid,
|
||||
user_id=m.user_id,
|
||||
device_id=m.device_id,
|
||||
revoked_at=m.revoked_at,
|
||||
last_seen_at=m.last_seen_at,
|
||||
refresh_jti_hash=m.refresh_jti_hash,
|
||||
refresh_expires_at=m.refresh_expires_at,
|
||||
user_agent=m.user_agent,
|
||||
first_ip=m.first_ip,
|
||||
last_ip=m.last_ip,
|
||||
)
|
||||
|
||||
async def get_by_user_device(self, user_id: str, device_id: str) -> Optional[SessionEntity]:
|
||||
res = await self._session.execute(
|
||||
select(Session).where(Session.user_id == user_id, Session.device_id == device_id)
|
||||
)
|
||||
m = res.scalar_one_or_none()
|
||||
if m is None:
|
||||
return None
|
||||
|
||||
return SessionEntity(
|
||||
sid=m.sid,
|
||||
user_id=m.user_id,
|
||||
device_id=m.device_id,
|
||||
revoked_at=m.revoked_at,
|
||||
last_seen_at=m.last_seen_at,
|
||||
refresh_jti_hash=m.refresh_jti_hash,
|
||||
refresh_expires_at=m.refresh_expires_at,
|
||||
user_agent=m.user_agent,
|
||||
first_ip=m.first_ip,
|
||||
last_ip=m.last_ip,
|
||||
)
|
||||
|
||||
async def upsert_by_device(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
sid: str,
|
||||
refresh_jti_hash: str,
|
||||
refresh_expires_at: datetime,
|
||||
user_agent: str | None,
|
||||
ip: str | None,
|
||||
now: datetime,
|
||||
) -> SessionEntity:
|
||||
res = await self._session.execute(
|
||||
select(Session).where(Session.user_id == user_id, Session.device_id == device_id)
|
||||
)
|
||||
m = res.scalar_one_or_none()
|
||||
|
||||
if m is None:
|
||||
m = Session(
|
||||
sid=sid,
|
||||
user_id=user_id,
|
||||
device_id=device_id,
|
||||
revoked_at=None,
|
||||
last_seen_at=now,
|
||||
refresh_jti_hash=refresh_jti_hash,
|
||||
refresh_expires_at=refresh_expires_at,
|
||||
user_agent=user_agent,
|
||||
first_ip=ip,
|
||||
last_ip=ip,
|
||||
)
|
||||
self._session.add(m)
|
||||
await self._session.flush()
|
||||
|
||||
self._logger.info(f'Session created (user_id={user_id}, device_id={device_id}, sid={sid})')
|
||||
else:
|
||||
m.sid = sid
|
||||
m.revoked_at = None
|
||||
m.last_seen_at = now
|
||||
m.refresh_jti_hash = refresh_jti_hash
|
||||
m.refresh_expires_at = refresh_expires_at
|
||||
m.user_agent = user_agent
|
||||
m.last_ip = ip
|
||||
|
||||
await self._session.flush()
|
||||
|
||||
self._logger.info(f'Session updated (user_id={user_id}, device_id={device_id}, sid={sid})')
|
||||
|
||||
return SessionEntity(
|
||||
sid=m.sid,
|
||||
user_id=m.user_id,
|
||||
device_id=m.device_id,
|
||||
revoked_at=m.revoked_at,
|
||||
last_seen_at=m.last_seen_at,
|
||||
refresh_jti_hash=m.refresh_jti_hash,
|
||||
refresh_expires_at=m.refresh_expires_at,
|
||||
user_agent=m.user_agent,
|
||||
first_ip=m.first_ip,
|
||||
last_ip=m.last_ip,
|
||||
)
|
||||
|
||||
async def revoke_by_sid(self, sid: str, now: datetime) -> None:
|
||||
# Интерфейс требует None -> просто делаем update и flush
|
||||
await self._session.execute(
|
||||
update(Session)
|
||||
.where(Session.sid == sid, Session.revoked_at.is_(None))
|
||||
.values(revoked_at=now)
|
||||
.execution_options(synchronize_session='fetch')
|
||||
)
|
||||
await self._session.flush()
|
||||
|
||||
async def rotate_refresh(
|
||||
self,
|
||||
sid: str,
|
||||
new_jti_hash: str,
|
||||
new_refresh_expires_at: datetime,
|
||||
now: datetime,
|
||||
ip: str | None,
|
||||
user_agent: str | None,
|
||||
) -> None:
|
||||
|
||||
values = {
|
||||
'refresh_jti_hash': new_jti_hash,
|
||||
'refresh_expires_at': new_refresh_expires_at,
|
||||
'last_seen_at': now,
|
||||
'user_agent': user_agent,
|
||||
}
|
||||
if ip is not None:
|
||||
values['last_ip'] = ip
|
||||
|
||||
await self._session.execute(
|
||||
update(Session)
|
||||
.where(Session.sid == sid, Session.revoked_at.is_(None))
|
||||
.values(**values)
|
||||
.execution_options(synchronize_session='fetch')
|
||||
)
|
||||
await self._session.flush()
|
||||
|
||||
async def touch_last_seen(self, sid: str, *, ip: str | None, now: datetime) -> None:
|
||||
values = {'last_seen_at': now}
|
||||
if ip is not None:
|
||||
values['last_ip'] = ip
|
||||
|
||||
await self._session.execute(
|
||||
update(Session)
|
||||
.where(Session.sid == sid, Session.revoked_at.is_(None))
|
||||
.values(**values)
|
||||
.execution_options(synchronize_session='fetch')
|
||||
)
|
||||
await self._session.flush()
|
||||
|
||||
async def rotate_refresh_if_match(
|
||||
self,
|
||||
*,
|
||||
sid: str,
|
||||
old_jti_hash: str,
|
||||
new_jti_hash: str,
|
||||
new_refresh_expires_at: datetime,
|
||||
now: datetime,
|
||||
ip: str | None,
|
||||
user_agent: str | None,
|
||||
) -> bool:
|
||||
values = {
|
||||
'refresh_jti_hash': new_jti_hash,
|
||||
'refresh_expires_at': new_refresh_expires_at,
|
||||
'last_seen_at': now,
|
||||
'user_agent': user_agent,
|
||||
}
|
||||
if ip is not None:
|
||||
values['last_ip'] = ip
|
||||
|
||||
res = await self._session.execute(
|
||||
update(Session)
|
||||
.where(
|
||||
Session.sid == sid,
|
||||
Session.revoked_at.is_(None),
|
||||
Session.refresh_jti_hash == old_jti_hash, # ✅ защита от гонок
|
||||
)
|
||||
.values(**values)
|
||||
.execution_options(synchronize_session='fetch')
|
||||
)
|
||||
await self._session.flush()
|
||||
return (res.rowcount or 0) > 0
|
||||
118
src/infrastructure/database/repositories/user_repository.py
Normal file
118
src/infrastructure/database/repositories/user_repository.py
Normal file
@@ -0,0 +1,118 @@
|
||||
from __future__ import annotations
|
||||
from fastapi import status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from src.application.contracts import ILogger
|
||||
from src.application.domain.exceptions import ApplicationException
|
||||
from src.application.abstractions.repositories import IUserRepository
|
||||
from src.application.domain.entities import UserEntity
|
||||
from src.infrastructure.database.models import UserModel
|
||||
|
||||
|
||||
class UserRepository(IUserRepository):
|
||||
def __init__(self, session: AsyncSession, logger: ILogger):
|
||||
self._session = session
|
||||
self._logger = logger
|
||||
|
||||
async def _get_active_user(self, user_id: str) -> UserModel:
|
||||
stmt = (
|
||||
select(UserModel)
|
||||
.where(
|
||||
UserModel.id == user_id,
|
||||
UserModel.is_deleted.is_(False),
|
||||
)
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
user: UserModel | None = result.scalar_one_or_none()
|
||||
if user is None:
|
||||
self._logger.warning(f'User not found with user_id {user_id}')
|
||||
raise ApplicationException(status_code=status.HTTP_404_NOT_FOUND, message='User not found')
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def _to_entity(user: UserModel) -> UserEntity:
|
||||
return UserEntity(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
password_hash=None,
|
||||
first_name=user.first_name,
|
||||
middle_name=user.middle_name,
|
||||
last_name=user.last_name,
|
||||
birth_date=user.birth_date,
|
||||
crypto_wallet=user.crypto_wallet,
|
||||
phone=user.phone,
|
||||
bik=user.bik,
|
||||
account_number=user.account_number,
|
||||
card_number=user.card_number,
|
||||
inn=user.inn,
|
||||
kyc_verified_at=user.kyc_verified_at,
|
||||
kyc_verified=user.kyc_verified,
|
||||
is_deleted=user.is_deleted,
|
||||
created_at=user.created_at,
|
||||
updated_at=user.updated_at,
|
||||
)
|
||||
|
||||
async def get_user_by_id(self, user_id: str) -> UserEntity:
|
||||
try:
|
||||
user = await self._get_active_user(user_id)
|
||||
return self._to_entity(user)
|
||||
except ApplicationException:
|
||||
raise
|
||||
except SQLAlchemyError as exception:
|
||||
self._logger.exception(str(exception))
|
||||
raise ApplicationException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, message=f'Database error: {str(exception)}')
|
||||
|
||||
async def _update_field(self, user_id: str, **fields: object) -> UserEntity:
|
||||
try:
|
||||
user = await self._get_active_user(user_id)
|
||||
for key, value in fields.items():
|
||||
setattr(user, key, value)
|
||||
await self._session.flush()
|
||||
await self._session.refresh(user)
|
||||
return self._to_entity(user)
|
||||
except ApplicationException:
|
||||
raise
|
||||
except SQLAlchemyError as exception:
|
||||
self._logger.exception(str(exception))
|
||||
raise ApplicationException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, message=f'Database error: {str(exception)}')
|
||||
|
||||
async def set_phone(self, user_id: str, phone: str) -> UserEntity:
|
||||
return await self._update_field(user_id, phone=phone)
|
||||
|
||||
async def set_bank_details(self, user_id: str, **fields: str) -> UserEntity:
|
||||
return await self._update_field(user_id, **fields)
|
||||
|
||||
async def set_crypto_wallet(self, user_id: str, wallet_address: str) -> UserEntity:
|
||||
return await self._update_field(user_id, crypto_wallet=wallet_address)
|
||||
|
||||
async def get_password_hash(self, user_id: str) -> str:
|
||||
try:
|
||||
user = await self._get_active_user(user_id)
|
||||
return user.password_hash
|
||||
except ApplicationException:
|
||||
raise
|
||||
except SQLAlchemyError as exception:
|
||||
self._logger.exception(str(exception))
|
||||
raise ApplicationException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, message=f'Database error: {str(exception)}')
|
||||
|
||||
async def set_password(self, user_id: str, password_hash: str) -> UserEntity:
|
||||
return await self._update_field(user_id, password_hash=password_hash)
|
||||
|
||||
async def set_email(self, user_id: str, email: str) -> UserEntity:
|
||||
return await self._update_field(user_id, email=email)
|
||||
|
||||
async def email_exists(self, email: str) -> bool:
|
||||
try:
|
||||
stmt = (
|
||||
select(UserModel)
|
||||
.where(
|
||||
UserModel.email == email,
|
||||
UserModel.is_deleted.is_(False),
|
||||
)
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
return result.scalar_one_or_none() is not None
|
||||
except SQLAlchemyError as exception:
|
||||
self._logger.exception(str(exception))
|
||||
raise ApplicationException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, message=f'Database error: {str(exception)}')
|
||||
42
src/infrastructure/database/unit_of_work.py
Normal file
42
src/infrastructure/database/unit_of_work.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.abstractions.repositories import IUserRepository, ISessionRepository
|
||||
from src.application.contracts import ILogger
|
||||
from src.infrastructure.database.repositories import UserRepository, SessionRepository
|
||||
|
||||
|
||||
|
||||
class UnitOfWork(IUnitOfWork):
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession], logger: ILogger):
|
||||
self.session_factory = session_factory
|
||||
self._session: AsyncSession = None
|
||||
self._user_repository: IUserRepository = None
|
||||
self._session_repository: ISessionRepository = None
|
||||
self._logger: ILogger = logger
|
||||
|
||||
async def __aenter__(self):
|
||||
self._session = self.session_factory()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if exc_type:
|
||||
self._logger.error(str(exc_val))
|
||||
await self._session.rollback()
|
||||
self._logger.error(f'Rollback: str{exc_val})')
|
||||
else:
|
||||
await self._session.flush()
|
||||
await self._session.commit()
|
||||
self._logger.debug('Commit')
|
||||
await self._session.close()
|
||||
|
||||
@property
|
||||
def user_repository(self) -> IUserRepository:
|
||||
if self._user_repository is None:
|
||||
self._user_repository = UserRepository(session=self._session, logger=self._logger)
|
||||
return self._user_repository
|
||||
|
||||
@property
|
||||
def session_repository(self) -> ISessionRepository:
|
||||
if self._session_repository is None:
|
||||
self._session_repository = SessionRepository(session=self._session, logger=self._logger)
|
||||
return self._session_repository
|
||||
Reference in New Issue
Block a user