Coverage for databases/main.py: 97%

187 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-08-27 18:25 +0000

1from __future__ import annotations 

2 

3import os 

4import re 

5import json 

6import shlex 

7import shutil 

8import contextlib 

9from copy import deepcopy 

10from typing import ( 

11 Iterator, 

12 Optional, 

13 cast, 

14) 

15from pathlib import Path 

16from contextvars import ContextVar, copy_context 

17 

18import nox 

19import yaml 

20import click 

21import rtoml 

22import typer 

23from jinja2 import Environment, StrictUndefined, FileSystemLoader 

24from nox.command import CommandFailed 

25 

26from lib.utils import flatten, escape_path 

27from prisma._compat import model_copy, model_json, cached_property 

28from pipelines.utils import ( 

29 setup_coverage, 

30 get_pkg_location, 

31 maybe_install_nodejs_bin, 

32) 

33 

34from .utils import DatabaseConfig 

35from ._serve import start_database 

36from ._types import SupportedDatabase 

37from .constants import ( 

38 ROOT_DIR, 

39 TESTS_DIR, 

40 DATABASES_DIR, 

41 PYTEST_CONFIG, 

42 CONFIG_MAPPING, 

43 PYRIGHT_CONFIG, 

44 SYNC_TESTS_DIR, 

45 FEATURES_MAPPING, 

46 SUPPORTED_DATABASES, 

47) 

48 

# TODO: switch to a pretty logging setup (structlog?)

# The nox session for the current CLI invocation. `entrypoint` sets it before
# dispatching to `cli`, and commands / `Runner` read it back with `session_ctx.get()`.
session_ctx: ContextVar[nox.Session] = ContextVar('session_ctx')


# TODO: proper progname in help
cli = typer.Typer(help='Test suite for testing Prisma Client Python against different database providers.')

59 

60 

@cli.command()
def test(
    *,
    databases: list[str] = cast('list[str]', SUPPORTED_DATABASES),  # pyright: ignore[reportCallInDefaultInitializer]
    exclude_databases: list[str] = [],  # pyright: ignore[reportCallInDefaultInitializer]
    inplace: bool = False,
    pytest_args: Optional[str] = None,
    lint: bool = True,
    test: bool = True,
    coverage: bool = False,
    pydantic_v2: bool = True,
    for_async: bool = typer.Option(default=True, is_flag=False),  # pyright: ignore[reportCallInDefaultInitializer]
) -> None:
    """Run unit tests and Pyright"""
    # Pyright is never run when testing against pydantic v1.
    run_lint = lint and pydantic_v2

    session = session_ctx.get()

    excluded = set(validate_databases(exclude_databases))
    selected: list[SupportedDatabase] = [db for db in validate_databases(databases) if db not in excluded]

    with session.chdir(DATABASES_DIR):
        _setup_test_env(session, pydantic_v2=pydantic_v2, inplace=inplace)

        for db in selected:
            print(title(CONFIG_MAPPING[db].name))

            # Give each provider its own coverage data location so that running
            # the next database does not clobber what this one recorded.
            if coverage:  # pragma: no branch
                setup_coverage(session, identifier=db)

            runner = Runner(database=db, track_coverage=coverage, for_async=for_async)
            runner.setup()

            if test:  # pragma: no branch
                runner.test(pytest_args=pytest_args)

            if run_lint:  # pragma: no branch
                runner.lint()

104 

105 

@cli.command()
def serve(database: str, *, version: Optional[str] = None) -> None:
    """Start a database server using docker-compose"""
    # Arguments are evaluated left to right, so the name is validated before
    # the session is looked up, exactly as before.
    start_database(
        validate_database(database),
        version=version,
        session=session_ctx.get(),
    )

111 

112 

