Coverage for src/prisma/generator/generator.py: 94%

152 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-08-27 18:25 +0000

1import os 

2import sys 

3import json 

4import logging 

5import traceback 

6from abc import ABC, abstractmethod 

7from typing import Any, Dict, List, Type, Generic, Optional, cast 

8from pathlib import Path 

9from contextvars import ContextVar 

10from typing_extensions import override 

11 

12from jinja2 import Environment, StrictUndefined, FileSystemLoader 

13from pydantic import BaseModel, ValidationError 

14 

15from . import jsonrpc 

16from .. import __version__ 

17from .types import PartialModel 

18from .utils import ( 

19 copy_tree, 

20 is_same_path, 

21 resolve_template_path, 

22) 

23from ..utils import DEBUG, DEBUG_GENERATOR 

24from .errors import PartialTypeGeneratorError 

25from .models import PythonData, DefaultData 

26from .._types import BaseModelT, InheritsGeneric, get_args 

27from .filters import quote 

28from .jsonrpc import Manifest 

29from .._compat import model_json, model_parse, cached_property 

30 

__all__ = (
    'BASE_PACKAGE_DIR',
    'GenericGenerator',
    'BaseGenerator',
    'Generator',
    'render_template',
    'cleanup_templates',
    'partial_models_ctx',
)

log: logging.Logger = logging.getLogger(__name__)

# two levels above this file, i.e. the package that contains the ``generator`` package
BASE_PACKAGE_DIR = Path(__file__).parent.parent
GENERIC_GENERATOR_NAME = 'prisma.generator.generator.GenericGenerator'

# templates that must be rendered only after every other template has been rendered
DEFERRED_TEMPLATES = {'partials.py.jinja'}

# shared Jinja environment loading from the ``templates`` directory next to this file;
# StrictUndefined makes any reference to a missing template variable an error
DEFAULT_ENV = Environment(
    loader=FileSystemLoader(Path(__file__).parent / 'templates'),
    undefined=StrictUndefined,
    trim_blocks=True,
    lstrip_blocks=True,
)

# the pyright ignore is required because Jinja2 filters are not typed
# and Pyright infers the type from the default builtin filters which
# results in an overly restrictive type
DEFAULT_ENV.filters.update(quote=quote)  # pyright: ignore

# Partial models collected while the partial type generator runs.
# NOTE(review): the `default=[]` list is a single object shared by every context that
# has not called `.set()`; callers presumably append to `partial_models_ctx.get()` and
# rely on that sharing — confirm before changing the default.
partial_models_ctx: ContextVar[List[PartialModel]] = ContextVar('partial_models_ctx', default=[])

61 

62 

