Coverage for src/prisma/_transactions.py: 93%
96 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-08-27 18:25 +0000
1from __future__ import annotations
3import logging
4import warnings
5from types import TracebackType
6from typing import TYPE_CHECKING, Generic, TypeVar
7from datetime import timedelta
9from ._types import TransactionId
10from .errors import TransactionNotStartedError
11from ._builder import dumps
13if TYPE_CHECKING:
14 from ._base_client import SyncBasePrisma, AsyncBasePrisma
# Type variables tying each transaction manager to the client flavour it wraps.
# The bounds are string forward references; the concrete classes are only
# imported under `TYPE_CHECKING`, so they are resolved by type checkers only.
_SyncPrismaT = TypeVar("_SyncPrismaT", bound="SyncBasePrisma")
_AsyncPrismaT = TypeVar("_AsyncPrismaT", bound="AsyncBasePrisma")

# Logger named after this module, used for transaction enter/exit diagnostics.
log: logging.Logger = logging.getLogger(__name__)
class AsyncTransactionManager(Generic[_AsyncPrismaT]):
    """Context manager for wrapping a Prisma instance within a transaction.

    This should never be created manually, instead it should be used
    through the Prisma.tx() method.
    """

    def __init__(
        self,
        *,
        client: _AsyncPrismaT,
        max_wait: int | timedelta,
        timeout: int | timedelta,
    ) -> None:
        """Store the client and normalise the transaction time limits.

        Args:
            client: The Prisma client the transaction is started from.
            max_wait: Time limit sent to the engine as `max_wait`. An `int` is
                interpreted as milliseconds and is deprecated.
            timeout: Time limit sent to the engine as `timeout`. An `int` is
                interpreted as milliseconds and is deprecated.
        """
        self.__client = client
        self._max_wait = self._to_timedelta(max_wait, 'max_wait')
        self._timeout = self._to_timedelta(timeout, 'timeout')
        # Populated by `start()`; `None` means no transaction has been started.
        self._tx_id: TransactionId | None = None

    @staticmethod
    def _to_timedelta(value: int | timedelta, name: str) -> timedelta:
        """Return `value` as a timedelta, emitting a DeprecationWarning for ints.

        Args:
            value: Either a `timedelta` (returned unchanged) or an int number of
                milliseconds (deprecated).
            name: The argument name, used in the warning message.
        """
        if not isinstance(value, int):
            return value

        message = (
            f'Passing an int as `{name}` argument is deprecated '
            'and will be removed in the next major release. '
            'Use a `datetime.timedelta` instance instead.'
        )
        # One frame deeper than warning directly from `__init__`:
        # this helper -> __init__ -> Prisma.tx() -> the user's call site.
        warnings.warn(message, DeprecationWarning, stacklevel=4)
        return timedelta(milliseconds=value)

    async def start(self, *, _from_context: bool = False) -> _AsyncPrismaT:
        """Start the transaction and return the wrapped Prisma instance.

        The returned client comes from `client._copy()` and has its `_tx_id`
        set to the id returned by the engine. `_tx_id` is assigned on the copy,
        not on the client this manager was created with.

        Args:
            _from_context: Internal flag set by `__aenter__` so the
                "already in a transaction" warning is attributed to user code.
        """
        if self.__client.is_transaction():
            # if we were called from the context manager then the stacklevel
            # needs to be one higher to warn on the actual offending code
            warnings.warn(
                'The current client is already in a transaction. This can lead to surprising behaviour.',
                UserWarning,
                stacklevel=3 if _from_context else 2,
            )

        # Both limits are sent to the engine as whole milliseconds.
        tx_id = await self.__client._engine.start_transaction(
            content=dumps(
                {
                    'timeout': int(self._timeout.total_seconds() * 1000),
                    'max_wait': int(self._max_wait.total_seconds() * 1000),
                }
            ),
        )
        self._tx_id = tx_id
        client = self.__client._copy()
        client._tx_id = tx_id
        return client

    async def commit(self) -> None:
        """Commit the transaction to the database, this transaction will no longer be usable.

        Raises:
            TransactionNotStartedError: If `start()` has not been called.
        """
        if self._tx_id is None:
            raise TransactionNotStartedError()

        await self.__client._engine.commit_transaction(self._tx_id)

    async def rollback(self) -> None:
        """Do not commit the changes to the database, this transaction will no longer be usable.

        Raises:
            TransactionNotStartedError: If `start()` has not been called.
        """
        if self._tx_id is None:
            raise TransactionNotStartedError()

        await self.__client._engine.rollback_transaction(self._tx_id)

    async def __aenter__(self) -> _AsyncPrismaT:
        return await self.start(_from_context=True)

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Commit on a clean exit; otherwise roll back.

        A failure while rolling back is logged and swallowed. This method
        returns `None`, so the exception that triggered the rollback still
        propagates to the caller.
        """
        if exc is None:
            log.debug('Transaction exited with no exception - committing')
            await self.commit()
            return

        log.debug('Transaction exited with exc type: %s - rolling back', exc_type)

        try:
            await self.rollback()
        except Exception as rollback_exc:
            # Distinct name so the original `exc` parameter is not shadowed.
            log.warning(
                'Encountered exc `%s` while rolling back a transaction. Ignoring and raising original exception',
                rollback_exc,
            )
class SyncTransactionManager(Generic[_SyncPrismaT]):
    """Context manager for wrapping a Prisma instance within a transaction.

    This should never be created manually, instead it should be used
    through the Prisma.tx() method.
    """

    def __init__(
        self,
        *,
        client: _SyncPrismaT,
        max_wait: int | timedelta,
        timeout: int | timedelta,
    ) -> None:
        """Store the client and normalise the transaction time limits.

        Args:
            client: The Prisma client the transaction is started from.
            max_wait: Time limit sent to the engine as `max_wait`. An `int` is
                interpreted as milliseconds and is deprecated.
            timeout: Time limit sent to the engine as `timeout`. An `int` is
                interpreted as milliseconds and is deprecated.
        """
        self.__client = client
        self._max_wait = self._to_timedelta(max_wait, 'max_wait')
        self._timeout = self._to_timedelta(timeout, 'timeout')
        # Populated by `start()`; `None` means no transaction has been started.
        self._tx_id: TransactionId | None = None

    @staticmethod
    def _to_timedelta(value: int | timedelta, name: str) -> timedelta:
        """Return `value` as a timedelta, emitting a DeprecationWarning for ints.

        Args:
            value: Either a `timedelta` (returned unchanged) or an int number of
                milliseconds (deprecated).
            name: The argument name, used in the warning message.
        """
        if not isinstance(value, int):
            return value

        message = (
            f'Passing an int as `{name}` argument is deprecated '
            'and will be removed in the next major release. '
            'Use a `datetime.timedelta` instance instead.'
        )
        # One frame deeper than warning directly from `__init__`:
        # this helper -> __init__ -> Prisma.tx() -> the user's call site.
        warnings.warn(message, DeprecationWarning, stacklevel=4)
        return timedelta(milliseconds=value)

    def start(self, *, _from_context: bool = False) -> _SyncPrismaT:
        """Start the transaction and return the wrapped Prisma instance.

        The returned client comes from `client._copy()` and has its `_tx_id`
        set to the id returned by the engine. `_tx_id` is assigned on the copy,
        not on the client this manager was created with.

        Args:
            _from_context: Internal flag set by `__enter__` so the
                "already in a transaction" warning is attributed to user code.
        """
        if self.__client.is_transaction():
            # if we were called from the context manager then the stacklevel
            # needs to be one higher to warn on the actual offending code
            warnings.warn(
                'The current client is already in a transaction. This can lead to surprising behaviour.',
                UserWarning,
                stacklevel=3 if _from_context else 2,
            )

        # Both limits are sent to the engine as whole milliseconds.
        tx_id = self.__client._engine.start_transaction(
            content=dumps(
                {
                    'timeout': int(self._timeout.total_seconds() * 1000),
                    'max_wait': int(self._max_wait.total_seconds() * 1000),
                }
            ),
        )
        self._tx_id = tx_id
        client = self.__client._copy()
        client._tx_id = tx_id
        return client

    def commit(self) -> None:
        """Commit the transaction to the database, this transaction will no longer be usable.

        Raises:
            TransactionNotStartedError: If `start()` has not been called.
        """
        if self._tx_id is None:
            raise TransactionNotStartedError()

        self.__client._engine.commit_transaction(self._tx_id)

    def rollback(self) -> None:
        """Do not commit the changes to the database, this transaction will no longer be usable.

        Raises:
            TransactionNotStartedError: If `start()` has not been called.
        """
        if self._tx_id is None:
            raise TransactionNotStartedError()

        self.__client._engine.rollback_transaction(self._tx_id)

    def __enter__(self) -> _SyncPrismaT:
        return self.start(_from_context=True)

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Commit on a clean exit; otherwise roll back.

        A failure while rolling back is logged and swallowed. This method
        returns `None`, so the exception that triggered the rollback still
        propagates to the caller.
        """
        if exc is None:
            log.debug('Transaction exited with no exception - committing')
            self.commit()
            return

        log.debug('Transaction exited with exc type: %s - rolling back', exc_type)

        try:
            self.rollback()
        except Exception as rollback_exc:
            # Distinct name so the original `exc` parameter is not shadowed.
            log.warning(
                'Encountered exc `%s` while rolling back a transaction. Ignoring and raising original exception',
                rollback_exc,
            )