Files
b2b/src/infrastructure/database/unit_of_work.py
2026-06-03 13:49:16 +03:00

62 lines
2.5 KiB
Python

from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from src.application.abstractions import IUnitOfWork
from src.application.abstractions.repositories import (
IUserRepository,
ILegalEntityRepository,
IPurchaseRequestRepository,
)
from src.application.contracts import ILogger
from src.infrastructure.database.repositories import (
UserRepository,
LegalEntityRepository,
PurchaseRequestRepository,
)
class UnitOfWork(IUnitOfWork):
    """Async unit of work that scopes one ``AsyncSession`` to an ``async with`` block.

    On entry a new session is created from ``session_factory`` and any
    repositories cached by a previous ``async with`` are discarded. The
    repository properties lazily build repositories bound to the current
    session. On exit the session is committed when the block finished
    normally, or rolled back when it raised; in both cases the session is
    then closed, even if the commit or rollback itself fails. Exceptions are
    never suppressed (``__aexit__`` returns ``None``).
    """

    def __init__(self, session_factory: async_sessionmaker[AsyncSession], logger: ILogger):
        """Store collaborators; no session is opened until ``__aenter__``.

        Args:
            session_factory: Called once per ``async with`` to create the session.
            logger: Used for lifecycle messages and handed to every repository.
        """
        self.session_factory = session_factory
        # Set by __aenter__; None until the first ``async with``.
        self._session: AsyncSession | None = None
        # Lazily created by the properties below, bound to ``self._session``.
        self._user_repository: IUserRepository | None = None
        self._legal_entity_repository: ILegalEntityRepository | None = None
        self._purchase_request_repository: IPurchaseRequestRepository | None = None
        self._logger: ILogger = logger

    async def __aenter__(self) -> "UnitOfWork":
        """Open a fresh session and reset cached repositories.

        Returns:
            This unit of work, so repositories can be reached via ``async with uow as u``.
        """
        self._logger.debug('UnitOfWork enter')
        # Repositories from a previous context are bound to the old session;
        # drop them so the properties rebuild them against the new one.
        self._user_repository = None
        self._legal_entity_repository = None
        self._purchase_request_repository = None
        self._session = self.session_factory()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Commit on success, roll back on error, and always close the session.

        Any exception raised in the ``async with`` body, or by commit, rollback
        or close, propagates to the caller.
        """
        try:
            if exc_type:
                self._logger.error(f'UnitOfWork rollback_on_error exc_type={exc_type.__name__} exc_val={exc_val!r}')
                await self._session.rollback()
                self._logger.debug(f'UnitOfWork session rollback done exc_type={exc_type.__name__}')
            else:
                # Explicit flush first, then commit (the original ordering is kept).
                await self._session.flush()
                await self._session.commit()
                self._logger.debug('UnitOfWork commit')
        finally:
            # Runs even if flush/commit/rollback raised; previously such a
            # failure skipped close() and left the session open.
            await self._session.close()
            self._logger.debug('UnitOfWork exit session closed')

    @property
    def user_repository(self) -> IUserRepository:
        """User repository bound to the current session, created on first access."""
        if self._user_repository is None:
            self._user_repository = UserRepository(session=self._session, logger=self._logger)
        return self._user_repository

    @property
    def legal_entity_repository(self) -> ILegalEntityRepository:
        """Legal-entity repository bound to the current session, created on first access."""
        if self._legal_entity_repository is None:
            self._legal_entity_repository = LegalEntityRepository(session=self._session, logger=self._logger)
        return self._legal_entity_repository

    @property
    def purchase_request_repository(self) -> IPurchaseRequestRepository:
        """Purchase-request repository bound to the current session, created on first access."""
        if self._purchase_request_repository is None:
            self._purchase_request_repository = PurchaseRequestRepository(session=self._session, logger=self._logger)
        return self._purchase_request_repository