Coverage for databases/main.py: 97%
187 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-08-27 18:25 +0000
1from __future__ import annotations
3import os
4import re
5import json
6import shlex
7import shutil
8import contextlib
9from copy import deepcopy
10from typing import (
11 Iterator,
12 Optional,
13 cast,
14)
15from pathlib import Path
16from contextvars import ContextVar, copy_context
18import nox
19import yaml
20import click
21import rtoml
22import typer
23from jinja2 import Environment, StrictUndefined, FileSystemLoader
24from nox.command import CommandFailed
26from lib.utils import flatten, escape_path
27from prisma._compat import model_copy, model_json, cached_property
28from pipelines.utils import (
29 setup_coverage,
30 get_pkg_location,
31 maybe_install_nodejs_bin,
32)
34from .utils import DatabaseConfig
35from ._serve import start_database
36from ._types import SupportedDatabase
37from .constants import (
38 ROOT_DIR,
39 TESTS_DIR,
40 DATABASES_DIR,
41 PYTEST_CONFIG,
42 CONFIG_MAPPING,
43 PYRIGHT_CONFIG,
44 SYNC_TESTS_DIR,
45 FEATURES_MAPPING,
46 SUPPORTED_DATABASES,
47)
# TODO: switch to a pretty logging setup
# structlog?

# Holds the nox session for the current CLI invocation. `entrypoint()` sets it
# so that commands and `Runner` can fetch the session without it being passed
# through every function call.
session_ctx: ContextVar[nox.Session] = ContextVar('session_ctx')

# TODO: proper progname in help
cli = typer.Typer(
    help=(
        'Test suite for testing Prisma Client Python against different database providers.'
    ),
)
@cli.command()
def test(
    *,
    databases: list[str] = cast('list[str]', SUPPORTED_DATABASES),  # pyright: ignore[reportCallInDefaultInitializer]
    exclude_databases: list[str] = [],  # pyright: ignore[reportCallInDefaultInitializer]
    inplace: bool = False,
    pytest_args: Optional[str] = None,
    lint: bool = True,
    test: bool = True,
    coverage: bool = False,
    pydantic_v2: bool = True,
    for_async: bool = typer.Option(default=True, is_flag=False),  # pyright: ignore[reportCallInDefaultInitializer]
) -> None:
    """Run unit tests and Pyright"""
    # Pyright is only run when testing against Pydantic v2.
    run_lint = lint and pydantic_v2

    session = session_ctx.get()

    # Validate the exclusions first so an invalid name there is reported first.
    excluded = set(validate_databases(exclude_databases))
    targets: list[SupportedDatabase] = []
    for candidate in validate_databases(databases):
        if candidate not in excluded:
            targets.append(candidate)

    with session.chdir(DATABASES_DIR):
        _setup_test_env(session, pydantic_v2=pydantic_v2, inplace=inplace)

        for target in targets:
            print(title(CONFIG_MAPPING[target].name))

            # Store coverage data under a per-database identifier so that
            # results from other databases are not overwritten.
            if coverage:  # pragma: no branch
                setup_coverage(session, identifier=target)

            runner = Runner(database=target, track_coverage=coverage, for_async=for_async)
            runner.setup()

            if test:  # pragma: no branch
                runner.test(pytest_args=pytest_args)

            if run_lint:  # pragma: no branch
                runner.lint()
@cli.command()
def serve(database: str, *, version: Optional[str] = None) -> None:
    """Start a database server using docker-compose"""
    # Arguments are evaluated left to right, so the name is validated before
    # the session is looked up.
    start_database(
        validate_database(database),
        version=version,
        session=session_ctx.get(),
    )
