Coverage for src/prisma/_builder.py: 95%
367 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-08-27 18:25 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2024-08-27 18:25 +0000
1from __future__ import annotations
3import json
4import decimal
5import inspect
6import logging
7import datetime
8from abc import ABC, abstractmethod
9from typing import TYPE_CHECKING, Any, Union, Mapping, Iterable, ForwardRef, cast
10from datetime import timezone
11from textwrap import indent
12from functools import singledispatch
13from typing_extensions import Literal, TypeGuard, override
15from pydantic import BaseModel
16from pydantic.fields import FieldInfo
18from . import fields
19from ._types import PrismaMethod
20from .errors import InvalidModelError, UnknownModelError, UnknownRelationalFieldError
21from ._compat import get_args, is_union, get_origin, model_fields, model_field_type
22from ._typing import is_list_type
23from ._constants import QUERY_BUILDER_ALIASES
25if TYPE_CHECKING:
26 from .bases import _PrismaModel as PrismaModel # noqa: TID251
27 from .types import Serializable # noqa: TID251
log: logging.Logger = logging.getLogger(__name__)

# The two GraphQL operation kinds a query can be sent as.
Operation = Literal['query', 'mutation']

# A child of a node is either another node or an already-rendered string.
ChildType = Union['AbstractNode', str]

# Python container types that are rendered as lists when building queries.
ITERABLES: tuple[type[Any], ...] = (list, tuple, set)

# Client method -> GraphQL operation kind.
# NOTE: the raw methods (`query_raw`, `query_first`, `execute_raw`) are sent as mutations.
METHOD_OPERATION_MAPPING: dict[PrismaMethod, Operation] = {
    # mutations
    'create': 'mutation',
    'delete': 'mutation',
    'update': 'mutation',
    'upsert': 'mutation',
    'query_raw': 'mutation',
    'query_first': 'mutation',
    'create_many': 'mutation',
    'execute_raw': 'mutation',
    'delete_many': 'mutation',
    'update_many': 'mutation',
    # queries
    'count': 'query',
    'group_by': 'query',
    'find_many': 'query',
    'find_first': 'query',
    'find_first_or_raise': 'query',
    'find_unique': 'query',
    'find_unique_or_raise': 'query',
}

# Client method -> template for the query engine's method name; `{model}` is
# substituted with the model's `__prisma_model__` name (or '' for raw queries).
METHOD_FORMAT_MAPPING: dict[PrismaMethod, str] = {
    'create': 'createOne{model}',
    'delete': 'deleteOne{model}',
    'update': 'updateOne{model}',
    'upsert': 'upsertOne{model}',
    'query_raw': 'queryRaw',
    'query_first': 'queryRaw',
    'create_many': 'createMany{model}',
    'execute_raw': 'executeRaw',
    'delete_many': 'deleteMany{model}',
    'update_many': 'updateMany{model}',
    'count': 'aggregate{model}',
    'group_by': 'groupBy{model}',
    'find_many': 'findMany{model}',
    'find_first': 'findFirst{model}',
    'find_first_or_raise': 'findFirst{model}OrThrow',
    'find_unique': 'findUnique{model}',
    'find_unique_or_raise': 'findUnique{model}OrThrow',
}