@cli.command(name='test-inverse')
def test_inverse(
    *,
    databases: list[str] = cast('list[str]', SUPPORTED_DATABASES),  # pyright: ignore[reportCallInDefaultInitializer]
    coverage: bool = False,
    inplace: bool = False,
    pytest_args: Optional[str] = None,
    pydantic_v2: bool = True,
    for_async: bool = typer.Option(default=True, is_flag=False),  # pyright: ignore[reportCallInDefaultInitializer]
) -> None:
    """Ensure unsupported features actually result in either:

    - Prisma Schema validation failing
    - Our unit tests & linters fail
    """
    # For every feature a database marks as unsupported, re-run the suite with that
    # feature treated as supported and check that something fails with exit code 1.
    # Raises RuntimeError if setup, pytest and Pyright all succeed for such a feature.
    session = session_ctx.get()
    validated_databases = validate_databases(databases)

    with session.chdir(DATABASES_DIR):
        _setup_test_env(session, pydantic_v2=pydantic_v2, inplace=inplace)

        for database in validated_databases:
            config = CONFIG_MAPPING[database]
            print(title(config.name))

            if not config.unsupported_features:
                print(f'There are no unsupported features for {database}.')
                continue

            # point coverage to store data in a database specific location
            # as to not overwrite any existing data from other database tests
            if coverage:  # pragma: no branch
                setup_coverage(session, identifier=database)

            # TODO: support for testing a given list of unsupported features
            for feature in config.unsupported_features:
                print(title(f'Testing {feature} feature'))

                # Treat this one feature as supported: it is passed to the schema template
                # and its test files are no longer excluded (see `Runner.exclude_files`).
                new_config = model_copy(config, deep=True)
                new_config.unsupported_features.remove(feature)

                runner = Runner(
                    database=database,
                    config=new_config,
                    for_async=for_async,
                    track_coverage=coverage,
                )

                with raises_command({1}) as setup_result:
                    runner.setup()

                if setup_result.did_raise:
                    print('Test setup failed (expectedly); Skipping pytest & pyright checks')
                    continue

                with raises_command({1}) as test_result:
                    runner.test(pytest_args=pytest_args)

                with raises_command({1}) as lint_result:
                    runner.lint()

                # Previously nothing checked these results, so a feature that works fine
                # would still be reported as "successfully failed" below.
                if not test_result.did_raise and not lint_result.did_raise:
                    raise RuntimeError(
                        f'Expected the {feature} feature to fail for {database} but setup, '
                        + 'pytest and pyright all passed; it may not need to be marked as unsupported'
                    )

            click.echo(
                click.style(
                    f'✅ All tests successfully failed for {database}',
                    fg='green',
                )
            )

180 

181 

def _setup_test_env(session: nox.Session, *, pydantic_v2: bool, inplace: bool) -> None:
    """Install everything the database test suite needs into ``session``.

    The requirement paths are relative; both callers invoke this inside
    ``session.chdir(DATABASES_DIR)``. The client itself is installed from ``..``.
    """
    pydantic_spec = ('-r', '../pipelines/requirements/deps/pydantic.txt') if pydantic_v2 else ('pydantic==1.10.0',)
    session.install(*pydantic_spec)

    session.install('-r', 'requirements.txt')
    maybe_install_nodejs_bin(session)

    # An editable install (`inplace`) is handy for updating the generated code
    # so that Pylance picks it up.
    client_spec = ('-U', '-e', '..') if inplace else ('-U', '..')
    session.install(*client_spec)

    # Run `prisma py version` against the freshly installed client.
    session.run('python', '-m', 'prisma', 'py', 'version')

198 

199 

class RaisesCommandResult:
    """Outcome object yielded by ``raises_command``."""

    # True once an allowed `CommandFailed` has been intercepted
    did_raise: bool

    def __init__(self) -> None:
        # nothing has been intercepted when the block starts
        self.did_raise = False

205 

206 

# Parses the reason of nox's `CommandFailed` exception (e.g. 'Returned code 1');
# group 1 captures the numeric exit code.
COMMAND_FAILED_RE = re.compile(pattern=r'Returned code (\d+)')

209 

210 

@contextlib.contextmanager
def raises_command(
    allowed_exit_codes: set[int],
) -> Iterator[RaisesCommandResult]:
    """Swallow a nox ``CommandFailed`` raised inside the block if its exit code is allowed.

    The yielded ``RaisesCommandResult`` has ``did_raise`` set to True when such a
    failure was swallowed. Other exceptions propagate unchanged.

    Raises:
        RuntimeError: if the exit code cannot be parsed from the failure, or is
            not in ``allowed_exit_codes``.
    """
    outcome = RaisesCommandResult()

    try:
        yield outcome
    except CommandFailed as exc:
        parsed = COMMAND_FAILED_RE.match(exc.reason or '')
        if parsed is None:
            raise RuntimeError(f'Could not extract exit code from exception {exc}') from exc

        code = int(parsed.group(1))
        if code in allowed_exit_codes:
            outcome.did_raise = True
        else:
            raise RuntimeError(
                f'Unknown code: {code}; Something may have gone wrong '
                + 'or this exit code must be added to the list of known exit codes; '
                + f'Allowed exit codes: {allowed_exit_codes}'
            ) from exc

236 

237 

