Coverage for src/prisma/_base_client.py: 89%

218 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-08-27 18:25 +0000

1from __future__ import annotations 

2 

3import logging 

4import warnings 

5from types import TracebackType 

6from typing import Any, Generic, TypeVar, overload 

7from pathlib import Path 

8from datetime import timedelta 

9from typing_extensions import Self, Literal 

10 

11from pydantic import BaseModel 

12 

13from ._types import Datasource, HttpConfig, PrismaMethod, MetricsFormat, TransactionId, DatasourceOverride 

14from .engine import ( 

15 SyncQueryEngine, 

16 AsyncQueryEngine, 

17 BaseAbstractEngine, 

18 SyncAbstractEngine, 

19 AsyncAbstractEngine, 

20) 

21from .errors import ClientNotConnectedError, ClientNotRegisteredError 

22from ._compat import model_parse, removeprefix 

23from ._builder import QueryBuilder 

24from ._metrics import Metrics 

25from ._registry import get_client 

26from .generator.models import EngineType 

27 

log: logging.Logger = logging.getLogger(__name__)


class UseClientDefault:
    """For certain parameters such as `timeout=...` we can make our intent more clear
    by typing the parameter with this class rather than using None, for example:

    ```py
    def connect(timeout: Union[int, timedelta, UseClientDefault] = UseClientDefault()) -> None: ...
    ```

    relays the intention more clearly than:

    ```py
    def connect(timeout: Union[int, timedelta, None] = None) -> None: ...
    ```

    This solution also allows us to indicate an "unset" state that is uniquely distinct
    from `None` which may be useful in the future.
    """

    def __repr__(self) -> str:
        # Render as the module-level sentinel name so that signatures shown by
        # `help()` / `inspect.signature()` read `timeout=USE_CLIENT_DEFAULT`
        # instead of an opaque `<... object at 0x...>` address.
        return 'USE_CLIENT_DEFAULT'


# Shared sentinel instance used as the default for "use the client's setting"
# parameters; callers detect it with `isinstance(value, UseClientDefault)`.
USE_CLIENT_DEFAULT = UseClientDefault()

51 

52 

def load_env(*, override: bool = False, **kwargs: Any) -> None:
    """Load environment variables from dotenv files.

    The files below are loaded one after the other, each path being
    relative to the current working directory:

    - .env
    - prisma/.env

    ``override`` and any additional keyword arguments are passed through
    unchanged to ``dotenv.load_dotenv()`` for every file.
    """
    from dotenv import load_dotenv

    for dotenv_path in ('.env', 'prisma/.env'):
        load_dotenv(dotenv_path, override=override, **kwargs)

66 

67 

_EngineT = TypeVar('_EngineT', bound=BaseAbstractEngine)


