Coverage for src/prisma/generator/generator.py: 94%
153 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-04-28 15:17 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2024-04-28 15:17 +0000
1import os
2import sys
3import json
4import shutil
5import logging
6import traceback
7from abc import ABC, abstractmethod
8from typing import Any, Dict, List, Type, Generic, Optional, cast
9from pathlib import Path
10from contextvars import ContextVar
11from typing_extensions import override
13from jinja2 import Environment, StrictUndefined, FileSystemLoader
14from pydantic import BaseModel, ValidationError
16from . import jsonrpc
17from .. import __version__
18from .types import PartialModel
19from .utils import (
20 copy_tree,
21 is_same_path,
22 resolve_template_path,
23)
24from ..utils import DEBUG, DEBUG_GENERATOR
25from .errors import PartialTypeGeneratorError
26from .models import PythonData, DefaultData
27from .._types import BaseModelT, InheritsGeneric, get_args
28from .filters import quote
29from .jsonrpc import Manifest
30from .._compat import model_json, model_parse, cached_property
__all__ = (
    'BASE_PACKAGE_DIR',
    'GenericGenerator',
    'BaseGenerator',
    'Generator',
    'render_template',
    'cleanup_templates',
    'partial_models_ctx',
)

log: logging.Logger = logging.getLogger(__name__)

# root directory of the `prisma` package (two levels above this module)
BASE_PACKAGE_DIR = Path(__file__).parent.parent

GENERIC_GENERATOR_NAME = 'prisma.generator.generator.GenericGenerator'

# set of templates that should be rendered after every other template
DEFERRED_TEMPLATES = {'partials.py.jinja'}

# Jinja environment for the templates stored in the `templates/` directory next to
# this module. StrictUndefined makes a reference to a missing template variable raise
# an error instead of silently rendering as an empty string.
DEFAULT_ENV = Environment(
    loader=FileSystemLoader(Path(__file__).parent / 'templates'),
    undefined=StrictUndefined,
    trim_blocks=True,
    lstrip_blocks=True,
)

# the type: ignore is required because Jinja2 filters are not typed
# and Pyright infers the type from the default builtin filters which
# results in an overly restrictive type
DEFAULT_ENV.filters['quote'] = quote  # pyright: ignore