# Sentinel used to distinguish "attribute missing" from any real attribute value.
MISSING = object()
class QueryBuilder:
    """Builds the payload sent to the Prisma query engine for a single client method call."""

    method: PrismaMethod
    """The name of the actions method that this query is for, e.g. `find_unique`"""

    method_format: str
    """Template denoting how the internal method name should be constructed, e.g. `findUnique{model}`"""

    operation: Operation
    """The GraphQL operation, e.g. `query`, `mutation`"""

    model: type[PrismaModel] | None
    """The Pydantic model that will be used to parse the response.

    Used to extract the model name & build selections.
    """

    include: dict[str, Any] | None
    """Mapping of relational fields to include in the result"""

    arguments: dict[str, Any]
    """Arguments to pass to the query"""

    root_selection: list[str] | None
    """List of fields to select"""

    prisma_models: set[str]
    """The names of all models present in the schema.prisma"""

    relational_field_mappings: dict[str, dict[str, str]]
    """A mapping of model name to a mapping of field name to relational model name

    e.g. {'User': {'posts': 'Post'}}
    """

    __slots__ = (
        'method',
        'method_format',
        'operation',
        'model',
        'include',
        'arguments',
        'root_selection',
        'prisma_models',
        'relational_field_mappings',
    )

    def __init__(
        self,
        *,
        method: PrismaMethod,
        arguments: dict[str, Any],
        prisma_models: set[str],
        relational_field_mappings: dict[str, dict[str, str]],
        model: type[BaseModel] | None = None,
        root_selection: list[str] | None = None,
    ) -> None:
        """Prepare a query for `method`.

        `arguments` are copied with their keys translated through `QUERY_BUILDER_ALIASES`
        (see `_transform_aliases`). An `include` entry is then removed from the copy and
        stored on `self.include`, as it is rendered as part of the field selection rather
        than as a query argument.

        Raises InvalidModelError if `model` is given for a non-raw method and is a class
        that is not a Prisma model type.
        """
        self.method = method
        self.method_format = METHOD_FORMAT_MAPPING[method]
        self.operation = METHOD_OPERATION_MAPPING[method]
        self.root_selection = root_selection
        self.prisma_models = prisma_models
        self.relational_field_mappings = relational_field_mappings
        self.arguments = args = self._transform_aliases(arguments)
        self.include = args.pop('include', None)

        # Note: we ignore the `model` argument for raw queries as users may want to pass in a model
        # that isn't a `PrismaModel` because they've defined it manually & enforcing that
        # they subclass `PrismaModel` doesn't bring any real benefits.
        if model is None or method in {'execute_raw', 'query_raw', 'query_first'}:
            self.model = None
        else:
            if not _is_prisma_model_type(model) or not hasattr(model, '__prisma_model__'):
                raise InvalidModelError(model)

            self.model = model

    def build(self) -> str:
        """Build the JSON payload that should be sent to the QueryEngine"""
        data: dict[str, object] = {
            'variables': {},
            'operation_name': self.operation,
            'query': self.build_query(),
        }
        return dumps(data)

    def build_query(self) -> str:
        """Build the GraphQL query

        Example query (indentation illustrative):

            query {
              result: findUniqueUser
              (
                where: {
                  id: "ckq23ky3003510r8zll5m2hma"
                }
              )
              {
                id
                name
                profile {
                  id
                  user_id
                  bio
                }
              }
            }
        """
        query = self._create_root_node().render()
        log.debug('Generated query: \n%s', query)
        return query

    def _create_root_node(self) -> 'RootNode':
        """Create the root node holding the result/arguments node and the field selection."""
        root = RootNode(builder=self)
        root.add(ResultNode.create(self))
        root.add(
            Selection.create(
                self,
                model=self.model,
                include=self.include,
                root_selection=self.root_selection,
            )
        )
        return root

    def get_default_fields(self, model: type[PrismaModel]) -> list[str]:
        """Returns a list of all the scalar fields of a model

        Raises InvalidModelError if the model has no `__prisma_model__` attribute.
        Raises UnknownModelError if the model's name is not one of `prisma_models`.
        """
        name = getattr(model, '__prisma_model__', MISSING)
        if name is MISSING:
            raise InvalidModelError(model)

        # reuse the looked-up value instead of reading the attribute a second time,
        # consistent with `get_relational_model`
        name = cast(str, name)
        if name not in self.prisma_models:
            raise UnknownModelError(name)

        # by default we exclude every field that points to a PrismaModel as that indicates that it is a relational field
        # we explicitly keep fields that point to anything else, even other pydantic.BaseModel types, as they can be used to deserialize JSON
        return [
            field
            for field, info in model_fields(model).items()
            if not _field_is_prisma_model(info, name=field, parent=model)
        ]

    def get_relational_model(self, current_model: type[PrismaModel], field: str) -> type[PrismaModel]:
        """Returns the model that the field is related to.

        Raises InvalidModelError if the current model has no `__prisma_model__` attribute.
        Raises UnknownModelError if the model has no entry in `relational_field_mappings`.
        Raises UnknownRelationalFieldError if the field is not a relational field of the model.
        Raises RuntimeError if the field's annotation does not resolve to a Prisma model type.
        """
        name = getattr(current_model, '__prisma_model__', MISSING)
        if name is MISSING:
            raise InvalidModelError(current_model)

        name = cast(str, name)

        try:
            mappings = self.relational_field_mappings[name]
        except KeyError as exc:
            raise UnknownModelError(name) from exc

        if field not in mappings:
            raise UnknownRelationalFieldError(model=current_model.__name__, field=field)

        try:
            info = model_fields(current_model)[field]
        except KeyError as exc:
            raise UnknownRelationalFieldError(model=current_model.__name__, field=field) from exc

        model = _prisma_model_for_field(info, name=field, parent=current_model)
        if not model:
            raise RuntimeError(
                f"The `{field}` field doesn't appear to be a Prisma Model type. "
                + 'Is the field a pydantic.BaseModel type and does it have a `__prisma_model__` class variable?'
            )

        return model

    def _transform_aliases(self, arguments: dict[str, Any]) -> dict[str, Any]:
        """Return a copy of `arguments` with keys renamed via `QUERY_BUILDER_ALIASES`.

        e.g. order_by -> orderBy

        Nested dicts are transformed recursively, as are dicts directly inside
        `ITERABLES` values; those iterables are converted to lists. Keys without an
        alias are kept unchanged.
        """
        transformed: dict[str, Any] = {}
        for key, value in arguments.items():
            alias = QUERY_BUILDER_ALIASES.get(key, key)
            if isinstance(value, dict):
                transformed[alias] = self._transform_aliases(arguments=value)
            elif isinstance(value, ITERABLES):
                # it is safe to map any iterable type to a list here as it is only being used
                # to serialise the query and we only officially support lists anyway
                transformed[alias] = [
                    self._transform_aliases(data) if isinstance(data, dict) else data  # type: ignore
                    for data in value
                ]
            else:
                transformed[alias] = value
        return transformed