class Runner:
    """Set up, test and lint the generated client for a single database provider.

    Generated artifacts are written to a per-database directory:
    - the Prisma Schema and the pytest config go under
      ``ROOT_DIR / '.tests_cache' / 'databases' / <database>``;
    - the Pyright config is written into ``DATABASES_DIR``.
    """

    database: SupportedDatabase
    session: nox.Session
    config: DatabaseConfig
    cache_dir: Path
    track_coverage: bool
    for_async: bool

    def __init__(
        self,
        *,
        database: SupportedDatabase,
        track_coverage: bool,
        for_async: bool,
        config: DatabaseConfig | None = None,
    ) -> None:
        """
        Args:
            database: the provider to run against.
            track_coverage: run Python commands under ``coverage run`` instead of ``python``.
            for_async: test the async client's tests (otherwise the sync client's).
            config: overrides the default ``CONFIG_MAPPING[database]`` configuration.

        ``session_ctx`` must be set (``entrypoint`` does this) as the session is read here.
        """
        self.database = database
        self.session = session_ctx.get()
        self.for_async = for_async
        self.track_coverage = track_coverage
        self.config = config if config is not None else CONFIG_MAPPING[database]
        self.cache_dir = ROOT_DIR / '.tests_cache' / 'databases' / database

    def _create_cache_dir(self) -> None:
        """Recreate ``cache_dir`` empty, discarding anything left by a previous run."""
        cache_dir = self.cache_dir
        if cache_dir.exists():  # pragma: no cover
            shutil.rmtree(cache_dir)

        cache_dir.mkdir(parents=True, exist_ok=True)

    def setup(self) -> None:
        """Write the Pyright, pytest and Prisma Schema files, then generate the client."""
        # TODO: split up more
        print('database config: ' + model_json(self.config, indent=2))
        print('for async: ', self.for_async)

        exclude_files = self.exclude_files
        if exclude_files:
            # sorted so the log is stable across runs (sets have no fixed order)
            print(f'excluding files:\n{yaml.dump(sorted(exclude_files))}')
        else:
            print('excluding files: []')

        self._create_cache_dir()

        # TODO: only create this if linting
        self._create_pyright_config()

        # TODO: only create this if testing
        pytest_config = self.cache_dir / 'pyproject.toml'
        rtoml.dump(PYTEST_CONFIG, pytest_config, pretty=True)

        # create a Prisma Schema file
        env = Environment(
            trim_blocks=True,
            lstrip_blocks=True,
            undefined=StrictUndefined,
            keep_trailing_newline=True,
            loader=FileSystemLoader(DATABASES_DIR / 'templates'),
        )
        template = env.get_template('schema.prisma.jinja2')
        self.schema.write_text(
            template.render(
                # template variables
                config=self.config,
                for_async=self.for_async,
                partial_generator=escape_path(DATABASES_DIR / 'partials.py'),
            )
        )

        # generate the client
        self.session.run(*self.python_args, 'prisma_cleanup')
        self.session.run(
            *self.python_args,
            'prisma',
            'generate',
            f'--schema={self.schema}',
        )

    def test(self, *, pytest_args: str | None) -> None:
        """Force-reset the database to the schema, then run pytest.

        Args:
            pytest_args: extra pytest arguments, split with shell-like syntax.
        """
        # ensure DB is in correct state
        self.session.run(
            *self.python_args,
            'prisma',
            'db',
            'push',
            '--force-reset',
            '--accept-data-loss',
            '--skip-generate',
            f'--schema={self.schema}',
        )

        args: list[str] = []
        if pytest_args is not None:  # pragma: no cover
            args = shlex.split(pytest_args)

        # TODO: use PYTEST_ADDOPTS instead
        self.session.run(
            *self.python_args,
            'pytest',
            *args,
            # sorted so the command line is reproducible between runs
            *[f'--ignore={path}' for path in sorted(self.exclude_files)],
            env={
                'PRISMA_DATABASE': self.database,
                # TODO: this should be accessible in the core client
                'DATABASE_CONFIG': model_json(self.config),
            },
        )

    def lint(self) -> None:
        """Run Pyright using the config written by ``setup()``."""
        self.session.run('pyright', '-p', str(self.pyright_config.absolute()))

    def _create_pyright_config(self) -> None:
        """Write the Pyright config; all paths in it are relative to ``DATABASES_DIR``."""
        # use the session this runner was created with, like every other method
        pkg_location = os.path.relpath(get_pkg_location(self.session, 'prisma'), DATABASES_DIR)
        pkg_path = Path(pkg_location)

        pyright_config = deepcopy(PYRIGHT_CONFIG)
        # sorted so the written config does not change between identical runs
        pyright_config['exclude'].extend(sorted(self.exclude_files))

        # exclude the mypy plugin so that we don't have to install `mypy`, it is also
        # not dynamically generated which means it will stay the same across database providers
        pyright_config['exclude'].append(str(pkg_path.joinpath('mypy.py')))

        # exclude our vendored code
        # this needs to explicitly be the package location as otherwise our `include`
        # of the package directory overrides other `exclude`s for the vendor dir
        pyright_config['exclude'].append(str(pkg_path.joinpath('_vendor')))

        # add the generated client code to Pyright too
        pyright_config['include'].append(pkg_location)

        # ensure only the tests for sync / async are checked
        pyright_config['include'].append(tests_reldir(for_async=self.for_async))

        self.pyright_config.write_text(json.dumps(pyright_config, indent=2))

    @cached_property
    def python_args(self) -> list[str]:
        """Command prefix for running a Python module, optionally under coverage."""
        return shlex.split('coverage run --rcfile=../.coveragerc -m' if self.track_coverage else 'python -m')

    @property
    def pyright_config(self) -> Path:
        """Path of this database's Pyright config file."""
        # TODO: move this to the cache dir, it requires some clever path renaming
        # as Pyright requires that `exclude` be relative to the location of the config file
        return DATABASES_DIR.joinpath(f'{self.database}.pyrightconfig.json')

    @property
    def schema(self) -> Path:
        """Path of the generated Prisma Schema file."""
        return self.cache_dir.joinpath('schema.prisma')

    @cached_property
    def exclude_files(self) -> set[str]:
        """Test paths (relative to ``DATABASES_DIR``) that must not be run or type checked.

        This covers the tests for every unsupported feature, plus the whole test
        directory of the other (sync vs async) client.
        """
        files = [
            tests_relpath(path, for_async=self.for_async)
            for path in flatten([FEATURES_MAPPING[feature] for feature in self.config.unsupported_features])
        ]

        # ensure the tests for the sync client are not ran during the async tests and vice versa
        files.append(tests_reldir(for_async=not self.for_async))

        return set(files)