# NOTE(review): the `default=[]` list is one object shared by every context that never
# calls `partial_models_ctx.set()`, so anything appended to the value returned by
# `.get()` persists for the lifetime of the process — confirm this is intended.
partial_models_ctx: ContextVar[List[PartialModel]] = ContextVar('partial_models_ctx', default=[])
class GenericGenerator(ABC, Generic[BaseModelT]):
    """Base class for Prisma generators implemented in Python.

    Subclasses parametrise this class with the pydantic model that the params of a
    ``generate`` request are parsed into (e.g. ``GenericGenerator[PythonData]``) and
    implement :meth:`get_manifest` and :meth:`generate`. :meth:`run` drives the
    JSON-RPC request loop with Prisma.
    """

    @abstractmethod
    def get_manifest(self) -> Manifest:
        """Get the metadata for this generator

        This is used by prisma to display the post-generate message e.g.

        ✔ Generated Prisma Client Python to ./.venv/lib/python3.9/site-packages/prisma
        """
        ...

    @abstractmethod
    def generate(self, data: BaseModelT) -> None:
        """Generate output from the request data parsed into :attr:`data_class`."""
        ...

    @classmethod
    def invoke(cls) -> None:
        """Shorthand for calling ``cls().run()``, i.e. instantiate this generator and run it."""
        generator = cls()
        generator.run()

    def run(self) -> None:
        """Run the generation loop

        This can only be called from a prisma generation, e.g.

        ```prisma
        generator client {
          provider = "python generator.py"
        }
        ```

        Lines are read with ``jsonrpc.readline()`` until it returns ``None``; each line is
        parsed into a request and handled by :meth:`_on_request`. If anything raises, an
        error response is sent for the most recently parsed request and the loop ends; an
        exception raised before any request was parsed is re-raised instead.

        Raises:
            RuntimeError: if ``PRISMA_GENERATOR_INVOCATION`` is not set in the environment.
        """
        if not os.environ.get('PRISMA_GENERATOR_INVOCATION'):
            raise RuntimeError('Attempted to invoke a generator outside of Prisma generation')

        request = None
        try:
            while True:
                line = jsonrpc.readline()
                if line is None:
                    log.debug('Prisma invocation ending')
                    break

                request = jsonrpc.parse(line)
                self._on_request(request)
        except Exception as exc:
            # there is no request id to reply to, so surface the error directly
            if request is None:
                raise exc from None

            # NOTE(review): if parsing a *later* line fails, `request` still refers to the
            # previous (already answered) request and the error is reported under its id.

            # We don't care about being overly verbose or printing potentially redundant data here
            if DEBUG or DEBUG_GENERATOR:
                traceback.print_exc()

            # Do not include the full stack trace for pydantic validation errors as they are typically
            # the fault of the user.
            if isinstance(exc, ValidationError):
                message = str(exc)
            elif isinstance(exc, PartialTypeGeneratorError):
                # TODO: remove our internal frame from this stack trace
                message = (
                    'An exception ocurred while running the partial type generator\n' + traceback.format_exc().strip()
                )
            else:
                message = traceback.format_exc().strip()

            response = jsonrpc.ErrorResponse(
                id=request.id,
                error={
                    # code copied from https://github.com/prisma/prisma/blob/main/packages/generator-helper/src/generatorHandler.ts
                    'code': -32000,
                    'message': message,
                    'data': {},
                },
            )
            jsonrpc.reply(response)

    def _on_request(self, request: jsonrpc.Request) -> None:
        """Handle a single JSON-RPC request and send the reply.

        ``getManifest`` is answered with :meth:`get_manifest`; ``generate`` parses the
        request params into :attr:`data_class` and passes them to :meth:`generate`.

        Raises:
            RuntimeError: for a ``generate`` request without params, or an unknown method.
        """
        if request.method == 'getManifest':
            response = jsonrpc.SuccessResponse(
                id=request.id,
                result=dict(
                    manifest=self.get_manifest(),
                ),
            )
        elif request.method == 'generate':
            if request.params is None:  # pragma: no cover
                raise RuntimeError('Prisma JSONRPC did not send data to generate.')

            if DEBUG_GENERATOR:
                _write_debug_data('params', json.dumps(request.params, indent=2))

            data = model_parse(self.data_class, request.params)

            if DEBUG_GENERATOR:
                _write_debug_data('data', model_json(data, indent=2))

            self.generate(data)
            response = jsonrpc.SuccessResponse(id=request.id, result=None)
        else:  # pragma: no cover
            raise RuntimeError(f'JSON RPC received unexpected method: {request.method}')

        jsonrpc.reply(response)

    @cached_property
    def data_class(self) -> Type[BaseModelT]:
        """Return the BaseModel used to parse the Prisma DMMF

        The model is the first type argument given to ``GenericGenerator`` in the
        class's (possibly inherited) ``__orig_bases__``.

        Raises:
            RuntimeError: if the generic type arguments cannot be resolved.
            TypeError: if the type argument is not a subclass of pydantic's ``BaseModel``.
        """
        # we need to cast to object as otherwise pyright correctly marks the code as unreachable,
        # this is because __orig_bases__ is not present in the typeshed stubs as it is
        # intended to be for internal use only, however I could not find a method
        # for resolving generic TypeVars for inherited subclasses without using it.
        # please create an issue or pull request if you know of a solution.
        cls = cast(object, self.__class__)
        if not isinstance(cls, InheritsGeneric):
            raise RuntimeError('Could not resolve generic type arguments.')

        typ: Optional[Any] = None
        for base in cls.__orig_bases__:
            # __orig_bases__ also lists plain, non-parametrised bases (e.g. a mixin in
            # `class MyGenerator(Mixin, GenericGenerator[Data])`), which have no
            # __origin__ attribute, so look it up defensively.
            if getattr(base, '__origin__', None) is GenericGenerator:
                typ = base
                break

        if typ is None:  # pragma: no cover
            raise RuntimeError(
                'Could not find the GenericGenerator type;\n'
                'This should never happen;\n'
                f'Does {cls} inherit from {GenericGenerator} ?'
            )

        args = get_args(typ)
        if not args:
            raise RuntimeError(f'Could not resolve generic arguments from type: {typ}')

        model = args[0]
        # check isinstance() first: issubclass() on a non-class (e.g. an unresolved
        # TypeVar) raises its own, much less helpful, TypeError
        if not isinstance(model, type) or not issubclass(model, BaseModel):
            raise TypeError(
                f'Expected first generic type argument argument to be a subclass of {BaseModel} '
                f'but got {model} instead.'
            )

        # we know the type we have resolved is the same as the first generic argument
        # passed to GenericGenerator, safe to cast
        return cast(Type[BaseModelT], model)
class BaseGenerator(GenericGenerator[DefaultData]):
    """Generator base class whose ``generate()`` receives the ``DefaultData`` model.

    Subclass it and implement ``get_manifest()`` and ``generate()`` to build a custom
    Prisma generator.
    """
