Coverage for tests/utils.py: 100%
139 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-08-27 18:25 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2024-08-27 18:25 +0000
1from __future__ import annotations
3import os
4import sys
5import uuid
6import inspect
7import textwrap
8import contextlib
9import subprocess
10from typing import (
11 TYPE_CHECKING,
12 Any,
13 List,
14 Tuple,
15 Union,
16 Mapping,
17 Callable,
18 Iterator,
19 Optional,
20 cast,
21)
22from pathlib import Path
23from typing_extensions import override
25import click
26import pytest
27from click.testing import Result, CliRunner
29from prisma import _config
30from lib.utils import escape_path
31from prisma.cli import main
32from prisma._proxy import LazyProxy
33from prisma._types import FuncType
34from prisma.binaries import platform
35from prisma.generator.utils import copy_tree
36from prisma.generator.generator import BASE_PACKAGE_DIR
38if TYPE_CHECKING:
39 from _pytest.pytester import Pytester, RunResult
40 from _pytest.monkeypatch import MonkeyPatch
# Positional arguments (as a tuple) paired with keyword arguments (as a mapping).
CapturedArgs = Tuple[Tuple[object, ...], Mapping[str, object]]


# as we are generating new modules we need to clear them from
# the module cache so that python actually picks them up
# when we import them again, however we also have to ignore
# any prisma.generator modules as we rely on the import caching
# mechanism for loading partial model types
IMPORT_RELOADER = """
import sys
for name in sys.modules.copy():
    if 'prisma' in name and 'generator' not in name:
        sys.modules.pop(name, None)
"""

# NOTE: the schema templates below are rendered with `str.format`, so literal
# braces are doubled (`{{` / `}}`) while `{output}` and `{options}` are the
# placeholders that get substituted.
DEFAULT_GENERATOR = """
generator db {{
  provider = "coverage run -m prisma"
  output = "{output}"
  {options}
}}

"""

SCHEMA_HEADER = (
    """
datasource db {{
  provider = "sqlite"
  url = "file:dev.db"
}}

"""
    + DEFAULT_GENERATOR
)

DEFAULT_SCHEMA = (
    SCHEMA_HEADER
    + """
model User {{
  id String @id @default(cuid())
  created_at DateTime @default(now())
  updated_at DateTime @updatedAt
  name String
}}
"""
)
class Runner:
    """Test helper that invokes click commands through a `CliRunner`.

    Constructing a `Runner` immediately patches `subprocess.run` (through the
    given `MonkeyPatch`) so that output produced by subprocesses is re-printed
    in this process, see `patch_subprocess`.
    """

    def __init__(self, patcher: 'MonkeyPatch') -> None:
        self._runner = CliRunner()
        self._patcher = patcher
        # when set, used by `invoke()` whenever no explicit `cli` is given
        self.default_cli: Optional[click.Command] = None
        self.patch_subprocess()

    def invoke(
        self,
        args: Optional[List[str]] = None,
        cli: Optional[click.Command] = None,
        **kwargs: Any,
    ) -> Result:
        """Invoke a command with the underlying `CliRunner` and return its result.

        The command is chosen in this order:

        1. the explicit `cli` argument
        2. `self.default_cli`
        3. a throwaway click command that forwards `args` to `main()`

        In the first two cases `args` is handed to click unchanged. In the
        last case click receives no arguments at all and `main()` is called
        with `['prisma', *args]` (or `None` when `args` is `None`).

        `args` is never modified, so the same list can safely be reused for
        several invocations. Any extra keyword arguments are forwarded to
        `CliRunner.invoke()`.
        """
        default_args: Optional[List[str]] = None

        if cli is not None:
            default_args = args
        elif self.default_cli is not None:
            cli = self.default_cli
            default_args = args
        else:
            # fake invocation context: prepend the program name.
            # A new list is built instead of inserting into `args` so that
            # the caller's list is left untouched.
            argv: Optional[List[str]] = None
            if args is not None:  # pragma: no branch
                argv = ['prisma', *args]

            def _cli() -> None:
                main(argv, use_handler=False, do_cleanup=False)

            cli = click.command()(_cli)

            # we don't pass any args to click as we need to parse them ourselves
            default_args = []

        return self._runner.invoke(cli, default_args, **kwargs)

    def patch_subprocess(self) -> None:
        """As we can't pass a fd from something like io.TextIO to a subprocess
        we need to override the subprocess.run method to pipe the output and then
        print the output ourselves so that it can be captured by anything higher in
        call stack.

        The replacement always forces `stdout`, `stderr` and `encoding`,
        overriding whatever the caller passed for those keyword arguments.
        """

        def _patched_subprocess_run(*args: Any, **kwargs: Any) -> 'subprocess.CompletedProcess[str]':
            kwargs['stdout'] = subprocess.PIPE
            kwargs['stderr'] = subprocess.PIPE
            kwargs['encoding'] = sys.getdefaultencoding()

            process = old_subprocess_run(*args, **kwargs)

            assert isinstance(process.stdout, str)

            # replay the piped output on this process' own streams
            print(process.stdout)
            print(process.stderr, file=sys.stderr)
            return process

        # keep a reference to the real implementation before replacing it;
        # the closure above only looks this name up when it is called
        old_subprocess_run = subprocess.run
        self._patcher.setattr(subprocess, 'run', _patched_subprocess_run, raising=True)
