Coverage for src/prisma/generator/generator.py: 94%

153 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-04-28 15:17 +0000

1import os 

2import sys 

3import json 

4import shutil 

5import logging 

6import traceback 

7from abc import ABC, abstractmethod 

8from typing import Any, Dict, List, Type, Generic, Optional, cast 

9from pathlib import Path 

10from contextvars import ContextVar 

11from typing_extensions import override 

12 

13from jinja2 import Environment, StrictUndefined, FileSystemLoader 

14from pydantic import BaseModel, ValidationError 

15 

16from . import jsonrpc 

17from .. import __version__ 

18from .types import PartialModel 

19from .utils import ( 

20 copy_tree, 

21 is_same_path, 

22 resolve_template_path, 

23) 

24from ..utils import DEBUG, DEBUG_GENERATOR 

25from .errors import PartialTypeGeneratorError 

26from .models import PythonData, DefaultData 

27from .._types import BaseModelT, InheritsGeneric, get_args 

28from .filters import quote 

29from .jsonrpc import Manifest 

30from .._compat import model_json, model_parse, cached_property 

31 

# public names exported by this module
__all__ = (
    'BASE_PACKAGE_DIR',
    'GenericGenerator',
    'BaseGenerator',
    'Generator',
    'render_template',
    'cleanup_templates',
    'partial_models_ctx',
)

log: logging.Logger = logging.getLogger(__name__)

# the package directory, i.e. the directory two levels above this file
BASE_PACKAGE_DIR = Path(__file__).parent.parent

# dotted name of the GenericGenerator class defined in this module
GENERIC_GENERATOR_NAME = 'prisma.generator.generator.GenericGenerator'

# set of templates that should be rendered after every other template
DEFERRED_TEMPLATES = {'partials.py.jinja'}

# templates are loaded from the `templates` directory next to this file,
# StrictUndefined makes any reference to an undefined variable an error
DEFAULT_ENV = Environment(
    loader=FileSystemLoader(Path(__file__).with_name('templates')),
    undefined=StrictUndefined,
    trim_blocks=True,
    lstrip_blocks=True,
)

# the type: ignore is required because Jinja2 filters are not typed
# and Pyright infers the type from the default builtin filters which
# results in an overly restrictive type
DEFAULT_ENV.filters['quote'] = quote  # pyright: ignore

# NOTE: the default is a single list object, every context that has not
# called `.set()` gets that same list back from `.get()`
partial_models_ctx: ContextVar[List[PartialModel]] = ContextVar('partial_models_ctx', default=[])

62 

63 