class BasePrisma(Generic[_EngineT]):
    """State and helpers shared by the sync and async Prisma client bases.

    Connection options are received in ``__init__()``. Schema metadata
    produced at generation time is attached separately through
    ``_set_generated_properties()``. The engine instance is kept in
    ``_internal_engine``, which is ``None`` while the client is not connected.
    """

    _log_queries: bool
    _datasource: DatasourceOverride | None
    _connect_timeout: int | timedelta
    _tx_id: TransactionId | None
    _http_config: HttpConfig
    _internal_engine: _EngineT | None
    _copied: bool

    # from generation
    _schema_path: Path
    _prisma_models: set[str]
    _packaged_schema_path: Path
    _engine_type: EngineType
    _active_provider: str
    _preview_features: set[str]
    _default_datasource_name: str
    _relational_field_mappings: dict[str, dict[str, str]]

    __slots__ = (
        '_copied',
        '_tx_id',
        '_datasource',
        '_log_queries',
        '_http_config',
        '_schema_path',
        '_engine_type',
        '_prisma_models',
        '_active_provider',
        '_connect_timeout',
        '_internal_engine',
        '_packaged_schema_path',
        '_preview_features',
        '_default_datasource_name',
        '_relational_field_mappings',
    )

    def __init__(
        self,
        *,
        use_dotenv: bool,
        log_queries: bool,
        datasource: DatasourceOverride | None,
        connect_timeout: int | timedelta,
        http: HttpConfig | None,
    ) -> None:
        """Store the connection options and optionally load dotenv files.

        Args:
            use_dotenv: When true, ``load_env()`` is called to read ``.env``
                and ``prisma/.env``.
            log_queries: Passed to the engine when a subclass creates it.
            datasource: Optional datasource override applied on ``connect()``.
            connect_timeout: Default timeout for ``connect()``. Passing an
                ``int`` (seconds) is deprecated and is converted to a
                ``timedelta``.
            http: HTTP options passed to the engine. ``None`` is stored as ``{}``.
        """
        # NOTE: if you add any more properties here then you may also need to forward
        # them in the `_copy()` method.
        self._internal_engine = None
        self._log_queries = log_queries
        self._datasource = datasource

        if isinstance(connect_timeout, int):
            message = (
                'Passing an int as `connect_timeout` argument is deprecated '
                'and will be removed in the next major release. '
                'Use a `datetime.timedelta` instance instead.'
            )
            # NOTE(review): stacklevel=2 attributes this warning to whatever
            # called this `__init__` (possibly a generated subclass rather than
            # user code). Confirm whether a higher stacklevel is appropriate.
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            connect_timeout = timedelta(seconds=connect_timeout)

        self._connect_timeout = connect_timeout
        self._http_config: HttpConfig = http or {}
        self._tx_id: TransactionId | None = None
        self._copied: bool = False

        if use_dotenv:
            load_env()

    def _set_generated_properties(
        self,
        *,
        schema_path: Path,
        engine_type: EngineType,
        packaged_schema_path: Path,
        active_provider: str,
        prisma_models: set[str],
        preview_features: set[str],
        relational_field_mappings: dict[str, dict[str, str]],
        default_datasource_name: str,
    ) -> None:
        """We pass through generated metadata using this method
        instead of the `__init__()` because that causes weirdness
        for our `_copy()` method as this base class has arguments
        that the subclasses do not.
        """
        self._schema_path = schema_path
        self._engine_type = engine_type
        self._prisma_models = prisma_models
        self._active_provider = active_provider
        self._packaged_schema_path = packaged_schema_path
        self._preview_features = preview_features
        self._relational_field_mappings = relational_field_mappings
        self._default_datasource_name = default_datasource_name

    @property
    def _default_datasource(self) -> Datasource:
        """The schema's default datasource; must be provided by a subclass."""
        raise NotImplementedError('`_default_datasource` should be implemented in a subclass')

    def is_registered(self) -> bool:
        """Returns True if this client instance is registered"""
        try:
            return get_client() is self
        except ClientNotRegisteredError:
            return False

    def is_transaction(self) -> bool:
        """Returns True if the client is wrapped within a transaction"""
        return self._tx_id is not None

    def is_connected(self) -> bool:
        """Returns True if the client is connected to the query engine, False otherwise."""
        return self._internal_engine is not None

    def __del__(self) -> None:
        # Note: as the transaction manager holds a reference to the original
        # client as well as the transaction client the original client cannot
        # be `free`d before the transaction is finished. So stopping the engine
        # here should be safe.
        #
        # `getattr` with a default keeps this safe for instances whose
        # `__init__` never ran, or raised before these slots were assigned.
        # Otherwise the AttributeError would surface as "Exception ignored in
        # __del__" noise during garbage collection.
        engine = getattr(self, '_internal_engine', None)
        # Copies made by `_copy()` share the original's engine, so only the
        # original client is allowed to stop it.
        if engine is not None and not getattr(self, '_copied', False):
            log.debug('unclosed client - stopping engine')
            self._internal_engine = None
            engine.stop()

    @property
    def _engine(self) -> _EngineT:
        """The connected engine.

        Raises:
            ClientNotConnectedError: if no engine is set.
        """
        engine = self._internal_engine
        if engine is None:
            raise ClientNotConnectedError()
        return engine

    @_engine.setter
    def _engine(self, engine: _EngineT) -> None:
        self._internal_engine = engine

    def _copy(self) -> Self:
        """Return a new Prisma instance using the same engine process (if connected).

        This is only intended for private usage, there are no guarantees around this API.
        """
        new = self.__class__(
            use_dotenv=False,
            http=self._http_config,
            datasource=self._datasource,
            log_queries=self._log_queries,
            connect_timeout=self._connect_timeout,
        )
        # Mark the copy so that its `__del__` never stops the shared engine.
        new._copied = True

        if self._internal_engine is not None:
            new._engine = self._internal_engine

        return new

    def _make_sqlite_datasource(self) -> DatasourceOverride:
        """Override the default SQLite path to protect against
        https://github.com/RobertCraigie/prisma-client-py/issues/409

        When the default datasource has a ``source_file_path``, a relative
        SQLite path is resolved against that file's directory. Otherwise it
        is resolved against the schema directory (see ``_make_sqlite_url()``).
        """
        source_file_path: str | Path | None = self._default_datasource.get('source_file_path')
        if source_file_path:
            source_file_path = Path(source_file_path).parent

        return {
            'name': self._default_datasource['name'],
            'url': self._make_sqlite_url(
                self._default_datasource['url'],
                relative_to=source_file_path,
            ),
        }

    def _make_sqlite_url(self, url: str, *, relative_to: Path | str | None = None) -> str:
        """Return ``url`` with a relative SQLite file path made absolute.

        URLs with neither a ``file:`` nor a ``sqlite:`` prefix are returned
        unchanged, as are URLs whose path is already absolute. A relative path
        is resolved against ``relative_to``, which defaults to the directory of
        ``_schema_path``. The result is returned with a ``file:`` prefix.
        """
        url_path = removeprefix(removeprefix(url, 'file:'), 'sqlite:')
        if url_path == url:
            # No recognised prefix was stripped, so there is nothing to rewrite.
            return url

        if Path(url_path).is_absolute():
            return url

        base = self._schema_path.parent if relative_to is None else Path(relative_to)
        return f'file:{base.joinpath(url_path).resolve()}'

    def _prepare_connect_args(
        self,
        *,
        timeout: int | timedelta | UseClientDefault = USE_CLIENT_DEFAULT,
    ) -> tuple[timedelta, list[DatasourceOverride] | None]:
        """Returns (timeout, datasources) to be passed to `AbstractEngine.connect()`

        ``timeout`` falls back to the client's ``connect_timeout``. An ``int``
        is treated as seconds, which is deprecated. ``datasources`` is ``None``
        unless the user supplied a datasource override, or the active provider
        is SQLite, in which case a relative database path is resolved to an
        absolute one.
        """
        if isinstance(timeout, UseClientDefault):
            timeout = self._connect_timeout

        if isinstance(timeout, int):
            message = (
                'Passing an int as `timeout` argument is deprecated '
                'and will be removed in the next major release. '
                'Use a `datetime.timedelta` instance instead.'
            )
            # stacklevel=3: this helper is invoked from `connect()`, so skip
            # both frames and attribute the warning to the user's `connect()`
            # call. Python's default filters only display DeprecationWarning
            # when it is attributed to `__main__`, so pointing at library code
            # would hide it from users.
            warnings.warn(message, DeprecationWarning, stacklevel=3)
            timeout = timedelta(seconds=timeout)

        datasources: list[DatasourceOverride] | None = None
        if self._datasource is not None:
            ds = self._datasource.copy()
            ds.setdefault('name', self._default_datasource_name)
            datasources = [ds]
        elif self._active_provider == 'sqlite':
            log.debug('overriding default SQLite datasource path')
            # Override the default SQLite path to protect against
            # https://github.com/RobertCraigie/prisma-client-py/issues/409
            datasources = [self._make_sqlite_datasource()]

        log.debug('datasources: %s', datasources)
        return timeout, datasources

    def _make_query_builder(
        self,
        *,
        method: PrismaMethod,
        arguments: dict[str, Any],
        model: type[BaseModel] | None,
        root_selection: list[str] | None,
    ) -> QueryBuilder:
        """Create a ``QueryBuilder`` that carries this client's generated model metadata."""
        return QueryBuilder(
            method=method,
            model=model,
            arguments=arguments,
            root_selection=root_selection,
            prisma_models=self._prisma_models,
            relational_field_mappings=self._relational_field_mappings,
        )

