Coverage for src/prisma/_base_client.py: 90%

209 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-04-28 15:17 +0000

1from __future__ import annotations 

2 

3import logging 

4import warnings 

5from types import TracebackType 

6from typing import Any, Generic, TypeVar, overload 

7from pathlib import Path 

8from datetime import timedelta 

9from typing_extensions import Self, Literal 

10 

11from pydantic import BaseModel 

12 

13from ._types import Datasource, HttpConfig, PrismaMethod, MetricsFormat, TransactionId, DatasourceOverride 

14from .engine import ( 

15 SyncQueryEngine, 

16 AsyncQueryEngine, 

17 BaseAbstractEngine, 

18 SyncAbstractEngine, 

19 AsyncAbstractEngine, 

20) 

21from .errors import ClientNotConnectedError, ClientNotRegisteredError 

22from ._compat import model_parse, removeprefix 

23from ._builder import QueryBuilder 

24from ._metrics import Metrics 

25from ._registry import get_client 

26from .generator.models import EngineType 

27 

# Module-level logger, named after this module so callers can tune its level.
log: logging.Logger = logging.getLogger(name=__name__)

29 

30 

class UseClientDefault:
    """Sentinel type meaning "fall back to the value configured on the client".

    For certain parameters such as `timeout=...` we can make our intent more clear
    by typing the parameter with this class rather than using None, for example:

    ```py
    def connect(timeout: Union[int, timedelta, UseClientDefault] = UseClientDefault()) -> None: ...
    ```

    conveys the intention more clearly than:

    ```py
    def connect(timeout: Union[int, timedelta, None] = None) -> None: ...
    ```

    It also gives us an "unset" state that is distinct from `None`, which may
    be useful in the future.
    """

    def __repr__(self) -> str:
        # A readable repr keeps `help()` signatures and debug output tidy
        # instead of showing `<...UseClientDefault object at 0x...>`.
        return 'USE_CLIENT_DEFAULT'


# Shared sentinel instance used as the default value for such parameters.
USE_CLIENT_DEFAULT: UseClientDefault = UseClientDefault()

51 

52 

def load_env(*, override: bool = False, **kwargs: Any) -> None:
    """Load environment variables from dotenv files.

    The following files are read, in this order, relative to the current
    working directory:

    - .env
    - prisma/.env

    `override` and any extra keyword arguments are passed unchanged to
    `dotenv.load_dotenv()` for each file.
    """
    # Imported lazily so python-dotenv is only needed when this is called.
    from dotenv import load_dotenv

    for dotenv_path in ('.env', 'prisma/.env'):
        load_dotenv(dotenv_path, override=override, **kwargs)

66 

67 

# Engine type held by a client; bound so every engine shares the abstract API.
_EngineT = TypeVar('_EngineT', bound=BaseAbstractEngine)