class GenericGenerator(ABC, Generic[BaseModelT]):
    """Base class for Prisma generators written in Python.

    Subclasses parameterise this class with the pydantic model used to parse the
    data Prisma sends, e.g. ``class MyGenerator(GenericGenerator[MyData])``, and
    implement :meth:`get_manifest` and :meth:`generate`.
    """

    @abstractmethod
    def get_manifest(self) -> Manifest:
        """Get the metadata for this generator

        This is used by prisma to display the post-generate message e.g.

        ✔ Generated Prisma Client Python to ./.venv/lib/python3.9/site-packages/prisma
        """
        ...

    @abstractmethod
    def generate(self, data: BaseModelT) -> None:
        """Generate output from ``data``, the parsed ``generate`` request params."""
        ...

    @classmethod
    def invoke(cls) -> None:
        """Shorthand for calling ``cls().run()``"""
        generator = cls()
        generator.run()

    def run(self) -> None:
        """Run the generation loop

        This can only be called from a prisma generation, e.g.

        ```prisma
        generator client {
          provider = "python generator.py"
        }
        ```

        Raises:
            RuntimeError: if the ``PRISMA_GENERATOR_INVOCATION`` environment variable is not set.
        """
        if not os.environ.get('PRISMA_GENERATOR_INVOCATION'):
            raise RuntimeError('Attempted to invoke a generator outside of Prisma generation')

        request = None
        try:
            while True:
                line = jsonrpc.readline()
                if line is None:
                    log.debug('Prisma invocation ending')
                    break

                request = jsonrpc.parse(line)
                self._on_request(request)
        except Exception as exc:
            # nothing has been parsed yet so there is no request id to reply to
            if request is None:
                raise exc from None

            # NOTE(review): `request` is not reset per iteration, so a failure while reading
            # or parsing a later line is reported against the previously parsed request's id.
            # Confirm this is intended.

            # We don't care about being overly verbose or printing potentially redundant data here
            if DEBUG or DEBUG_GENERATOR:
                traceback.print_exc()

            # Do not include the full stack trace for pydantic validation errors as they are typically
            # the fault of the user.
            if isinstance(exc, ValidationError):
                message = str(exc)
            elif isinstance(exc, PartialTypeGeneratorError):
                # TODO: remove our internal frame from this stack trace
                message = (
                    'An exception ocurred while running the partial type generator\n' + traceback.format_exc().strip()
                )
            else:
                message = traceback.format_exc().strip()

            response = jsonrpc.ErrorResponse(
                id=request.id,
                error={
                    # code copied from https://github.com/prisma/prisma/blob/main/packages/generator-helper/src/generatorHandler.ts
                    'code': -32000,
                    'message': message,
                    'data': {},
                },
            )
            jsonrpc.reply(response)

    def _on_request(self, request: jsonrpc.Request) -> None:
        """Handle a single JSON-RPC request and reply with a success response.

        ``getManifest`` replies with :meth:`get_manifest`. ``generate`` parses the
        params into :attr:`data_class` and calls :meth:`generate`.

        Raises:
            RuntimeError: if ``generate`` has no params or the method is unknown.
        """
        response = None
        if request.method == 'getManifest':
            response = jsonrpc.SuccessResponse(
                id=request.id,
                result=dict(
                    manifest=self.get_manifest(),
                ),
            )
        elif request.method == 'generate':
            if request.params is None:  # pragma: no cover
                raise RuntimeError('Prisma JSONRPC did not send data to generate.')

            if DEBUG_GENERATOR:
                _write_debug_data('params', json.dumps(request.params, indent=2))

            data = model_parse(self.data_class, request.params)

            if DEBUG_GENERATOR:
                _write_debug_data('data', model_json(data, indent=2))

            self.generate(data)
            response = jsonrpc.SuccessResponse(id=request.id, result=None)
        else:  # pragma: no cover
            raise RuntimeError(f'JSON RPC received unexpected method: {request.method}')

        jsonrpc.reply(response)

    @cached_property
    def data_class(self) -> Type[BaseModelT]:
        """Return the BaseModel used to parse the Prisma DMMF

        Raises:
            RuntimeError: if the generic type argument cannot be resolved.
            TypeError: if the resolved argument is not a ``BaseModel`` subclass.
        """

        # we need to cast to object as otherwise pyright correctly marks the code as unreachable,
        # this is because __orig_bases__ is not present in the typeshed stubs as it is
        # intended to be for internal use only, however I could not find a method
        # for resolving generic TypeVars for inherited subclasses without using it.
        # please create an issue or pull request if you know of a solution.
        cls = cast(object, self.__class__)
        if not isinstance(cls, InheritsGeneric):
            raise RuntimeError('Could not resolve generic type arguments.')

        typ: Optional[Any] = None
        for base in cls.__orig_bases__:
            # __orig_bases__ also lists plain (non-subscripted) bases such as mixins,
            # which have no __origin__ attribute at all
            if getattr(base, '__origin__', None) == GenericGenerator:
                typ = base
                break

        if typ is None:  # pragma: no cover
            raise RuntimeError(
                'Could not find the GenericGenerator type;\n'
                'This should never happen;\n'
                f'Does {cls} inherit from {GenericGenerator} ?'
            )

        args = get_args(typ)
        if not args:
            raise RuntimeError(f'Could not resolve generic arguments from type: {typ}')

        model = args[0]
        # issubclass() raises its own, less helpful, TypeError when given a non-class
        # (e.g. List[int] or an unresolved TypeVar), so check for a class first
        if not isinstance(model, type) or not issubclass(model, BaseModel):
            raise TypeError(
                f'Expected first generic type argument argument to be a subclass of {BaseModel} '
                f'but got {model} instead.'
            )

        # we know the type we have resolved is the same as the first generic argument
        # passed to GenericGenerator, safe to cast
        return cast(Type[BaseModelT], model)

206 

207 

class BaseGenerator(GenericGenerator[DefaultData]):
    """Base class for custom generators that consume the ``DefaultData`` model.

    Subclass it and implement ``get_manifest`` and ``generate``.
    """

210 

211 

