Coverage for src/prisma/_base_client.py: 89%
218 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-08-27 18:25 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2024-08-27 18:25 +0000
1from __future__ import annotations
3import logging
4import warnings
5from types import TracebackType
6from typing import Any, Generic, TypeVar, overload
7from pathlib import Path
8from datetime import timedelta
9from typing_extensions import Self, Literal
11from pydantic import BaseModel
13from ._types import Datasource, HttpConfig, PrismaMethod, MetricsFormat, TransactionId, DatasourceOverride
14from .engine import (
15 SyncQueryEngine,
16 AsyncQueryEngine,
17 BaseAbstractEngine,
18 SyncAbstractEngine,
19 AsyncAbstractEngine,
20)
21from .errors import ClientNotConnectedError, ClientNotRegisteredError
22from ._compat import model_parse, removeprefix
23from ._builder import QueryBuilder
24from ._metrics import Metrics
25from ._registry import get_client
26from .generator.models import EngineType
# Logger for this module; engine lifecycle and datasource override messages are
# emitted through it at debug level.
log: logging.Logger = logging.getLogger(name=__name__)
class UseClientDefault:
    """Sentinel type for parameters that fall back to a client-level default.

    For certain parameters such as `timeout=...` we can make our intent more clear
    by typing the parameter with this class rather than using None, for example:

    ```py
    def connect(timeout: Union[int, timedelta, UseClientDefault] = UseClientDefault()) -> None: ...
    ```

    conveys the intention more clearly than:

    ```py
    def connect(timeout: Union[int, timedelta, None] = None) -> None: ...
    ```

    This solution also allows us to indicate an "unset" state that is uniquely distinct
    from `None` which may be useful in the future.

    The sentinel is detected with `isinstance(value, UseClientDefault)`, so any
    instance works, but the shared `USE_CLIENT_DEFAULT` instance should be preferred.
    """

    def __repr__(self) -> str:
        # Shown in `help()` / IDE signatures instead of the default
        # `<... object at 0x...>` form.
        return f'{type(self).__name__}()'


# Shared sentinel instance used as the default value for `timeout` parameters.
USE_CLIENT_DEFAULT = UseClientDefault()
def load_env(*, override: bool = False, **kwargs: Any) -> None:
    """Load environment variables from dotenv files.

    The following files are read, in this order, relative to the current
    working directory:

    - .env
    - prisma/.env

    Args:
        override: passed as the `override` flag of python-dotenv's `load_dotenv`.
        **kwargs: forwarded unchanged to every `load_dotenv` call.
    """
    # Imported lazily so python-dotenv is only needed when dotenv loading is used.
    from dotenv import load_dotenv

    for dotenv_path in ('.env', 'prisma/.env'):
        load_dotenv(dotenv_path, override=override, **kwargs)
# Engine type held by a client; `SyncBasePrisma` / `AsyncBasePrisma` bind it to
# the sync and async engine interfaces respectively.
_EngineT = TypeVar('_EngineT', bound=BaseAbstractEngine)


