Coverage for src/prisma/mypy.py: 0%
192 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-08-27 18:25 +0000
1import re
2import copy
3import logging
4import builtins
5import operator
6from typing import (
7 Any,
8 Dict,
9 Type as TypingType,
10 Union,
11 Callable,
12 Optional,
13 cast,
14)
15from configparser import ConfigParser
16from typing_extensions import override
18from mypy.nodes import (
19 Var,
20 Node,
21 Context,
22 IntExpr,
23 StrExpr,
24 CallExpr,
25 DictExpr,
26 NameExpr,
27 TypeInfo,
28 BytesExpr,
29 Expression,
30 SymbolTable,
31 SymbolTableNode,
32)
33from mypy.types import (
34 Type,
35 Instance,
36 NoneType,
37 UnionType,
38)
39from mypy.plugin import Plugin, MethodContext, CheckerPluginInterface
40from mypy.options import Options
41from mypy.errorcodes import ErrorCode
# Matches the fully-qualified name of a method defined directly on one of the
# ``prisma.actions.<...>Actions`` classes. The trailing method name, which may
# not contain another ``.``, is captured in the ``name`` group.
CLIENT_ACTION_CHILD = re.compile(r'prisma\.actions\.(.*)Actions\.(?P<name>(((?!\.).)*$))')

# Action method names whose ``include`` argument the plugin inspects.
ACTIONS = [
    'create',
    'find_unique',
    'delete',
    'update',
    'find_first',
    'find_many',
    'upsert',
    'update_many',
    'delete_many',
    'count',
]

# Name of the config-file section holding this plugin's settings.
CONFIGFILE_KEY = 'prisma-mypy'

log: logging.Logger = logging.getLogger(__name__)

# Pyright type checks have to be disabled because of how the mypy API is typed:
# mypy writes hints such as `Bogus[str]` where it means `str`, and relies on
# internal magic to turn Bogus[T] into T, which pyright cannot understand.
# pyright: reportGeneralTypeIssues=false, reportUnnecessaryComparison=false
def plugin(version: str) -> TypingType[Plugin]:  # noqa: ARG001
    """mypy plugin entry point: return the plugin class to instantiate.

    ``version`` (the running mypy version) is ignored; the same class is
    returned regardless of it.
    """
    plugin_class: TypingType[Plugin] = PrismaPlugin
    return plugin_class
class PrismaPluginConfig:
    """Settings for the Prisma mypy plugin, read from mypy's config file.

    Every name in ``__slots__`` is read as a boolean option from the
    ``CONFIGFILE_KEY`` section of the config file. It defaults to ``True``
    when the option, the section, or the config file itself is missing.
    """

    __slots__ = ('warn_parsing_errors',)
    # whether to emit a mypy error when the plugin fails to parse an argument
    warn_parsing_errors: bool

    def __init__(self, options: Options) -> None:
        # Initialise every slot with its default first. Otherwise, when mypy
        # runs without a config file, the slots stay unset and reading e.g.
        # ``warn_parsing_errors`` later raises AttributeError.
        for key in self.__slots__:
            setattr(self, key, True)

        if options.config_file is None:  # pragma: no cover
            return

        plugin_config = ConfigParser()
        plugin_config.read(options.config_file)
        for key in self.__slots__:
            setting = plugin_config.getboolean(CONFIGFILE_KEY, key, fallback=True)
            setattr(self, key, setting)
