Coverage for src/prisma/engine/_query.py: 84% (213 statements)


from __future__ import annotations

import os
import sys
import json
import time
import atexit
import signal
import asyncio
import logging
import subprocess
from typing import TYPE_CHECKING, Any, overload
from pathlib import Path
from datetime import timedelta
from typing_extensions import Literal, override

from . import utils, errors
from ._http import SyncHTTPEngine, AsyncHTTPEngine
from ..utils import DEBUG, _env_bool, time_since
from .._types import HttpConfig, TransactionId
from .._builder import dumps
from ..binaries import platform
from .._constants import DEFAULT_CONNECT_TIMEOUT

if TYPE_CHECKING:
    from ..types import MetricsFormat, DatasourceOverride  # noqa: TID251


__all__ = (
    'SyncQueryEngine',
    'AsyncQueryEngine',
)

log: logging.Logger = logging.getLogger(__name__)

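# BaseQueryEngine owns the query engine subprocess: it resolves the engine
# binary, spawns it against the schema at `dml_path`, and tears it down again.
# The HTTP side (actually talking to the spawned engine) comes from the
# SyncHTTPEngine/AsyncHTTPEngine mixins used by the concrete subclasses below.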

class BaseQueryEngine:
    dml_path: Path
    url: str | None
    file: Path | None
    process: subprocess.Popen[bytes] | subprocess.Popen[str] | None

    def __init__(
        self,
        *,
        dml_path: Path,
        log_queries: bool = False,
    ) -> None:
        self.dml_path = dml_path
        self._log_queries = log_queries
        self.process = None
        self.file = None

    def _ensure_file(self) -> Path:
        # circular import
        from ..client import BINARY_PATHS  # noqa: TID251

        return utils.ensure(BINARY_PATHS.query_engine)

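    # _spawn_process starts the engine binary on a free local port and records
    # the base URL in `self.url`. Configuration flows through environment
    # variables (PRISMA_DML_PATH, RUST_LOG, OVERWRITE_DATASOURCES, ...) and
    # CLI flags; metrics and raw queries are always enabled.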

    def _spawn_process(
        self,
        *,
        file: Path,
        datasources: list[DatasourceOverride] | None,
    ) -> tuple[str, subprocess.Popen[bytes] | subprocess.Popen[str]]:
        port = utils.get_open_port()
        log.debug('Running query engine on port %i', port)

        self.url = f'http://localhost:{port}'

        env = os.environ.copy()
        env.update(
            PRISMA_DML_PATH=str(self.dml_path.absolute()),
            RUST_LOG='error',
            RUST_LOG_FORMAT='json',
            PRISMA_CLIENT_ENGINE_TYPE='binary',
            PRISMA_ENGINE_PROTOCOL='graphql',
        )

        if DEBUG:
            env.update(RUST_LOG='info')

        if datasources is not None:
            env.update(OVERWRITE_DATASOURCES=dumps(datasources))

        # TODO: remove the noise from these query logs
        if self._log_queries:
            env.update(LOG_QUERIES='y')

        args: list[str] = [
            str(file.absolute()),
            '-p',
            str(port),
            '--enable-metrics',
            '--enable-raw-queries',
        ]
        if _env_bool('__PRISMA_PY_PLAYGROUND'):
            env.update(RUST_LOG='info')
            args.append('--enable-playground')

        log.debug('Starting query engine...')
        popen_kwargs: dict[str, Any] = {
            'env': env,
            'stdout': sys.stdout,
            'stderr': sys.stderr,
            'text': False,
        }
        if platform.name() != 'windows':
            # ensure SIGINT is unblocked before forking the query engine
            # https://github.com/RobertCraigie/prisma-client-py/pull/678
            popen_kwargs['preexec_fn'] = lambda: signal.pthread_sigmask(
                signal.SIG_UNBLOCK, [signal.SIGINT, signal.SIGTERM]
            )

        self.process = subprocess.Popen(args, **popen_kwargs)

        return self.url, self.process

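    # Shutdown strategy: on Windows the process is killed outright; elsewhere
    # it is asked to exit with SIGINT first, escalating to SIGKILL only if it
    # has not exited within the given timeout.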

    def _kill_process(self, timeout: timedelta | None) -> None:
        if self.process is None:
            return

        if timeout is not None:
            total_seconds = timeout.total_seconds()
        else:
            total_seconds = None

        if platform.name() == 'windows':
            self.process.kill()
            self.process.wait(timeout=total_seconds)
        else:
            self.process.send_signal(signal.SIGINT)
            try:
                self.process.wait(timeout=total_seconds)
            except subprocess.TimeoutExpired:
                self.process.send_signal(signal.SIGKILL)

        self.process = None


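# Illustrative usage sketch for the class below (`payload` is a placeholder
# for a request body in the engine's GraphQL wire protocol):
#
#     engine = SyncQueryEngine(dml_path=Path('schema.prisma'))
#     engine.connect()
#     result = engine.query(payload, tx_id=None)
#     engine.close()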

class SyncQueryEngine(BaseQueryEngine, SyncHTTPEngine):
    file: Path | None

    def __init__(
        self,
        *,
        dml_path: Path,
        log_queries: bool = False,
        http_config: HttpConfig | None = None,
    ) -> None:
        # this is a little weird but it's needed to distinguish between
        # the different required arguments for our two base classes
        BaseQueryEngine.__init__(self, dml_path=dml_path, log_queries=log_queries)
        SyncHTTPEngine.__init__(self, url=None, **(http_config or {}))

        # ensure the query engine process is terminated when we are
        atexit.register(self.stop)

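    # close() tears down both the engine subprocess and the HTTP session;
    # aclose() delegates to close() since this engine's session is synchronous
    # and there is nothing to await.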

    @override
    def close(self, *, timeout: timedelta | None = None) -> None:
        log.debug('Disconnecting query engine...')

        self._kill_process(timeout=timeout)
        self._close_session()

        log.debug('Disconnected query engine')

    @override
    async def aclose(self, *, timeout: timedelta | None = None) -> None:
        self.close(timeout=timeout)
        self._close_session()

    @override
    def connect(
        self,
        timeout: timedelta = DEFAULT_CONNECT_TIMEOUT,
        datasources: list[DatasourceOverride] | None = None,
    ) -> None:
        log.debug('Connecting to query engine')
        if datasources:
            log.debug('Datasources: %s', datasources)

        if self.process is not None:
            raise errors.AlreadyConnectedError('Already connected to the query engine')

        start = time.monotonic()
        self.file = file = self._ensure_file()

        try:
            self.spawn(file, timeout=timeout, datasources=datasources)
        except Exception:
            self.close()
            raise

        log.debug('Connecting to query engine took %s', time_since(start))

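    # spawn() polls GET /status every 0.1s for roughly `timeout`, retrying on
    # connection errors and on responses that still report gql errors; if the
    # loop never succeeds, EngineConnectionError is raised with the last
    # exception chained onto it.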

    def spawn(
        self,
        file: Path,
        timeout: timedelta = DEFAULT_CONNECT_TIMEOUT,
        datasources: list[DatasourceOverride] | None = None,
    ) -> None:
        self._spawn_process(file=file, datasources=datasources)

        last_exc = None
        for _ in range(int(timeout.total_seconds() / 0.1)):
            try:
                data = self.request('GET', '/status')
            except Exception as exc:
                # TODO(someday): only retry on ConnectionError
                if isinstance(exc, AttributeError):
                    raise

                last_exc = exc
                log.debug(
                    'Could not connect to query engine due to %s; retrying...',
                    exc,
                )
                time.sleep(0.1)
                continue

            if data.get('Errors') is not None:
                log.debug('Could not connect due to gql errors; retrying...')
                time.sleep(0.1)
                continue

            break
        else:
            raise errors.EngineConnectionError('Could not connect to the query engine') from last_exc

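    # Queries are POSTed to the engine root; when running inside a transaction
    # the transaction id travels in the X-transaction-id header.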

    @override
    def query(
        self,
        content: str,
        *,
        tx_id: TransactionId | None,
    ) -> Any:
        headers: dict[str, str] = {}
        if tx_id is not None:
            headers['X-transaction-id'] = tx_id

        return self.request(
            'POST',
            '/',
            content=content,
            headers=headers,
        )

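    # Interactive transactions map onto three engine endpoints:
    # /transaction/start, /transaction/{id}/commit and /transaction/{id}/rollback.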

    @override
    def start_transaction(self, *, content: str) -> TransactionId:
        result = self.request(
            'POST',
            '/transaction/start',
            content=content,
        )
        return TransactionId(result['id'])

    @override
    def commit_transaction(self, tx_id: TransactionId) -> None:
        self.request('POST', f'/transaction/{tx_id}/commit')

    @override
    def rollback_transaction(self, tx_id: TransactionId) -> None:
        self.request('POST', f'/transaction/{tx_id}/rollback')

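    # metrics() is overloaded on `format`: 'json' parses the response into a
    # dict while 'prometheus' returns the raw exposition text. Optional global
    # labels are serialised to JSON and sent as the request body.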

    @overload
    def metrics(
        self,
        *,
        format: Literal['json'],
        global_labels: dict[str, str] | None,
    ) -> dict[str, Any]: ...

    @overload
    def metrics(
        self,
        *,
        format: Literal['prometheus'],
        global_labels: dict[str, str] | None,
    ) -> str: ...

    @override
    def metrics(
        self,
        *,
        format: MetricsFormat,
        global_labels: dict[str, str] | None,
    ) -> str | dict[str, Any]:
        if global_labels is not None:
            content = json.dumps(global_labels)
        else:
            content = None

        return self.request(  # type: ignore[no-any-return]
            'GET',
            f'/metrics?format={format}',
            content=content,
            parse_response=format == 'json',
        )


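# AsyncQueryEngine mirrors SyncQueryEngine, awaiting its HTTP requests and
# using asyncio.sleep() while polling for readiness. Illustrative usage
# sketch (`payload` is a placeholder for a request body in the engine's
# GraphQL wire protocol):
#
#     engine = AsyncQueryEngine(dml_path=Path('schema.prisma'))
#     await engine.connect()
#     result = await engine.query(payload, tx_id=None)
#     await engine.aclose()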

class AsyncQueryEngine(BaseQueryEngine, AsyncHTTPEngine):
    file: Path | None

    def __init__(
        self,
        *,
        dml_path: Path,
        log_queries: bool = False,
        http_config: HttpConfig | None = None,
    ) -> None:
        # this is a little weird but it's needed to distinguish between
        # the different required arguments for our two base classes
        BaseQueryEngine.__init__(self, dml_path=dml_path, log_queries=log_queries)
        AsyncHTTPEngine.__init__(self, url=None, **(http_config or {}))

        # ensure the query engine process is terminated when we are
        atexit.register(self.stop)

    @override
    def close(self, *, timeout: timedelta | None = None) -> None:
        log.debug('Disconnecting query engine...')

        self._kill_process(timeout=timeout)

        log.debug('Disconnected query engine')

    @override
    async def aclose(self, *, timeout: timedelta | None = None) -> None:
        self.close(timeout=timeout)
        await self._close_session()

    @override
    async def connect(
        self,
        timeout: timedelta = DEFAULT_CONNECT_TIMEOUT,
        datasources: list[DatasourceOverride] | None = None,
    ) -> None:
        log.debug('Connecting to query engine')
        if datasources:
            log.debug('Datasources: %s', datasources)

        if self.process is not None:
            raise errors.AlreadyConnectedError('Already connected to the query engine')

        start = time.monotonic()
        self.file = file = self._ensure_file()

        try:
            await self.spawn(file, timeout=timeout, datasources=datasources)
        except Exception:
            self.close()
            raise

        log.debug('Connecting to query engine took %s', time_since(start))

    async def spawn(
        self,
        file: Path,
        timeout: timedelta = DEFAULT_CONNECT_TIMEOUT,
        datasources: list[DatasourceOverride] | None = None,
    ) -> None:
        self._spawn_process(file=file, datasources=datasources)

        last_exc = None
        for _ in range(int(timeout.total_seconds() / 0.1)):
            try:
                data = await self.request('GET', '/status')
            except Exception as exc:
                # TODO(someday): only retry on ConnectionError
                if isinstance(exc, AttributeError):
                    raise

                last_exc = exc
                log.debug(
                    'Could not connect to query engine due to %s; retrying...',
                    exc,
                )
                await asyncio.sleep(0.1)
                continue

            if data.get('Errors') is not None:
                log.debug('Could not connect due to gql errors; retrying...')
                await asyncio.sleep(0.1)
                continue

            break
        else:
            raise errors.EngineConnectionError('Could not connect to the query engine') from last_exc

    @override
    async def query(
        self,
        content: str,
        *,
        tx_id: TransactionId | None,
    ) -> Any:
        headers: dict[str, str] = {}
        if tx_id is not None:
            headers['X-transaction-id'] = tx_id

        return await self.request(
            'POST',
            '/',
            content=content,
            headers=headers,
        )

    @override
    async def start_transaction(self, *, content: str) -> TransactionId:
        result = await self.request(
            'POST',
            '/transaction/start',
            content=content,
        )
        return TransactionId(result['id'])

    @override
    async def commit_transaction(self, tx_id: TransactionId) -> None:
        await self.request('POST', f'/transaction/{tx_id}/commit')

    @override
    async def rollback_transaction(self, tx_id: TransactionId) -> None:
        await self.request('POST', f'/transaction/{tx_id}/rollback')

    @overload
    async def metrics(
        self,
        *,
        format: Literal['json'],
        global_labels: dict[str, str] | None,
    ) -> dict[str, Any]: ...

    @overload
    async def metrics(
        self,
        *,
        format: Literal['prometheus'],
        global_labels: dict[str, str] | None,
    ) -> str: ...

    @override
    async def metrics(
        self,
        *,
        format: MetricsFormat,
        global_labels: dict[str, str] | None,
    ) -> str | dict[str, Any]:
        if global_labels is not None:
            content = json.dumps(global_labels)
        else:
            content = None

        return await self.request(  # type: ignore[no-any-return]
            'GET',
            f'/metrics?format={format}',
            content=content,
            parse_response=format == 'json',
        )