class BasePrisma(Generic[_EngineT]):
    """State and helpers shared by the sync and async Prisma clients.

    Engine-specific behaviour (`connect()`, `disconnect()`, `_create_engine()`)
    lives in subclasses. Metadata produced by the generator is attached with
    `_set_generated_properties()` rather than through `__init__()`.
    """

    _log_queries: bool
    _datasource: DatasourceOverride | None
    _connect_timeout: int | timedelta
    _tx_id: TransactionId | None
    _http_config: HttpConfig
    _internal_engine: _EngineT | None
    _copied: bool

    # from generation
    _schema_path: Path
    _prisma_models: set[str]
    _packaged_schema_path: Path
    _engine_type: EngineType
    _active_provider: str
    _preview_features: set[str]
    _default_datasource_name: str
    _relational_field_mappings: dict[str, dict[str, str]]

    __slots__ = (
        '_copied',
        '_tx_id',
        '_datasource',
        '_log_queries',
        '_http_config',
        '_schema_path',
        '_engine_type',
        '_prisma_models',
        '_active_provider',
        '_connect_timeout',
        '_internal_engine',
        '_packaged_schema_path',
        '_preview_features',
        '_default_datasource_name',
        '_relational_field_mappings',
    )

    def __init__(
        self,
        *,
        use_dotenv: bool,
        log_queries: bool,
        datasource: DatasourceOverride | None,
        connect_timeout: int | timedelta,
        http: HttpConfig | None,
    ) -> None:
        """Initialise the client state shared by all subclasses.

        Args:
            use_dotenv: when true, call `load_env()` after the fields are set.
            log_queries: forwarded to the query engine when it is created.
            datasource: optional datasource override sent to the engine on connect.
            connect_timeout: default timeout used by `connect()`. Passing an `int`
                (seconds) is deprecated and emits a `DeprecationWarning`.
            http: HTTP configuration forwarded to the engine; `None` becomes `{}`.
        """
        # NOTE: if you add any more properties here then you may also need to forward
        # them in the `_copy()` method.
        self._internal_engine = None
        self._log_queries = log_queries
        self._datasource = datasource

        if isinstance(connect_timeout, int):
            message = (
                'Passing an int as `connect_timeout` argument is deprecated '
                'and will be removed in the next major release. '
                'Use a `datetime.timedelta` instance instead.'
            )
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            connect_timeout = timedelta(seconds=connect_timeout)

        self._connect_timeout = connect_timeout
        self._http_config: HttpConfig = http or {}
        self._tx_id: TransactionId | None = None
        self._copied: bool = False

        if use_dotenv:
            load_env()

    def _set_generated_properties(
        self,
        *,
        schema_path: Path,
        engine_type: EngineType,
        packaged_schema_path: Path,
        active_provider: str,
        prisma_models: set[str],
        preview_features: set[str],
        relational_field_mappings: dict[str, dict[str, str]],
        default_datasource_name: str,
    ) -> None:
        """We pass through generated metadata using this method
        instead of the `__init__()` because that causes weirdness
        for our `_copy()` method as this base class has arguments
        that the subclasses do not.
        """
        self._schema_path = schema_path
        self._engine_type = engine_type
        self._prisma_models = prisma_models
        self._active_provider = active_provider
        self._packaged_schema_path = packaged_schema_path
        self._preview_features = preview_features
        self._relational_field_mappings = relational_field_mappings
        self._default_datasource_name = default_datasource_name

    @property
    def _default_datasource(self) -> Datasource:
        """The datasource declared in the schema; must be provided by a subclass."""
        raise NotImplementedError('`_default_datasource` should be implemented in a subclass')

    def is_registered(self) -> bool:
        """Returns True if this client instance is registered"""
        try:
            return get_client() is self
        except ClientNotRegisteredError:
            return False

    def is_transaction(self) -> bool:
        """Returns True if the client is wrapped within a transaction"""
        return self._tx_id is not None

    def is_connected(self) -> bool:
        """Returns True if the client is connected to the query engine, False otherwise."""
        return self._internal_engine is not None

    def __del__(self) -> None:
        # Note: as the transaction manager holds a reference to the original
        # client as well as the transaction client the original client cannot
        # be `free`d before the transaction is finished. So stopping the engine
        # here should be safe.
        #
        # `getattr` with defaults: slot attributes may be unset if construction
        # failed part-way (e.g. a subclass `__init__` raising before
        # `BasePrisma.__init__` ran), and `__del__` must not raise in that case.
        engine = getattr(self, '_internal_engine', None)
        if engine is not None and not getattr(self, '_copied', False):
            log.debug('unclosed client - stopping engine')
            self._internal_engine = None
            engine.stop()

    @property
    def _engine(self) -> _EngineT:
        """The connected engine; raises `ClientNotConnectedError` if there is none."""
        engine = self._internal_engine
        if engine is None:
            raise ClientNotConnectedError()
        return engine

    @_engine.setter
    def _engine(self, engine: _EngineT) -> None:
        self._internal_engine = engine

    def _copy(self) -> Self:
        """Return a new Prisma instance using the same engine process (if connected).

        This is only intended for private usage, there are no guarantees around this API.
        The copy is flagged as `_copied`, so its `__del__` never stops the shared engine.
        """
        new = self.__class__(
            use_dotenv=False,
            http=self._http_config,
            datasource=self._datasource,
            log_queries=self._log_queries,
            connect_timeout=self._connect_timeout,
        )
        new._copied = True

        if self._internal_engine is not None:
            new._engine = self._internal_engine

        return new

    def _make_sqlite_datasource(self) -> DatasourceOverride:
        """Override the default SQLite path to protect against
        https://github.com/RobertCraigie/prisma-client-py/issues/409

        A relative SQLite URL is resolved against the directory containing the
        datasource's `source_file_path` when one is set, and against the schema
        directory otherwise (see `_make_sqlite_url`).
        """
        default = self._default_datasource
        source_file_path: str | Path | None = default.get('source_file_path')
        # Map any falsy value (None or '') to None so `_make_sqlite_url` falls back
        # to the schema directory instead of resolving against the CWD.
        relative_to = Path(source_file_path).parent if source_file_path else None

        return {
            'name': default['name'],
            'url': self._make_sqlite_url(default['url'], relative_to=relative_to),
        }

    def _make_sqlite_url(self, url: str, *, relative_to: Path | str | None = None) -> str:
        """Return `url` with a relative SQLite file path made absolute.

        URLs with neither a `file:` nor a `sqlite:` prefix, and URLs whose path is
        already absolute, are returned unchanged. Relative paths are resolved
        against `relative_to`, which defaults to the directory containing the
        schema file. The result always uses the `file:` prefix.
        """
        url_path = removeprefix(removeprefix(url, 'file:'), 'sqlite:')
        if url_path == url:
            return url

        if Path(url_path).is_absolute():
            return url

        if relative_to is None:
            relative_to = self._schema_path.parent

        if isinstance(relative_to, str):
            relative_to = Path(relative_to)

        return f'file:{relative_to.joinpath(url_path).resolve()}'

    def _prepare_connect_args(
        self,
        *,
        timeout: int | timedelta | UseClientDefault = USE_CLIENT_DEFAULT,
    ) -> tuple[timedelta, list[DatasourceOverride] | None]:
        """Returns (timeout, datasources) to be passed to `AbstractEngine.connect()`

        `timeout` falls back to the client's `connect_timeout`. Datasources are the
        user override (named after the default datasource if no name was given),
        or, for SQLite, the default datasource with its path made absolute.
        """
        if isinstance(timeout, UseClientDefault):
            timeout = self._connect_timeout

        if isinstance(timeout, int):
            message = (
                'Passing an int as `timeout` argument is deprecated '
                'and will be removed in the next major release. '
                'Use a `datetime.timedelta` instance instead.'
            )
            # stacklevel=3 skips this helper *and* the public `connect()` that calls
            # it, so the warning is attributed to user code (and is therefore not
            # hidden by the default DeprecationWarning filters).
            warnings.warn(message, DeprecationWarning, stacklevel=3)
            timeout = timedelta(seconds=timeout)

        datasources: list[DatasourceOverride] | None = None
        if self._datasource is not None:
            ds = self._datasource.copy()
            ds.setdefault('name', self._default_datasource_name)
            datasources = [ds]
        elif self._active_provider == 'sqlite':
            log.debug('overriding default SQLite datasource path')
            # Override the default SQLite path to protect against
            # https://github.com/RobertCraigie/prisma-client-py/issues/409
            datasources = [self._make_sqlite_datasource()]

        log.debug('datasources: %s', datasources)
        return timeout, datasources

    def _make_query_builder(
        self,
        *,
        method: PrismaMethod,
        arguments: dict[str, Any],
        model: type[BaseModel] | None,
        root_selection: list[str] | None,
    ) -> QueryBuilder:
        """Create a `QueryBuilder` for one query using this client's generated metadata."""
        return QueryBuilder(
            method=method,
            model=model,
            arguments=arguments,
            root_selection=root_selection,
            prisma_models=self._prisma_models,
            relational_field_mappings=self._relational_field_mappings,
        )