class GenericGenerator(ABC, Generic[BaseModelT]):
    """Abstract base class for generators driven by Prisma over JSON-RPC.

    The generic argument is the pydantic model that the params of a
    `generate` request are parsed into, it is resolved at runtime by
    the `data_class` property.
    """

    @abstractmethod
    def get_manifest(self) -> Manifest:
        """Get the metadata for this generator

        This is used by prisma to display the post-generate message e.g.

        ✔ Generated Prisma Client Python to ./.venv/lib/python3.9/site-packages/prisma
        """
        ...

    @abstractmethod
    def generate(self, data: BaseModelT) -> None: ...

    @classmethod
    def invoke(cls) -> None:
        """Shorthand for calling cls().run()"""
        generator = cls()
        generator.run()

    def run(self) -> None:
        """Run the generation loop

        This can only be called from a prisma generation, e.g.

        ```prisma
        generator client {
          provider = "python generator.py"
        }
        ```

        Requests are read line by line until `jsonrpc.readline()` returns `None`.
        If handling a request fails, a JSON-RPC error response is sent for the
        most recently parsed request instead of propagating the exception.

        Raises:
            RuntimeError: if `PRISMA_GENERATOR_INVOCATION` is not set in the environment.
            Exception: any error raised before a single request could be parsed.
        """
        if not os.environ.get('PRISMA_GENERATOR_INVOCATION'):
            raise RuntimeError('Attempted to invoke a generator outside of Prisma generation')

        request = None
        try:
            while True:
                line = jsonrpc.readline()
                if line is None:
                    log.debug('Prisma invocation ending')
                    break

                request = jsonrpc.parse(line)
                self._on_request(request)
        except Exception as exc:
            # without a parsed request there is no id to address an error response to
            if request is None:
                raise exc from None

            # We don't care about being overly verbose or printing potentially redundant data here
            if DEBUG or DEBUG_GENERATOR:
                traceback.print_exc()

            # Do not include the full stack trace for pydantic validation errors as they are typically
            # the fault of the user.
            if isinstance(exc, ValidationError):
                message = str(exc)
            elif isinstance(exc, PartialTypeGeneratorError):
                # TODO: remove our internal frame from this stack trace
                message = (
                    'An exception ocurred while running the partial type generator\n' + traceback.format_exc().strip()
                )
            else:
                message = traceback.format_exc().strip()

            response = jsonrpc.ErrorResponse(
                id=request.id,
                error={
                    # code copied from https://github.com/prisma/prisma/blob/main/packages/generator-helper/src/generatorHandler.ts
                    'code': -32000,
                    'message': message,
                    'data': {},
                },
            )
            jsonrpc.reply(response)

    def _on_request(self, request: jsonrpc.Request) -> None:
        """Handle a single JSON-RPC request and send the success response.

        `getManifest` replies with the result of `get_manifest()`, `generate`
        parses the request params into `data_class`, calls `generate()` and
        replies with a `None` result.

        Raises:
            RuntimeError: for a `generate` request without params or for any other method.
        """
        response = None
        if request.method == 'getManifest':
            response = jsonrpc.SuccessResponse(
                id=request.id,
                result=dict(
                    manifest=self.get_manifest(),
                ),
            )
        elif request.method == 'generate':
            if request.params is None:  # pragma: no cover
                raise RuntimeError('Prisma JSONRPC did not send data to generate.')

            if DEBUG_GENERATOR:
                _write_debug_data('params', json.dumps(request.params, indent=2))

            data = model_parse(self.data_class, request.params)

            if DEBUG_GENERATOR:
                _write_debug_data('data', model_json(data, indent=2))

            self.generate(data)
            response = jsonrpc.SuccessResponse(id=request.id, result=None)
        else:  # pragma: no cover
            raise RuntimeError(f'JSON RPC received unexpected method: {request.method}')

        jsonrpc.reply(response)

    @cached_property
    def data_class(self) -> Type[BaseModelT]:
        """Return the BaseModel used to parse the Prisma DMMF

        Raises:
            RuntimeError: if the generic type arguments could not be resolved.
            TypeError: if the first generic argument is not a `BaseModel` subclass.
        """

        # we need to cast to object as otherwise pyright correctly marks the code as unreachable,
        # this is because __orig_bases__ is not present in the typeshed stubs as it is
        # intended to be for internal use only, however I could not find a method
        # for resolving generic TypeVars for inherited subclasses without using it.
        # please create an issue or pull request if you know of a solution.
        cls = cast(object, self.__class__)
        if not isinstance(cls, InheritsGeneric):
            raise RuntimeError('Could not resolve generic type arguments.')

        typ: Optional[Any] = None
        for base in cls.__orig_bases__:
            # plain, non-generic bases (e.g. a mixin class) have no `__origin__`,
            # skip them instead of failing with an AttributeError
            if getattr(base, '__origin__', None) == GenericGenerator:
                typ = base
                break

        if typ is None:  # pragma: no cover
            raise RuntimeError(
                'Could not find the GenericGenerator type;\n'
                'This should never happen;\n'
                f'Does {cls} inherit from {GenericGenerator} ?'
            )

        args = get_args(typ)
        if not args:
            raise RuntimeError(f'Could not resolve generic arguments from type: {typ}')

        model = args[0]
        if not issubclass(model, BaseModel):
            raise TypeError(
                f'Expected first generic type argument argument to be a subclass of {BaseModel} '
                f'but got {model} instead.'
            )

        # we know the type we have resolved is the same as the first generic argument
        # passed to GenericGenerator, safe to cast
        return cast(Type[BaseModelT], model)

207 

208 

class BaseGenerator(GenericGenerator[DefaultData]):
    """Generator base class whose `generate` data is parsed as `DefaultData`."""

211 

212 