305 

306 

class SyncBasePrisma(BasePrisma[SyncAbstractEngine]):
    """Synchronous client base that drives a ``SyncAbstractEngine``."""

    __slots__ = ()

    def connect(
        self,
        timeout: int | timedelta | UseClientDefault = USE_CLIENT_DEFAULT,
    ) -> None:
        """Connect to the Prisma query engine.

        It is required to call this before accessing data. If no engine exists
        yet, one is created from the packaged schema first.
        """
        if self._internal_engine is None:
            self._internal_engine = self._create_engine(dml_path=self._packaged_schema_path)

        timeout, datasources = self._prepare_connect_args(timeout=timeout)

        self._internal_engine.connect(
            timeout=timeout,
            datasources=datasources,
        )

    def disconnect(self, timeout: float | timedelta | None = None) -> None:
        """Disconnect the Prisma query engine.

        Does nothing if the client is not connected. Passing a number of
        seconds as ``timeout`` is deprecated; use a ``timedelta`` instead.
        """
        if self._internal_engine is None:
            return

        # Convert the timeout before detaching the engine. If the warning is
        # escalated to an error, the engine then stays attached (and will still
        # be stopped by `__del__`) instead of being leaked.
        if isinstance(timeout, (int, float)):
            message = (
                'Passing a number as `timeout` argument is deprecated '
                'and will be removed in the next major release. '
                'Use a `datetime.timedelta` instead.'
            )
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            timeout = timedelta(seconds=timeout)

        engine = self._internal_engine
        self._internal_engine = None

        # `_internal_engine` is now cleared, so `__del__` can no longer stop
        # this engine. `stop()` must therefore run even if `close()` raises.
        try:
            engine.close(timeout=timeout)
        finally:
            engine.stop(timeout=timeout)

    def __enter__(self) -> Self:
        self.connect()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        if self.is_connected():
            self.disconnect()

    @overload
    def get_metrics(
        self,
        format: Literal['json'] = 'json',
        *,
        global_labels: dict[str, str] | None = None,
    ) -> Metrics: ...

    @overload
    def get_metrics(
        self,
        format: Literal['prometheus'],
        *,
        global_labels: dict[str, str] | None = None,
    ) -> str: ...

    def get_metrics(
        self,
        format: MetricsFormat = 'json',
        *,
        global_labels: dict[str, str] | None = None,
    ) -> str | Metrics:
        """Metrics give you a detailed insight into how the Prisma Client interacts with your database.

        You can retrieve metrics in either JSON or Prometheus formats.

        For more details see https://www.prisma.io/docs/concepts/components/prisma-client/metrics.
        """
        response = self._engine.metrics(format=format, global_labels=global_labels)
        if format == 'prometheus':
            # For the prometheus format we return the response as-is
            assert isinstance(response, str)
            return response

        return model_parse(Metrics, response)

    def _create_engine(self, dml_path: Path | None = None) -> SyncAbstractEngine:
        """Instantiate the engine for ``_engine_type``.

        ``dml_path`` defaults to the packaged schema path.

        Raises:
            NotImplementedError: for engine types other than ``binary``.
        """
        if self._engine_type == EngineType.binary:
            return SyncQueryEngine(
                dml_path=dml_path or self._packaged_schema_path,
                log_queries=self._log_queries,
                http_config=self._http_config,
            )

        raise NotImplementedError(f'Unsupported engine type: {self._engine_type}')

    @property
    def _engine_class(self) -> type[SyncAbstractEngine]:
        """The engine class matching ``_engine_type``."""
        if self._engine_type == EngineType.binary:
            return SyncQueryEngine

        raise RuntimeError(f'Unhandled engine type: {self._engine_type}')

    # TODO: don't return Any
    def _execute(
        self,
        method: PrismaMethod,
        arguments: dict[str, Any],
        model: type[BaseModel] | None = None,
        root_selection: list[str] | None = None,
    ) -> Any:
        """Build a query and run it on the engine, within the current transaction if any."""
        builder = self._make_query_builder(
            method=method, model=model, arguments=arguments, root_selection=root_selection
        )
        return self._engine.query(builder.build(), tx_id=self._tx_id)

