Coverage for src/prisma/engine/_query.py: 84%
213 statements
coverage.py v7.2.7, created at 2024-08-27 18:25 +0000

from __future__ import annotations

import os
import sys
import json
import time
import atexit
import signal
import asyncio
import logging
import subprocess
from typing import TYPE_CHECKING, Any, overload
from pathlib import Path
from datetime import timedelta
from typing_extensions import Literal, override

from . import utils, errors
from ._http import SyncHTTPEngine, AsyncHTTPEngine
from ..utils import DEBUG, _env_bool, time_since
from .._types import HttpConfig, TransactionId
from .._builder import dumps
from ..binaries import platform
from .._constants import DEFAULT_CONNECT_TIMEOUT

if TYPE_CHECKING:
    from ..types import MetricsFormat, DatasourceOverride  # noqa: TID251


__all__ = (
    'SyncQueryEngine',
    'AsyncQueryEngine',
)

log: logging.Logger = logging.getLogger(__name__)


class BaseQueryEngine:
    dml_path: Path
    url: str | None
    file: Path | None
    process: subprocess.Popen[bytes] | subprocess.Popen[str] | None

    def __init__(
        self,
        *,
        dml_path: Path,
        log_queries: bool = False,
    ) -> None:
        self.dml_path = dml_path
        self._log_queries = log_queries
        self.process = None
        self.file = None

    def _ensure_file(self) -> Path:
        # circular import
        from ..client import BINARY_PATHS  # noqa: TID251

        return utils.ensure(BINARY_PATHS.query_engine)
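
    # Spawn the query engine binary as a local HTTP server on a free port.
    # The environment points it at our schema via PRISMA_DML_PATH, selects the
    # binary engine speaking the GraphQL wire protocol, and (optionally)
    # injects datasource overrides as JSON through OVERWRITE_DATASOURCES.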
    def _spawn_process(
        self,
        *,
        file: Path,
        datasources: list[DatasourceOverride] | None,
    ) -> tuple[str, subprocess.Popen[bytes] | subprocess.Popen[str]]:
        port = utils.get_open_port()
        log.debug('Running query engine on port %i', port)

        self.url = f'http://localhost:{port}'

        env = os.environ.copy()
        env.update(
            PRISMA_DML_PATH=str(self.dml_path.absolute()),
            RUST_LOG='error',
            RUST_LOG_FORMAT='json',
            PRISMA_CLIENT_ENGINE_TYPE='binary',
            PRISMA_ENGINE_PROTOCOL='graphql',
        )

        if DEBUG:
            env.update(RUST_LOG='info')

        if datasources is not None:
            env.update(OVERWRITE_DATASOURCES=dumps(datasources))

        # TODO: remove the noise from these query logs
        if self._log_queries:
            env.update(LOG_QUERIES='y')

        args: list[str] = [
            str(file.absolute()),
            '-p',
            str(port),
            '--enable-metrics',
            '--enable-raw-queries',
        ]
        if _env_bool('__PRISMA_PY_PLAYGROUND'):
            env.update(RUST_LOG='info')
            args.append('--enable-playground')

        log.debug('Starting query engine...')
        popen_kwargs: dict[str, Any] = {
            'env': env,
            'stdout': sys.stdout,
            'stderr': sys.stderr,
            'text': False,
        }
        if platform.name() != 'windows':
            # ensure SIGINT is unblocked before forking the query engine
            # https://github.com/RobertCraigie/prisma-client-py/pull/678
            popen_kwargs['preexec_fn'] = lambda: signal.pthread_sigmask(
                signal.SIG_UNBLOCK, [signal.SIGINT, signal.SIGTERM]
            )

        self.process = subprocess.Popen(args, **popen_kwargs)

        return self.url, self.process
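
    # Shutdown strategy: on POSIX, ask nicely with SIGINT and escalate to
    # SIGKILL if the engine has not exited within `timeout`; Windows has no
    # comparable signal semantics, so the process is killed outright there.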
    def _kill_process(self, timeout: timedelta | None) -> None:
        if self.process is None:
            return

        if timeout is not None:
            total_seconds = timeout.total_seconds()
        else:
            total_seconds = None

        if platform.name() == 'windows':
            self.process.kill()
            self.process.wait(timeout=total_seconds)
        else:
            self.process.send_signal(signal.SIGINT)
            try:
                self.process.wait(timeout=total_seconds)
            except subprocess.TimeoutExpired:
                self.process.send_signal(signal.SIGKILL)

        self.process = None
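

# The two concrete engines below share BaseQueryEngine's process management
# and mirror each other method for method; they differ only in transport,
# layering a blocking HTTP session (SyncQueryEngine) or an asyncio one
# (AsyncQueryEngine) on top.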
class SyncQueryEngine(BaseQueryEngine, SyncHTTPEngine):
    file: Path | None

    def __init__(
        self,
        *,
        dml_path: Path,
        log_queries: bool = False,
        http_config: HttpConfig | None = None,
    ) -> None:
        # this is a little weird but it's needed to distinguish between
        # the different required arguments for our two base classes
        BaseQueryEngine.__init__(self, dml_path=dml_path, log_queries=log_queries)
        SyncHTTPEngine.__init__(self, url=None, **(http_config or {}))

        # ensure the query engine process is terminated when we are
        atexit.register(self.stop)

    @override
    def close(self, *, timeout: timedelta | None = None) -> None:
        log.debug('Disconnecting query engine...')

        self._kill_process(timeout=timeout)
        self._close_session()

        log.debug('Disconnected query engine')

    @override
    async def aclose(self, *, timeout: timedelta | None = None) -> None:
        self.close(timeout=timeout)
        self._close_session()

    @override
    def connect(
        self,
        timeout: timedelta = DEFAULT_CONNECT_TIMEOUT,
        datasources: list[DatasourceOverride] | None = None,
    ) -> None:
        log.debug('Connecting to query engine')
        if datasources:
            log.debug('Datasources: %s', datasources)

        if self.process is not None:
            raise errors.AlreadyConnectedError('Already connected to the query engine')

        start = time.monotonic()
        self.file = file = self._ensure_file()

        try:
            self.spawn(file, timeout=timeout, datasources=datasources)
        except Exception:
            self.close()
            raise

        log.debug('Connecting to query engine took %s', time_since(start))

    def spawn(
        self,
        file: Path,
        timeout: timedelta = DEFAULT_CONNECT_TIMEOUT,
        datasources: list[DatasourceOverride] | None = None,
    ) -> None:
        self._spawn_process(file=file, datasources=datasources)

        last_exc = None
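        # Poll the engine's /status endpoint every 100ms until it responds
        # without errors, giving up with EngineConnectionError once `timeout`
        # has elapsed.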
        for _ in range(int(timeout.total_seconds() / 0.1)):
            try:
                data = self.request('GET', '/status')
            except Exception as exc:
                # TODO(someday): only retry on ConnectionError
                if isinstance(exc, AttributeError):
                    raise

                last_exc = exc
                log.debug(
                    'Could not connect to query engine due to %s; retrying...',
                    exc,
                )
                time.sleep(0.1)
                continue

            if data.get('Errors') is not None:
                log.debug('Could not connect due to gql errors; retrying...')
                time.sleep(0.1)
                continue

            break
        else:
            raise errors.EngineConnectionError('Could not connect to the query engine') from last_exc

    @override
    def query(
        self,
        content: str,
        *,
        tx_id: TransactionId | None,
    ) -> Any:
        headers: dict[str, str] = {}
        if tx_id is not None:
            headers['X-transaction-id'] = tx_id

        return self.request(
            'POST',
            '/',
            content=content,
            headers=headers,
        )
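
    # Interactive transactions: /transaction/start hands back an id that
    # callers thread through `query(..., tx_id=...)` until they commit or
    # roll back.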
    @override
    def start_transaction(self, *, content: str) -> TransactionId:
        result = self.request(
            'POST',
            '/transaction/start',
            content=content,
        )
        return TransactionId(result['id'])

    @override
    def commit_transaction(self, tx_id: TransactionId) -> None:
        self.request('POST', f'/transaction/{tx_id}/commit')

    @override
    def rollback_transaction(self, tx_id: TransactionId) -> None:
        self.request('POST', f'/transaction/{tx_id}/rollback')

    @overload
    def metrics(
        self,
        *,
        format: Literal['json'],
        global_labels: dict[str, str] | None,
    ) -> dict[str, Any]: ...

    @overload
    def metrics(
        self,
        *,
        format: Literal['prometheus'],
        global_labels: dict[str, str] | None,
    ) -> str: ...
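
    # Per the overloads above: format='json' parses the response body into a
    # dict, while format='prometheus' returns the raw exposition text.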
    @override
    def metrics(
        self,
        *,
        format: MetricsFormat,
        global_labels: dict[str, str] | None,
    ) -> str | dict[str, Any]:
        if global_labels is not None:
            content = json.dumps(global_labels)
        else:
            content = None

        return self.request(  # type: ignore[no-any-return]
            'GET',
            f'/metrics?format={format}',
            content=content,
            parse_response=format == 'json',
        )


class AsyncQueryEngine(BaseQueryEngine, AsyncHTTPEngine):
    file: Path | None

    def __init__(
        self,
        *,
        dml_path: Path,
        log_queries: bool = False,
        http_config: HttpConfig | None = None,
    ) -> None:
        # this is a little weird but it's needed to distinguish between
        # the different required arguments for our two base classes
        BaseQueryEngine.__init__(self, dml_path=dml_path, log_queries=log_queries)
        AsyncHTTPEngine.__init__(self, url=None, **(http_config or {}))

        # ensure the query engine process is terminated when we are
        atexit.register(self.stop)

    @override
    def close(self, *, timeout: timedelta | None = None) -> None:
        log.debug('Disconnecting query engine...')

        self._kill_process(timeout=timeout)

        log.debug('Disconnected query engine')

    @override
    async def aclose(self, *, timeout: timedelta | None = None) -> None:
        self.close(timeout=timeout)
        await self._close_session()

    @override
    async def connect(
        self,
        timeout: timedelta = DEFAULT_CONNECT_TIMEOUT,
        datasources: list[DatasourceOverride] | None = None,
    ) -> None:
        log.debug('Connecting to query engine')
        if datasources:
            log.debug('Datasources: %s', datasources)

        if self.process is not None:
            raise errors.AlreadyConnectedError('Already connected to the query engine')

        start = time.monotonic()
        self.file = file = self._ensure_file()

        try:
            await self.spawn(file, timeout=timeout, datasources=datasources)
        except Exception:
            self.close()
            raise

        log.debug('Connecting to query engine took %s', time_since(start))

    async def spawn(
        self,
        file: Path,
        timeout: timedelta = DEFAULT_CONNECT_TIMEOUT,
        datasources: list[DatasourceOverride] | None = None,
    ) -> None:
        self._spawn_process(file=file, datasources=datasources)

        last_exc = None
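        # Same polling strategy as SyncQueryEngine.spawn, made non-blocking
        # with asyncio.sleep.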
        for _ in range(int(timeout.total_seconds() / 0.1)):
            try:
                data = await self.request('GET', '/status')
            except Exception as exc:
                # TODO(someday): only retry on ConnectionError
                if isinstance(exc, AttributeError):
                    raise

                last_exc = exc
                log.debug(
                    'Could not connect to query engine due to %s; retrying...',
                    exc,
                )
                await asyncio.sleep(0.1)
                continue

            if data.get('Errors') is not None:
                log.debug('Could not connect due to gql errors; retrying...')
                await asyncio.sleep(0.1)
                continue

            break
        else:
            raise errors.EngineConnectionError('Could not connect to the query engine') from last_exc

    @override
    async def query(
        self,
        content: str,
        *,
        tx_id: TransactionId | None,
    ) -> Any:
        headers: dict[str, str] = {}
        if tx_id is not None:
            headers['X-transaction-id'] = tx_id

        return await self.request(
            'POST',
            '/',
            content=content,
            headers=headers,
        )

    @override
    async def start_transaction(self, *, content: str) -> TransactionId:
        result = await self.request(
            'POST',
            '/transaction/start',
            content=content,
        )
        return TransactionId(result['id'])

    @override
    async def commit_transaction(self, tx_id: TransactionId) -> None:
        await self.request('POST', f'/transaction/{tx_id}/commit')

    @override
    async def rollback_transaction(self, tx_id: TransactionId) -> None:
        await self.request('POST', f'/transaction/{tx_id}/rollback')

    @overload
    async def metrics(
        self,
        *,
        format: Literal['json'],
        global_labels: dict[str, str] | None,
    ) -> dict[str, Any]: ...

    @overload
    async def metrics(
        self,
        *,
        format: Literal['prometheus'],
        global_labels: dict[str, str] | None,
    ) -> str: ...

    @override
    async def metrics(
        self,
        *,
        format: MetricsFormat,
        global_labels: dict[str, str] | None,
    ) -> str | dict[str, Any]:
        if global_labels is not None:
            content = json.dumps(global_labels)
        else:
            content = None

        return await self.request(  # type: ignore[no-any-return]
            'GET',
            f'/metrics?format={format}',
            content=content,
            parse_response=format == 'json',
        )
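

# A minimal usage sketch (illustrative only, not part of the original module;
# the schema path and query payload here are hypothetical):
#
#     from pathlib import Path
#
#     engine = SyncQueryEngine(dml_path=Path('schema.prisma'))
#     engine.connect()
#     try:
#         result = engine.query('...', tx_id=None)
#     finally:
#         engine.close()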