Coverage for src/prisma/_transactions.py: 93%

96 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-08-27 18:25 +0000

1from __future__ import annotations 

2 

3import logging 

4import warnings 

5from types import TracebackType 

6from typing import TYPE_CHECKING, Generic, TypeVar 

7from datetime import timedelta 

8 

9from ._types import TransactionId 

10from .errors import TransactionNotStartedError 

11from ._builder import dumps 

12 

13if TYPE_CHECKING: 

14 from ._base_client import SyncBasePrisma, AsyncBasePrisma 

15 

# Logger shared by both transaction managers, named after this module.
log: logging.Logger = logging.getLogger(__name__)


# Type variables tying each manager to the concrete client class it wraps, so that
# `start()` / `__enter__` / `__aenter__` are typed as returning that same client type.
# The bounds are string forward references because the base client classes are only
# imported under TYPE_CHECKING.
_AsyncPrismaT = TypeVar('_AsyncPrismaT', bound='AsyncBasePrisma')
_SyncPrismaT = TypeVar('_SyncPrismaT', bound='SyncBasePrisma')

21 

22 

class AsyncTransactionManager(Generic[_AsyncPrismaT]):
    """Context manager for wrapping a Prisma instance within a transaction.

    This should never be created manually, instead it should be used
    through the Prisma.tx() method.

    Entering the context starts an interactive transaction and yields a copy of
    the client bound to it. Exiting without an exception commits; exiting with an
    exception rolls back and lets the original exception propagate.
    """

    def __init__(
        self,
        *,
        client: _AsyncPrismaT,
        max_wait: int | timedelta,
        timeout: int | timedelta,
    ) -> None:
        """Store the client and normalise the transaction time limits.

        Args:
            client: The client the transaction is started on.
            max_wait: Value sent to the engine as ``max_wait``. Passing an ``int``
                (interpreted as milliseconds) is deprecated.
            timeout: Value sent to the engine as ``timeout``. Passing an ``int``
                (interpreted as milliseconds) is deprecated.
        """
        self.__client = client
        self._max_wait = self._coerce_duration(max_wait, 'max_wait')
        self._timeout = self._coerce_duration(timeout, 'timeout')
        # Set by start(); None means the transaction has not been started yet.
        self._tx_id: TransactionId | None = None

    @staticmethod
    def _coerce_duration(value: int | timedelta, name: str) -> timedelta:
        """Return `value` as a timedelta, emitting a DeprecationWarning for the int form."""
        if isinstance(value, int):
            message = (
                f'Passing an int as `{name}` argument is deprecated '
                'and will be removed in the next major release. '
                'Use a `datetime.timedelta` instance instead.'
            )
            # This helper adds one frame compared to warning directly from __init__
            # (which used stacklevel=3), so 4 targets the same caller frame
            # (presumably the user's call to Prisma.tx()).
            warnings.warn(message, DeprecationWarning, stacklevel=4)
            return timedelta(milliseconds=value)
        return value

    async def start(self, *, _from_context: bool = False) -> _AsyncPrismaT:
        """Start the transaction and return the wrapped Prisma instance.

        Emits a UserWarning if the underlying client is already in a transaction.
        The engine is sent ``timeout`` and ``max_wait`` as whole milliseconds.

        Returns:
            A copy of the client whose ``_tx_id`` is set to the new transaction id.
        """
        if self.__client.is_transaction():
            # if we were called from the context manager then the stacklevel
            # needs to be one higher to warn on the actual offending code
            warnings.warn(
                'The current client is already in a transaction. This can lead to surprising behaviour.',
                UserWarning,
                stacklevel=3 if _from_context else 2,
            )

        payload = {
            'timeout': int(self._timeout.total_seconds() * 1000),
            'max_wait': int(self._max_wait.total_seconds() * 1000),
        }
        tx_id = await self.__client._engine.start_transaction(content=dumps(payload))
        self._tx_id = tx_id
        client = self.__client._copy()
        client._tx_id = tx_id
        return client

    async def commit(self) -> None:
        """Commit the transaction to the database, this transaction will no longer be usable.

        Raises:
            TransactionNotStartedError: if start() has not been called.
        """
        if self._tx_id is None:
            raise TransactionNotStartedError()

        await self.__client._engine.commit_transaction(self._tx_id)

    async def rollback(self) -> None:
        """Do not commit the changes to the database, this transaction will no longer be usable.

        Raises:
            TransactionNotStartedError: if start() has not been called.
        """
        if self._tx_id is None:
            raise TransactionNotStartedError()

        await self.__client._engine.rollback_transaction(self._tx_id)

    async def __aenter__(self) -> _AsyncPrismaT:
        """Start the transaction; see start()."""
        return await self.start(_from_context=True)

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Commit on a clean exit, otherwise roll back.

        Returns None, so any exception raised inside the ``async with`` block
        propagates. A failure during rollback is logged and swallowed so it does
        not mask that original exception.
        """
        if exc is None:
            log.debug('Transaction exited with no exception - committing')
            await self.commit()
            return

        log.debug('Transaction exited with exc type: %s - rolling back', exc_type)

        try:
            await self.rollback()
        except Exception as rollback_exc:
            # Distinct name so the `exc` parameter (the original exception) is not shadowed.
            log.warning(
                'Encountered exc `%s` while rolling back a transaction. Ignoring and raising original exception',
                rollback_exc,
            )

123 

124 

class SyncTransactionManager(Generic[_SyncPrismaT]):
    """Context manager for wrapping a Prisma instance within a transaction.

    This should never be created manually, instead it should be used
    through the Prisma.tx() method.

    Entering the context starts an interactive transaction and yields a copy of
    the client bound to it. Exiting without an exception commits; exiting with an
    exception rolls back and lets the original exception propagate.
    """

    def __init__(
        self,
        *,
        client: _SyncPrismaT,
        max_wait: int | timedelta,
        timeout: int | timedelta,
    ) -> None:
        """Store the client and normalise the transaction time limits.

        Args:
            client: The client the transaction is started on.
            max_wait: Value sent to the engine as ``max_wait``. Passing an ``int``
                (interpreted as milliseconds) is deprecated.
            timeout: Value sent to the engine as ``timeout``. Passing an ``int``
                (interpreted as milliseconds) is deprecated.
        """
        self.__client = client
        self._max_wait = self._coerce_duration(max_wait, 'max_wait')
        self._timeout = self._coerce_duration(timeout, 'timeout')
        # Set by start(); None means the transaction has not been started yet.
        self._tx_id: TransactionId | None = None

    @staticmethod
    def _coerce_duration(value: int | timedelta, name: str) -> timedelta:
        """Return `value` as a timedelta, emitting a DeprecationWarning for the int form."""
        if isinstance(value, int):
            message = (
                f'Passing an int as `{name}` argument is deprecated '
                'and will be removed in the next major release. '
                'Use a `datetime.timedelta` instance instead.'
            )
            # This helper adds one frame compared to warning directly from __init__
            # (which used stacklevel=3), so 4 targets the same caller frame
            # (presumably the user's call to Prisma.tx()).
            warnings.warn(message, DeprecationWarning, stacklevel=4)
            return timedelta(milliseconds=value)
        return value

    def start(self, *, _from_context: bool = False) -> _SyncPrismaT:
        """Start the transaction and return the wrapped Prisma instance.

        Emits a UserWarning if the underlying client is already in a transaction.
        The engine is sent ``timeout`` and ``max_wait`` as whole milliseconds.

        Returns:
            A copy of the client whose ``_tx_id`` is set to the new transaction id.
        """
        if self.__client.is_transaction():
            # if we were called from the context manager then the stacklevel
            # needs to be one higher to warn on the actual offending code
            warnings.warn(
                'The current client is already in a transaction. This can lead to surprising behaviour.',
                UserWarning,
                stacklevel=3 if _from_context else 2,
            )

        payload = {
            'timeout': int(self._timeout.total_seconds() * 1000),
            'max_wait': int(self._max_wait.total_seconds() * 1000),
        }
        tx_id = self.__client._engine.start_transaction(content=dumps(payload))
        self._tx_id = tx_id
        client = self.__client._copy()
        client._tx_id = tx_id
        return client

    def commit(self) -> None:
        """Commit the transaction to the database, this transaction will no longer be usable.

        Raises:
            TransactionNotStartedError: if start() has not been called.
        """
        if self._tx_id is None:
            raise TransactionNotStartedError()

        self.__client._engine.commit_transaction(self._tx_id)

    def rollback(self) -> None:
        """Do not commit the changes to the database, this transaction will no longer be usable.

        Raises:
            TransactionNotStartedError: if start() has not been called.
        """
        if self._tx_id is None:
            raise TransactionNotStartedError()

        self.__client._engine.rollback_transaction(self._tx_id)

    def __enter__(self) -> _SyncPrismaT:
        """Start the transaction; see start()."""
        return self.start(_from_context=True)

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Commit on a clean exit, otherwise roll back.

        Returns None, so any exception raised inside the ``with`` block
        propagates. A failure during rollback is logged and swallowed so it does
        not mask that original exception.
        """
        if exc is None:
            log.debug('Transaction exited with no exception - committing')
            self.commit()
            return

        log.debug('Transaction exited with exc type: %s - rolling back', exc_type)

        try:
            self.rollback()
        except Exception as rollback_exc:
            # Distinct name so the `exc` parameter (the original exception) is not shadowed.
            log.warning(
                'Encountered exc `%s` while rolling back a transaction. Ignoring and raising original exception',
                rollback_exc,
            )