424 

425 

class AsyncBasePrisma(BasePrisma[AsyncAbstractEngine]):
    """Asynchronous client base that drives an ``AsyncAbstractEngine``."""

    __slots__ = ()

    async def connect(
        self,
        timeout: int | timedelta | UseClientDefault = USE_CLIENT_DEFAULT,
    ) -> None:
        """Connect to the Prisma query engine.

        It is required to call this before accessing data. If no engine exists
        yet, one is created from the packaged schema first.
        """
        if self._internal_engine is None:
            self._internal_engine = self._create_engine(dml_path=self._packaged_schema_path)

        timeout, datasources = self._prepare_connect_args(timeout=timeout)

        await self._internal_engine.connect(
            timeout=timeout,
            datasources=datasources,
        )

    async def disconnect(self, timeout: float | timedelta | None = None) -> None:
        """Disconnect the Prisma query engine.

        Does nothing if the client is not connected. Passing a number of
        seconds as ``timeout`` is deprecated; use a ``timedelta`` instead.
        """
        if self._internal_engine is None:
            return

        # Convert the timeout before detaching the engine. If the warning is
        # escalated to an error, the engine then stays attached (and will still
        # be stopped by `__del__`) instead of being leaked.
        if isinstance(timeout, (int, float)):
            message = (
                'Passing a number as `timeout` argument is deprecated '
                'and will be removed in the next major release. '
                'Use a `datetime.timedelta` instead.'
            )
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            timeout = timedelta(seconds=timeout)

        engine = self._internal_engine
        self._internal_engine = None

        # `_internal_engine` is now cleared, so `__del__` can no longer stop
        # this engine. `stop()` must therefore run even if `aclose()` raises
        # or the awaiting task is cancelled.
        try:
            await engine.aclose(timeout=timeout)
        finally:
            engine.stop(timeout=timeout)

    async def __aenter__(self) -> Self:
        await self.connect()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        if self.is_connected():
            await self.disconnect()

    @overload
    async def get_metrics(
        self,
        format: Literal['json'] = 'json',
        *,
        global_labels: dict[str, str] | None = None,
    ) -> Metrics: ...

    @overload
    async def get_metrics(
        self,
        format: Literal['prometheus'],
        *,
        global_labels: dict[str, str] | None = None,
    ) -> str: ...

    async def get_metrics(
        self,
        format: MetricsFormat = 'json',
        *,
        global_labels: dict[str, str] | None = None,
    ) -> str | Metrics:
        """Metrics give you a detailed insight into how the Prisma Client interacts with your database.

        You can retrieve metrics in either JSON or Prometheus formats.

        For more details see https://www.prisma.io/docs/concepts/components/prisma-client/metrics.
        """
        response = await self._engine.metrics(format=format, global_labels=global_labels)
        if format == 'prometheus':
            # For the prometheus format we return the response as-is
            assert isinstance(response, str)
            return response

        return model_parse(Metrics, response)

    def _create_engine(self, dml_path: Path | None = None) -> AsyncAbstractEngine:
        """Instantiate the engine for ``_engine_type``.

        ``dml_path`` defaults to the packaged schema path.

        Raises:
            NotImplementedError: for engine types other than ``binary``.
        """
        if self._engine_type == EngineType.binary:
            return AsyncQueryEngine(
                dml_path=dml_path or self._packaged_schema_path,
                log_queries=self._log_queries,
                http_config=self._http_config,
            )

        raise NotImplementedError(f'Unsupported engine type: {self._engine_type}')

    @property
    def _engine_class(self) -> type[AsyncAbstractEngine]:
        """The engine class matching ``_engine_type``."""
        if self._engine_type == EngineType.binary:
            return AsyncQueryEngine

        raise RuntimeError(f'Unhandled engine type: {self._engine_type}')

    # TODO: don't return Any
    async def _execute(
        self,
        *,
        method: PrismaMethod,
        arguments: dict[str, Any],
        model: type[BaseModel] | None = None,
        root_selection: list[str] | None = None,
    ) -> Any:
        """Build a query and run it on the engine, within the current transaction if any."""
        builder = self._make_query_builder(
            method=method, model=model, arguments=arguments, root_selection=root_selection
        )
        return await self._engine.query(builder.build(), tx_id=self._tx_id)