class Testdir:
    """Convenience wrapper around pytest's `Pytester` for prisma tests.

    Adds helpers for writing a schema file, running `prisma generate`,
    copying the base package into the test directory and turning a function
    body into a standalone source file.
    """

    # the class name starts with `Test`; this attribute tells pytest that it
    # is a helper and must not be collected as a test class
    __test__ = False

    # re-exported on the class so tests can reach the templates via the fixture
    SCHEMA_HEADER = SCHEMA_HEADER
    default_schema = DEFAULT_SCHEMA
    default_generator = DEFAULT_GENERATOR

    def __init__(self, pytester: Pytester) -> None:
        self.pytester = pytester

    def _make_relative(self, path: Union[str, Path]) -> str:  # pragma: no cover
        """Return `path` as a string, relative to `self.path` when it is absolute."""
        candidate = path if isinstance(path, Path) else Path(path)

        if candidate.is_absolute():
            return str(candidate.relative_to(self.path))

        return str(candidate)

    def make_from_function(
        self,
        function: FuncType,
        ext: str = '.py',
        name: Optional[Union[str, Path]] = None,
        **env: Any,
    ) -> None:
        """Write the body of `function` to a file in the test directory.

        The source is produced by `get_source_from_function(function, **env)`.
        When `name` is given it is used (relative to the test directory) as
        the file name, otherwise `makefile()` picks the name.
        """
        code = get_source_from_function(function, **env)

        if not name:
            self.makefile(ext, code)
            return

        self.makefile(ext, **{self._make_relative(name): code})

    def copy_pkg(self, clean: bool = True) -> None:
        """Copy `BASE_PACKAGE_DIR` to `<testdir>/prisma`.

        When `clean` is true `prisma_cleanup.cleanup()` is then run in a
        separate python process, which must exit with status 0.
        """
        copy_tree(BASE_PACKAGE_DIR, self.path / 'prisma')

        if clean:  # pragma: no branch
            cleanup = self.runpython_c('import prisma_cleanup; prisma_cleanup.cleanup()')
            assert cleanup.ret == 0

    def generate(
        self, schema: Optional[str] = None, options: str = '', **extra: Any
    ) -> 'subprocess.CompletedProcess[bytes]':
        """Write the schema file and run `python -m prisma generate` against it.

        The combined stdout/stderr of the process is printed to stdout.

        Raises:
            subprocess.CalledProcessError: if the process exits with a
                non-zero return code.
        """
        schema_path = self.make_schema(schema, options, **extra)
        command = [sys.executable, '-m', 'prisma', 'generate', f'--schema={schema_path}']
        completed = subprocess.run(
            command,
            env=os.environ,
            stdout=subprocess.PIPE,
            # stderr is merged into stdout
            stderr=subprocess.STDOUT,
        )
        print(completed.stdout.decode(sys.getdefaultencoding()), file=sys.stdout)

        if completed.returncode != 0:
            raise subprocess.CalledProcessError(
                completed.returncode,
                command,
                completed.stdout,
                completed.stderr,
            )

        return completed

    def make_schema(
        self,
        schema: Optional[str] = None,
        options: str = '',
        output: Optional[str] = None,
        **extra: Any,
    ) -> Path:
        """Render a schema template to `<testdir>/schema.prisma` and return its path.

        `schema` defaults to `self.default_schema` and `output` defaults to
        `'prisma'`. The template is rendered with `str.format`, receiving the
        escaped absolute output path as `output`, `options` as given and any
        additional keyword arguments from `extra`.
        """
        template = self.default_schema if schema is None else schema

        if output is None:  # pragma: no branch
            output = 'prisma'

        rendered = template.format(
            output=escape_path(self.path.joinpath(output)),
            options=options,
            **extra,
        )
        schema_path = self.path.joinpath('schema.prisma')
        schema_path.write_text(rendered)
        return schema_path

    def makefile(self, ext: str, *args: str, **kwargs: str) -> None:
        """Forward to `Pytester.makefile()`, discarding its return value."""
        self.pytester.makefile(ext, *args, **kwargs)

    def runpytest(self, *args: Union[str, 'os.PathLike[str]'], **kwargs: Any) -> 'RunResult':
        """Forward to `Pytester.runpytest()` with the sugar plugin disabled."""
        # pytest-sugar breaks result parsing
        return self.pytester.runpytest('-p', 'no:sugar', *args, **kwargs)

    def runpython_c(self, command: str) -> 'RunResult':
        """Forward to `Pytester.runpython_c()`."""
        return self.pytester.runpython_c(command)

    @contextlib.contextmanager
    def redirect_stdout_to_file(
        self,
    ) -> Iterator[Path]:
        """Redirect `sys.stdout` into a new uniquely named file in the test directory.

        Yields the path of that file; the file is closed and the redirect is
        undone when the context exits.
        """
        target = self.path.joinpath(f'stdout-{uuid.uuid4()}.txt')

        with target.open('w') as stream, contextlib.redirect_stdout(stream):
            yield target

    @property
    def path(self) -> Path:
        """The test directory, i.e. `self.pytester.path` as a `Path`."""
        return Path(self.pytester.path)

    @override
    def __repr__(self) -> str:  # pragma: no cover
        return str(self)

    @override
    def __str__(self) -> str:  # pragma: no cover
        return f'<Testdir {self.path} >'
