from typing import Optional

from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from src.application.abstractions import IUnitOfWork
from src.application.abstractions.repositories import (
    IUserRepository,
    ILegalEntityRepository,
    IPurchaseRequestRepository,
)
from src.application.contracts import ILogger
from src.infrastructure.database.repositories import (
    UserRepository,
    LegalEntityRepository,
    PurchaseRequestRepository,
)


class UnitOfWork(IUnitOfWork):
    """Async context manager that scopes one database session and its repositories.

    Entering the context creates a new session from ``session_factory`` and
    drops any repositories cached from a previous use. Leaving the context
    commits when the body finished without an exception, rolls back otherwise,
    and always closes the session.

    Repositories are created lazily on first property access and bound to the
    session that is current at that moment.
    """

    def __init__(self, session_factory: async_sessionmaker[AsyncSession], logger: ILogger):
        """Store the collaborators; no session is created until ``__aenter__``.

        Args:
            session_factory: Callable invoked on each ``__aenter__`` to build
                the session for that unit of work.
            logger: Logger used by this object and handed to every repository
                it creates.
        """
        self.session_factory = session_factory
        # Everything below stays None until the context is entered / the
        # corresponding repository property is first read.
        self._session: Optional[AsyncSession] = None
        self._user_repository: Optional[IUserRepository] = None
        self._legal_entity_repository: Optional[ILegalEntityRepository] = None
        self._purchase_request_repository: Optional[IPurchaseRequestRepository] = None
        self._logger: ILogger = logger

    async def __aenter__(self):
        """Open a fresh session, reset the cached repositories and return ``self``."""
        self._logger.debug('UnitOfWork enter')
        # Repositories cached by an earlier use hold the earlier session, so
        # they are discarded before the new session is created.
        self._user_repository = None
        self._legal_entity_repository = None
        self._purchase_request_repository = None
        self._session = self.session_factory()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Finish the unit of work: commit or roll back, then close the session.

        Args:
            exc_type: Exception class raised inside the ``async with`` body,
                or ``None`` when the body completed normally.
            exc_val: The exception instance, or ``None``.
            exc_tb: The traceback, or ``None``.

        Returns:
            ``None`` implicitly, so an exception raised in the body is never
            suppressed and propagates to the caller.
        """
        try:
            if exc_type is not None:
                self._logger.error(
                    f'UnitOfWork rollback_on_error exc_type={exc_type.__name__} exc_val={exc_val!r}'
                )
                await self._session.rollback()
                self._logger.debug(f'UnitOfWork session rollback done exc_type={exc_type.__name__}')
            else:
                await self._session.flush()
                await self._session.commit()
                self._logger.debug('UnitOfWork commit')
        finally:
            # Close even when rollback/flush/commit raised; otherwise a failed
            # commit would leave the session open.
            await self._session.close()
            self._logger.debug('UnitOfWork exit session closed')

    @property
    def user_repository(self) -> IUserRepository:
        """Return the user repository, creating it on first access."""
        if self._user_repository is None:
            self._user_repository = UserRepository(session=self._session, logger=self._logger)
        return self._user_repository

    @property
    def legal_entity_repository(self) -> ILegalEntityRepository:
        """Return the legal-entity repository, creating it on first access."""
        if self._legal_entity_repository is None:
            self._legal_entity_repository = LegalEntityRepository(
                session=self._session, logger=self._logger
            )
        return self._legal_entity_repository

    @property
    def purchase_request_repository(self) -> IPurchaseRequestRepository:
        """Return the purchase-request repository, creating it on first access."""
        if self._purchase_request_repository is None:
            self._purchase_request_repository = PurchaseRequestRepository(
                session=self._session, logger=self._logger
            )
        return self._purchase_request_repository