def _prisma_model_for_field(
    field: FieldInfo,
    *,
    name: str,
    parent: type[BaseModel],
) -> type[PrismaModel] | None:
    """Return the Prisma model type a pydantic field refers to, or None if it is not relational.

    Union annotations are inspected member by member, and list annotations are
    unwrapped to their first type argument. The first candidate that has a
    `__prisma_model__` attribute is returned.

    Raises RuntimeError if the field type cannot be determined, or if a forward
    reference is reached before any candidate matched.
    """
    owner = parent.__name__
    annotation = model_field_type(field)
    if annotation is None:
        raise RuntimeError(f'Unexpected field type is None for {owner}.{name}')

    candidates: Iterable[type] = get_args(annotation) if is_union(get_origin(annotation)) else [annotation]

    for candidate in candidates:
        if isinstance(candidate, ForwardRef):
            raise RuntimeError(
                f'Encountered forward reference for {owner}.{name}; Forward references must be evaluated using {owner}.update_forward_refs()'
            )

        # e.g. List[Post] -> Post
        if is_list_type(candidate) and candidate is not None:
            candidate = get_args(candidate)[0]

        if hasattr(candidate, '__prisma_model__'):
            return candidate

    return None
def _field_is_prisma_model(field: FieldInfo, *, name: str, parent: type[BaseModel]) -> bool:
    """Report whether the field refers to a Prisma model, i.e. is relational.

    Fields holding a single model as well as fields holding a list of models count.
    """
    related = _prisma_model_for_field(field, name=name, parent=parent)
    return related is not None
def _is_prisma_model_type(type_: type[BaseModel]) -> TypeGuard[type[PrismaModel]]:
    """Return whether `type_` is a subclass of the internal `_PrismaModel` base.

    The import happens at call time; presumably to avoid an import cycle with
    `.bases` (TODO confirm).
    """
    from .bases import _PrismaModel as prisma_base  # noqa: TID251

    return issubclass(type_, prisma_base)
class AbstractNode(ABC):
    """Interface shared by every element of the query tree."""

    __slots__ = ()

    @abstractmethod
    def render(self) -> str | None:
        """Render the node to a string

        None is returned if the node should not be rendered.
        """
        ...

    def should_render(self) -> bool:
        """Whether or not the node should be rendered.

        If False, rendering of the node is skipped. Useful for some nodes as they
        should only actually be rendered if they have any children.
        """
        return True
