Initial commit
This commit is contained in:
42
src/infrastructure/database/unit_of_work.py
Normal file
42
src/infrastructure/database/unit_of_work.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from src.application.abstractions import IUnitOfWork
|
||||
from src.application.abstractions.repositories import IUserRepository, ISessionRepository
|
||||
from src.application.contracts import ILogger
|
||||
from src.infrastructure.database.repositories import UserRepository, SessionRepository
|
||||
|
||||
|
||||
|
||||
class UnitOfWork(IUnitOfWork):
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession], logger: ILogger):
|
||||
self.session_factory = session_factory
|
||||
self._session: AsyncSession = None
|
||||
self._user_repository: IUserRepository = None
|
||||
self._session_repository: ISessionRepository = None
|
||||
self._logger: ILogger = logger
|
||||
|
||||
async def __aenter__(self):
|
||||
self._session = self.session_factory()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if exc_type:
|
||||
self._logger.error(str(exc_val))
|
||||
await self._session.rollback()
|
||||
self._logger.error(f'Rollback: str{exc_val})')
|
||||
else:
|
||||
await self._session.flush()
|
||||
await self._session.commit()
|
||||
self._logger.debug('Commit')
|
||||
await self._session.close()
|
||||
|
||||
@property
|
||||
def user_repository(self) -> IUserRepository:
|
||||
if self._user_repository is None:
|
||||
self._user_repository = UserRepository(session=self._session, logger=self._logger)
|
||||
return self._user_repository
|
||||
|
||||
@property
|
||||
def session_repository(self) -> ISessionRepository:
|
||||
if self._session_repository is None:
|
||||
self._session_repository = SessionRepository(session=self._session, logger=self._logger)
|
||||
return self._session_repository
|
||||
Reference in New Issue
Block a user