399 

400 

def validate_databases(databases: list[str]) -> list[SupportedDatabase]:
    """Validate a list of database names, where each entry may be comma separated.

    Typer requires `List` options to be repeated, e.g.
    `--databases=sqlite --databases=postgresql`. There was no quick Typer option
    for the shorter `--databases=sqlite,postgresql` form, so entries are split on
    commas here by hand.
    """
    names = [name for entry in databases for name in entry.split(',')]
    return [validate_database(name) for name in names]

412 

413 

def validate_database(database: str) -> SupportedDatabase:
    """Normalise a database name and check that it is supported.

    Surrounding whitespace is ignored, so comma separated lists written with
    spaces (e.g. ``"sqlite, postgresql"``) are accepted.

    Raises:
        ValueError: if the normalised name is not in ``SUPPORTED_DATABASES``.
    """
    # We convert the input to lowercase so that we don't have to define
    # two separate names in the CI matrix.
    database = database.strip().lower()
    if database not in SUPPORTED_DATABASES:  # pragma: no cover
        raise ValueError(f'Unknown database: {database}')

    return database

422 

423 

def tests_reldir(*, for_async: bool) -> str:
    """Return the async (``TESTS_DIR``) or sync tests directory relative to ``DATABASES_DIR``."""
    if for_async:
        tests_dir = TESTS_DIR
    else:
        tests_dir = SYNC_TESTS_DIR
    return str(tests_dir.relative_to(DATABASES_DIR))

426 

427 

def tests_relpath(path: str, *, for_async: bool) -> str:
    """Resolve ``path`` inside the async or sync tests directory, relative to ``DATABASES_DIR``."""
    absolute = (TESTS_DIR if for_async else SYNC_TESTS_DIR).joinpath(path)
    return str(absolute.relative_to(DATABASES_DIR))

431 

432 

def title(text: str) -> str:
    """Render ``text`` in bold between two 30-character dash rules."""
    # TODO: improve formatting
    rule = '-' * 30
    return f'{rule} {click.style(text, bold=True)} {rule}'

437 

438 

def entrypoint(session: nox.Session) -> None:
    """Run the Typer ``cli`` with ``session`` available through ``session_ctx``.

    The session's positional arguments are forwarded to the CLI.
    """

    def _run_cli() -> None:
        session_ctx.set(session)
        cli(session.posargs)

    # Run inside a copy of the current context so the `session_ctx.set` above
    # does not leak the session into the caller's context.
    return copy_context().run(_run_cli)