Coverage for src/prisma/generator/schema.py: 95%

100 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-08-27 18:25 +0000

1from enum import Enum 

2from typing import Any, Dict, List, Type, Tuple, Union, Optional 

3from typing_extensions import ClassVar 

4 

5from pydantic import BaseModel 

6 

7from .utils import to_constant_case 

8from .models import Model as ModelInfo, AnyData, PrimaryKey, DMMFEnumType, data_ctx 

9from .._compat import ( 

10 PYDANTIC_V2, 

11 ConfigDict, 

12 model_rebuild, 

13 root_validator, 

14 cached_property, 

15) 

16 

17 

class Kind(str, Enum):
    """How a generated ``PrismaType`` is rendered.

    Members also subclass ``str``, so each compares equal to its plain string
    value (``Kind.union == 'union'``).
    """

    alias = 'alias'
    union = 'union'
    typeddict = 'typeddict'
    enum = 'enum'

23 

24 

class PrismaType(BaseModel):
    """Base description of a type emitted by the client generator.

    ``subtypes`` holds helper types that are carried along with this type
    (e.g. the ``...Inner`` dicts of compound keys).
    """

    name: str
    kind: Kind
    subtypes: List['PrismaType'] = []

    @classmethod
    def from_variants(cls, variants: List['PrismaType'], **kwargs: Any) -> Union['PrismaUnion', 'PrismaAlias']:
        """Wrap ``variants`` in the narrowest fitting type.

        Zero or one variant produces a ``PrismaAlias`` built from
        ``subtypes=variants``. Two or more produce a ``PrismaUnion``. Any extra
        keyword arguments (e.g. ``name``) are forwarded to that constructor.
        """
        if len(variants) <= 1:
            return PrismaAlias(subtypes=variants, **kwargs)
        return PrismaUnion(variants=variants, **kwargs)

37 

38 

class PrismaDict(PrismaType):
    """A generated dictionary type.

    ``fields`` maps each key to the source text of its type expression
    (e.g. ``'SortOrder'`` or ``'List[...]'``). ``total`` is presumably
    rendered as the ``TypedDict`` ``total=`` flag; confirm against the
    template.
    """

    kind: Kind = Kind.typeddict
    # key name -> type expression as source text
    fields: Dict[str, str]
    total: bool

43 

44 