class SyncBasePrisma(BasePrisma[SyncAbstractEngine]):
    """Base class for the synchronous Prisma client."""

    __slots__ = ()

    def connect(
        self,
        timeout: int | timedelta | UseClientDefault = USE_CLIENT_DEFAULT,
    ) -> None:
        """Connect to the Prisma query engine.

        It is required to call this before accessing data.

        Args:
            timeout: connection timeout; defaults to the client's `connect_timeout`.
                Passing an `int` (seconds) is deprecated.
        """
        if self._internal_engine is None:
            self._internal_engine = self._create_engine(dml_path=self._packaged_schema_path)

        timeout, datasources = self._prepare_connect_args(timeout=timeout)

        self._internal_engine.connect(
            timeout=timeout,
            datasources=datasources,
        )

    def disconnect(self, timeout: float | timedelta | None = None) -> None:
        """Disconnect the Prisma query engine.

        Does nothing if the client is not connected. The engine is stopped even
        if closing it raises.

        Args:
            timeout: forwarded to the engine's `close()` and `stop()`. Passing a
                number (seconds) is deprecated.
        """
        engine = self._internal_engine
        if engine is None:
            return

        # Normalise the timeout *before* detaching the engine: if the deprecation
        # warning is escalated to an error, the engine stays attached to the
        # client instead of being silently leaked.
        if isinstance(timeout, (int, float)):
            message = (
                'Passing a number as `timeout` argument is deprecated '
                'and will be removed in the next major release. '
                'Use a `datetime.timedelta` instead.'
            )
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            timeout = timedelta(seconds=timeout)

        self._internal_engine = None
        try:
            engine.close(timeout=timeout)
        finally:
            # Once detached, nothing in this client will stop the engine later,
            # so it must be stopped here even if `close()` failed.
            engine.stop(timeout=timeout)

    def __enter__(self) -> Self:
        """Connect on entering a `with` block."""
        self.connect()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Disconnect on leaving a `with` block if still connected."""
        if self.is_connected():
            self.disconnect()

    @overload
    def get_metrics(
        self,
        format: Literal['json'] = 'json',
        *,
        global_labels: dict[str, str] | None = None,
    ) -> Metrics: ...

    @overload
    def get_metrics(
        self,
        format: Literal['prometheus'],
        *,
        global_labels: dict[str, str] | None = None,
    ) -> str: ...

    def get_metrics(
        self,
        format: MetricsFormat = 'json',
        *,
        global_labels: dict[str, str] | None = None,
    ) -> str | Metrics:
        """Metrics give you a detailed insight into how the Prisma Client interacts with your database.

        You can retrieve metrics in either JSON or Prometheus formats.

        For more details see https://www.prisma.io/docs/concepts/components/prisma-client/metrics.
        """
        response = self._engine.metrics(format=format, global_labels=global_labels)
        if format == 'prometheus':
            # For the prometheus format we return the response as-is
            assert isinstance(response, str)
            return response

        return model_parse(Metrics, response)

    def _create_engine(self, dml_path: Path | None = None) -> SyncAbstractEngine:
        """Instantiate (without connecting) the engine for the generated engine type."""
        if self._engine_type == EngineType.binary:
            return SyncQueryEngine(
                dml_path=dml_path or self._packaged_schema_path,
                log_queries=self._log_queries,
                http_config=self._http_config,
            )

        raise NotImplementedError(f'Unsupported engine type: {self._engine_type}')

    @property
    def _engine_class(self) -> type[SyncAbstractEngine]:
        """The engine class used for the generated engine type."""
        if self._engine_type == EngineType.binary:
            return SyncQueryEngine

        raise RuntimeError(f'Unhandled engine type: {self._engine_type}')

    # TODO: don't return Any
    def _execute(
        self,
        method: PrismaMethod,
        arguments: dict[str, Any],
        model: type[BaseModel] | None = None,
        root_selection: list[str] | None = None,
    ) -> Any:
        """Build a query and run it on the engine, within the current transaction if any."""
        builder = self._make_query_builder(
            method=method, model=model, arguments=arguments, root_selection=root_selection
        )
        return self._engine.query(builder.build(), tx_id=self._tx_id)
class AsyncBasePrisma(BasePrisma[AsyncAbstractEngine]):
    """Base class for the asynchronous Prisma client."""

    __slots__ = ()

    async def connect(
        self,
        timeout: int | timedelta | UseClientDefault = USE_CLIENT_DEFAULT,
    ) -> None:
        """Connect to the Prisma query engine.

        It is required to call this before accessing data.

        Args:
            timeout: connection timeout; defaults to the client's `connect_timeout`.
                Passing an `int` (seconds) is deprecated.
        """
        if self._internal_engine is None:
            self._internal_engine = self._create_engine(dml_path=self._packaged_schema_path)

        timeout, datasources = self._prepare_connect_args(timeout=timeout)

        await self._internal_engine.connect(
            timeout=timeout,
            datasources=datasources,
        )

    async def disconnect(self, timeout: float | timedelta | None = None) -> None:
        """Disconnect the Prisma query engine.

        Does nothing if the client is not connected. The engine is stopped even
        if closing it raises (or the close is cancelled).

        Args:
            timeout: forwarded to the engine's `aclose()` and `stop()`. Passing a
                number (seconds) is deprecated.
        """
        engine = self._internal_engine
        if engine is None:
            return

        # Normalise the timeout *before* detaching the engine: if the deprecation
        # warning is escalated to an error, the engine stays attached to the
        # client instead of being silently leaked.
        if isinstance(timeout, (int, float)):
            message = (
                'Passing a number as `timeout` argument is deprecated '
                'and will be removed in the next major release. '
                'Use a `datetime.timedelta` instead.'
            )
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            timeout = timedelta(seconds=timeout)

        self._internal_engine = None
        try:
            await engine.aclose(timeout=timeout)
        finally:
            # Once detached, nothing in this client will stop the engine later,
            # so it must be stopped here even if `aclose()` failed.
            engine.stop(timeout=timeout)

    async def __aenter__(self) -> Self:
        """Connect on entering an `async with` block."""
        await self.connect()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Disconnect on leaving an `async with` block if still connected."""
        if self.is_connected():
            await self.disconnect()

    @overload
    async def get_metrics(
        self,
        format: Literal['json'] = 'json',
        *,
        global_labels: dict[str, str] | None = None,
    ) -> Metrics: ...

    @overload
    async def get_metrics(
        self,
        format: Literal['prometheus'],
        *,
        global_labels: dict[str, str] | None = None,
    ) -> str: ...

    async def get_metrics(
        self,
        format: MetricsFormat = 'json',
        *,
        global_labels: dict[str, str] | None = None,
    ) -> str | Metrics:
        """Metrics give you a detailed insight into how the Prisma Client interacts with your database.

        You can retrieve metrics in either JSON or Prometheus formats.

        For more details see https://www.prisma.io/docs/concepts/components/prisma-client/metrics.
        """
        response = await self._engine.metrics(format=format, global_labels=global_labels)
        if format == 'prometheus':
            # For the prometheus format we return the response as-is
            assert isinstance(response, str)
            return response

        return model_parse(Metrics, response)

    def _create_engine(self, dml_path: Path | None = None) -> AsyncAbstractEngine:
        """Instantiate (without connecting) the engine for the generated engine type."""
        if self._engine_type == EngineType.binary:
            return AsyncQueryEngine(
                dml_path=dml_path or self._packaged_schema_path,
                log_queries=self._log_queries,
                http_config=self._http_config,
            )

        raise NotImplementedError(f'Unsupported engine type: {self._engine_type}')

    @property
    def _engine_class(self) -> type[AsyncAbstractEngine]:
        """The engine class used for the generated engine type."""
        if self._engine_type == EngineType.binary:
            return AsyncQueryEngine

        raise RuntimeError(f'Unhandled engine type: {self._engine_type}')

    # TODO: don't return Any
    async def _execute(
        self,
        *,
        method: PrismaMethod,
        arguments: dict[str, Any],
        model: type[BaseModel] | None = None,
        root_selection: list[str] | None = None,
    ) -> Any:
        """Build a query and run it on the engine, within the current transaction if any."""
        builder = self._make_query_builder(
            method=method, model=model, arguments=arguments, root_selection=root_selection
        )
        return await self._engine.query(builder.build(), tx_id=self._tx_id)