Coverage for src/prisma/_base_client.py: 90%
209 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-04-28 15:17 +0000
1from __future__ import annotations
3import logging
4import warnings
5from types import TracebackType
6from typing import Any, Generic, TypeVar, overload
7from pathlib import Path
8from datetime import timedelta
9from typing_extensions import Self, Literal
11from pydantic import BaseModel
13from ._types import Datasource, HttpConfig, PrismaMethod, MetricsFormat, TransactionId, DatasourceOverride
14from .engine import (
15 SyncQueryEngine,
16 AsyncQueryEngine,
17 BaseAbstractEngine,
18 SyncAbstractEngine,
19 AsyncAbstractEngine,
20)
21from .errors import ClientNotConnectedError, ClientNotRegisteredError
22from ._compat import model_parse, removeprefix
23from ._builder import QueryBuilder
24from ._metrics import Metrics
25from ._registry import get_client
26from .generator.models import EngineType
# Module-level logger named after this module (`__name__`); used for debug
# messages such as the "unclosed client" notice emitted from `BasePrisma.__del__`.
log: logging.Logger = logging.getLogger(__name__)
class UseClientDefault:
    """Sentinel type meaning "fall back to the value configured on the client".

    Typing a parameter with this class, rather than with `None`, states the
    intent directly. For example:

    ```py
    def connect(timeout: Union[int, timedelta, UseClientDefault] = UseClientDefault()) -> None: ...
    ```

    says more than the equivalent:

    ```py
    def connect(timeout: Union[int, timedelta, None] = None) -> None: ...
    ```

    Because it is a dedicated type, the sentinel also provides an "unset" state
    that can never be confused with an explicit `None`, which may be useful
    in the future.
    """


# Shared sentinel instance used as the default value for such parameters.
# Callers detect it with `isinstance(value, UseClientDefault)`, not by identity.
USE_CLIENT_DEFAULT: UseClientDefault = UseClientDefault()
def load_env(*, override: bool = False, **kwargs: Any) -> None:
    """Load environment variables from dotenv files.

    The following files are read, in this order, relative to the current
    working directory:

    - .env
    - prisma/.env

    Args:
        override: passed through to `load_dotenv` for each file.
        **kwargs: any further keyword arguments, forwarded verbatim to `load_dotenv`.
    """
    # Imported inside the function, so python-dotenv is only resolved when this is called.
    from dotenv import load_dotenv

    for env_file in ('.env', 'prisma/.env'):
        load_dotenv(env_file, override=override, **kwargs)