@cli.command(name='test-inverse')
def test_inverse(
    *,
    databases: list[str] = cast('list[str]', SUPPORTED_DATABASES),  # pyright: ignore[reportCallInDefaultInitializer]
    coverage: bool = False,
    inplace: bool = False,
    pytest_args: Optional[str] = None,
    pydantic_v2: bool = True,
    for_async: bool = typer.Option(default=True, is_flag=False),  # pyright: ignore[reportCallInDefaultInitializer]
) -> None:
    """Ensure unsupported features actually result in either:

    - Prisma Schema validation failing
    - Our unit tests & linters fail
    """
    # NOTE: the docstring above is the typer help text for this command.
    #
    # For every unsupported feature of each database, the feature is removed from
    # the unsupported list and the suite is re-run. Exit code 1 from schema
    # setup, pytest or pyright is treated as an expected failure (see
    # `raises_command`). If neither pytest nor pyright fail after a successful
    # setup, a RuntimeError is raised because the feature was not shown to fail.
    session = session_ctx.get()
    validated_databases = validate_databases(databases)

    with session.chdir(DATABASES_DIR):
        _setup_test_env(session, pydantic_v2=pydantic_v2, inplace=inplace)

        for database in validated_databases:
            config = CONFIG_MAPPING[database]
            print(title(config.name))

            if not config.unsupported_features:
                print(f'There are no unsupported features for {database}.')
                continue

            # point coverage to store data in a database specific location
            # as to not overwrite any existing data from other database tests
            if coverage:  # pragma: no branch
                setup_coverage(session, identifier=database)

            # TODO: support for testing a given list of unsupported features
            for feature in config.unsupported_features:
                print(title(f'Testing {feature} feature'))

                # deep copy so the shared config in CONFIG_MAPPING is not mutated
                new_config = model_copy(config, deep=True)
                new_config.unsupported_features.remove(feature)

                runner = Runner(
                    database=database,
                    config=new_config,
                    for_async=for_async,
                    track_coverage=coverage,
                )

                with raises_command({1}) as setup_result:
                    runner.setup()

                if setup_result.did_raise:
                    print('Test setup failed (expectedly); Skipping pytest & pyright checks')
                    continue

                with raises_command({1}) as test_result:
                    runner.test(pytest_args=pytest_args)

                with raises_command({1}) as lint_result:
                    runner.lint()

                # Previously this case was silently accepted and the success message
                # below was still printed, even though nothing actually failed.
                if not test_result.did_raise and not lint_result.did_raise:
                    raise RuntimeError(
                        f'Expected the {feature} feature to fail for {database} but both '
                        + 'pytest and pyright passed; the feature may not actually be '
                        + 'unsupported or its tests may not exercise it'
                    )

            click.echo(
                click.style(
                    f'✅ All tests successfully failed for {database}',
                    fg='green',
                )
            )
def _setup_test_env(session: nox.Session, *, pydantic_v2: bool, inplace: bool) -> None:
    # Install the requested Pydantic major version first.
    pydantic_args = (
        ('-r', '../pipelines/requirements/deps/pydantic.txt') if pydantic_v2 else ('pydantic==1.10.0',)
    )
    session.install(*pydantic_args)

    session.install('-r', 'requirements.txt')
    maybe_install_nodejs_bin(session)

    # An editable install is useful for updating the generated code so that
    # Pylance picks it up.
    package_args = ('-U', '-e', '..') if inplace else ('-U', '..')
    session.install(*package_args)

    # Print the installed client version as a sanity check.
    session.run('python', '-m', 'prisma', 'py', 'version')
class RaisesCommandResult:
    """Mutable outcome object yielded by `raises_command`.

    `did_raise` starts out as `False`. `raises_command` sets it to `True` when
    the wrapped code raised a `CommandFailed` with an allowed exit code.
    """

    did_raise: bool

    def __init__(self) -> None:
        # nothing has failed until the context manager observes a failure
        self.did_raise = False
# Matches the reason attached to nox's `CommandFailed` exception
# (e.g. "Returned code 1"); group 1 captures the numeric exit code.
COMMAND_FAILED_RE: re.Pattern[str] = re.compile(r'Returned code (\d+)')
@contextlib.contextmanager
def raises_command(
    allowed_exit_codes: set[int],
) -> Iterator[RaisesCommandResult]:
    """Suppress `nox` `CommandFailed` errors whose exit code is expected.

    The yielded `RaisesCommandResult` has `did_raise` set to `True` when a command
    failed with one of `allowed_exit_codes`; that failure is swallowed.

    Raises:
        RuntimeError: if the exit code cannot be parsed from the failure reason,
            or if it is not in `allowed_exit_codes`.
    Any exception that is not a `CommandFailed` propagates unchanged.
    """
    outcome = RaisesCommandResult()

    try:
        yield outcome
    except CommandFailed as error:
        found = COMMAND_FAILED_RE.match(error.reason or '')
        if not found:
            raise RuntimeError(f'Could not extract exit code from exception {error}') from error

        code = int(found.group(1))
        if code in allowed_exit_codes:
            # expected failure: record it and let the generator finish,
            # which suppresses the exception
            outcome.did_raise = True
            return

        raise RuntimeError(
            f'Unknown code: {code}; Something may have gone wrong '
            + 'or this exit code must be added to the list of known exit codes; '
            + f'Allowed exit codes: {allowed_exit_codes}'
        ) from error