class Generator(GenericGenerator[PythonData]):
    """The generator that renders Prisma Client Python from the bundled templates."""

    @override
    def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None:
        # any attempt to subclass is rejected when the subclass is created
        raise TypeError(f'{Generator} cannot be subclassed, maybe you meant {BaseGenerator}?')

    @override
    def get_manifest(self) -> Manifest:
        """Return the manifest describing this generator."""
        engines = ['queryEngine']
        return Manifest(
            name=f'Prisma Client Python (v{__version__})',
            default_output=BASE_PACKAGE_DIR,
            requires_engines=engines,
        )

    @override
    def generate(self, data: PythonData) -> None:
        """Copy the package into the output directory and render every template.

        If anything fails while rendering, the rendered templates are removed
        again with `cleanup_templates` and the error is re-raised.
        """
        config = data.generator.config

        output_dir = Path(data.generator.output.value)
        if not output_dir.exists():
            output_dir.mkdir(parents=True, exist_ok=True)

        # nothing to copy when generating directly into the package itself
        if not is_same_path(BASE_PACKAGE_DIR, output_dir):
            copy_tree(BASE_PACKAGE_DIR, output_dir)

        # copy the Prisma Schema file used to generate the client to the
        # package so we can use it to instantiate the query engine
        schema_copy = output_dir / 'schema.prisma'
        if not is_same_path(data.schema_path, schema_copy):
            shutil.copy(data.schema_path, schema_copy)

        params = data.to_params()

        try:
            for template_name in DEFAULT_ENV.list_templates():
                renderable = template_name.endswith('.py.jinja') and not template_name.startswith('_')
                if renderable and template_name not in DEFERRED_TEMPLATES:
                    render_template(output_dir, template_name, params)

            partial_generator = config.partial_type_generator
            if partial_generator:
                log.debug('Generating partial types')
                partial_generator.run()

            # deferred templates are rendered last, once the partial models are available
            params['partial_models'] = partial_models_ctx.get()
            for template_name in DEFERRED_TEMPLATES:
                render_template(output_dir, template_name, params)
        except BaseException:
            cleanup_templates(output_dir, env=DEFAULT_ENV)
            raise

        log.debug('Finished generating Prisma Client Python')

265 

266 

def cleanup_templates(rootdir: Path, *, env: Optional[Environment] = None) -> None:
    """Revert module to pre-generation state

    Every template listed by `env` (`DEFAULT_ENV` when not given) is resolved
    to its output path under `rootdir` and that file is deleted if it exists.
    """
    environment = DEFAULT_ENV if env is None else env

    # lazily resolved so each path is checked and removed before the next is resolved
    rendered_files = (
        resolve_template_path(rootdir=rootdir, name=template_name) for template_name in environment.list_templates()
    )
    for rendered in rendered_files:
        if not rendered.exists():
            continue

        log.debug('Removing rendered template at %s', rendered)
        rendered.unlink()

277 

278 

def render_template(
    rootdir: Path,
    name: str,
    params: Dict[str, Any],
    *,
    env: Optional[Environment] = None,
) -> None:
    """Render the template `name` with `params` and write it below `rootdir`.

    The template is looked up in `env` (`DEFAULT_ENV` when not given), the
    output path comes from `resolve_template_path` and its parent directory
    is created when missing.
    """
    environment = DEFAULT_ENV if env is None else env
    rendered = environment.get_template(name).render(**params)

    target = resolve_template_path(rootdir=rootdir, name=name)
    directory = target.parent
    if not directory.exists():
        directory.mkdir(parents=True, exist_ok=True)

    encoding = sys.getdefaultencoding()
    target.write_bytes(rendered.encode(encoding))
    log.debug('Rendered template to %s', target.absolute())

298 

299 

def _write_debug_data(name: str, output: str) -> None:
    """Write `output` to `debug-{name}.json` in the directory of this module.

    Args:
        name: label used in the file name and in the debug log message.
        output: the already serialised text to write.
    """
    path = Path(__file__).parent.joinpath(f'debug-{name}.json')

    # explicit encoding so that non-ASCII output does not depend on the locale's
    # default codec, this also matches the UTF-8 output of `render_template`
    with path.open('w', encoding='utf-8') as file:
        file.write(output)

    log.debug('Wrote generator %s to %s', name, path.absolute())