Coverage for src/prisma/mypy.py: 0%

192 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-08-27 18:25 +0000

1import re 

2import copy 

3import logging 

4import builtins 

5import operator 

6from typing import ( 

7 Any, 

8 Dict, 

9 Type as TypingType, 

10 Union, 

11 Callable, 

12 Optional, 

13 cast, 

14) 

15from configparser import ConfigParser 

16from typing_extensions import override 

17 

18from mypy.nodes import ( 

19 Var, 

20 Node, 

21 Context, 

22 IntExpr, 

23 StrExpr, 

24 CallExpr, 

25 DictExpr, 

26 NameExpr, 

27 TypeInfo, 

28 BytesExpr, 

29 Expression, 

30 SymbolTable, 

31 SymbolTableNode, 

32) 

33from mypy.types import ( 

34 Type, 

35 Instance, 

36 NoneType, 

37 UnionType, 

38) 

39from mypy.plugin import Plugin, MethodContext, CheckerPluginInterface 

40from mypy.options import Options 

41from mypy.errorcodes import ErrorCode 

42 

# Module-wide logger; parsing problems are reported at debug level.
log: logging.Logger = logging.getLogger(__name__)

# Name of the section in the mypy config file that holds this plugin's settings.
CONFIGFILE_KEY = 'prisma-mypy'

# Matches the fully-qualified name of an attribute defined directly on a
# ``prisma.actions.<Something>Actions`` class. The final path segment (which
# must not contain a dot) is captured in the ``name`` group.
CLIENT_ACTION_CHILD = re.compile(r'prisma\.actions\.(.*)Actions\.(?P<name>(((?!\.).)*$))')

# Action method names whose calls are inspected for an ``include`` argument.
ACTIONS = [
    'create',
    'find_unique',
    'delete',
    'update',
    'find_first',
    'find_many',
    'upsert',
    'update_many',
    'delete_many',
    'count',
]

61 

62 

63# due to the way the mypy API is typed we unfortunately have to disable Pyright type checks 

64# this is because mypy type hints are written like this: `Bogus[str]` instead of `str` 

65# mypy uses internal magic to transform Bogus[T] to T which pyright cannot understand. 

66# pyright: reportGeneralTypeIssues=false, reportUnnecessaryComparison=false 

67 

68 

def plugin(version: str) -> TypingType[Plugin]:  # noqa: ARG001
    """Hand mypy the plugin class to instantiate.

    ``version`` is the mypy version string supplied by mypy; it is not
    consulted, so the same class is returned for every version.
    """
    plugin_class = PrismaPlugin
    return plugin_class

71 

72 

class PrismaPluginConfig:
    """Plugin settings read from the mypy config file.

    Each name in ``__slots__`` is read as a boolean option from the
    ``CONFIGFILE_KEY`` section of the config file. An option that is missing
    (or a missing section, or no config file at all) defaults to ``True``.
    """

    __slots__ = ('warn_parsing_errors',)
    # Whether to report a mypy error when an ``include`` argument cannot be parsed.
    warn_parsing_errors: bool

    def __init__(self, options: 'Options') -> None:
        """Load settings from ``options.config_file``.

        Args:
            options: mypy options; only ``config_file`` is read.
        """
        # Assign defaults before anything else: with ``__slots__`` an attribute
        # that is never set raises AttributeError when it is read. Before this
        # fix, that happened whenever mypy ran without a config file.
        for key in self.__slots__:
            setattr(self, key, True)

        if options.config_file is None:  # pragma: no cover
            return

        plugin_config = ConfigParser()
        plugin_config.read(options.config_file)
        for key in self.__slots__:
            setting = plugin_config.getboolean(CONFIGFILE_KEY, key, fallback=True)
            setattr(self, key, setting)

86 

87 

