Coverage for src/prisma/generator/schema.py: 95%
100 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-08-27 18:25 +0000
1from enum import Enum
2from typing import Any, Dict, List, Type, Tuple, Union, Optional
3from typing_extensions import ClassVar
5from pydantic import BaseModel
7from .utils import to_constant_case
8from .models import Model as ModelInfo, AnyData, PrimaryKey, DMMFEnumType, data_ctx
9from .._compat import (
10 PYDANTIC_V2,
11 ConfigDict,
12 model_rebuild,
13 root_validator,
14 cached_property,
15)
class Kind(str, Enum):
    """Discriminator recording which variety of `PrismaType` a value is.

    Mixes in `str`, so members compare equal to their plain string values
    (e.g. ``Kind.alias == 'alias'``).
    """

    alias = 'alias'
    union = 'union'
    typeddict = 'typeddict'
    enum = 'enum'
class PrismaType(BaseModel):
    """Base model for every type description produced by the generator.

    `name` is the generated type's name, and `kind` says which subclass this is.
    `subtypes` holds nested `PrismaType` definitions attached to this one, for
    example the inner dict of a compound unique key.
    """

    name: str
    kind: Kind
    subtypes: List['PrismaType'] = []

    @classmethod
    def from_variants(cls, variants: List['PrismaType'], **kwargs: Any) -> Union['PrismaUnion', 'PrismaAlias']:
        """Return either a `PrismaUnion` or a `PrismaAlias` depending on the number of variants.

        With zero or one variant a `PrismaAlias` is built, with the variants as
        its subtypes. With two or more, a `PrismaUnion` over all of them is built.
        Extra keyword arguments (e.g. ``name``) are forwarded to the constructor.
        """
        if len(variants) <= 1:
            return PrismaAlias(subtypes=variants, **kwargs)
        return PrismaUnion(variants=variants, **kwargs)
class PrismaDict(PrismaType):
    """Dict-shaped type (`Kind.typeddict`) with a fixed set of keys."""

    kind: Kind = Kind.typeddict
    # maps each key to the type expression of its value, written as source text
    # (e.g. 'SortOrder', 'List[...]')
    fields: Dict[str, str]
    # presumably rendered as the TypedDict `total=` flag — TODO confirm in the template
    total: bool