# Engine type held by a client. `BasePrisma` is generic over it so that the
# sync/async subclasses (`BasePrisma[SyncAbstractEngine]` /
# `BasePrisma[AsyncAbstractEngine]`) get precisely typed `_engine` access.
_EngineT = TypeVar('_EngineT', bound=BaseAbstractEngine)
class BasePrisma(Generic[_EngineT]):
    """State and helpers shared by the sync and async Prisma clients.

    Connecting, disconnecting and executing queries are implemented by the
    `SyncBasePrisma` / `AsyncBasePrisma` subclasses, which bind `_EngineT`
    to the matching engine type.
    """

    _log_queries: bool
    _datasource: DatasourceOverride | None
    _connect_timeout: int | timedelta
    _tx_id: TransactionId | None
    _http_config: HttpConfig
    _internal_engine: _EngineT | None
    _copied: bool

    # from generation (assigned by `_set_generated_properties()`)
    _schema_path: Path
    _prisma_models: set[str]
    _active_provider: str
    _packaged_schema_path: Path
    _engine_type: EngineType
    _default_datasource_name: str
    _relational_field_mappings: dict[str, dict[str, str]]

    __slots__ = (
        '_copied',
        '_tx_id',
        '_datasource',
        '_log_queries',
        '_http_config',
        '_schema_path',
        '_engine_type',
        '_prisma_models',
        '_active_provider',
        '_connect_timeout',
        '_internal_engine',
        '_packaged_schema_path',
        '_default_datasource_name',
        '_relational_field_mappings',
    )

    def __init__(
        self,
        *,
        use_dotenv: bool,
        log_queries: bool,
        datasource: DatasourceOverride | None,
        connect_timeout: int | timedelta,
        http: HttpConfig | None,
    ) -> None:
        """Store the runtime configuration; no engine is created here.

        Args:
            use_dotenv: if True, call `load_env()` to read `.env` / `prisma/.env`.
            log_queries: stored and later passed to the engine when it is created.
            datasource: optional datasource override applied when connecting.
            connect_timeout: default timeout used by `connect()`. An `int` is
                interpreted as seconds and is deprecated.
            http: HTTP configuration for the engine; `None` is stored as `{}`.
        """
        # NOTE: if you add any more properties here then you may also need to forward
        # them in the `_copy()` method.
        self._internal_engine = None
        self._log_queries = log_queries
        self._datasource = datasource

        if isinstance(connect_timeout, int):
            message = (
                'Passing an int as `connect_timeout` argument is deprecated '
                'and will be removed in the next major release. '
                'Use a `datetime.timedelta` instance instead.'
            )
            # NOTE(review): stacklevel=2 attributes the warning to the direct caller of
            # this `__init__` (presumably a generated subclass `__init__`); confirm
            # whether the user's call site one frame further out is intended.
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            connect_timeout = timedelta(seconds=connect_timeout)

        self._connect_timeout = connect_timeout
        self._http_config: HttpConfig = http or {}
        self._tx_id: TransactionId | None = None
        self._copied: bool = False

        if use_dotenv:
            load_env()

    def _set_generated_properties(
        self,
        *,
        schema_path: Path,
        engine_type: EngineType,
        packaged_schema_path: Path,
        active_provider: str,
        prisma_models: set[str],
        relational_field_mappings: dict[str, dict[str, str]],
        default_datasource_name: str,
    ) -> None:
        """We pass through generated metadata using this method
        instead of the `__init__()` because that causes weirdness
        for our `_copy()` method as this base class has arguments
        that the subclasses do not.
        """
        self._schema_path = schema_path
        self._engine_type = engine_type
        self._prisma_models = prisma_models
        self._active_provider = active_provider
        self._packaged_schema_path = packaged_schema_path
        self._relational_field_mappings = relational_field_mappings
        self._default_datasource_name = default_datasource_name

    @property
    def _default_datasource(self) -> Datasource:
        """The schema's default datasource; must be provided by a subclass."""
        raise NotImplementedError('`_default_datasource` should be implemented in a subclass')

    def is_registered(self) -> bool:
        """Returns True if this client instance is registered"""
        try:
            return get_client() is self
        except ClientNotRegisteredError:
            return False

    def is_transaction(self) -> bool:
        """Returns True if the client is wrapped within a transaction"""
        return self._tx_id is not None

    def is_connected(self) -> bool:
        """Returns True if the client is connected to the query engine, False otherwise."""
        return self._internal_engine is not None

    def __del__(self) -> None:
        # Note: as the transaction manager holds a reference to the original
        # client as well as the transaction client the original client cannot
        # be `free`d before the transaction is finished. So stopping the engine
        # here should be safe.
        #
        # `getattr` with defaults: `__del__` can run on a partially initialised
        # instance (e.g. created via `__new__` or when `__init__` failed early), where
        # reading an unassigned slot would raise AttributeError during finalisation.
        engine = getattr(self, '_internal_engine', None)
        if engine is not None and not getattr(self, '_copied', False):
            log.debug('unclosed client - stopping engine')
            self._internal_engine = None
            engine.stop()

    @property
    def _engine(self) -> _EngineT:
        """The connected engine.

        Raises:
            ClientNotConnectedError: if no engine is attached to this client.
        """
        engine = self._internal_engine
        if engine is None:
            raise ClientNotConnectedError()
        return engine

    @_engine.setter
    def _engine(self, engine: _EngineT) -> None:
        self._internal_engine = engine

    def _copy(self) -> Self:
        """Return a new Prisma instance using the same engine process (if connected).

        This is only intended for private usage, there are no guarantees around this API.
        """
        new = self.__class__(
            use_dotenv=False,
            http=self._http_config,
            datasource=self._datasource,
            log_queries=self._log_queries,
            connect_timeout=self._connect_timeout,
        )
        # Marks the copy as not owning the engine: `__del__` skips stopping it.
        new._copied = True

        if self._internal_engine is not None:
            new._engine = self._internal_engine

        return new

    def _make_sqlite_datasource(self) -> DatasourceOverride:
        """Override the default SQLite path to protect against
        https://github.com/RobertCraigie/prisma-client-py/issues/409
        """
        return {
            'name': self._default_datasource['name'],
            'url': self._make_sqlite_url(self._default_datasource['url']),
        }

    def _make_sqlite_url(self, url: str, *, relative_to: Path | None = None) -> str:
        """Return `url` with a relative SQLite file path made absolute.

        URLs without a `file:` / `sqlite:` prefix, and URLs whose path is already
        absolute, are returned unchanged. Otherwise the path is resolved against
        `relative_to` (default: the directory containing the schema file) and
        returned with a `file:` prefix.
        """
        url_path = removeprefix(removeprefix(url, 'file:'), 'sqlite:')
        if url_path == url:
            return url

        if Path(url_path).is_absolute():
            return url

        if relative_to is None:
            relative_to = self._schema_path.parent

        return f'file:{relative_to.joinpath(url_path).resolve()}'

    def _prepare_connect_args(
        self,
        *,
        timeout: int | timedelta | UseClientDefault = USE_CLIENT_DEFAULT,
    ) -> tuple[timedelta, list[DatasourceOverride] | None]:
        """Returns (timeout, datasources) to be passed to `AbstractEngine.connect()`

        `timeout` falls back to the client's `connect_timeout`; an `int` is converted
        to a `timedelta` in seconds (deprecated). The datasource list is the user's
        override (named after the default datasource if no name was given), a
        rewritten SQLite datasource for the `sqlite` provider, or `None`.
        """
        if isinstance(timeout, UseClientDefault):
            timeout = self._connect_timeout

        if isinstance(timeout, int):
            message = (
                'Passing an int as `timeout` argument is deprecated '
                'and will be removed in the next major release. '
                'Use a `datetime.timedelta` instance instead.'
            )
            # stacklevel=3: this helper is called from `connect()`, so the user's
            # `connect(timeout=...)` call site is two frames above this one.
            warnings.warn(message, DeprecationWarning, stacklevel=3)
            timeout = timedelta(seconds=timeout)

        datasources: list[DatasourceOverride] | None = None
        if self._datasource is not None:
            # copy so that filling in the default name does not mutate the user's dict
            ds = self._datasource.copy()
            ds.setdefault('name', self._default_datasource_name)
            datasources = [ds]
        elif self._active_provider == 'sqlite':
            # Override the default SQLite path to protect against
            # https://github.com/RobertCraigie/prisma-client-py/issues/409
            datasources = [self._make_sqlite_datasource()]

        return timeout, datasources

    def _make_query_builder(
        self,
        *,
        method: PrismaMethod,
        arguments: dict[str, Any],
        model: type[BaseModel] | None,
        root_selection: list[str] | None,
    ) -> QueryBuilder:
        """Create a `QueryBuilder` for one operation using this client's generated model metadata."""
        return QueryBuilder(
            method=method,
            model=model,
            arguments=arguments,
            root_selection=root_selection,
            prisma_models=self._prisma_models,
            relational_field_mappings=self._relational_field_mappings,
        )
