Coverage for tests/utils.py: 100%

139 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-08-27 18:25 +0000

1from __future__ import annotations 

2 

3import os 

4import sys 

5import uuid 

6import inspect 

7import textwrap 

8import contextlib 

9import subprocess 

10from typing import ( 

11 TYPE_CHECKING, 

12 Any, 

13 List, 

14 Tuple, 

15 Union, 

16 Mapping, 

17 Callable, 

18 Iterator, 

19 Optional, 

20 cast, 

21) 

22from pathlib import Path 

23from typing_extensions import override 

24 

25import click 

26import pytest 

27from click.testing import Result, CliRunner 

28 

29from prisma import _config 

30from lib.utils import escape_path 

31from prisma.cli import main 

32from prisma._proxy import LazyProxy 

33from prisma._types import FuncType 

34from prisma.binaries import platform 

35from prisma.generator.utils import copy_tree 

36from prisma.generator.generator import BASE_PACKAGE_DIR 

37 

38if TYPE_CHECKING: 

39 from _pytest.pytester import Pytester, RunResult 

40 from _pytest.monkeypatch import MonkeyPatch 

41 

42 

# (positional args, keyword args) pair, as recorded by ``patch_method``
CapturedArgs = Tuple[Tuple[object, ...], Mapping[str, object]]


# Generated client modules must be evicted from ``sys.modules`` so that
# importing them again picks up the freshly generated code. Modules whose
# name contains ``generator`` are kept, because loading partial model types
# relies on the import cache of the ``prisma.generator`` modules.
IMPORT_RELOADER = """
import sys
for name in sys.modules.copy():
    if 'prisma' in name and 'generator' not in name:
        sys.modules.pop(name, None)
"""

# Prisma ``generator`` block; ``{output}`` and ``{options}`` are filled in with
# ``str.format`` (literal braces are therefore doubled).
DEFAULT_GENERATOR = """
generator db {{
  provider = "coverage run -m prisma"
  output = "{output}"
  {options}
}}

"""

# SQLite datasource followed by the default generator block.
SCHEMA_HEADER = """
datasource db {{
  provider = "sqlite"
  url = "file:dev.db"
}}

""" + DEFAULT_GENERATOR

# Header plus a single ``User`` model; used when no schema is supplied.
DEFAULT_SCHEMA = SCHEMA_HEADER + """
model User {{
  id String @id @default(cuid())
  created_at DateTime @default(now())
  updated_at DateTime @updatedAt
  name String
}}
"""

89 

90 

class Runner:
    """Run the Prisma CLI in-process via click's ``CliRunner``.

    Constructing a runner immediately patches ``subprocess.run`` through the
    given ``MonkeyPatch`` (see ``patch_subprocess``).
    """

    def __init__(self, patcher: 'MonkeyPatch') -> None:
        self._runner = CliRunner()
        self._patcher = patcher
        # command used by ``invoke`` when no explicit ``cli`` is passed;
        # ``None`` means "run the Prisma CLI entry point"
        self.default_cli: Optional[click.Command] = None
        self.patch_subprocess()

    def invoke(
        self,
        args: Optional[List[str]] = None,
        cli: Optional[click.Command] = None,
        **kwargs: Any,
    ) -> Result:
        """Invoke a click command and return click's ``Result``.

        Resolution order: an explicit *cli*, then ``self.default_cli`` (both
        receive *args* directly), otherwise ``prisma.cli.main`` wrapped in a
        throwaway click command. In the last case *args* are handed to
        ``main`` prefixed with ``'prisma'`` instead of being parsed by click.
        The caller's *args* list is never modified. Extra ``kwargs`` are
        forwarded to ``CliRunner.invoke``.
        """
        default_args: Optional[List[str]] = None

        if cli is not None:
            default_args = args
        elif self.default_cli is not None:
            cli = self.default_cli
            default_args = args
        else:

            def _cli() -> None:
                # fake invocation context by prepending the program name;
                # build a new list so the caller's ``args`` stay untouched
                argv = None if args is None else ['prisma', *args]
                main(argv, use_handler=False, do_cleanup=False)

            cli = click.command()(_cli)

            # we don't pass any args to click as we need to parse them ourselves
            default_args = []

        return self._runner.invoke(cli, default_args, **kwargs)

    def patch_subprocess(self) -> None:
        """As we can't pass a fd from something like io.TextIO to a subprocess
        we need to override the subprocess.run method to pipe the output and then
        print the output ourselves so that it can be captured by anything higher in
        call stack.

        The patched function always captures stdout and stderr separately as
        text (``sys.getdefaultencoding()``), overriding any ``stdout``,
        ``stderr``, ``encoding`` or ``capture_output`` arguments from the caller.
        """
        old_subprocess_run = subprocess.run

        def _patched_subprocess_run(*args: Any, **kwargs: Any) -> 'subprocess.CompletedProcess[str]':
            # ``capture_output`` cannot be combined with explicit stdout/stderr
            # (subprocess.run raises ValueError); output is captured through
            # the pipes below either way, so drop it
            kwargs.pop('capture_output', None)
            kwargs['stdout'] = subprocess.PIPE
            kwargs['stderr'] = subprocess.PIPE
            kwargs['encoding'] = sys.getdefaultencoding()

            process = old_subprocess_run(*args, **kwargs)

            assert isinstance(process.stdout, str)

            # re-emit through Python's streams so CliRunner / capsys see it
            print(process.stdout)
            print(process.stderr, file=sys.stderr)
            return process

        self._patcher.setattr(subprocess, 'run', _patched_subprocess_run, raising=True)