class Runner:
    """Sets up, tests and type-checks the generated client for one database provider.

    `setup()` writes the Prisma schema, the pytest config and the Pyright config,
    then generates the client. `test()` and `lint()` use the files it produced.
    """

    database: SupportedDatabase
    session: nox.Session
    config: DatabaseConfig
    cache_dir: Path
    track_coverage: bool
    for_async: bool

    def __init__(
        self,
        *,
        database: SupportedDatabase,
        track_coverage: bool,
        for_async: bool,
        config: DatabaseConfig | None = None,
    ) -> None:
        """
        Args:
            database: The database provider to run against.
            track_coverage: Run Python commands via `coverage run` instead of `python`.
            for_async: Passed to the schema template. Also selects whether the
                async (`TESTS_DIR`) or sync (`SYNC_TESTS_DIR`) tests are used.
            config: Overrides the default config for `database` from `CONFIG_MAPPING`.
        """
        self.database = database
        # captured once so every method operates on the same nox session
        self.session = session_ctx.get()
        self.for_async = for_async
        self.track_coverage = track_coverage
        self.config = config if config is not None else CONFIG_MAPPING[database]
        self.cache_dir = ROOT_DIR / '.tests_cache' / 'databases' / database

    def _create_cache_dir(self) -> None:
        """Recreate an empty cache directory for this database."""
        cache_dir = self.cache_dir
        if cache_dir.exists():  # pragma: no cover
            shutil.rmtree(cache_dir)

        cache_dir.mkdir(parents=True, exist_ok=True)

    def setup(self) -> None:
        """Write the config files and schema, then generate the client."""
        print('database config: ' + model_json(self.config, indent=2))
        print('for async: ', self.for_async)

        exclude_files = self._sorted_exclude_files()
        if exclude_files:  # pragma: no branch
            print(f'excluding files:\n{yaml.dump(exclude_files)}')
        else:
            print('excluding files: []')

        self._create_cache_dir()

        # TODO: only create this if linting
        self._create_pyright_config()

        # TODO: only create this if testing
        self._create_pytest_config()

        self._create_schema()
        self._generate_client()

    def _create_pytest_config(self) -> None:
        """Write `PYTEST_CONFIG` to `pyproject.toml` inside the cache directory."""
        rtoml.dump(PYTEST_CONFIG, self.cache_dir / 'pyproject.toml', pretty=True)

    def _create_schema(self) -> None:
        """Render the Prisma Schema template for this config into `self.schema`."""
        env = Environment(
            trim_blocks=True,
            lstrip_blocks=True,
            undefined=StrictUndefined,
            keep_trailing_newline=True,
            loader=FileSystemLoader(DATABASES_DIR / 'templates'),
        )
        template = env.get_template('schema.prisma.jinja2')
        self.schema.write_text(
            template.render(
                # template variables
                config=self.config,
                for_async=self.for_async,
                partial_generator=escape_path(DATABASES_DIR / 'partials.py'),
            )
        )

    def _generate_client(self) -> None:
        """Run `prisma_cleanup`, then `prisma generate` against `self.schema`."""
        self.session.run(*self.python_args, 'prisma_cleanup')
        self.session.run(
            *self.python_args,
            'prisma',
            'generate',
            f'--schema={self.schema}',
        )

    def test(self, *, pytest_args: str | None) -> None:
        """Reset the database from the schema, then run pytest.

        Args:
            pytest_args: Extra arguments for pytest, split with shell-like rules.
        """
        # ensure DB is in correct state
        self.session.run(
            *self.python_args,
            'prisma',
            'db',
            'push',
            '--force-reset',
            '--accept-data-loss',
            '--skip-generate',
            f'--schema={self.schema}',
        )

        args: list[str] = []
        if pytest_args is not None:  # pragma: no cover
            args = shlex.split(pytest_args)

        ignore_args = [f'--ignore={path}' for path in self._sorted_exclude_files()]

        # TODO: use PYTEST_ADDOPTS instead
        self.session.run(
            *self.python_args,
            'pytest',
            *args,
            *ignore_args,
            env={
                'PRISMA_DATABASE': self.database,
                # TODO: this should be accessible in the core client
                'DATABASE_CONFIG': model_json(self.config),
            },
        )

    def lint(self) -> None:
        """Run Pyright with the config written by `setup()`."""
        self.session.run('pyright', '-p', str(self.pyright_config.absolute()))

    def _create_pyright_config(self) -> None:
        """Write a Pyright config covering the generated client and relevant tests."""
        pkg_location = os.path.relpath(get_pkg_location(self.session, 'prisma'), DATABASES_DIR)

        pyright_config = deepcopy(PYRIGHT_CONFIG)
        # sorted so the generated file is identical between runs
        pyright_config['exclude'].extend(self._sorted_exclude_files())

        # exclude the mypy plugin so that we don't have to install `mypy`, it is also
        # not dynamically generated which means it will stay the same across database providers
        pyright_config['exclude'].append(str(Path(pkg_location).joinpath('mypy.py')))

        # exclude our vendored code
        # this needs to explicitly be the package location as otherwise our `include`
        # of the package directory overrides other `exclude`s for the vendor dir
        pyright_config['exclude'].append(str(Path(pkg_location).joinpath('_vendor')))

        # add the generated client code to Pyright too
        pyright_config['include'].append(pkg_location)

        # ensure only the tests for sync / async are checked
        pyright_config['include'].append(tests_reldir(for_async=self.for_async))

        self.pyright_config.write_text(json.dumps(pyright_config, indent=2))

    @cached_property
    def python_args(self) -> list[str]:
        """Command prefix for running a Python module, optionally under coverage."""
        return shlex.split('coverage run --rcfile=../.coveragerc -m' if self.track_coverage else 'python -m')

    @property
    def pyright_config(self) -> Path:
        # TODO: move this to the cache dir, it requires some clever path renaming
        # as Pyright requires that `exclude` be relative to the location of the config file
        return DATABASES_DIR.joinpath(f'{self.database}.pyrightconfig.json')

    @property
    def schema(self) -> Path:
        """Location of the rendered Prisma Schema inside the cache directory."""
        return self.cache_dir.joinpath('schema.prisma')

    @cached_property
    def exclude_files(self) -> set[str]:
        """Test paths (relative to `DATABASES_DIR`) to skip for this config."""
        files = [
            tests_relpath(path, for_async=self.for_async)
            for path in flatten([FEATURES_MAPPING[feature] for feature in self.config.unsupported_features])
        ]

        # ensure the tests for the sync client are not ran during the async tests and vice versa
        files.append(tests_reldir(for_async=not self.for_async))

        return set(files)

    def _sorted_exclude_files(self) -> list[str]:
        """`exclude_files` in a stable order, so generated output is deterministic."""
        return sorted(self.exclude_files)