class BasePrisma(Generic[_EngineT]):
    """State and helpers shared by the synchronous and asynchronous clients.

    Subclasses bind `_EngineT` to a concrete abstract engine type and
    implement `connect()`, `disconnect()` and `_execute()`. Generated
    schema metadata is supplied through `_set_generated_properties()`.
    """

    _log_queries: bool
    _datasource: DatasourceOverride | None
    _connect_timeout: int | timedelta
    _tx_id: TransactionId | None
    _http_config: HttpConfig
    _internal_engine: _EngineT | None
    _copied: bool

    # from generation
    _schema_path: Path
    _prisma_models: set[str]
    _packaged_schema_path: Path
    _engine_type: EngineType
    _active_provider: str
    _default_datasource_name: str
    _relational_field_mappings: dict[str, dict[str, str]]

    __slots__ = (
        '_copied',
        '_tx_id',
        '_datasource',
        '_log_queries',
        '_http_config',
        '_schema_path',
        '_engine_type',
        '_prisma_models',
        '_active_provider',
        '_connect_timeout',
        '_internal_engine',
        '_packaged_schema_path',
        '_default_datasource_name',
        '_relational_field_mappings',
    )

    def __init__(
        self,
        *,
        use_dotenv: bool,
        log_queries: bool,
        datasource: DatasourceOverride | None,
        connect_timeout: int | timedelta,
        http: HttpConfig | None,
    ) -> None:
        """Initialise connection-independent client state.

        Args:
            use_dotenv: when True, call `load_env()` to read `.env` files.
            log_queries: passed to the query engine when it is created.
            datasource: optional datasource override used when connecting.
            connect_timeout: default timeout for `connect()`; an `int` is
                interpreted as seconds and is deprecated.
            http: HTTP configuration for the engine; `None` becomes `{}`.
        """
        # NOTE: if you add any more properties here then you may also need to forward
        # them in the `_copy()` method.
        self._internal_engine = None
        self._log_queries = log_queries
        self._datasource = datasource

        if isinstance(connect_timeout, int):
            message = (
                'Passing an int as `connect_timeout` argument is deprecated '
                'and will be removed in the next major release. '
                'Use a `datetime.timedelta` instance instead.'
            )
            # NOTE(review): stacklevel=2 attributes the warning to the direct
            # caller of this __init__; if subclasses reach it via
            # `super().__init__()` that is the subclass, not user code — confirm.
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            connect_timeout = timedelta(seconds=connect_timeout)

        self._connect_timeout = connect_timeout
        self._http_config: HttpConfig = http or {}
        self._tx_id: TransactionId | None = None
        self._copied: bool = False

        if use_dotenv:
            load_env()

    def _set_generated_properties(
        self,
        *,
        schema_path: Path,
        engine_type: EngineType,
        packaged_schema_path: Path,
        active_provider: str,
        prisma_models: set[str],
        relational_field_mappings: dict[str, dict[str, str]],
        default_datasource_name: str,
    ) -> None:
        """We pass through generated metadata using this method
        instead of the `__init__()` because that causes weirdness
        for our `_copy()` method as this base class has arguments
        that the subclasses do not.
        """
        self._schema_path = schema_path
        self._engine_type = engine_type
        self._prisma_models = prisma_models
        self._active_provider = active_provider
        self._packaged_schema_path = packaged_schema_path
        self._relational_field_mappings = relational_field_mappings
        self._default_datasource_name = default_datasource_name

    @property
    def _default_datasource(self) -> Datasource:
        """The schema's default datasource; must be provided by a subclass."""
        raise NotImplementedError('`_default_datasource` should be implemented in a subclass')

    def is_registered(self) -> bool:
        """Returns True if this client instance is registered"""
        try:
            return get_client() is self
        except ClientNotRegisteredError:
            return False

    def is_transaction(self) -> bool:
        """Returns True if the client is wrapped within a transaction"""
        return self._tx_id is not None

    def is_connected(self) -> bool:
        """Returns True if the client is connected to the query engine, False otherwise."""
        return self._internal_engine is not None

    def __del__(self) -> None:
        """Stop a still-running engine owned by this (non-copied) client."""
        # Note: as the transaction manager holds a reference to the original
        # client as well as the transaction client the original client cannot
        # be `free`d before the transaction is finished. So stopping the engine
        # here should be safe.
        #
        # `getattr` guards against partially-initialised instances (e.g. a
        # subclass `__init__` that raised before ours ran): unset `__slots__`
        # raise AttributeError, which would surface as an
        # "Exception ignored in __del__" message.
        engine = getattr(self, '_internal_engine', None)
        if engine is not None and not getattr(self, '_copied', False):
            log.debug('unclosed client - stopping engine')
            self._internal_engine = None
            engine.stop()

    @property
    def _engine(self) -> _EngineT:
        """The connected engine.

        Raises:
            ClientNotConnectedError: if no engine has been set.
        """
        engine = self._internal_engine
        if engine is None:
            raise ClientNotConnectedError()
        return engine

    @_engine.setter
    def _engine(self, engine: _EngineT) -> None:
        self._internal_engine = engine

    def _copy(self) -> Self:
        """Return a new Prisma instance using the same engine process (if connected).

        The copy is flagged as `_copied` so that its `__del__` does not stop
        the engine it shares with this instance.

        This is only intended for private usage, there are no guarantees around this API.
        """
        new = self.__class__(
            use_dotenv=False,
            http=self._http_config,
            datasource=self._datasource,
            log_queries=self._log_queries,
            connect_timeout=self._connect_timeout,
        )
        new._copied = True

        if self._internal_engine is not None:
            new._engine = self._internal_engine

        return new

    def _make_sqlite_datasource(self) -> DatasourceOverride:
        """Override the default SQLite path to protect against
        https://github.com/RobertCraigie/prisma-client-py/issues/409
        """
        return {
            'name': self._default_datasource['name'],
            'url': self._make_sqlite_url(self._default_datasource['url']),
        }

    def _make_sqlite_url(self, url: str, *, relative_to: Path | None = None) -> str:
        """Rewrite a relative SQLite file URL into an absolute `file:` URL.

        `url` is returned unchanged when it has neither a `file:` nor a
        `sqlite:` prefix, or when its path is already absolute. Otherwise the
        path is resolved against `relative_to`, which defaults to the
        directory containing the schema file.
        """
        url_path = removeprefix(removeprefix(url, 'file:'), 'sqlite:')
        if url_path == url:
            return url

        if Path(url_path).is_absolute():
            return url

        if relative_to is None:
            relative_to = self._schema_path.parent

        return f'file:{relative_to.joinpath(url_path).resolve()}'

    def _prepare_connect_args(
        self,
        *,
        timeout: int | timedelta | UseClientDefault = USE_CLIENT_DEFAULT,
    ) -> tuple[timedelta, list[DatasourceOverride] | None]:
        """Returns (timeout, datasources) to be passed to `AbstractEngine.connect()`

        `timeout` falls back to the client's `connect_timeout`; an `int` is
        treated as seconds and triggers a DeprecationWarning. `datasources` is
        the explicit override (with its name defaulted), a SQLite path fix-up,
        or None.
        """
        if isinstance(timeout, UseClientDefault):
            timeout = self._connect_timeout

        if isinstance(timeout, int):
            message = (
                'Passing an int as `timeout` argument is deprecated '
                'and will be removed in the next major release. '
                'Use a `datetime.timedelta` instance instead.'
            )
            # stacklevel=3: this helper is always invoked from `connect()`, so
            # skip both frames to attribute the warning to the user's call.
            warnings.warn(message, DeprecationWarning, stacklevel=3)
            timeout = timedelta(seconds=timeout)

        datasources: list[DatasourceOverride] | None = None
        if self._datasource is not None:
            ds = self._datasource.copy()
            ds.setdefault('name', self._default_datasource_name)
            datasources = [ds]
        elif self._active_provider == 'sqlite':
            # Override the default SQLite path to protect against
            # https://github.com/RobertCraigie/prisma-client-py/issues/409
            datasources = [self._make_sqlite_datasource()]

        return timeout, datasources

    def _make_query_builder(
        self,
        *,
        method: PrismaMethod,
        arguments: dict[str, Any],
        model: type[BaseModel] | None,
        root_selection: list[str] | None,
    ) -> QueryBuilder:
        """Create a QueryBuilder wired up with this client's generated model metadata."""
        return QueryBuilder(
            method=method,
            model=model,
            arguments=arguments,
            root_selection=root_selection,
            prisma_models=self._prisma_models,
            relational_field_mappings=self._relational_field_mappings,
        )