class PrismaPlugin(Plugin):
    """Mypy plugin that refines the return type of Prisma client action calls.

    Calls to methods named in ``ACTIONS`` on ``prisma.actions.*Actions``
    classes are intercepted. When an ``include`` argument is given, list
    relation fields that were explicitly included lose their ``Optional``
    in the returned model type.
    """

    config: PrismaPluginConfig

    def __init__(self, options: Options) -> None:
        self.config = PrismaPluginConfig(options)
        super().__init__(options)

    @override
    def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], Type]]:
        """Return the action hook for supported ``*Actions`` methods, else ``None``."""
        match = CLIENT_ACTION_CHILD.match(fullname)
        if not match:
            return None

        if match.group('name') in ACTIONS:
            return self.handle_action_invocation

        return None

    def handle_action_invocation(self, ctx: MethodContext) -> Type:
        """Return the (possibly refined) return type of an action call."""
        # TODO: if an error occurs, log it so that we don't cause mypy to
        # exit prematurely.
        return self._handle_include(ctx)

    def _handle_include(self, ctx: MethodContext) -> Type:
        """Recursively remove Optional from a relational field of a model
        if it was explicitly included.

        An argument could be made that this is over-engineered
        and while I do agree to an extent, the benefit of this
        method over just setting the default value to an empty list
        is that access to a relational field without explicitly
        including it will raise an error when type checking, e.g

        user = await client.user.find_unique(where={'id': user_id})
        print('\n'.join(p.title for p in user.posts))
        """
        include_expr = self.get_arg_named('include', ctx)
        if include_expr is None:
            return ctx.default_return_type

        if not isinstance(ctx.default_return_type, Instance):
            # TODO: resolve this?
            return ctx.default_return_type

        # When the declared return type is ``typing.Coroutine``, the type to
        # refine is its third type argument (the awaited result).
        is_coroutine = self.is_coroutine_type(ctx.default_return_type)
        if is_coroutine:
            actual_ret = ctx.default_return_type.args[2]
        else:
            actual_ret = ctx.default_return_type

        is_optional = self.is_optional_type(actual_ret)
        if is_optional:
            actual_ret = cast(UnionType, actual_ret)
            model_type = actual_ret.items[0]
        else:
            model_type = actual_ret

        if not isinstance(model_type, Instance):
            return ctx.default_return_type

        try:
            include = self.parse_expression_to_dict(include_expr)
            new_model = self.modify_model_from_include(model_type, include)
        except Exception as exc:
            # Best effort: any parsing failure falls back to the declared type.
            log.debug(
                'Ignoring %s exception while parsing include: %s',
                type(exc).__name__,
                exc,
            )

            # TODO: test this, pytest-mypy-plugins does not bode well with multiple line output
            if self.config.warn_parsing_errors:
                # TODO: add more details
                # e.g. "include" to "find_unique" of "UserActions"
                if isinstance(exc, UnparsedExpression):
                    err_ctx = exc.context
                else:
                    err_ctx = include_expr

                error_unable_to_parse(
                    ctx.api,
                    err_ctx,
                    'the "include" argument',
                )

            return ctx.default_return_type

        if is_optional:
            actual_ret = cast(UnionType, actual_ret)
            modified_ret = self.copy_modified_optional_type(actual_ret, new_model)
        else:
            modified_ret = new_model  # type: ignore

        if is_coroutine:
            arg1, arg2, _ = ctx.default_return_type.args
            return ctx.default_return_type.copy_modified(args=[arg1, arg2, modified_ret])

        return modified_ret

    def modify_model_from_include(self, model: Instance, data: Dict[Any, Any]) -> Instance:
        """Return a copy of *model* whose fields are refined according to *data*."""
        names = model.type.names.copy()
        for key, node in model.type.names.items():
            names[key] = self.maybe_modify_included_field(key, node, data)

        return self.copy_modified_instance(model, names)

    def maybe_modify_included_field(
        self,
        key: Union[str, Expression, Node],
        node: SymbolTableNode,
        data: Dict[Any, Any],
    ) -> SymbolTableNode:
        """Return *node*, or a copy with ``Optional`` removed if *key* is included.

        Only fields typed ``Optional[List[...]]`` are modified. A dict value
        containing an ``include`` key is applied recursively to the list's
        element model.

        Raises:
            UnparsedExpression: if the include value for *key* is still an
                unresolved mypy ``Expression``/``Node``.
        """
        value = data.get(key)
        if value is False or value is None:
            return node

        if isinstance(value, (Expression, Node)):
            raise UnparsedExpression(value)

        # we do not want to remove the Optional from a field that is not a list
        # as the Optional indicates that the field is optional on a database level
        if (
            not isinstance(node.node, Var)
            or node.node.type is None
            or not isinstance(node.node.type, UnionType)
            or not self.is_optional_union_type(node.node.type)
            or not self.is_list_type(node.node.type.items[0])
        ):
            log.debug(
                'Not modifying included field: %s',
                key,
            )
            return node

        list_type = node.node.type.items[0]
        if (
            isinstance(value, dict)
            and 'include' in value
            and isinstance(list_type, Instance)
            and list_type.args
            and isinstance(list_type.args[0], Instance)
        ):
            nested_model = self.modify_model_from_include(list_type.args[0], value['include'])
            # Build a new list type rather than assigning to ``args``: the
            # Instance is shared with the original field type, so mutating it
            # would leak the refinement into the unmodified model. The refined
            # model also has to *replace* the element type, not be prepended.
            list_type = list_type.copy_modified(args=[nested_model, *list_type.args[1:]])

        # copy the symbol table node and its Var so the modified field is not leaked
        new = node.copy()
        new.node = copy.copy(new.node)
        assert isinstance(new.node, Var)
        new.node.type = list_type
        return new

    def get_arg_named(self, name: str, ctx: MethodContext) -> Optional[Expression]:
        """Return the expression for an argument."""
        # keyword arguments
        for i, names in enumerate(ctx.arg_names):
            for j, arg_name in enumerate(names):
                if arg_name == name:
                    return ctx.args[i][j]

        # positional arguments
        for i, arg_name in enumerate(ctx.callee_arg_names):
            if arg_name == name and ctx.args[i]:
                return ctx.args[i][0]

        return None

    def is_optional_type(self, typ: Type) -> bool:
        """Return whether *typ* is a two-item union whose second item is ``None``."""
        return isinstance(typ, UnionType) and self.is_optional_union_type(typ)

    def is_optional_union_type(self, typ: UnionType) -> bool:
        """Return whether *typ* has exactly two items and the second is ``None``."""
        return len(typ.items) == 2 and isinstance(typ.items[1], NoneType)

    # TODO: why is fullname Any?

    def is_coroutine_type(self, typ: Instance) -> bool:
        """Return whether *typ* is a ``typing.Coroutine`` instance."""
        return bool(typ.type.fullname == 'typing.Coroutine')

    def is_list_type(self, typ: Type) -> bool:
        """Return whether *typ* is a ``builtins.list`` instance."""
        return isinstance(typ, Instance) and typ.type.fullname == 'builtins.list'

    def is_dict_call_type(self, expr: NameExpr) -> bool:
        """Return whether calling *expr* builds a dict (``dict`` or a dict-like class)."""
        # statically wise, TypedDicts do not inherit from dict
        # so we cannot check that, just checking if the expression
        # inherits from a class that ends with dict is good enough
        # for our use case
        return bool(expr.fullname == 'builtins.dict') or bool(
            isinstance(expr.node, TypeInfo)
            and expr.node.bases
            and expr.node.bases[0].type.fullname.lower().endswith('dict')
        )

    def copy_modified_instance(self, instance: Instance, names: SymbolTable) -> Instance:
        """Return a copy of *instance* backed by a new TypeInfo using *names*.

        The original TypeInfo is placed in the new MRO directly after the new one.
        """
        new = copy.copy(instance)
        new.type = TypeInfo(names, new.type.defn, new.type.module_name)
        new.type.mro = [new.type, *instance.type.mro]
        new.type.bases = instance.type.bases
        new.type.metaclass_type = instance.type.metaclass_type
        return new

    def copy_modified_optional_type(self, original: UnionType, typ: Type) -> UnionType:
        """Return a copy of *original* with its first item replaced by *typ*."""
        new = copy.copy(original)
        new.items = new.items.copy()
        new.items[0] = typ
        return new

    def parse_expression_to_dict(self, expression: Expression) -> Dict[Any, Any]:
        """Statically evaluate a dict literal or ``dict(...)``-style call.

        Raises:
            TypeError: if *expression* is neither a DictExpr nor a CallExpr.
        """
        if isinstance(expression, DictExpr):
            return self._dictexpr_to_dict(expression)

        if isinstance(expression, CallExpr):
            return self._callexpr_to_dict(expression)

        raise TypeError(f'Cannot parse expression of type={type(expression).__name__} to a dictionary.')

    def _dictexpr_to_dict(self, expr: DictExpr) -> Dict[Any, Any]:
        """Evaluate the keys and values of a dict literal."""
        parsed = {}
        for key_expr, value_expr in expr.items:
            if key_expr is None:
                # mypy stores ``{**other}`` unpacking entries with a ``None``
                # key; their contents are not resolved and are skipped.
                continue

            key = self._resolve_expression(key_expr)
            value = self._resolve_expression(value_expr)
            parsed[key] = value

        return parsed

    def _callexpr_to_dict(self, expr: CallExpr, strict: bool = True) -> Dict[str, Any]:
        """Evaluate the keyword arguments of a ``dict(...)``-style call.

        Positional arguments (``arg_name is None``) are skipped.

        Raises:
            TypeError: if the callee is not a name, or (when *strict*) is not
                a dict-building type.
        """
        if not isinstance(expr.callee, NameExpr):
            raise TypeError(f'Expected CallExpr.callee to be a NameExpr but got {type(expr.callee)} instead.')

        if strict and not self.is_dict_call_type(expr.callee):
            raise TypeError(f'Expected builtins.dict to be called but got {expr.callee.fullname} instead')

        parsed = {}
        for arg_name, value_expr in zip(expr.arg_names, expr.args):
            if arg_name is None:
                continue

            value = self._resolve_expression(value_expr)
            parsed[arg_name] = value

        return parsed

    def _resolve_expression(self, expression: Expression) -> Any:
        """Return the static value of *expression*, or the expression itself if unknown."""
        if isinstance(expression, (StrExpr, BytesExpr, IntExpr)):
            return expression.value

        if isinstance(expression, NameExpr):
            return self._resolve_name_expression(expression)

        if isinstance(expression, DictExpr):
            return self._dictexpr_to_dict(expression)

        if isinstance(expression, CallExpr):
            return self._callexpr_to_dict(expression)

        return expression

    def _resolve_name_expression(self, expression: NameExpr) -> Any:
        """Resolve a name bound to a Var; otherwise return the expression unchanged."""
        if isinstance(expression.node, Var):
            return self._resolve_var_node(expression.node)

        return expression

    def _resolve_var_node(self, node: Var) -> Any:
        """Return a Final's recorded value, a builtin's runtime value, or *node* itself."""
        if node.is_final:
            return node.final_value

        if node.fullname.startswith('builtins.'):
            return self._resolve_builtin(node.fullname)

        return node

    def _resolve_builtin(self, fullname: str) -> Any:
        """Look up ``builtins.<path>`` on the runtime ``builtins`` module."""
        # ``attrgetter`` follows a single dotted path; passing the components as
        # separate arguments would instead return a tuple of sibling attributes.
        _, _, path = fullname.partition('.')
        return operator.attrgetter(path)(builtins)