class Node(AbstractNode):
    """Base node that renders an optional prefix, its children and an optional suffix."""

    joiner: str
    indent: str
    builder: QueryBuilder
    children: list[ChildType]

    __slots__ = ('joiner', 'indent', 'builder', 'children')

    def __init__(
        self, builder: QueryBuilder, *, joiner: str = '\n', indent: str = ' ', children: list[ChildType] | None = None
    ) -> None:
        self.builder = builder
        self.joiner = joiner
        self.indent = indent
        # the given list (if any) is used as-is, not copied
        self.children = [] if children is None else children

    def enter(self) -> str | None:
        """Return the string rendered *before* the children, or None for nothing."""
        return None

    def depart(self) -> str | None:
        """Return the string rendered *after* the children, or None for nothing."""
        return None

    @override
    def render(self) -> str | None:
        """Render this node and its children to a single string.

        The result is `enter()`, then each child's rendered content prefixed with
        `self.indent` on every line, then `depart()`, all joined with `self.joiner`.
        `enter()`/`depart()` results of None and children that render to an empty
        value are left out. Returns None when `should_render()` is False.
        """
        if not self.should_render():
            return None

        pieces: list[str] = []

        opening = self.enter()
        if opening is not None:
            pieces.append(opening)

        for child in self.children:
            rendered = child if isinstance(child, str) else child.render()
            if rendered:
                pieces.append(indent(rendered, self.indent))

        closing = self.depart()
        if closing is not None:
            pieces.append(closing)

        return self.joiner.join(pieces)

    def add(self, child: ChildType) -> None:
        """Append a single child."""
        self.children.append(child)

    def create_children(self) -> list[ChildType]:
        """Return the children this node generates for itself.

        `create()` appends these after any children passed to the constructor.
        """
        return []

    @classmethod
    def create(cls, builder: QueryBuilder | None = None, **kwargs: Any) -> 'Node':
        """Construct the node from `kwargs` and populate its generated children.

        Useful for subclasses whose `create_children` depends on attributes that are
        only set in their `__init__`.
        """
        kwargs.setdefault('builder', builder)
        instance = cls(**kwargs)
        instance.children.extend(instance.create_children())
        return instance
class RootNode(Node):
    """Outermost node, wrapping everything in the builder's operation block.

    Rendered node examples:

        query {
          <children>
        }

    or

        mutation {
          <children>
        }
    """

    __slots__ = ()

    @override
    def enter(self) -> str:
        operation = self.builder.operation
        return f'{operation} {{'

    @override
    def depart(self) -> str:
        return '}'

    @override
    def render(self) -> str:
        rendered = super().render()
        if rendered:
            return rendered

        # Unreachable in practice: this node never opts out of rendering and
        # `enter()` always yields a non-empty string, so the base render cannot
        # come back empty. The check only narrows the `str | None` return type.
        raise RuntimeError('Could not generate query.')  # pragma: no cover
class ResultNode(Node):
    """Names the query engine method being called, followed by its arguments.

    Rendered node examples:

        result: findUniqueUser
        <children>

    or

        result: executeRaw
        <children>
    """

    __slots__ = ()

    def __init__(self, indent: str = '', **kwargs: Any) -> None:
        # children (the argument list) are not indented relative to this node
        super().__init__(indent=indent, **kwargs)

    @override
    def enter(self) -> str:
        model = self.builder.model
        model_name = '' if model is None else model.__prisma_model__
        return 'result: ' + self.builder.method_format.format(model=model_name)

    @override
    def depart(self) -> str | None:
        return None

    @override
    def create_children(self) -> list[ChildType]:
        arguments_node = Arguments.create(self.builder, arguments=self.builder.arguments)
        return [arguments_node]