class PrismaPlugin(Plugin):
    """mypy plugin that refines the return types of Prisma client actions.

    For the action methods listed in ``ACTIONS``, an explicit ``include``
    argument is parsed. The included list relations of the returned model
    then have ``Optional`` removed from their type.
    """

    config: PrismaPluginConfig

    def __init__(self, options: Options) -> None:
        self.config = PrismaPluginConfig(options)
        super().__init__(options)

    @override
    def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], Type]]:
        """Return the include hook for supported ``*Actions`` methods, else ``None``."""
        match = CLIENT_ACTION_CHILD.match(fullname)
        if not match:
            return None

        if match.group('name') in ACTIONS:
            return self.handle_action_invocation

        return None

    def handle_action_invocation(self, ctx: MethodContext) -> Type:
        """Compute the return type of a call to one of the supported actions."""
        # TODO: if an error occurs, log it so that we don't cause mypy to
        # exit prematurely.
        return self._handle_include(ctx)

    def _handle_include(self, ctx: MethodContext) -> Type:
        r"""Recursively remove Optional from a relational field of a model
        if it was explicitly included.

        An argument could be made that this is over-engineered
        and while I do agree to an extent, the benefit of this
        method over just setting the default value to an empty list
        is that access to a relational field without explicitly
        including it will raise an error when type checking, e.g.

            user = await client.user.find_unique(where={'id': user_id})
            print('\n'.join(p.title for p in user.posts))

        Falls back to ``ctx.default_return_type`` in three cases: there is no
        ``include`` argument, the return type has an unexpected shape, or the
        argument cannot be parsed. In the last case a ``prisma-parsing`` error
        is also reported when ``warn_parsing_errors`` is enabled.
        """
        include_expr = self.get_arg_named('include', ctx)
        if include_expr is None:
            return ctx.default_return_type

        if not isinstance(ctx.default_return_type, Instance):
            # TODO: resolve this?
            return ctx.default_return_type

        # for a Coroutine[Any, Any, T] return type, operate on T
        is_coroutine = self.is_coroutine_type(ctx.default_return_type)
        if is_coroutine:
            actual_ret = ctx.default_return_type.args[2]
        else:
            actual_ret = ctx.default_return_type

        # for an Optional[Model] return type, operate on Model
        is_optional = self.is_optional_type(actual_ret)
        if is_optional:
            actual_ret = cast(UnionType, actual_ret)
            model_type = actual_ret.items[0]
        else:
            model_type = actual_ret

        if not isinstance(model_type, Instance):
            return ctx.default_return_type

        try:
            include = self.parse_expression_to_dict(include_expr)
            new_model = self.modify_model_from_include(model_type, include)
        except Exception as exc:
            # parsing is best-effort: fall back to the default type instead of failing
            log.debug(
                'Ignoring %s exception while parsing include: %s',
                type(exc).__name__,
                exc,
            )

            # TODO: test this, pytest-mypy-plugins does not bode well with multiple line output
            if self.config.warn_parsing_errors:
                # TODO: add more details
                # e.g. "include" to "find_unique" of "UserActions"
                if isinstance(exc, UnparsedExpression):
                    err_ctx = exc.context
                else:
                    err_ctx = include_expr

                error_unable_to_parse(
                    ctx.api,
                    err_ctx,
                    'the "include" argument',
                )

            return ctx.default_return_type

        if is_optional:
            actual_ret = cast(UnionType, actual_ret)
            modified_ret = self.copy_modified_optional_type(actual_ret, new_model)
        else:
            modified_ret = new_model  # type: ignore

        if is_coroutine:
            arg1, arg2, _ = ctx.default_return_type.args
            return ctx.default_return_type.copy_modified(args=[arg1, arg2, modified_ret])

        return modified_ret

    def modify_model_from_include(self, model: Instance, data: Dict[Any, Any]) -> Instance:
        """Return a copy of ``model`` with its fields adjusted per the parsed ``include`` dict."""
        names = model.type.names.copy()
        for key, node in model.type.names.items():
            names[key] = self.maybe_modify_included_field(key, node, data)

        return self.copy_modified_instance(model, names)

    def maybe_modify_included_field(
        self,
        key: Union[str, Expression, Node],
        node: SymbolTableNode,
        data: Dict[Any, Any],
    ) -> SymbolTableNode:
        """Return ``node`` with Optional removed if ``key`` is included in ``data``.

        Only fields typed ``Optional[list[...]]`` are modified. Any other field,
        and any field that is not included (missing, ``False`` or ``None``), is
        returned unchanged. A nested ``{'include': {...}}`` value is applied
        recursively to the list's element model.

        Raises:
            UnparsedExpression: the value for ``key`` could not be resolved
                statically and is still a mypy expression/node.
        """
        value = data.get(key)
        if value is False or value is None:
            return node

        if isinstance(value, (Expression, Node)):
            raise UnparsedExpression(value)

        # we do not want to remove the Optional from a field that is not a list
        # as the Optional indicates that the field is optional on a database level
        if (
            not isinstance(node.node, Var)
            or node.node.type is None
            or not isinstance(node.node.type, UnionType)
            or not self.is_optional_union_type(node.node.type)
            or not self.is_list_type(node.node.type.items[0])
        ):
            log.debug(
                'Not modifying included field: %s',
                key,
            )
            return node

        # copy the symbol and its Var so that the modified field is not leaked
        # into the original model's TypeInfo, which is shared by every use of it
        new = node.copy()
        new.node = copy.copy(new.node)
        assert isinstance(new.node, Var)
        list_type = node.node.type.items[0]

        if (
            isinstance(value, dict)
            and 'include' in value
            and isinstance(list_type, Instance)
            and list_type.args
            and isinstance(list_type.args[0], Instance)
        ):
            model = self.modify_model_from_include(list_type.args[0], value['include'])
            # Replace (not prepend to) the element type, and do it on a copy:
            # the list Instance is shared with the original field's type, so
            # assigning to its ``args`` would leak the change.
            list_type = list_type.copy_modified(args=[model, *list_type.args[1:]])

        new.node.type = list_type
        return new

    def get_arg_named(self, name: str, ctx: MethodContext) -> Optional[Expression]:
        """Return the expression passed for argument ``name`` (keyword or positional), or ``None``."""
        # keyword arguments
        for i, names in enumerate(ctx.arg_names):
            for j, arg_name in enumerate(names):
                if arg_name == name:
                    return ctx.args[i][j]

        # positional arguments
        for i, arg_name in enumerate(ctx.callee_arg_names):
            if arg_name == name and ctx.args[i]:
                return ctx.args[i][0]

        return None

    def is_optional_type(self, typ: Type) -> bool:
        """Return whether ``typ`` is a ``Union[X, None]``."""
        return isinstance(typ, UnionType) and self.is_optional_union_type(typ)

    def is_optional_union_type(self, typ: UnionType) -> bool:
        """Return whether the union has exactly two items, the second being ``None``."""
        return len(typ.items) == 2 and isinstance(typ.items[1], NoneType)

    # TODO: why is fullname Any?

    def is_coroutine_type(self, typ: Instance) -> bool:
        return bool(typ.type.fullname == 'typing.Coroutine')

    def is_list_type(self, typ: Type) -> bool:
        return isinstance(typ, Instance) and typ.type.fullname == 'builtins.list'

    def is_dict_call_type(self, expr: NameExpr) -> bool:
        # statically wise, TypedDicts do not inherit from dict
        # so we cannot check that, just checking if the expression
        # inherits from a class that ends with dict is good enough
        # for our use case
        return bool(expr.fullname == 'builtins.dict') or bool(
            isinstance(expr.node, TypeInfo)
            and expr.node.bases
            and expr.node.bases[0].type.fullname.lower().endswith('dict')
        )

    def copy_modified_instance(self, instance: Instance, names: SymbolTable) -> Instance:
        """Return a shallow copy of ``instance`` backed by a new TypeInfo holding ``names``."""
        new = copy.copy(instance)
        new.type = TypeInfo(names, new.type.defn, new.type.module_name)
        new.type.mro = [new.type, *instance.type.mro]
        new.type.bases = instance.type.bases
        new.type.metaclass_type = instance.type.metaclass_type
        return new

    def copy_modified_optional_type(self, original: UnionType, typ: Type) -> UnionType:
        """Return a copy of ``original`` with its first (non-None) item replaced by ``typ``."""
        new = copy.copy(original)
        new.items = new.items.copy()
        new.items[0] = typ
        return new

    def parse_expression_to_dict(self, expression: Expression) -> Dict[Any, Any]:
        """Statically evaluate a dict display or ``dict(...)``-style call.

        Raises:
            TypeError: ``expression`` is neither form.
        """
        if isinstance(expression, DictExpr):
            return self._dictexpr_to_dict(expression)

        if isinstance(expression, CallExpr):
            return self._callexpr_to_dict(expression)

        raise TypeError(f'Cannot parse expression of type={type(expression).__name__} to a dictionary.')

    def _dictexpr_to_dict(self, expr: DictExpr) -> Dict[Any, Any]:
        parsed = {}
        for key_expr, value_expr in expr.items:
            if key_expr is None:
                # a ``**mapping`` unpacking entry in the dict display has no key;
                # it cannot be resolved statically, so it is skipped
                continue

            key = self._resolve_expression(key_expr)
            value = self._resolve_expression(value_expr)
            parsed[key] = value

        return parsed

    def _callexpr_to_dict(self, expr: CallExpr, strict: bool = True) -> Dict[str, Any]:
        if not isinstance(expr.callee, NameExpr):
            raise TypeError(f'Expected CallExpr.callee to be a NameExpr but got {type(expr.callee)} instead.')

        if strict and not self.is_dict_call_type(expr.callee):
            raise TypeError(f'Expected builtins.dict to be called but got {expr.callee.fullname} instead')

        parsed = {}
        for arg_name, value_expr in zip(expr.arg_names, expr.args):
            # positional / ``**kwargs`` arguments have no name and are skipped
            if arg_name is None:
                continue

            value = self._resolve_expression(value_expr)
            parsed[arg_name] = value

        return parsed

    def _resolve_expression(self, expression: Expression) -> Any:
        """Return the Python value for ``expression`` if it can be resolved, else the expression itself."""
        if isinstance(expression, (StrExpr, BytesExpr, IntExpr)):
            return expression.value

        if isinstance(expression, NameExpr):
            return self._resolve_name_expression(expression)

        if isinstance(expression, DictExpr):
            return self._dictexpr_to_dict(expression)

        if isinstance(expression, CallExpr):
            return self._callexpr_to_dict(expression)

        return expression

    def _resolve_name_expression(self, expression: NameExpr) -> Any:
        if isinstance(expression.node, Var):
            return self._resolve_var_node(expression.node)

        return expression

    def _resolve_var_node(self, node: Var) -> Any:
        if node.is_final:
            return node.final_value

        if node.fullname.startswith('builtins.'):
            return self._resolve_builtin(node.fullname)

        return node

    def _resolve_builtin(self, fullname: str) -> Any:
        # Drop the leading "builtins." and look up the remainder on the builtins
        # module. Passing it as ONE (possibly dotted) string makes attrgetter walk
        # nested attributes; separate arguments would instead return a tuple.
        return operator.attrgetter(fullname.split('.', 1)[1])(builtins)
class UnparsedExpression(Exception):
    """Raised when an included value is still a raw, unresolved mypy expression.

    ``context`` holds the offending expression/node so that an error can be
    reported at its location.
    """

    context: Union[Expression, Node]

    def __init__(self, context: Union[Expression, Node]) -> None:
        self.context = context
        node_kind = type(context).__name__
        super().__init__(f'Tried to access a ({node_kind}) expression that was not parsed.')
# Error code attached to every parse failure this plugin reports.
ERROR_PARSING = ErrorCode('prisma-parsing', 'Unable to parse', 'Prisma')


def error_unable_to_parse(api: CheckerPluginInterface, context: Context, detail: str) -> None:
    """Report a ``prisma-parsing`` error that asks the user to file a bug.

    Args:
        api: mypy checker interface used to emit the error.
        context: node whose location the error is attached to.
        detail: short description of what could not be parsed.
    """
    issue_url = 'https://github.com/RobertCraigie/prisma-client-py/issues/new/choose'
    message_lines = (
        f'The prisma mypy plugin was unable to parse: {detail}',
        f'Please consider reporting this bug at {issue_url} so we can try to fix it!',
    )
    api.fail('\n'.join(message_lines), context, code=ERROR_PARSING)