364 

365 

class UnparsedExpression(Exception):
    """Signals that an include value is still an unresolved mypy node.

    The offending node is kept on ``context``. When this exception is caught
    in ``PrismaPlugin._handle_include``, that node is used as the location
    of the reported error.
    """

    # The mypy node that could not be statically evaluated.
    context: Union[Expression, Node]

    def __init__(self, context: Union[Expression, Node]) -> None:
        node_kind = type(context).__name__
        message = f'Tried to access a ({node_kind}) expression that was not parsed.'
        self.context = context
        super().__init__(message)

372 

373 

# Error code attached to every parse failure reported by the plugin.
ERROR_PARSING = ErrorCode('prisma-parsing', 'Unable to parse', 'Prisma')


def error_unable_to_parse(api: CheckerPluginInterface, context: Context, detail: str) -> None:
    """Emit a ``prisma-parsing`` error at *context* describing what failed.

    Args:
        api: mypy checker interface used to report the error.
        context: node whose location the error is attached to.
        detail: short description of what could not be parsed.
    """
    issue_url = 'https://github.com/RobertCraigie/prisma-client-py/issues/new/choose'
    message_lines = (
        f'The prisma mypy plugin was unable to parse: {detail}',
        f'Please consider reporting this bug at {issue_url} so we can try to fix it!',
    )
    api.fail('\n'.join(message_lines), context, code=ERROR_PARSING)