class PrismaUnion(PrismaType):
    """Union type over a list of variants.

    `PrismaType.from_variants` builds one of these when it has more than one variant.
    """

    kind: Kind = Kind.union
    variants: List[PrismaType]

    @root_validator(pre=True)
    @classmethod
    def add_subtypes(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Append every variant to `subtypes` before field validation runs.

        All variants are added as subtypes so that rendering subtypes does not
        need a special case for unions.

        A fresh list and a fresh mapping are returned, rather than extending the
        caller's `subtypes` list in place. A list or dict supplied by the caller
        is therefore never mutated.
        """
        if 'variants' not in values:
            return values

        merged_subtypes = [*values.get('subtypes', []), *values['variants']]
        return {**values, 'subtypes': merged_subtypes}
class PrismaEnum(PrismaType):
    """Enum type described by its members."""

    kind: Kind = Kind.enum
    # (member name, member value) pairs; `construct_enum_type` fills these as
    # (to_constant_case(value), value)
    members: List[Tuple[str, str]]
class PrismaAlias(PrismaType):
    """Alias of another type; `to` is the name of the aliased type."""

    kind: Kind = Kind.alias
    to: str

    @root_validator(pre=True)
    @classmethod
    def transform_to(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Default `to` to the name of the first subtype when it is not given.

        Only a non-empty `subtypes` value is used. An empty list, such as the one
        `from_variants([])` would pass, is left alone, so the missing required
        `to` field is reported as a normal validation error rather than failing
        with an IndexError.
        """
        if 'to' not in values and values.get('subtypes'):
            values['to'] = values['subtypes'][0].name
        return values
class Schema(BaseModel):
    """All datamodel models of the DMMF, each wrapped in a `Model`."""

    models: List['Model']

    @classmethod
    def from_data(cls, data: AnyData) -> 'Schema':
        """Wrap every model in `data.dmmf.datamodel.models` and build a `Schema`."""
        dmmf_models = data.dmmf.datamodel.models
        wrapped = [Model(info=dmmf_model) for dmmf_model in dmmf_models]
        return cls(models=wrapped)

    def get_model(self, name: str) -> 'Model':
        """Return the first model whose `info.name` equals `name`.

        Raises:
            LookupError: if no model has that name.
        """
        match = next((candidate for candidate in self.models if candidate.info.name == name), None)
        if match is None:
            raise LookupError(f'Unknown model: {name}')
        return match
class Model(BaseModel):
    """Wrapper around one DMMF model that lazily derives its generated input types."""

    info: ModelInfo

    # The `cached_property` attributes below must not be treated as pydantic
    # fields; the config option that achieves this differs between v1 and v2.
    if PYDANTIC_V2:
        model_config: ClassVar[ConfigDict] = ConfigDict(ignored_types=(cached_property,))
    else:

        class Config:
            keep_untouched: Tuple[Type[Any], ...] = (cached_property,)

    @cached_property
    def where_unique(self) -> PrismaType:
        """Describe the `<Model>WhereUniqueInput` type.

        The variants, in order, are:

        * one single-key dict per scalar field that is an id or is unique;
        * one dict per compound key (the compound primary key first, then each
          unique index). Each such dict maps the key's name to a nested
          `...Inner` dict holding the key's fields.

        These are combined with `PrismaType.from_variants`, giving a union when
        there is more than one variant.
        """
        info = self.info
        model_name = info.name

        variants: List[PrismaType] = []
        for scalar in info.scalar_fields:
            if not (scalar.is_id or scalar.is_unique):
                continue
            variants.append(
                PrismaDict(
                    total=True,
                    name=f'_{model_name}WhereUnique_{scalar.name}_Input',
                    fields={scalar.name: scalar.python_type},
                )
            )

        compound_keys = [info.compound_primary_key, *info.unique_indexes]
        for key in compound_keys:
            if key is None:
                # the model has no compound primary key
                continue

            if isinstance(key, PrimaryKey):
                outer_name = f'_{model_name}CompoundPrimaryKey'
            else:
                outer_name = f'_{model_name}Compound{key.name}Key'
            inner_name = f'{outer_name}Inner'

            # key.fields holds field names; resolve each to its field for the type
            inner_fields: Dict[str, str] = {}
            for field_name in key.fields:
                resolved = info.resolve_field(field_name)
                inner_fields[resolved.name] = resolved.python_type

            inner = PrismaDict(total=True, name=inner_name, fields=inner_fields)
            variants.append(
                PrismaDict(
                    name=outer_name,
                    total=True,
                    fields={key.name: inner_name},
                    subtypes=[inner],
                )
            )

        return PrismaType.from_variants(variants, name=f'{model_name}WhereUniqueInput')

    @cached_property
    def order_by(self) -> PrismaType:
        """Describe the `<Model>OrderByInput` type.

        There is one variant per scalar field, mapping that field to 'SortOrder'.
        When the active provider of the first datasource in `data_ctx` is
        'postgresql' or 'mysql', a `_relevance` variant is also added for
        full-text search relevance sorting.
        """
        model_name = self.info.name

        variants: List[PrismaType] = []
        for scalar in self.info.scalar_fields:
            variants.append(
                PrismaDict(
                    name=f'_{model_name}_{scalar.name}_OrderByInput',
                    total=True,
                    fields={scalar.name: 'SortOrder'},
                )
            )

        # Full-text search relevance sorting
        provider = data_ctx.get().datasources[0].active_provider
        if provider in {'postgresql', 'mysql'}:
            inner_name = f'_{model_name}_RelevanceInner'
            inner = PrismaDict(
                name=inner_name,
                total=True,
                fields={
                    'fields': f'List[{model_name}ScalarFieldKeys]',
                    'search': 'str',
                    'sort': 'SortOrder',
                },
            )
            variants.append(
                PrismaDict(
                    name=f'_{model_name}_RelevanceOrderByInput',
                    total=True,
                    fields={'_relevance': inner_name},
                    subtypes=[inner],
                )
            )

        return PrismaType.from_variants(variants, name=f'{model_name}OrderByInput')
class ClientTypes(BaseModel):
    """Client-level types taken from the DMMF's Prisma enum types."""

    transaction_isolation_level: Optional[PrismaEnum]

    @classmethod
    def from_data(cls, data: AnyData) -> 'ClientTypes':
        """Build from the `TransactionIsolationLevel` enum, if the DMMF defines one."""
        prisma_enum_types = data.dmmf.prisma_schema.enum_types.prisma
        isolation_level = construct_enum_type(prisma_enum_types, name='TransactionIsolationLevel')
        return cls(transaction_isolation_level=isolation_level)
def construct_enum_type(dmmf_enum_types: List[DMMFEnumType], *, name: str) -> Optional[PrismaEnum]:
    """Build a `PrismaEnum` from the first DMMF enum type called `name`.

    Each value becomes a ``(to_constant_case(str(value)), str(value))`` member.
    Returns None when no enum type has that name.
    """
    found = None
    for candidate in dmmf_enum_types:
        if candidate.name == name:
            found = candidate
            break

    if not found:
        return None

    members = []
    for raw_value in found.values:
        text = str(raw_value)
        members.append((to_constant_case(text), text))

    return PrismaEnum(name=name, members=members)
# Resolve string forward references now that every model class exists.
# `Schema.models` refers to 'Model', and every `PrismaType` subclass inherits
# `subtypes: List['PrismaType']`. Rebuild each subclass, not only some of them,
# so that none is left with an unresolved reference.
model_rebuild(Schema)
model_rebuild(PrismaType)
model_rebuild(PrismaDict)
model_rebuild(PrismaAlias)
model_rebuild(PrismaUnion)
model_rebuild(PrismaEnum)