289 

290 

class SyncBasePrisma(BasePrisma[SyncAbstractEngine]):
    """Base class for the synchronous Prisma client."""

    __slots__ = ()

    def connect(
        self,
        timeout: int | timedelta | UseClientDefault = USE_CLIENT_DEFAULT,
    ) -> None:
        """Connect to the Prisma query engine.

        It is required to call this before accessing data.
        """
        if self._internal_engine is None:
            self._internal_engine = self._create_engine(dml_path=self._packaged_schema_path)

        timeout, datasources = self._prepare_connect_args(timeout=timeout)

        self._internal_engine.connect(
            timeout=timeout,
            datasources=datasources,
        )

    def disconnect(self, timeout: float | timedelta | None = None) -> None:
        """Disconnect the Prisma query engine.

        Does nothing if the client is not connected. A numeric `timeout` is
        interpreted as seconds and is deprecated. The engine is stopped even
        if closing it raises.
        """
        if self._internal_engine is None:
            return

        engine = self._internal_engine
        self._internal_engine = None

        if isinstance(timeout, (int, float)):
            message = (
                'Passing a number as `timeout` argument is deprecated '
                'and will be removed in the next major release. '
                'Use a `datetime.timedelta` instead.'
            )
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            timeout = timedelta(seconds=timeout)

        try:
            engine.close(timeout=timeout)
        finally:
            # Always stop the engine so a failing `close()` does not leak it;
            # the client no longer references it after this method.
            engine.stop(timeout=timeout)

    def __enter__(self) -> Self:
        self.connect()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        if self.is_connected():
            self.disconnect()

    @overload
    def get_metrics(
        self,
        format: Literal['json'] = 'json',
        *,
        global_labels: dict[str, str] | None = None,
    ) -> Metrics: ...

    @overload
    def get_metrics(
        self,
        format: Literal['prometheus'],
        *,
        global_labels: dict[str, str] | None = None,
    ) -> str: ...

    def get_metrics(
        self,
        format: MetricsFormat = 'json',
        *,
        global_labels: dict[str, str] | None = None,
    ) -> str | Metrics:
        """Metrics give you a detailed insight into how the Prisma Client interacts with your database.

        You can retrieve metrics in either JSON or Prometheus formats.

        For more details see https://www.prisma.io/docs/concepts/components/prisma-client/metrics.
        """
        response = self._engine.metrics(format=format, global_labels=global_labels)
        if format == 'prometheus':
            # For the prometheus format we return the response as-is
            assert isinstance(response, str)
            return response

        return model_parse(Metrics, response)

    def _create_engine(self, dml_path: Path | None = None) -> SyncAbstractEngine:
        """Instantiate the engine for the generated engine type.

        Raises:
            NotImplementedError: for engine types other than `binary`.
        """
        if self._engine_type == EngineType.binary:
            return SyncQueryEngine(
                dml_path=dml_path or self._packaged_schema_path,
                log_queries=self._log_queries,
                http_config=self._http_config,
            )

        raise NotImplementedError(f'Unsupported engine type: {self._engine_type}')

    @property
    def _engine_class(self) -> type[SyncAbstractEngine]:
        """The engine class for the generated engine type."""
        if self._engine_type == EngineType.binary:
            return SyncQueryEngine

        raise RuntimeError(f'Unhandled engine type: {self._engine_type}')

    # TODO: don't return Any
    def _execute(
        self,
        method: PrismaMethod,
        arguments: dict[str, Any],
        model: type[BaseModel] | None = None,
        root_selection: list[str] | None = None,
    ) -> Any:
        """Build a query and run it on the engine, within the current transaction if any."""
        builder = self._make_query_builder(
            method=method, model=model, arguments=arguments, root_selection=root_selection
        )
        return self._engine.query(builder.build(), tx_id=self._tx_id)