class Generator(GenericGenerator[PythonData]):
    """The built-in generator that renders the Prisma Client Python package.

    This class is final: defining a subclass raises ``TypeError``; custom generators
    should subclass ``BaseGenerator`` instead.
    """

    @override
    def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None:
        raise TypeError(f'{Generator} cannot be subclassed, maybe you meant {BaseGenerator}?')

    @override
    def get_manifest(self) -> Manifest:
        return Manifest(
            name=f'Prisma Client Python (v{__version__})',
            default_output=BASE_PACKAGE_DIR,
            requires_engines=[
                'queryEngine',
            ],
        )

    @override
    def generate(self, data: PythonData) -> None:
        """Render the client package into the configured output directory.

        The output directory is created if needed, the package at ``BASE_PACKAGE_DIR`` is
        copied into it (unless it *is* that directory) together with the Prisma schema,
        and then every non-private ``*.py.jinja`` template is rendered. Templates listed in
        ``DEFERRED_TEMPLATES`` are rendered last, after the optional partial type
        generator has run, so that they can see ``partial_models``.

        If anything fails (or the process is interrupted) while rendering, the rendered
        templates are removed again before the error is propagated.
        """
        config = data.generator.config
        rootdir = Path(data.generator.output.value)
        if not rootdir.exists():
            rootdir.mkdir(parents=True, exist_ok=True)

        if not is_same_path(BASE_PACKAGE_DIR, rootdir):
            copy_tree(BASE_PACKAGE_DIR, rootdir)

        # copy the Prisma Schema file used to generate the client to the
        # package so we can use it to instantiate the query engine
        packaged_schema = rootdir / 'schema.prisma'
        if not is_same_path(data.schema_path, packaged_schema):
            shutil.copy(data.schema_path, packaged_schema)

        params = data.to_params()

        try:
            for name in DEFAULT_ENV.list_templates():
                # only top-level output templates; `_`-prefixed ones are not rendered here
                # and deferred ones are rendered below
                if not name.endswith('.py.jinja') or name.startswith('_') or name in DEFERRED_TEMPLATES:
                    continue

                render_template(rootdir, name, params)

            if config.partial_type_generator:
                log.debug('Generating partial types')
                config.partial_type_generator.run()

            params['partial_models'] = partial_models_ctx.get()
            for name in DEFERRED_TEMPLATES:
                render_template(rootdir, name, params)
        except BaseException:
            # Deliberately broad (same as a bare `except:`): even on KeyboardInterrupt we
            # want to remove half-rendered output before re-raising the original error.
            cleanup_templates(rootdir, env=DEFAULT_ENV)
            raise

        log.debug('Finished generating Prisma Client Python')
def cleanup_templates(rootdir: Path, *, env: Optional[Environment] = None) -> None:
    """Revert module to pre-generation state

    For every template known to ``env`` (``DEFAULT_ENV`` when omitted), delete the file
    that ``resolve_template_path`` maps it to under ``rootdir``, if that file exists.
    """
    templates_env = DEFAULT_ENV if env is None else env

    for template_name in templates_env.list_templates():
        target = resolve_template_path(rootdir=rootdir, name=template_name)
        if not target.exists():
            continue

        log.debug('Removing rendered template at %s', target)
        target.unlink()
def render_template(
    rootdir: Path,
    name: str,
    params: Dict[str, Any],
    *,
    env: Optional[Environment] = None,
) -> None:
    """Render template ``name`` with ``params`` and write the result to disk.

    Args:
        rootdir: output root passed to ``resolve_template_path`` to locate the target file.
        name: template name within ``env``.
        params: keyword arguments passed to ``Template.render``.
        env: Jinja environment to load the template from; ``DEFAULT_ENV`` when omitted.
    """
    jinja_env = env if env is not None else DEFAULT_ENV
    rendered = jinja_env.get_template(name).render(**params)

    destination = resolve_template_path(rootdir=rootdir, name=name)
    parent_dir = destination.parent
    if not parent_dir.exists():
        parent_dir.mkdir(parents=True, exist_ok=True)

    # written as bytes, which bypasses text-mode newline translation
    destination.write_bytes(rendered.encode(sys.getdefaultencoding()))
    log.debug('Rendered template to %s', destination.absolute())
def _write_debug_data(name: str, output: str) -> None:
    """Write ``output`` to ``debug-{name}.json`` next to this module (debugging aid).

    The file is written as UTF-8 explicitly: the platform default text encoding
    (e.g. cp1252 on Windows) cannot represent every character that may appear in
    the serialised data.
    """
    path = Path(__file__).parent.joinpath(f'debug-{name}.json')

    with path.open('w', encoding='utf-8') as file:
        file.write(output)

    log.debug('Wrote generator %s to %s', name, path.absolute())