class Generator(GenericGenerator[PythonData]):
    """The built-in generator that renders the Prisma Client Python package.

    This class is final: subclassing it raises ``TypeError``.
    """

    @override
    def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None:
        raise TypeError(f'{Generator} cannot be subclassed, maybe you meant {BaseGenerator}?')

    @override
    def get_manifest(self) -> Manifest:
        return Manifest(
            name=f'Prisma Client Python (v{__version__})',
            default_output=BASE_PACKAGE_DIR,
            requires_engines=[
                'queryEngine',
            ],
        )

    @override
    def generate(self, data: PythonData) -> None:
        """Render every client template into the configured output directory.

        The package is copied to the output directory if it differs from
        ``BASE_PACKAGE_DIR``, and the schema is written next to it. Then all
        public ``*.py.jinja`` templates are rendered. After that the optional
        partial type generator runs, and finally the deferred templates are
        rendered. On any failure (including ``KeyboardInterrupt``), rendered
        templates are removed before the exception is re-raised.
        """
        config = data.generator.config
        rootdir = Path(data.generator.output.value)
        if not rootdir.exists():
            rootdir.mkdir(parents=True, exist_ok=True)

        if not is_same_path(BASE_PACKAGE_DIR, rootdir):
            copy_tree(BASE_PACKAGE_DIR, rootdir)

        # copy the Prisma Schema file used to generate the client to the
        # package so we can use it to instantiate the query engine
        packaged_schema = rootdir / 'schema.prisma'
        if not is_same_path(data.schema_path, packaged_schema):
            packaged_schema.write_text(data.datamodel)

        params = data.to_params()

        try:
            for name in DEFAULT_ENV.list_templates():
                if not name.endswith('.py.jinja') or name.startswith('_') or name in DEFERRED_TEMPLATES:
                    continue

                render_template(rootdir, name, params)

            if config.partial_type_generator:
                log.debug('Generating partial types')
                config.partial_type_generator.run()

            params['partial_models'] = partial_models_ctx.get()
            # sorted() keeps the render order deterministic (set iteration order is not)
            for name in sorted(DEFERRED_TEMPLATES):
                render_template(rootdir, name, params)
        except BaseException:
            # deliberately BaseException (like the previous bare `except:`) so that an
            # interrupted generation also leaves no partially rendered client behind
            cleanup_templates(rootdir, env=DEFAULT_ENV)
            raise

        log.debug('Finished generating Prisma Client Python')

264 

265 

def cleanup_templates(rootdir: Path, *, env: Optional[Environment] = None) -> None:
    """Revert module to pre-generation state

    For every template listed by ``env`` (``DEFAULT_ENV`` when omitted), delete
    its rendered output under ``rootdir`` if that file exists.
    """
    environment = DEFAULT_ENV if env is None else env

    for template_name in environment.list_templates():
        rendered = resolve_template_path(rootdir=rootdir, name=template_name)
        if not rendered.exists():
            continue

        log.debug('Removing rendered template at %s', rendered)
        rendered.unlink()

276 

277 

def render_template(
    rootdir: Path,
    name: str,
    params: Dict[str, Any],
    *,
    env: Optional[Environment] = None,
) -> None:
    """Render template ``name`` with ``params`` and write the result under ``rootdir``.

    The destination file comes from ``resolve_template_path``; its parent
    directory is created when missing. ``env`` defaults to ``DEFAULT_ENV``.
    """
    environment = DEFAULT_ENV if env is None else env
    rendered = environment.get_template(name).render(**params)

    destination = resolve_template_path(rootdir=rootdir, name=name)
    parent_dir = destination.parent
    if not parent_dir.exists():
        parent_dir.mkdir(parents=True, exist_ok=True)

    destination.write_bytes(rendered.encode(sys.getdefaultencoding()))
    log.debug('Rendered template to %s', destination.absolute())

297 

298 

def _write_debug_data(name: str, output: str) -> None:
    """Write ``output`` to ``debug-{name}.json`` next to this module (debug aid).

    The file is written as UTF-8 explicitly instead of the platform's locale
    encoding. The serialised generator data may contain non-ASCII text, which
    encodings such as cp1252 on Windows cannot represent.
    """
    path = Path(__file__).parent.joinpath(f'debug-{name}.json')
    path.write_text(output, encoding='utf-8')

    log.debug('Wrote generator %s to %s', name, path.absolute())