class PrismaUnion(PrismaType):
    """A union of several variant types.

    Every variant is also recorded in ``subtypes``. That way code that walks
    ``subtypes`` does not need a special case for unions.
    """

    kind: Kind = Kind.union
    variants: List[PrismaType]

    @root_validator(pre=True)
    @classmethod
    def add_subtypes(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Append ``variants`` to ``subtypes`` before field validation.

        A new list and a shallow copy of ``values`` are built rather than
        extending the incoming ``subtypes`` in place. A list passed in by the
        caller is therefore never mutated, and validating the same input
        twice does not duplicate the variants.
        """
        if 'variants' in values:
            values = dict(values)
            values['subtypes'] = [*values.get('subtypes', []), *values['variants']]
        return values

59 

60 

class PrismaEnum(PrismaType):
    """A generated enum type.

    Each ``members`` entry is a ``(member_name, value)`` pair.
    """

    kind: Kind = Kind.enum
    # (constant-style member name, raw value) pairs
    members: List[Tuple[str, str]]

64 

65 

class PrismaAlias(PrismaType):
    """An alias that refers to another type by its name (``to``)."""

    kind: Kind = Kind.alias
    to: str

    @root_validator(pre=True)
    @classmethod
    def transform_to(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Default ``to`` to the name of the first subtype when it is not given.

        An empty (or otherwise falsy) ``subtypes`` value is left untouched.
        Pydantic then reports the missing ``to`` field as an ordinary
        validation error, instead of an ``IndexError`` escaping from this
        validator (e.g. for ``PrismaType.from_variants([])``).
        """
        if 'to' not in values and 'subtypes' in values and values['subtypes']:
            values['to'] = values['subtypes'][0].name
        return values

76 

77 

class Schema(BaseModel):
    """Generator-side wrapper around every model in the DMMF datamodel."""

    models: List['Model']

    @classmethod
    def from_data(cls, data: AnyData) -> 'Schema':
        """Wrap each datamodel model in ``data`` in a ``Model``."""
        wrapped = [Model(info=info) for info in data.dmmf.datamodel.models]
        return cls(models=wrapped)

    def get_model(self, name: str) -> 'Model':
        """Return the first model whose ``info.name`` equals ``name``.

        Raises:
            LookupError: if no model has that name.
        """
        found = next((candidate for candidate in self.models if candidate.info.name == name), None)
        if found is None:
            raise LookupError(f'Unknown model: {name}')
        return found

92 

93 

class Model(BaseModel):
    """Wraps one datamodel model (``info``) and derives its generated input types."""

    info: ModelInfo

    # The derived types below are memoised with ``cached_property``. Pydantic is
    # configured to ignore those attributes: ``ignored_types`` on v2,
    # ``keep_untouched`` on v1.
    if PYDANTIC_V2:
        model_config: ClassVar[ConfigDict] = ConfigDict(ignored_types=(cached_property,))
    else:

        class Config:
            keep_untouched: Tuple[Type[Any], ...] = (cached_property,)

    @cached_property
    def where_unique(self) -> PrismaType:
        """Build the ``{Model}WhereUniqueInput`` type.

        Variants are produced in this order:

        1. One single-key dict for each scalar field that is an id or unique.
        2. One dict for each compound key: the compound primary key first,
           then each unique index, skipping ``None`` entries.

        A compound variant maps the key's name to an ``...Inner`` dict. That
        inner dict holds the resolved fields and is attached as a subtype.
        """
        info = self.info
        model = info.name

        variants: List[PrismaType] = []
        for field in info.scalar_fields:
            if not (field.is_id or field.is_unique):
                continue
            variants.append(
                PrismaDict(
                    total=True,
                    name=f'_{model}WhereUnique_{field.name}_Input',
                    fields={field.name: field.python_type},
                )
            )

        for key in [info.compound_primary_key, *info.unique_indexes]:
            if key is None:
                continue

            if isinstance(key, PrimaryKey):
                name = f'_{model}CompoundPrimaryKey'
            else:
                name = f'_{model}Compound{key.name}Key'

            inner_name = f'{name}Inner'
            outer_fields = {key.name: inner_name}

            inner_fields = {}
            for field_name in key.fields:
                resolved = info.resolve_field(field_name)
                inner_fields[resolved.name] = resolved.python_type

            inner = PrismaDict(total=True, name=inner_name, fields=inner_fields)
            variants.append(
                PrismaDict(
                    name=name,
                    total=True,
                    fields=outer_fields,
                    subtypes=[inner],
                )
            )

        return PrismaType.from_variants(variants, name=f'{model}WhereUniqueInput')

    @cached_property
    def order_by(self) -> PrismaType:
        """Build the ``{Model}OrderByInput`` type.

        Each scalar field contributes a single-key dict mapping it to
        ``SortOrder``. When the first datasource's active provider is
        ``postgresql`` or ``mysql``, a ``_relevance`` variant (full-text search
        relevance sorting) is appended. Its inner dict is attached as a subtype.
        """
        model = self.info.name

        variants: List[PrismaType] = []
        for field in self.info.scalar_fields:
            variants.append(
                PrismaDict(
                    name=f'_{model}_{field.name}_OrderByInput',
                    total=True,
                    fields={field.name: 'SortOrder'},
                )
            )

        provider = data_ctx.get().datasources[0].active_provider
        if provider in {'postgresql', 'mysql'}:
            inner_name = f'_{model}_RelevanceInner'
            relevance_inner = PrismaDict(
                name=inner_name,
                total=True,
                fields={
                    'fields': f'List[{model}ScalarFieldKeys]',
                    'search': 'str',
                    'sort': 'SortOrder',
                },
            )
            variants.append(
                PrismaDict(
                    name=f'_{model}_RelevanceOrderByInput',
                    total=True,
                    fields={'_relevance': inner_name},
                    subtypes=[relevance_inner],
                )
            )

        return PrismaType.from_variants(variants, name=f'{model}OrderByInput')

183 

184 

class ClientTypes(BaseModel):
    """Client-level types derived from the DMMF's built-in ``prisma`` enum types."""

    transaction_isolation_level: Optional[PrismaEnum]

    @classmethod
    def from_data(cls, data: AnyData) -> 'ClientTypes':
        """Collect the client-level types from ``data``.

        ``transaction_isolation_level`` is ``None`` when no enum named
        ``TransactionIsolationLevel`` is present.
        """
        isolation_level = construct_enum_type(
            data.dmmf.prisma_schema.enum_types.prisma,
            name='TransactionIsolationLevel',
        )
        return cls(transaction_isolation_level=isolation_level)

195 

196 

def construct_enum_type(dmmf_enum_types: List[DMMFEnumType], *, name: str) -> Optional[PrismaEnum]:
    """Build a ``PrismaEnum`` from the first DMMF enum type called ``name``.

    Each member is a pair: the constant-cased string form of the value, and
    the string form itself.

    Returns:
        The enum, or ``None`` when no enum type with that name exists.
    """
    enum_type = None
    for candidate in dmmf_enum_types:
        if candidate.name == name:
            enum_type = candidate
            break

    if not enum_type:
        return None

    members = []
    for value in enum_type.values:
        text = str(value)
        members.append((to_constant_case(text), text))

    return PrismaEnum(name=name, members=members)

206 

207 

# Rebuild these models now that every class their string annotations refer
# to (e.g. 'Model' in Schema, 'PrismaType' in PrismaType) is defined.
for _model_cls in (Schema, PrismaType, PrismaDict, PrismaAlias):
    model_rebuild(_model_cls)
del _model_cls