def get_source_from_function(function: FuncType, **env: Any) -> str:
    """Return the body of `function` as standalone module source.

    The first source line of the function (its `def` line) is dropped and the
    remainder is dedented. Every entry of `env` is inserted as a
    `name = value` assignment at the position of the first line that is not
    an import (position 0 if every line is an import); string values are
    wrapped in single quotes, any other value is interpolated as-is.

    The result is prefixed with `IMPORT_RELOADER`.
    """
    body = inspect.getsource(function).splitlines()[1:]

    # setup env after imports
    insert_at = next(
        (
            position
            for position, text in enumerate(body)
            if not text.lstrip(' ').startswith(('import', 'from'))
        ),
        0,
    )

    source_lines = textwrap.dedent('\n'.join(body)).splitlines()
    for name, value in env.items():
        rendered = f"'{value}'" if isinstance(value, str) else value
        source_lines.insert(insert_at, f'{name} = {rendered}')

    return IMPORT_RELOADER + '\n'.join(source_lines)
@contextlib.contextmanager
def set_config(config: _config.Config) -> Iterator[_config.Config]:
    """Temporarily make `config` the object behind the `_config.config` proxy.

    Yields `config`; whatever the proxy pointed at before is put back when
    the context exits, even if the body raised.
    """
    lazy_config = cast(LazyProxy[_config.Config], _config.config)
    previous = lazy_config.__get_proxied__()

    try:
        lazy_config.__set_proxied__(config)
        yield config
    finally:
        # always restore the previously proxied config
        lazy_config.__set_proxied__(previous)
def patch_method(
    patcher: 'MonkeyPatch',
    obj: object,
    attr: str,
    callback: Optional[Callable[..., Any]] = None,
) -> Callable[[], Optional[CapturedArgs]]:
    """Helper for patching functions that are incompatible with MonkeyPatch.setattr

    e.g. __init__ methods

    `obj.<attr>` is replaced by a wrapper that records the arguments of its
    most recent call as `(args[1:], kwargs)`, i.e. without the first
    positional argument. If `callback` is given it is then called with the
    original attribute followed by all of the arguments; the wrapper itself
    always returns `None`.

    Returns a zero-argument function giving the most recently captured
    arguments, or `None` if the wrapper has not been called yet.
    """
    # work around for pyright: https://github.com/microsoft/pyright/issues/2757
    captured = cast(Optional[CapturedArgs], None)
    original = getattr(obj, attr)

    def wrapper(*args: Any, **kwargs: Any) -> None:
        nonlocal captured
        captured = (args[1:], kwargs)

        if callback is None:
            return

        callback(original, *args, **kwargs)

    patcher.setattr(obj, attr, wrapper, raising=True)
    return lambda: captured
# pytest marker that skips the decorated test when running on windows
skipif_windows = pytest.mark.skipif(
    platform.name() == 'windows',
    reason='Test is disabled on windows',
)