class SyncBasePrisma(BasePrisma[SyncAbstractEngine]):
    """Base class for the synchronous client; drives a `SyncAbstractEngine` with plain (non-async) calls."""

    __slots__ = ()

    def connect(
        self,
        timeout: int | timedelta | UseClientDefault = USE_CLIENT_DEFAULT,
    ) -> None:
        """Connect to the Prisma query engine.

        It is required to call this before accessing data. An engine is created on
        first use; if one already exists it is reused.

        Args:
            timeout: forwarded to the engine's `connect()`. Defaults to the client's
                `connect_timeout`. Passing an `int` (seconds) is deprecated.
        """
        if self._internal_engine is None:
            self._internal_engine = self._create_engine(dml_path=self._packaged_schema_path)

        timeout, datasources = self._prepare_connect_args(timeout=timeout)

        self._internal_engine.connect(
            timeout=timeout,
            datasources=datasources,
        )

    def disconnect(self, timeout: float | timedelta | None = None) -> None:
        """Disconnect the Prisma query engine.

        Does nothing if the client is not connected. The engine is detached from
        the client, closed and then stopped. Stopping is attempted even if closing
        raises.

        Args:
            timeout: forwarded to the engine's `close()` and `stop()`. Passing a
                number (seconds) is deprecated.
        """
        engine = self._internal_engine
        if engine is None:
            return

        if isinstance(timeout, (int, float)):
            message = (
                'Passing a number as `timeout` argument is deprecated '
                'and will be removed in the next major release. '
                'Use a `datetime.timedelta` instead.'
            )
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            timeout = timedelta(seconds=timeout)

        self._internal_engine = None
        try:
            engine.close(timeout=timeout)
        finally:
            # The client no longer references the engine, so if `close()` raised and
            # `stop()` were skipped, this client could never stop the engine later.
            engine.stop(timeout=timeout)

    def __enter__(self) -> Self:
        self.connect()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Disconnect on leaving the `with` block if still connected; exceptions propagate."""
        if self.is_connected():
            self.disconnect()

    @overload
    def get_metrics(
        self,
        format: Literal['json'] = 'json',
        *,
        global_labels: dict[str, str] | None = None,
    ) -> Metrics: ...

    @overload
    def get_metrics(
        self,
        format: Literal['prometheus'],
        *,
        global_labels: dict[str, str] | None = None,
    ) -> str: ...

    def get_metrics(
        self,
        format: MetricsFormat = 'json',
        *,
        global_labels: dict[str, str] | None = None,
    ) -> str | Metrics:
        """Metrics give you a detailed insight into how the Prisma Client interacts with your database.

        You can retrieve metrics in either JSON or Prometheus formats.

        For more details see https://www.prisma.io/docs/concepts/components/prisma-client/metrics.
        """
        response = self._engine.metrics(format=format, global_labels=global_labels)
        if format == 'prometheus':
            # For the prometheus format we return the response as-is
            assert isinstance(response, str)
            return response

        return model_parse(Metrics, response)

    def _create_engine(self, dml_path: Path | None = None) -> SyncAbstractEngine:
        """Build a new (unconnected) engine for the configured engine type.

        Raises:
            NotImplementedError: for engine types other than `EngineType.binary`.
        """
        if self._engine_type == EngineType.binary:
            return SyncQueryEngine(
                dml_path=dml_path or self._packaged_schema_path,
                log_queries=self._log_queries,
                http_config=self._http_config,
            )

        raise NotImplementedError(f'Unsupported engine type: {self._engine_type}')

    @property
    def _engine_class(self) -> type[SyncAbstractEngine]:
        if self._engine_type == EngineType.binary:
            return SyncQueryEngine

        raise RuntimeError(f'Unhandled engine type: {self._engine_type}')

    # TODO: don't return Any
    def _execute(
        self,
        method: PrismaMethod,
        arguments: dict[str, Any],
        model: type[BaseModel] | None = None,
        root_selection: list[str] | None = None,
    ) -> Any:
        """Build the query for `method` and run it on the engine, within the current transaction if any."""
        builder = self._make_query_builder(
            method=method, model=model, arguments=arguments, root_selection=root_selection
        )
        return self._engine.query(builder.build(), tx_id=self._tx_id)
class AsyncBasePrisma(BasePrisma[AsyncAbstractEngine]):
    """Base class for the asyncio client; drives an `AsyncAbstractEngine` with awaited calls."""

    __slots__ = ()

    async def connect(
        self,
        timeout: int | timedelta | UseClientDefault = USE_CLIENT_DEFAULT,
    ) -> None:
        """Connect to the Prisma query engine.

        It is required to call this before accessing data. An engine is created on
        first use; if one already exists it is reused.

        Args:
            timeout: forwarded to the engine's `connect()`. Defaults to the client's
                `connect_timeout`. Passing an `int` (seconds) is deprecated.
        """
        if self._internal_engine is None:
            self._internal_engine = self._create_engine(dml_path=self._packaged_schema_path)

        timeout, datasources = self._prepare_connect_args(timeout=timeout)

        await self._internal_engine.connect(
            timeout=timeout,
            datasources=datasources,
        )

    async def disconnect(self, timeout: float | timedelta | None = None) -> None:
        """Disconnect the Prisma query engine.

        Does nothing if the client is not connected. The engine is detached from
        the client, closed via `aclose()` and then stopped. Stopping is attempted
        even if closing raises or the `await` is cancelled.

        Args:
            timeout: forwarded to the engine's `aclose()` and `stop()`. Passing a
                number (seconds) is deprecated.
        """
        engine = self._internal_engine
        if engine is None:
            return

        if isinstance(timeout, (int, float)):
            message = (
                'Passing a number as `timeout` argument is deprecated '
                'and will be removed in the next major release. '
                'Use a `datetime.timedelta` instead.'
            )
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            timeout = timedelta(seconds=timeout)

        self._internal_engine = None
        try:
            await engine.aclose(timeout=timeout)
        finally:
            # The client no longer references the engine, so if `aclose()` raised and
            # `stop()` were skipped, this client could never stop the engine later.
            engine.stop(timeout=timeout)

    async def __aenter__(self) -> Self:
        await self.connect()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Disconnect on leaving the `async with` block if still connected; exceptions propagate."""
        if self.is_connected():
            await self.disconnect()

    @overload
    async def get_metrics(
        self,
        format: Literal['json'] = 'json',
        *,
        global_labels: dict[str, str] | None = None,
    ) -> Metrics: ...

    @overload
    async def get_metrics(
        self,
        format: Literal['prometheus'],
        *,
        global_labels: dict[str, str] | None = None,
    ) -> str: ...

    async def get_metrics(
        self,
        format: MetricsFormat = 'json',
        *,
        global_labels: dict[str, str] | None = None,
    ) -> str | Metrics:
        """Metrics give you a detailed insight into how the Prisma Client interacts with your database.

        You can retrieve metrics in either JSON or Prometheus formats.

        For more details see https://www.prisma.io/docs/concepts/components/prisma-client/metrics.
        """
        response = await self._engine.metrics(format=format, global_labels=global_labels)
        if format == 'prometheus':
            # For the prometheus format we return the response as-is
            assert isinstance(response, str)
            return response

        return model_parse(Metrics, response)

    def _create_engine(self, dml_path: Path | None = None) -> AsyncAbstractEngine:
        """Build a new (unconnected) engine for the configured engine type.

        Raises:
            NotImplementedError: for engine types other than `EngineType.binary`.
        """
        if self._engine_type == EngineType.binary:
            return AsyncQueryEngine(
                dml_path=dml_path or self._packaged_schema_path,
                log_queries=self._log_queries,
                http_config=self._http_config,
            )

        raise NotImplementedError(f'Unsupported engine type: {self._engine_type}')

    @property
    def _engine_class(self) -> type[AsyncAbstractEngine]:
        if self._engine_type == EngineType.binary:
            return AsyncQueryEngine

        raise RuntimeError(f'Unhandled engine type: {self._engine_type}')

    # TODO: don't return Any
    async def _execute(
        self,
        *,
        method: PrismaMethod,
        arguments: dict[str, Any],
        model: type[BaseModel] | None = None,
        root_selection: list[str] | None = None,
    ) -> Any:
        """Build the query for `method` and run it on the engine, within the current transaction if any."""
        builder = self._make_query_builder(
            method=method, model=model, arguments=arguments, root_selection=root_selection
        )
        return await self._engine.query(builder.build(), tx_id=self._tx_id)