class Arguments(Node):
    """Renders the parenthesised argument list of a query or relational include.

    Rendered node example:

        (
          key1: "1"
          key2: "[\"John\",\"123\"]"
          key3: true
          key4: {
            data: true
          }
        )
    """

    arguments: dict[str, Any]

    __slots__ = ('arguments',)

    def __init__(self, arguments: dict[str, Any], **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.arguments = arguments

    @override
    def should_render(self) -> bool:
        # nothing is rendered (not even `()`) when there are no arguments
        return bool(self.children)

    @override
    def enter(self) -> str:
        return '('

    @override
    def depart(self) -> str:
        return ')'

    @override
    def create_children(self, arguments: dict[str, Any] | None = None) -> list[ChildType]:
        """Create one child per argument, skipping arguments whose value is None.

        `arguments` defaults to `self.arguments`; when passed explicitly it takes
        precedence (it used to be accepted but silently ignored).
        """
        if arguments is None:
            arguments = self.arguments

        children: list[ChildType] = []

        for arg, value in arguments.items():
            if value is None:
                # ignore None values for convenience
                continue

            if isinstance(value, dict):
                children.append(Key(arg, node=Data.create(self.builder, data=value)))
            elif isinstance(value, ITERABLES):
                # NOTE: we have a special case for execute_raw, query_raw and query_first
                # here as prisma expects parameters to be passed as a json string
                # value like "[\"John\",\"123\"]", and we encode twice to ensure
                # that only the inner quotes are escaped
                if self.builder.method in {'query_raw', 'query_first', 'execute_raw'}:
                    children.append(f'{arg}: {dumps(dumps(value))}')
                else:
                    children.append(Key(arg, node=ListNode.create(self.builder, data=value)))
            else:
                children.append(f'{arg}: {dumps(value)}')

        return children
class Data(Node):
    """Renders a mapping as a brace-delimited object, one `key: value` per line.

    Rendered node example:

        {
          key1: "a"
          key2: 3
          key3: [
            "name"
          ]
        }

    Unlike `Arguments`, None values are kept and rendered as `null`.
    """

    data: Mapping[str, Any]

    __slots__ = ('data',)

    def __init__(self, data: Mapping[str, Any], **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.data = data

    @override
    def enter(self) -> str:
        return '{'

    @override
    def depart(self) -> str:
        return '}'

    @override
    def create_children(self) -> list[ChildType]:
        builder = self.builder
        children: list[ChildType] = []

        for key, value in self.data.items():
            child: ChildType
            if isinstance(value, dict):
                child = Key(key, node=Data.create(builder, data=value))
            elif isinstance(value, ITERABLES):
                child = Key(key, node=ListNode.create(builder, data=value))
            else:
                child = f'{key}: {dumps(value)}'
            children.append(child)

        return children
class ListNode(Node):
    """Renders an iterable as a bracket-delimited list.

    dict items become nested `Data` nodes; every other item is JSON-encoded via `dumps`.
    Parts are joined with `,\\n`, and that joiner is also applied after `[` and before `]`.
    The output therefore has a leading and a trailing comma, e.g. `[,\\n "a",\\n "b",\\n]`.
    This is presumably harmless because GraphQL treats commas as insignificant.
    """

    data: Iterable[Any]

    __slots__ = ('data',)

    def __init__(self, data: Iterable[Any], joiner: str = ',\n', **kwargs: Any) -> None:
        super().__init__(joiner=joiner, **kwargs)
        self.data = data

    @override
    def enter(self) -> str:
        return '['

    @override
    def depart(self) -> str:
        return ']'

    @override
    def create_children(self) -> list[ChildType]:
        return [Data.create(self.builder, data=item) if isinstance(item, dict) else dumps(item) for item in self.data]
class Selection(Node):
    """Represents field selections

    Example no include:

        {
          id
          name
        }

    Example include={'posts': True}

        {
          id
          name
          posts {
            id
            title
          }
        }

    Example include={'posts': {'where': {'title': {'contains': 'Test'}}}}

        {
          id
          name
          posts(
            where: {
              title: {
                contains: 'Test'
              }
            }
          )
          {
            id
            title
          }
        }
    """

    model: type[PrismaModel] | None
    include: dict[str, Any] | None
    root_selection: list[str] | None

    __slots__ = ('model', 'include', 'root_selection')

    def __init__(
        self,
        model: type[PrismaModel] | None = None,
        include: dict[str, Any] | None = None,
        root_selection: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.model = model
        self.include = include
        self.root_selection = root_selection

    @override
    def should_render(self) -> bool:
        return bool(self.children)

    @override
    def enter(self) -> str:
        return '{'

    @override
    def depart(self) -> str:
        return '}'

    @override
    def create_children(self) -> list[ChildType]:
        """Create the scalar field names plus one selection per included relation.

        Raises ValueError if `include` is set without a model.
        Raises TypeError if an include value is neither a bool nor a dict.
        """
        model = self.model
        children: list[ChildType] = []

        # root_selection, if present, overrides the model's default fields as it is
        # used by methods such as count() that do not support returning model fields
        if self.root_selection is not None:
            children.extend(self.root_selection)
        elif model is not None:
            children.extend(self.builder.get_default_fields(model))

        include = self.include
        if include is None:
            return children

        if model is None:
            raise ValueError('Cannot include fields when model is None.')

        for field, option in include.items():
            if option is False:
                continue
            children.extend(self._create_include_children(field, option, model))

        return children

    def _create_include_children(self, field: str, option: Any, model: type[PrismaModel]) -> list[ChildType]:
        """Create the child node(s) selecting the relational `field` of `model`."""
        builder = self.builder

        if option is True:
            # e.g. posts { post_fields }
            related = builder.get_relational_model(current_model=model, field=field)
            return [Key(field, sep=' ', node=Selection.create(builder, include=None, model=related))]

        if isinstance(option, dict):
            # e.g. given {'posts': {where': {'published': True}}} return
            # posts( where: { published: true }) { post_fields }
            arguments = option.copy()
            nested_include = arguments.pop('include', None)
            arguments_key = Key(field, sep='', node=Arguments.create(builder, arguments=arguments))
            related = builder.get_relational_model(current_model=model, field=field)
            return [arguments_key, Selection.create(builder, include=nested_include, model=related)]

        raise TypeError(f'Expected `bool` or `dict` include value but got {type(option)} instead.')
class Key(AbstractNode):
    """Prefixes the rendering of a child node with `key` followed by `sep`.

    If the child renders to nothing, only the `key`/`sep` prefix is produced.
    """

    key: str
    sep: str
    node: Node

    __slots__ = ('key', 'sep', 'node')

    def __init__(self, key: str, node: Node, sep: str = ': ') -> None:
        self.key = key
        self.node = node
        self.sep = sep

    @override
    def render(self) -> str:
        body = self.node.render()
        prefix = f'{self.key}{self.sep}'
        return prefix + body if body else prefix
@singledispatch
def serializer(obj: Any) -> Serializable:
    """JSON `default` hook dispatching on the object's type.

    Types are supported by registering implementations with
    `@serializer.register(SomeType)`. This fallback raises TypeError for anything
    unregistered, naming the class itself when `obj` is a class.
    """
    offending = obj if inspect.isclass(obj) else type(obj)
    raise TypeError(f'Type {offending} not serializable')
@serializer.register(datetime.datetime)
def serialize_datetime(dt: datetime.datetime) -> str:
    """Render a datetime as an ISO 8601 string in UTC, truncated to milliseconds.

    Naive datetimes are assumed to already be in UTC; aware datetimes in any other
    timezone are converted to UTC first. Sub-millisecond digits are dropped, not
    rounded: https://github.com/RobertCraigie/prisma-client-py/issues/129
    """
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=timezone.utc)
    elif dt.tzinfo != timezone.utc:
        dt = dt.astimezone(timezone.utc)

    whole_millis = dt.microsecond // 1000
    return dt.replace(microsecond=whole_millis * 1000).isoformat()
@serializer.register(fields.Json)
def serialize_json(obj: fields.Json) -> str:
    """Encode the data wrapped by a `Json` field value as a JSON string.

    Without this hook a value like {'hello': 'world'} would be rendered as a Data
    node. For `Json` fields it must instead be sent as a raw JSON string, so only use
    this for fields of the `Json` type.
    """
    wrapped = obj.data
    return dumps(wrapped)
@serializer.register(fields.Base64)
def serialize_base64(obj: fields.Base64) -> str:
    """Serialize a Base64 wrapper object through its `str()` form.

    NOTE(review): presumably that is the base64-encoded text rather than raw binary
    data; confirm against `fields.Base64.__str__`.
    """
    encoded: str = str(obj)
    return encoded
@serializer.register(decimal.Decimal)
def serialize_decimal(obj: decimal.Decimal) -> str:
    """Serialize a Decimal as its string form, so no precision is lost to float conversion."""
    text: str = str(obj)
    return text
def dumps(obj: Any, **kwargs: Any) -> str:
    """Encode `obj` as JSON with this module's defaults.

    Unless the caller overrides them, values `json` cannot handle natively are passed
    to `serializer`, and non-ASCII characters are emitted as-is rather than escaped.
    """
    options = {'default': serializer, 'ensure_ascii': False, **kwargs}
    return json.dumps(obj, **options)
877# black does not respect the fmt: off comment without this
878# fmt: on