def validate_databases(databases: list[str]) -> list[SupportedDatabase]:
    # By default Typer only accepts `List` options when the flag is repeated,
    # e.g. `--databases=sqlite --databases=postgresql`.
    #
    # The comma-separated form `--databases=sqlite,postgresql` is preferred, and
    # no quick Typer option was found for it, so each value is split on commas
    # here. Both forms therefore work, and they can be mixed.
    names = flatten([entry.split(',') for entry in databases])
    return [validate_database(name) for name in names]
def validate_database(database: str) -> SupportedDatabase:
    """Normalise and validate a single database name from the command line.

    The name has surrounding whitespace stripped, so that comma-separated input
    such as `--databases="sqlite, postgresql"` works. It is also lowercased, so
    that two separate names don't have to be defined in the CI matrix.

    Raises:
        ValueError: if the normalised name is not in `SUPPORTED_DATABASES`.
    """
    database = database.strip().lower()
    if database not in SUPPORTED_DATABASES:  # pragma: no cover
        raise ValueError(f'Unknown database: {database}')

    return database
def tests_reldir(*, for_async: bool) -> str:
    """Return the async or sync tests directory, relative to `DATABASES_DIR`."""
    if for_async:
        tests_dir = TESTS_DIR
    else:
        tests_dir = SYNC_TESTS_DIR
    return str(tests_dir.relative_to(DATABASES_DIR))
def tests_relpath(path: str, *, for_async: bool) -> str:
    """Resolve `path` inside the async/sync tests dir, relative to `DATABASES_DIR`."""
    base = SYNC_TESTS_DIR if not for_async else TESTS_DIR
    full_path = base.joinpath(path)
    return str(full_path.relative_to(DATABASES_DIR))
def title(text: str) -> str:
    """Return `text` in bold with 30 dashes on each side, for use as a header."""
    # TODO: improve formatting
    dashes = '-' * 30
    return f'{dashes} {click.style(text, bold=True)} {dashes}'
def entrypoint(session: nox.Session) -> None:
    """Wrapper over `cli()` that sets a `session` context variable for easier usage."""

    def _run_cli() -> None:
        session_ctx.set(session)
        cli(session.posargs)

    # Run inside a copy of the current context so that setting `session_ctx`
    # does not leak the session object outside of this call.
    return copy_context().run(_run_cli)