149 

150 

class Testdir:
    """Wrapper around pytest's ``Pytester`` with helpers for Prisma schema/client tests."""

    # prevent pytest from collecting this class as a test case
    __test__ = False
    SCHEMA_HEADER = SCHEMA_HEADER
    default_schema = DEFAULT_SCHEMA
    default_generator = DEFAULT_GENERATOR

    def __init__(self, pytester: Pytester) -> None:
        self.pytester = pytester

    def _make_relative(self, path: Union[str, Path]) -> str:  # pragma: no cover
        """Return *path* as a string, made relative to ``self.path`` if absolute.

        Relative paths are returned unchanged. ``Path.relative_to`` raises
        ``ValueError`` for an absolute path outside ``self.path``.
        """
        if not isinstance(path, Path):
            path = Path(path)

        if not path.is_absolute():
            return str(path)

        return str(path.relative_to(self.path))

    def make_from_function(
        self,
        function: FuncType,
        ext: str = '.py',
        name: Optional[Union[str, Path]] = None,
        **env: Any,
    ) -> None:
        """Write the body of *function* (via ``get_source_from_function``) to a file.

        When *name* is given it is used as the file name, relative to the test
        directory; otherwise the source is passed positionally to
        ``Pytester.makefile``. ``env`` items are injected as assignments.
        """
        source = get_source_from_function(function, **env)

        if name:
            self.makefile(ext, **{self._make_relative(name): source})
        else:
            self.makefile(ext, source)

    def copy_pkg(self, clean: bool = True) -> None:
        """Copy ``BASE_PACKAGE_DIR`` into ``<testdir>/prisma``.

        When *clean* is true, ``prisma_cleanup.cleanup()`` is then run in a
        subprocess and must exit successfully.
        """
        path = self.path / 'prisma'
        copy_tree(BASE_PACKAGE_DIR, path)

        if clean:  # pragma: no branch
            result = self.runpython_c('import prisma_cleanup; prisma_cleanup.cleanup()')
            assert result.ret == 0

    def generate(
        self, schema: Optional[str] = None, options: str = '', **extra: Any
    ) -> 'subprocess.CompletedProcess[bytes]':
        """Write a schema (see ``make_schema``) and run ``prisma generate`` on it.

        The generator runs in a subprocess with stderr merged into stdout; the
        combined output is echoed to ``sys.stdout``.

        Raises:
            subprocess.CalledProcessError: if the generator exits non-zero.
        """
        path = self.make_schema(schema, options, **extra)
        args = [sys.executable, '-m', 'prisma', 'generate', f'--schema={path}']
        proc = subprocess.run(
            args,
            env=os.environ,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
        )

        output = proc.stdout
        if isinstance(output, bytes):
            # ``subprocess.run`` may have been patched to return text already,
            # so only decode bytes; decode leniently so undecodable output
            # cannot mask the real failure reported below
            output = output.decode(sys.getdefaultencoding(), errors='replace')
        print(output, file=sys.stdout)

        if proc.returncode != 0:
            raise subprocess.CalledProcessError(proc.returncode, args, proc.stdout, proc.stderr)

        return proc

    def make_schema(
        self,
        schema: Optional[str] = None,
        options: str = '',
        output: Optional[str] = None,
        **extra: Any,
    ) -> Path:
        """Render *schema* and write it to ``schema.prisma`` in the test directory.

        *schema* defaults to ``default_schema`` and is filled in with
        ``str.format`` using ``output`` (resolved under the test directory,
        default ``'prisma'``), ``options`` and any ``extra`` placeholders.
        Returns the path of the written schema file.
        """
        if schema is None:
            schema = self.default_schema

        if output is None:  # pragma: no branch
            output = 'prisma'

        path = self.path.joinpath('schema.prisma')
        path.write_text(
            schema.format(
                output=escape_path(self.path.joinpath(output)),
                options=options,
                **extra,
            )
        )
        return path

    def makefile(self, ext: str, *args: str, **kwargs: str) -> None:
        """Forward to ``Pytester.makefile``."""
        self.pytester.makefile(ext, *args, **kwargs)

    def runpytest(self, *args: Union[str, 'os.PathLike[str]'], **kwargs: Any) -> 'RunResult':
        """Run pytest in the test directory via ``Pytester.runpytest``."""
        # pytest-sugar breaks result parsing
        return self.pytester.runpytest('-p', 'no:sugar', *args, **kwargs)

    def runpython_c(self, command: str) -> 'RunResult':
        """Run ``python -c command`` via ``Pytester.runpython_c``."""
        return self.pytester.runpython_c(command)

    @contextlib.contextmanager
    def redirect_stdout_to_file(
        self,
    ) -> Iterator[Path]:
        """Redirect ``sys.stdout`` to a uniquely named file in the test directory.

        Yields the file's path; the file is closed when the block exits.
        """
        path = self.path.joinpath(f'stdout-{uuid.uuid4()}.txt')

        with path.open('w') as file:
            with contextlib.redirect_stdout(file):
                yield path

    @property
    def path(self) -> Path:
        """The Pytester working directory as a ``Path``."""
        return Path(self.pytester.path)

    @override
    def __repr__(self) -> str:  # pragma: no cover
        return str(self)

    @override
    def __str__(self) -> str:  # pragma: no cover
        return f'<Testdir {self.path} >'