408 

409 

class AsyncBasePrisma(BasePrisma[AsyncAbstractEngine]):
    """Base class for the asynchronous Prisma client."""

    __slots__ = ()

    async def connect(
        self,
        timeout: int | timedelta | UseClientDefault = USE_CLIENT_DEFAULT,
    ) -> None:
        """Connect to the Prisma query engine.

        It is required to call this before accessing data.
        """
        if self._internal_engine is None:
            self._internal_engine = self._create_engine(dml_path=self._packaged_schema_path)

        timeout, datasources = self._prepare_connect_args(timeout=timeout)

        await self._internal_engine.connect(
            timeout=timeout,
            datasources=datasources,
        )

    async def disconnect(self, timeout: float | timedelta | None = None) -> None:
        """Disconnect the Prisma query engine.

        Does nothing if the client is not connected. A numeric `timeout` is
        interpreted as seconds and is deprecated. The engine is stopped even
        if closing it raises.
        """
        if self._internal_engine is None:
            return

        engine = self._internal_engine
        self._internal_engine = None

        if isinstance(timeout, (int, float)):
            message = (
                'Passing a number as `timeout` argument is deprecated '
                'and will be removed in the next major release. '
                'Use a `datetime.timedelta` instead.'
            )
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            timeout = timedelta(seconds=timeout)

        try:
            await engine.aclose(timeout=timeout)
        finally:
            # Always stop the engine so a failing (or cancelled) `aclose()`
            # does not leak it; the client no longer references it.
            engine.stop(timeout=timeout)

    async def __aenter__(self) -> Self:
        await self.connect()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        if self.is_connected():
            await self.disconnect()

    @overload
    async def get_metrics(
        self,
        format: Literal['json'] = 'json',
        *,
        global_labels: dict[str, str] | None = None,
    ) -> Metrics: ...

    @overload
    async def get_metrics(
        self,
        format: Literal['prometheus'],
        *,
        global_labels: dict[str, str] | None = None,
    ) -> str: ...

    async def get_metrics(
        self,
        format: MetricsFormat = 'json',
        *,
        global_labels: dict[str, str] | None = None,
    ) -> str | Metrics:
        """Metrics give you a detailed insight into how the Prisma Client interacts with your database.

        You can retrieve metrics in either JSON or Prometheus formats.

        For more details see https://www.prisma.io/docs/concepts/components/prisma-client/metrics.
        """
        response = await self._engine.metrics(format=format, global_labels=global_labels)
        if format == 'prometheus':
            # For the prometheus format we return the response as-is
            assert isinstance(response, str)
            return response

        return model_parse(Metrics, response)

    def _create_engine(self, dml_path: Path | None = None) -> AsyncAbstractEngine:
        """Instantiate the engine for the generated engine type.

        Raises:
            NotImplementedError: for engine types other than `binary`.
        """
        if self._engine_type == EngineType.binary:
            return AsyncQueryEngine(
                dml_path=dml_path or self._packaged_schema_path,
                log_queries=self._log_queries,
                http_config=self._http_config,
            )

        raise NotImplementedError(f'Unsupported engine type: {self._engine_type}')

    @property
    def _engine_class(self) -> type[AsyncAbstractEngine]:
        """The engine class for the generated engine type."""
        if self._engine_type == EngineType.binary:
            return AsyncQueryEngine

        raise RuntimeError(f'Unhandled engine type: {self._engine_type}')

    # TODO: don't return Any
    async def _execute(
        self,
        *,
        method: PrismaMethod,
        arguments: dict[str, Any],
        model: type[BaseModel] | None = None,
        root_selection: list[str] | None = None,
    ) -> Any:
        """Build a query and run it on the engine, within the current transaction if any."""
        builder = self._make_query_builder(
            method=method, model=model, arguments=arguments, root_selection=root_selection
        )
        return await self._engine.query(builder.build(), tx_id=self._tx_id)