262 

263 

def get_source_from_function(function: FuncType, **env: Any) -> str:
    """Build a standalone script from the body of *function*.

    The first source line (expected to be the ``def`` line; decorators and
    multi-line signatures are not handled) is dropped and the remaining body
    is dedented. Each ``env`` item becomes a ``name = value`` assignment
    placed after the body's leading import statements, in the order given.
    String values are written with ``repr``, so quotes and backslashes (e.g.
    Windows paths) survive intact. Other values are written via ``str``.
    The result is prefixed with ``IMPORT_RELOADER``.
    """
    # skip the ``def ...:`` line, only the body is wanted
    lines = inspect.getsource(function).splitlines()[1:]

    # setup env after imports; a parenthesised import spanning several lines
    # is skipped as a whole so the assignments never land inside it
    start = len(lines)
    in_import_parens = False
    for index, line in enumerate(lines):
        stripped = line.strip()
        if in_import_parens:
            if ')' in stripped:
                in_import_parens = False
            continue
        if stripped.startswith(('import ', 'from ')):
            in_import_parens = '(' in stripped and ')' not in stripped
            continue
        start = index
        break

    lines = textwrap.dedent('\n'.join(lines)).splitlines()
    assignments = [
        f'{name} = {value!r}' if isinstance(value, str) else f'{name} = {value}'
        for name, value in env.items()
    ]
    lines[start:start] = assignments

    return IMPORT_RELOADER + '\n'.join(lines)

283 

284 

@contextlib.contextmanager
def set_config(config: _config.Config) -> Iterator[_config.Config]:
    """Temporarily make *config* the object behind the ``_config.config`` proxy.

    Yields *config*. The previously proxied object is restored on exit, even
    if the body (or installing *config*) raises.
    """
    lazy_config = cast(LazyProxy[_config.Config], _config.config)
    previous = lazy_config.__get_proxied__()
    try:
        lazy_config.__set_proxied__(config)
        yield config
    finally:
        lazy_config.__set_proxied__(previous)
295 

296 

297def patch_method( 

298 patcher: 'MonkeyPatch', 

299 obj: object, 

300 attr: str, 

301 callback: Optional[Callable[..., Any]] = None, 

302) -> Callable[[], Optional[CapturedArgs]]: 

303 """Helper for patching functions that are incompatible with MonkeyPatch.setattr 

304 

305 e.g. __init__ methods 

306 """ 

307 # work around for pyright: https://github.com/microsoft/pyright/issues/2757 

308 captured = cast(Optional[CapturedArgs], None) 

309 

310 def wrapper(*args: Any, **kwargs: Any) -> None: 

311 nonlocal captured 

312 captured = (args[1:], kwargs) 

313 

314 if callback is not None: 

315 callback(real_meth, *args, **kwargs) 

316 

317 real_meth = getattr(obj, attr) 

318 patcher.setattr(obj, attr, wrapper, raising=True) 

319 return lambda: captured 

320 

321 

# marker that skips a test when running on Windows
skipif_windows = pytest.mark.skipif(
    platform.name() == 'windows',
    reason='Test is disabled on windows',
)