Coverage for src/prisma/generator/models.py: 94%

640 statements  

coverage.py v7.2.7, created at 2024-04-28 15:17 +0000

1import os 

2import sys 

3import enum 

4import textwrap 

5import importlib 

6from typing import ( 

7 TYPE_CHECKING, 

8 Any, 

9 Dict, 

10 List, 

11 Type, 

12 Tuple, 

13 Union, 

14 Generic, 

15 TypeVar, 

16 ClassVar, 

17 Iterable, 

18 Iterator, 

19 NoReturn, 

20 Optional, 

21 cast, 

22) 

23from keyword import iskeyword 

24from pathlib import Path 

25from importlib import util as importlib_util, machinery 

26from itertools import chain 

27from contextvars import ContextVar 

28from importlib.abc import InspectLoader 

29from typing_extensions import Annotated, override 

30 

31import click 

32import pydantic 

33from pydantic.fields import PrivateAttr 

34 

35from .. import config 

36from .utils import Faker, Sampler, clean_multiline 

37from ..utils import DEBUG_GENERATOR, assert_never 

38from ..errors import UnsupportedListTypeError 

39from .._compat import ( 

40 PYDANTIC_V2, 

41 Field as FieldInfo, 

42 BaseConfig, 

43 ConfigDict, 

44 BaseSettings, 

45 GenericModel, 

46 PlainSerializer, 

47 BaseSettingsConfig, 

48 model_dict, 

49 model_parse, 

50 model_rebuild, 

51 root_validator, 

52 cached_property, 

53 field_validator, 

54) 

55from .._constants import QUERY_BUILDER_ALIASES 

56from ._dsl_parser import parse_schema_dsl 

57 

58__all__ = ( 

59 'AnyData', 

60 'PythonData', 

61 'DefaultData', 

62 'GenericData', 

63) 

64 

65_ModelT = TypeVar('_ModelT', bound=pydantic.BaseModel) 

66 

67# NOTE: this does not represent all the data that is passed by prisma 

68 

69ATOMIC_FIELD_TYPES = ['Int', 'BigInt', 'Float'] 

70 

71TYPE_MAPPING = { 

72 'String': '_str', 

73 'Bytes': "'fields.Base64'", 

74 'DateTime': 'datetime.datetime', 

75 'Boolean': '_bool', 

76 'Int': '_int', 

77 'Float': '_float', 

78 'BigInt': '_int', 

79 'Json': "'fields.Json'", 

80 'Decimal': 'decimal.Decimal', 

81} 
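
# note: the quoted values (e.g. "'fields.Json'") appear to be kept quoted on purpose so that
# templates can emit them as string (forward reference) annotations rather than plain names.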

82FILTER_TYPES = [ 

83 'String', 

84 'Bytes', 

85 'DateTime', 

86 'Boolean', 

87 'Int', 

88 'BigInt', 

89 'Float', 

90 'Json', 

91 'Decimal', 

92] 

93RECURSIVE_TYPE_DEPTH_WARNING = """Some types are disabled by default due to being incompatible with Mypy; it is highly recommended 

94to use Pyright instead and configure Prisma Python to use recursive types. To re-enable certain types:""" 

95 

96RECURSIVE_TYPE_DEPTH_WARNING_DESC = """ 

97generator client { 

98 provider = "prisma-client-py" 

99 recursive_type_depth = -1 

100} 

101 

102If you need to use Mypy, you can also disable this message by explicitly setting the default value: 

103 

104generator client { 

105 provider = "prisma-client-py" 

106 recursive_type_depth = 5 

107} 

108 

109For more information see: https://prisma-client-py.readthedocs.io/en/stable/reference/limitations/#default-type-limitations 

110""" 

111 

112FAKER: Faker = Faker() 

113 

114 

115ConfigT = TypeVar('ConfigT', bound=pydantic.BaseModel) 

116 

117# Although we should just be able to access the config from the datamodel, 

118# we have to do some validation that requires access to the config. This is difficult 

119# with heavily nested models, as our current workaround only sets the datamodel context 

120# post-validation, meaning we cannot access it in validators. To get around this we have 

121# a separate config context. 

122# TODO: better solution 

123data_ctx: ContextVar['AnyData'] = ContextVar('data_ctx') 

124config_ctx: ContextVar['Config'] = ContextVar('config_ctx') 

125 

126 

127def get_datamodel() -> 'Datamodel': 

128 return data_ctx.get().dmmf.datamodel 

129 

130 

131# typed to ensure the caller has to handle the cases where: 

132# - a custom generator config is being used 

133# - the config is invalid and therefore could not be set 

134def get_config() -> Union[None, pydantic.BaseModel, 'Config']: 

135 return config_ctx.get(None) 

136 

137 

138def get_list_types() -> Iterable[Tuple[str, str]]: 

139 # WARNING: do not edit this function without also editing Field.is_supported_scalar_list_type() 

140 return chain( 

141 ((t, TYPE_MAPPING[t]) for t in FILTER_TYPES), 

142 ((enum.name, f"'enums.{enum.name}'") for enum in get_datamodel().enums), 

143 ) 
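
# e.g. with TYPE_MAPPING above this yields pairs such as ('String', '_str') and ('Int', '_int'),
# plus ('Role', "'enums.Role'") assuming the schema defines an enum named Role.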

144 

145 

146def sql_param(num: int = 1) -> str: 

147 # TODO: add case for sqlserver 

148 active_provider = data_ctx.get().datasources[0].active_provider 

149 if active_provider == 'postgresql': 

150 return f'${num}' 

151 

152 # TODO: test 

153 if active_provider == 'mongodb': # pragma: no cover 

154 raise RuntimeError('no-op') 

155 

156 # SQLite and MySQL use this style so just default to it 

157 return '?' 
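
# e.g. sql_param(2) returns '$2' when the active provider is postgresql and '?' for sqlite/mysql.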

158 

159 

160def raise_err(msg: str) -> NoReturn: 

161 raise TemplateError(msg) 

162 

163 

164def type_as_string(typ: str) -> str: 

165 """Ensure a type string is wrapped with a string, e.g. 

166 

167 enums.Role -> 'enums.Role' 

168 """ 

169 # TODO: use this function internally in this module 

170 if not typ.startswith("'") and not typ.startswith('"'): 

171 return f"'{typ}'" 

172 return typ 

173 

174 

175def format_documentation(doc: str, indent: int = 4) -> str: 

176 """Format a schema comment by indenting nested lines, e.g. 

177 

178 '''Foo 

179 Bar''' 

180 

181 Becomes 

182 

183 '''Foo 

184 Bar 

185 ''' 

186 """ 

187 if not doc: 187 ↛ 189: line 187 didn't jump to line 189, because the condition on line 187 was never true

188 # empty string, nothing to do 

189 return doc 

190 

191 prefix = ' ' * indent 

192 first, *rest = doc.splitlines() 

193 return '\n'.join( 

194 [ 

195 first, 

196 *[textwrap.indent(line, prefix) for line in rest], 

197 prefix, 

198 ] 

199 ) 
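
# e.g. format_documentation('Foo\nBar') returns 'Foo\n    Bar\n    ' with the default indent of 4.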

200 

201 

202def _module_spec_serializer(spec: machinery.ModuleSpec) -> str: 

203 assert spec.origin is not None, 'Cannot serialize module with no origin' 

204 return spec.origin 

205 

206 

207def _pathlib_serializer(path: Path) -> str: 

208 return str(path.absolute()) 

209 

210 

211def _recursive_type_depth_factory() -> int: 

212 click.echo( 

213 click.style( 

214 f'\n{RECURSIVE_TYPE_DEPTH_WARNING}', 

215 fg='yellow', 

216 ) 

217 ) 

218 click.echo(f'{RECURSIVE_TYPE_DEPTH_WARNING_DESC}\n') 

219 return 5 

220 

221 

222class BaseModel(pydantic.BaseModel): 

223 if PYDANTIC_V2: 

224 model_config: ClassVar[ConfigDict] = ConfigDict( 

225 arbitrary_types_allowed=True, 

226 ignored_types=(cached_property,), 

227 ) 

228 else: 

229 

230 class Config(BaseConfig): 

231 arbitrary_types_allowed: bool = True 

232 json_encoders: Dict[Type[Any], Any] = { 

233 Path: _pathlib_serializer, 

234 machinery.ModuleSpec: _module_spec_serializer, 

235 } 

236 keep_untouched: Tuple[Type[Any], ...] = (cached_property,) 

237 

238 

239class InterfaceChoices(str, enum.Enum): 

240 sync = 'sync' 

241 asyncio = 'asyncio' 

242 

243 

244class EngineType(str, enum.Enum): 

245 binary = 'binary' 

246 library = 'library' 

247 dataproxy = 'dataproxy' 

248 

249 @override 

250 def __str__(self) -> str: 

251 return self.value 

252 

253 

254class Module(BaseModel): 

255 if TYPE_CHECKING: 

256 spec: machinery.ModuleSpec 

257 else: 

258 if PYDANTIC_V2: 

259 spec: Annotated[ 

260 machinery.ModuleSpec, 

261 PlainSerializer(lambda x: _module_spec_serializer(x), return_type=str), 

262 ] 

263 else: 

264 spec: machinery.ModuleSpec 

265 

266 if PYDANTIC_V2: 

267 model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True) 

268 else: 

269 

270 class Config(BaseModel.Config): 

271 arbitrary_types_allowed: bool = True 

272 

273 # for some reason this is needed in Pydantic v2 

274 @root_validator(pre=True, skip_on_failure=True) 

275 @classmethod 

276 def partial_type_generator_converter(cls, values: object) -> Any: 

277 if isinstance(values, str): 

278 return {'spec': values} 

279 return values 

280 

281 @field_validator('spec', pre=True, allow_reuse=True) 

282 @classmethod 

283 def spec_validator(cls, value: Optional[str]) -> machinery.ModuleSpec: 

284 spec: Optional[machinery.ModuleSpec] = None 

285 

286 # TODO: this should really work based off of the schema path 

287 # and this should support checking just partial_types.py if we are in a `prisma` dir 

288 if value is None: 

289 value = 'prisma/partial_types.py' 

290 

291 path = Path.cwd().joinpath(value) 

292 if path.exists(): 

293 spec = importlib_util.spec_from_file_location('prisma.partial_type_generator', value) 

294 elif value.startswith('.'): 294 ↛ 295: line 294 didn't jump to line 295, because the condition on line 294 was never true

295 raise ValueError(f'No file found at {value} and relative imports are not allowed.') 

296 else: 

297 try: 

298 spec = importlib_util.find_spec(value) 

299 except ModuleNotFoundError: 

300 spec = None 

301 

302 if spec is None: 

303 raise ValueError(f'Could not find a python file or module at {value}') 

304 

305 return spec 

306 

307 def run(self) -> None: 

308 importlib.invalidate_caches() 

309 mod = importlib_util.module_from_spec(self.spec) 

310 loader = self.spec.loader 

311 assert loader is not None, 'Expected an import loader to exist.' 

312 assert isinstance(loader, InspectLoader), f'Cannot execute module from loader type: {type(loader)}' 

313 

314 try: 

315 loader.exec_module(mod) 

316 except Exception as exc: 

317 raise PartialTypeGeneratorError() from exc 

318 

319 

320class GenericData(GenericModel, Generic[ConfigT]): 

321 """Root model for the data that prisma provides to the generator. 

322 

323 WARNING: only one instance of this class may exist at any given time and 

324 instances should only be constructed using the Data.parse_obj() method 

325 """ 

326 

327 datamodel: str 

328 version: str 

329 generator: 'Generator[ConfigT]' 

330 dmmf: 'DMMF' = FieldInfo(alias='dmmf') 

331 schema_path: Path = FieldInfo(alias='schemaPath') 

332 datasources: List['Datasource'] = FieldInfo(alias='datasources') 

333 other_generators: List['Generator[_ModelAllowAll]'] = FieldInfo(alias='otherGenerators') 

334 binary_paths: 'BinaryPaths' = FieldInfo(alias='binaryPaths', default_factory=lambda: BinaryPaths()) 334 ↛ exit: line 334 didn't run the lambda on line 334

335 

336 if PYDANTIC_V2: 

337 

338 @root_validator(pre=False) 

339 def _set_ctx(self: _ModelT) -> _ModelT: 

340 data_ctx.set(cast('GenericData[ConfigT]', self)) 

341 return self 

342 

343 else: 

344 

345 @classmethod 

346 @override 

347 def parse_obj(cls, obj: Any) -> 'GenericData[ConfigT]': 

348 data = super().parse_obj(obj) # pyright: ignore[reportDeprecated] 

349 data_ctx.set(data) 

350 return data 

351 

352 def to_params(self) -> Dict[str, Any]: 

353 """Get the parameters that should be sent to Jinja templates""" 

354 params = vars(self) 

355 params['type_schema'] = Schema.from_data(self) 

356 params['client_types'] = ClientTypes.from_data(self) 

357 

358 # add utility functions 

359 for func in [ 

360 sql_param, 

361 raise_err, 

362 type_as_string, 

363 get_list_types, 

364 clean_multiline, 

365 format_documentation, 

366 model_dict, 

367 ]: 

368 params[func.__name__] = func 

369 

370 return params 

371 

372 @root_validator(pre=True, allow_reuse=True, skip_on_failure=True) 

373 @classmethod 

374 def validate_version(cls, values: Dict[Any, Any]) -> Dict[Any, Any]: 

375 # TODO: test this 

376 version = values.get('version') 

377 if not DEBUG_GENERATOR and version != config.expected_engine_version: 377 ↛ 378: line 377 didn't jump to line 378, because the condition on line 377 was never true

378 raise ValueError( 

379 f'Prisma Client Python expected Prisma version: {config.expected_engine_version} ' 

380 f'but got: {version}\n' 

381 ' If this is intentional, set the PRISMA_PY_DEBUG_GENERATOR environment ' 

382 'variable to 1 and try again.\n' 

383 f' If you are using the Node CLI then you must switch to v{config.prisma_version}, e.g. ' 

384 f'npx prisma@{config.prisma_version} generate\n' 

385 ' or generate the client using the Python CLI, e.g. python3 -m prisma generate' 

386 ) 

387 return values 

388 

389 

390class BinaryPaths(BaseModel): 

391 """This class represents the paths to engine binaries. 

392 

393 Each property in this class is a mapping of platform name to absolute path, for example: 

394 

395 ```py 

396 # This is what will be set on an M1 chip if there are no other `binaryTargets` set 

397 binary_paths.query_engine == { 

398 'darwin-arm64': '/Users/robert/.cache/prisma-python/binaries/3.13.0/efdf9b1183dddfd4258cd181a72125755215ab7b/node_modules/prisma/query-engine-darwin-arm64' 

399 } 

400 ``` 

401 

402 These are only populated if the generator explicitly requests them using the `requires_engines` manifest property. 

403 """ 

404 

405 query_engine: Dict[str, str] = FieldInfo( 

406 default_factory=dict, 

407 alias='queryEngine', 

408 ) 

409 introspection_engine: Dict[str, str] = FieldInfo( 

410 default_factory=dict, 

411 alias='introspectionEngine', 

412 ) 

413 migration_engine: Dict[str, str] = FieldInfo( 

414 default_factory=dict, 

415 alias='migrationEngine', 

416 ) 

417 libquery_engine: Dict[str, str] = FieldInfo( 

418 default_factory=dict, 

419 alias='libqueryEngine', 

420 ) 

421 prisma_format: Dict[str, str] = FieldInfo( 

422 default_factory=dict, 

423 alias='prismaFmt', 

424 ) 

425 

426 if PYDANTIC_V2: 

427 model_config: ClassVar[ConfigDict] = ConfigDict(extra='allow') 

428 else: 

429 

430 class Config(BaseModel.Config): # pyright: ignore[reportDeprecated] 

431 extra: Any = ( 

432 pydantic.Extra.allow # pyright: ignore[reportDeprecated] 

433 ) 

434 

435 

436class Datasource(BaseModel): 

437 # TODO: provider enums 

438 name: str 

439 provider: str 

440 active_provider: str = FieldInfo(alias='activeProvider') 

441 url: 'OptionalValueFromEnvVar' 

442 

443 

444class Generator(GenericModel, Generic[ConfigT]): 

445 name: str 

446 output: 'ValueFromEnvVar' 

447 provider: 'OptionalValueFromEnvVar' 

448 config: ConfigT 

449 binary_targets: List['ValueFromEnvVar'] = FieldInfo(alias='binaryTargets') 

450 preview_features: List[str] = FieldInfo(alias='previewFeatures') 

451 

452 @field_validator('binary_targets') 

453 @classmethod 

454 def warn_binary_targets(cls, targets: List['ValueFromEnvVar']) -> List['ValueFromEnvVar']: 

455 # Prisma by default sends one binary target which is the current platform. 

456 if len(targets) > 1: 

457 click.echo( 

458 click.style( 

459 'Warning: ' + 'The binaryTargets option is not officially supported by Prisma Client Python.', 

460 fg='yellow', 

461 ), 

462 file=sys.stdout, 

463 ) 

464 

465 return targets 

466 

467 def has_preview_feature(self, feature: str) -> bool: 

468 return feature in self.preview_features 

469 

470 

471class ValueFromEnvVar(BaseModel): 

472 value: str 

473 from_env_var: Optional[str] = FieldInfo(alias='fromEnvVar') 

474 

475 

476class OptionalValueFromEnvVar(BaseModel): 

477 value: Optional[str] = None 

478 from_env_var: Optional[str] = FieldInfo(alias='fromEnvVar') 

479 

480 def resolve(self) -> str: 

481 value = self.value 

482 if value is not None: 

483 return value 

484 

485 env_var = self.from_env_var 

486 assert env_var is not None, 'from_env_var should not be None' 

487 value = os.environ.get(env_var) 

488 if value is None: 488 ↛ 489: line 488 didn't jump to line 489, because the condition on line 488 was never true

489 raise RuntimeError(f'Environment variable not found: {env_var}') 

490 

491 return value 
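
# e.g. a datasource url of { "fromEnvVar": "DATABASE_URL", "value": null } (hypothetical) resolves
# to os.environ['DATABASE_URL'], raising if the variable is not set.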

492 

493 

494class Config(BaseSettings): 

495 """Custom generator config options.""" 

496 

497 interface: InterfaceChoices = FieldInfo(default=InterfaceChoices.asyncio, env='PRISMA_PY_CONFIG_INTERFACE') 

498 partial_type_generator: Optional[Module] = FieldInfo(default=None, env='PRISMA_PY_CONFIG_PARTIAL_TYPE_GENERATOR') 

499 recursive_type_depth: int = FieldInfo( 

500 default_factory=_recursive_type_depth_factory, 

501 env='PRISMA_PY_CONFIG_RECURSIVE_TYPE_DEPTH', 

502 ) 

503 engine_type: EngineType = FieldInfo(default=EngineType.binary, env='PRISMA_PY_CONFIG_ENGINE_TYPE') 

504 

505 # this should be a list of experimental features 

506 # https://github.com/prisma/prisma/issues/12442 

507 enable_experimental_decimal: bool = FieldInfo(default=False, env='PRISMA_PY_CONFIG_ENABLE_EXPERIMENTAL_DECIMAL') 

508 

509 # this seems to be the only good method for setting the contextvar as 

510 # we don't control the actual construction of the object like we do for 

511 # the Data model. 

512 # we do not expose this to type checkers so that the generated __init__ 

513 # signature is preserved. 

514 if not TYPE_CHECKING: 

515 

516 def __init__(self, **kwargs: object) -> None: 

517 super().__init__(**kwargs) 

518 config_ctx.set(self) 

519 

520 if PYDANTIC_V2: 

521 model_config: ClassVar[ConfigDict] = ConfigDict( 

522 extra='forbid', 

523 use_enum_values=True, 

524 populate_by_name=True, 

525 ) 

526 else: 

527 if not TYPE_CHECKING: 

528 

529 class Config(BaseSettingsConfig): 

530 extra: pydantic.Extra = pydantic.Extra.forbid 

531 use_enum_values: bool = True 

532 env_prefix: str = 'prisma_py_config_' 

533 allow_population_by_field_name: bool = True 

534 

535 @classmethod 

536 def customise_sources(cls, init_settings, env_settings, file_secret_settings): 

537 # prioritise env settings over init settings 

538 return env_settings, init_settings, file_secret_settings 

539 

540 @root_validator(pre=True, skip_on_failure=True) 

541 @classmethod 

542 def transform_engine_type(cls, values: Dict[str, Any]) -> Dict[str, Any]: 

543 # prioritise env variable over schema option 

544 engine_type = os.environ.get('PRISMA_CLIENT_ENGINE_TYPE') 

545 if engine_type is None: 545 ↛ 549: line 545 didn't jump to line 549, because the condition on line 545 was never false

546 engine_type = values.get('engineType') 

547 

548 # only add engine_type if it is present 

549 if engine_type is not None: 

550 values['engine_type'] = engine_type 

551 values.pop('engineType', None) 

552 

553 return values 

554 

555 @root_validator(pre=True, skip_on_failure=True) 

556 @classmethod 

557 def removed_http_option_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]: 

558 http = values.get('http') 

559 if http is not None: 

560 if http in {'aiohttp', 'httpx-async'}: 

561 option = 'asyncio' 

562 elif http in {'requests', 'httpx-sync'}: 

563 option = 'sync' 

564 else: # pragma: no cover 

565 # invalid http option, let pydantic handle the error 

566 return values 

567 

568 raise ValueError( 

569 'The http option has been removed in favour of the interface option.\n' 

570 ' Please remove the http option from your Prisma schema and replace it with:\n' 

571 f' interface = "{option}"' 

572 ) 

573 return values 

574 

575 if PYDANTIC_V2: 

576 

577 @root_validator(pre=True, skip_on_failure=True) 

578 @classmethod 

579 def partial_type_generator_converter(cls, values: Dict[str, Any]) -> Dict[str, Any]: 

580 # ensure env resolving happens 

581 values = cast(Dict[str, Any], cls.root_validator(values)) # type: ignore 

582 

583 value = values.get('partial_type_generator') 

584 

585 try: 

586 values['partial_type_generator'] = Module( 

587 spec=value # pyright: ignore[reportArgumentType] 

588 ) 

589 except ValueError: 

590 if value is None: 

591 # no config value passed and the default location was not found 

592 return values 

593 raise 

594 

595 return values 

596 

597 else: 

598 

599 @field_validator('partial_type_generator', pre=True, always=True, allow_reuse=True) 

600 @classmethod 

601 def _partial_type_generator_converter(cls, value: Optional[str]) -> Optional[Module]: 

602 try: 

603 return Module( 

604 spec=value # pyright: ignore[reportArgumentType] 

605 ) 

606 except ValueError: 

607 if value is None: 

608 # no config value passed and the default location was not found 

609 return None 

610 raise 

611 

612 @field_validator('recursive_type_depth', always=True, allow_reuse=True) 

613 @classmethod 

614 def recursive_type_depth_validator(cls, value: int) -> int: 

615 if value < -1 or value in {0, 1}: 

616 raise ValueError('Value must equal -1 or be greater than 1.') 

617 return value 

618 

619 @field_validator('engine_type', always=True, allow_reuse=True) 

620 @classmethod 

621 def engine_type_validator(cls, value: EngineType) -> EngineType: 

622 if value == EngineType.binary: 

623 return value 

624 elif value == EngineType.dataproxy: # pragma: no cover 

625 raise ValueError('Prisma Client Python does not support the Prisma Data Proxy yet.') 

626 elif value == EngineType.library: # pragma: no cover 

627 raise ValueError('Prisma Client Python does not support native engine bindings yet.') 

628 else: # pragma: no cover 

629 assert_never(value) 

630 

631 

632class DMMFEnumType(BaseModel): 

633 name: str 

634 values: List[object] 

635 

636 

637class DMMFEnumTypes(BaseModel): 

638 prisma: List[DMMFEnumType] 

639 

640 

641class PrismaSchema(BaseModel): 

642 enum_types: DMMFEnumTypes = FieldInfo(alias='enumTypes') 

643 

644 

645class DMMF(BaseModel): 

646 datamodel: 'Datamodel' 

647 prisma_schema: PrismaSchema = FieldInfo(alias='schema') 

648 

649 

650class Datamodel(BaseModel): 

651 enums: List['Enum'] 

652 models: List['Model'] 

653 

654 # not implemented yet 

655 types: List[object] 

656 

657 @field_validator('types') 

658 @classmethod 

659 def no_composite_types_validator(cls, types: List[object]) -> object: 

660 if types: 

661 raise ValueError( 

662 'Composite types are not supported yet. Please indicate you need this here: https://github.com/RobertCraigie/prisma-client-py/issues/314' 

663 ) 

664 

665 return types 

666 

667 

668class Enum(BaseModel): 

669 name: str 

670 db_name: Optional[str] = FieldInfo(alias='dbName') 

671 values: List['EnumValue'] 

672 

673 

674class EnumValue(BaseModel): 

675 name: str 

676 db_name: Optional[str] = FieldInfo(alias='dbName') 

677 

678 

679class ModelExtension(BaseModel): 

680 instance_name: Optional[str] = None 

681 

682 @field_validator('instance_name') 

683 @classmethod 

684 def instance_name_validator(cls, name: Optional[str]) -> Optional[str]: 

685 if not name: 685 ↛ 686: line 685 didn't jump to line 686, because the condition on line 685 was never true

686 return name 

687 

688 if not name.isidentifier(): 

689 raise ValueError(f'Custom Model instance_name "{name}" is not a valid Python identifier') 

690 

691 return name 

692 

693 

694class Model(BaseModel): 

695 name: str 

696 documentation: Optional[str] = None 

697 db_name: Optional[str] = FieldInfo(alias='dbName') 

698 is_generated: bool = FieldInfo(alias='isGenerated') 

699 compound_primary_key: Optional['PrimaryKey'] = FieldInfo(alias='primaryKey') 

700 unique_indexes: List['UniqueIndex'] = FieldInfo(alias='uniqueIndexes') 

701 all_fields: List['Field'] = FieldInfo(alias='fields') 

702 

703 # stores the parsed DSL, not an actual field defined by prisma 

704 extension: Optional[ModelExtension] = None 

705 

706 _sampler: Sampler = PrivateAttr() 

707 

708 def __init__(self, **data: Any) -> None: 

709 super().__init__(**data) 

710 self._sampler = Sampler(self) 

711 

712 @root_validator(pre=True, allow_reuse=True) 

713 @classmethod 

714 def validate_dsl_extension(cls, values: Dict[Any, Any]) -> Dict[Any, Any]: 

715 documentation = values.get('documentation') 

716 if not documentation: 

717 return values 

718 

719 parsed = parse_schema_dsl(documentation) 

720 if parsed['type'] == 'invalid': 720 ↛ 721: line 720 didn't jump to line 721, because the condition on line 720 was never true

721 raise ValueError(parsed['error']) 

722 

723 if parsed['type'] == 'ok': 723 ↛ 726: line 723 didn't jump to line 726, because the condition on line 723 was never false

724 values['extension'] = model_parse(ModelExtension, parsed['value']['arguments']) 

725 

726 return values 

727 

728 @field_validator('name') 

729 @classmethod 

730 def name_validator(cls, name: str) -> str: 

731 if iskeyword(name): 

732 raise ValueError( 

733 f'Model name "{name}" shadows a Python keyword; ' 

734 f'use a different model name with \'@@map("{name}")\'.' 

735 ) 

736 

737 if iskeyword(name.lower()): 

738 raise ValueError( 

739 f'Model name "{name}" results in a client property that shadows a Python keyword; ' 

740 f'use a different model name with \'@@map("{name}")\'.' 

741 ) 

742 

743 return name 

744 

745 @property 

746 def related_models(self) -> Iterator['Model']: 

747 models = get_datamodel().models 

748 for field in self.relational_fields: 

749 for model in models: 

750 if field.type == model.name: 

751 yield model 

752 

753 @property 

754 def relational_fields(self) -> Iterator['Field']: 

755 for field in self.all_fields: 

756 if field.is_relational: 

757 yield field 

758 

759 @property 

760 def scalar_fields(self) -> Iterator['Field']: 

761 for field in self.all_fields: 

762 if not field.is_relational: 

763 yield field 

764 

765 @property 

766 def atomic_fields(self) -> Iterator['Field']: 

767 for field in self.all_fields: 

768 if field.type in ATOMIC_FIELD_TYPES: 

769 yield field 

770 

771 @property 

772 def required_array_fields(self) -> Iterator['Field']: 

773 for field in self.all_fields: 

774 if field.is_list and not field.relation_name and field.is_required: 

775 yield field 

776 

777 # TODO: support combined unique constraints 

778 @cached_property 

779 def id_field(self) -> Optional['Field']: 

780 """Find a field that can be passed to the model's `WhereUnique` filter""" 

781 for field in self.scalar_fields: # pragma: no branch 

782 if field.is_id or field.is_unique: 

783 return field 

784 return None 

785 

786 @property 

787 def has_relational_fields(self) -> bool: 

788 try: 

789 next(self.relational_fields) 

790 except StopIteration: 

791 return False 

792 else: 

793 return True 

794 

795 @property 

796 def instance_name(self) -> str: 

797 """The name of this model in the generated client class, e.g. 

798 

799 `User` -> `Prisma().user` 

800 """ 

801 if self.extension and self.extension.instance_name: 

802 return self.extension.instance_name 

803 

804 return self.name.lower() 

805 

806 @property 

807 def plural_name(self) -> str: 

808 name = self.instance_name 

809 if name.endswith('s'): 

810 return name 

811 return f'{name}s' 
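
# e.g. a model named 'User' (hypothetical) has instance_name 'user' and plural_name 'users';
# a name that already ends in 's' is returned unchanged.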

812 

813 def resolve_field(self, name: str) -> 'Field': 

814 for field in self.all_fields: 814 ↛ 818: line 814 didn't jump to line 818, because the loop on line 814 didn't complete

815 if field.name == name: 

816 return field 

817 

818 raise LookupError(f'Could not find a field with name: {name}') 

819 

820 def sampler(self) -> Sampler: 

821 return self._sampler 

822 

823 

824class Constraint(BaseModel): 

825 name: str 

826 fields: List[str] 

827 

828 @root_validator(pre=True, allow_reuse=True, skip_on_failure=True) 

829 @classmethod 

830 def resolve_name(cls, values: Dict[str, Any]) -> Dict[str, Any]: 

831 name = values.get('name') 

832 if isinstance(name, str): 

833 return values 

834 

835 values['name'] = '_'.join(values['fields']) 

836 return values 
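
# e.g. a constraint over fields ['first_name', 'last_name'] (hypothetical) with no explicit name
# resolves to the name 'first_name_last_name'.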

837 

838 

839class PrimaryKey(Constraint): 

840 pass 

841 

842 

843class UniqueIndex(Constraint): 

844 pass 

845 

846 

847class Field(BaseModel): 

848 name: str 

849 documentation: Optional[str] = None 

850 

851 # TODO: switch to enums 

852 kind: str 

853 type: str 

854 

855 is_id: bool = FieldInfo(alias='isId') 

856 is_list: bool = FieldInfo(alias='isList') 

857 is_unique: bool = FieldInfo(alias='isUnique') 

858 is_required: bool = FieldInfo(alias='isRequired') 

859 is_read_only: bool = FieldInfo(alias='isReadOnly') 

860 is_generated: bool = FieldInfo(alias='isGenerated') 

861 is_updated_at: bool = FieldInfo(alias='isUpdatedAt') 

862 

863 default: Optional[Union['DefaultValue', object, List[object]]] = None 

864 has_default_value: bool = FieldInfo(alias='hasDefaultValue') 

865 

866 relation_name: Optional[str] = FieldInfo(alias='relationName', default=None) 

867 relation_on_delete: Optional[str] = FieldInfo(alias='relationOnDelete', default=None) 

868 relation_to_fields: Optional[List[str]] = FieldInfo( 

869 alias='relationToFields', 

870 default=None, 

871 ) 

872 relation_from_fields: Optional[List[str]] = FieldInfo( 

873 alias='relationFromFields', 

874 default=None, 

875 ) 

876 

877 _last_sampled: Optional[str] = PrivateAttr(default=None) 

878 

879 @root_validator(pre=True, skip_on_failure=True) 

880 @classmethod 

881 def scalar_type_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]: 

882 kind = values.get('kind') 

883 type_ = values.get('type') 

884 

885 if kind == 'scalar': 

886 if type_ is not None and type_ not in TYPE_MAPPING: 886 ↛ 887: line 886 didn't jump to line 887, because the condition on line 886 was never true

887 raise ValueError(f'Unsupported scalar field type: {type_}') 

888 

889 return values 

890 

891 @field_validator('type') 

892 @classmethod 

893 def experimental_decimal_validator(cls, typ: str) -> str: 

894 if typ == 'Decimal': 

895 config = get_config() 

896 

897 # skip validating the experimental flag if we are 

898 # being called from a custom generator 

899 if isinstance(config, Config) and not config.enable_experimental_decimal: 

900 raise ValueError( 

901 'Support for the Decimal type is experimental\n' 

902 ' As such you must set the `enable_experimental_decimal` config flag to true\n' 

903 ' for more information see: https://github.com/RobertCraigie/prisma-client-py/issues/106' 

904 ) 

905 

906 return typ 

907 

908 @field_validator('name') 

909 @classmethod 

910 def name_validator(cls, name: str) -> str: 

911 if getattr(BaseModel, name, None): 

912 raise ValueError( 

913 f'Field name "{name}" shadows a BaseModel attribute; ' 

914 f'use a different field name with \'@map("{name}")\'.' 

915 ) 

916 

917 if iskeyword(name): 

918 raise ValueError( 

919 f'Field name "{name}" shadows a Python keyword; ' f'use a different field name with \'@map("{name}")\'.' 

920 ) 

921 

922 if name == 'prisma': 

923 raise ValueError( 

924 'Field name "prisma" shadows a Prisma Client Python method; ' 

925 'use a different field name with \'@map("prisma")\'.' 

926 ) 

927 

928 if name in QUERY_BUILDER_ALIASES: 

929 raise ValueError( 

930 f'Field name "{name}" shadows an internal keyword; ' 

931 f'use a different field name with \'@map("{name}")\'' 

932 ) 

933 

934 return name 

935 

936 # TODO: cache the properties 

937 @property 

938 def python_type(self) -> str: 

939 type_ = self._actual_python_type 

940 if self.is_list: 

941 return f'List[{type_}]' 

942 return type_ 

943 

944 @property 

945 def python_type_as_string(self) -> str: 

946 type_ = self._actual_python_type 

947 if self.is_list: 

948 type_ = type_.replace("'", "\\'") 

949 return f"'List[{type_}]'" 

950 

951 if not type_.startswith("'"): 

952 type_ = f"'{type_}'" 

953 

954 return type_ 

955 

956 @property 

957 def _actual_python_type(self) -> str: 

958 if self.kind == 'enum': 

959 return f"'enums.{self.type}'" 

960 

961 if self.kind == 'object': 

962 return f"'models.{self.type}'" 

963 

964 try: 

965 return TYPE_MAPPING[self.type] 

966 except KeyError as exc: 

967 # TODO: handle this better 

968 raise RuntimeError( 

969 f'Could not parse {self.name} due to unknown type: {self.type}', 

970 ) from exc 

971 

972 @property 

973 def create_input_type(self) -> str: 

974 if self.kind != 'object': 

975 return self.python_type 

976 

977 if self.is_list: 

978 return f"'{self.type}CreateManyNestedWithoutRelationsInput'" 

979 

980 return f"'{self.type}CreateNestedWithoutRelationsInput'" 

981 

982 @property 

983 def where_input_type(self) -> str: 

984 typ = self.type 

985 if self.is_relational: 

986 if self.is_list: 

987 return f"'{typ}ListRelationFilter'" 

988 return f"'{typ}RelationFilter'" 

989 

990 if self.is_list: 

991 self.check_supported_scalar_list_type() 

992 return f"'types.{typ}ListFilter'" 

993 

994 if typ in FILTER_TYPES: 

995 if self.is_optional: 

996 return f"Union[None, {self._actual_python_type}, 'types.{typ}Filter']" 

997 return f"Union[{self._actual_python_type}, 'types.{typ}Filter']" 

998 

999 return self.python_type 

1000 

1001 @property 

1002 def where_aggregates_input_type(self) -> str: 

1003 if self.is_relational: # pragma: no cover 

1004 raise RuntimeError('This type is not valid for relational fields') 

1005 

1006 typ = self.type 

1007 if typ in FILTER_TYPES: 

1008 return f"Union[{self._actual_python_type}, 'types.{typ}WithAggregatesFilter']" 

1009 return self.python_type 

1010 

1011 @property 

1012 def relational_args_type(self) -> str: 

1013 if self.is_list: 

1014 return f'FindMany{self.type}Args' 

1015 return f'{self.type}Args' 

1016 

1017 @property 

1018 def required_on_create(self) -> bool: 

1019 return ( 

1020 self.is_required 

1021 and not self.is_updated_at 

1022 and not self.has_default_value 

1023 and not self.relation_name 

1024 and not self.is_list 

1025 ) 

1026 

1027 @property 

1028 def is_optional(self) -> bool: 

1029 return not (self.is_required and not self.relation_name) 

1030 

1031 @property 

1032 def is_relational(self) -> bool: 

1033 return self.relation_name is not None 

1034 

1035 @property 

1036 def is_atomic(self) -> bool: 

1037 return self.type in ATOMIC_FIELD_TYPES 

1038 

1039 @property 

1040 def is_number(self) -> bool: 

1041 return self.type in {'Int', 'BigInt', 'Float'} 

1042 

1043 def maybe_optional(self, typ: str) -> str: 

1044 """Wrap the given type string within `Optional` if applicable""" 

1045 if self.is_required or self.is_relational: 

1046 return typ 

1047 return f'Optional[{typ}]' 
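
# e.g. maybe_optional('_str') returns 'Optional[_str]' for an optional scalar field and '_str'
# unchanged when the field is required or relational.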

1048 

1049 def get_update_input_type(self) -> str: 

1050 if self.kind == 'object': 

1051 if self.is_list: 

1052 return f"'{self.type}UpdateManyWithoutRelationsInput'" 

1053 return f"'{self.type}UpdateOneWithoutRelationsInput'" 

1054 

1055 if self.is_list: 

1056 self.check_supported_scalar_list_type() 

1057 return f"'types.{self.type}ListUpdate'" 

1058 

1059 if self.is_atomic: 

1060 return f'Union[Atomic{self.type}Input, {self.python_type}]' 

1061 

1062 return self.python_type 

1063 

1064 def check_supported_scalar_list_type(self) -> None: 

1065 if self.type not in FILTER_TYPES and self.kind != 'enum': # pragma: no branch 

1066 raise UnsupportedListTypeError(self.type) 

1067 

1068 def get_relational_model(self) -> Optional['Model']: 

1069 if not self.is_relational: 1069 ↛ 1070: line 1069 didn't jump to line 1070, because the condition on line 1069 was never true

1070 return None 

1071 

1072 name = self.type 

1073 for model in get_datamodel().models: 1073 ↛ 1076: line 1073 didn't jump to line 1076, because the loop on line 1073 didn't complete

1074 if model.name == name: 

1075 return model 

1076 return None 

1077 

1078 def get_corresponding_enum(self) -> Optional['Enum']: 

1079 typ = self.type 

1080 for enum in get_datamodel().enums: 

1081 if enum.name == typ: 1081 ↛ 1080: line 1081 didn't jump to line 1080, because the condition on line 1081 was never false

1082 return enum 

1083 return None # pragma: no cover 

1084 

1085 def get_sample_data(self, *, increment: bool = True) -> str: 

1086 # returning the same data that was last sampled is useful 

1087 # for documenting methods like upsert() where data is duplicated 

1088 if not increment and self._last_sampled is not None: 

1089 return self._last_sampled 

1090 

1091 sampled = self._get_sample_data() 

1092 if self.is_list: 

1093 sampled = f'[{sampled}]' 

1094 

1095 self._last_sampled = sampled 

1096 return sampled 

1097 

1098 def _get_sample_data(self) -> str: 

1099 if self.is_relational: # pragma: no cover 

1100 raise RuntimeError('Data sampling for relational fields not supported yet') 

1101 

1102 if self.kind == 'enum': 

1103 enum = self.get_corresponding_enum() 

1104 assert enum is not None, self.type 

1105 return f'enums.{enum.name}.{FAKER.from_list(enum.values).name}' 

1106 

1107 typ = self.type 

1108 if typ == 'Boolean': 

1109 return str(FAKER.boolean()) 

1110 elif typ == 'Int': 

1111 return str(FAKER.integer()) 

1112 elif typ == 'String': 

1113 return f"'{FAKER.string()}'" 

1114 elif typ == 'Float': 

1115 return f'{FAKER.integer()}.{FAKER.integer() // 10000}' 

1116 elif typ == 'BigInt': # pragma: no cover 

1117 return str(FAKER.integer() * 12) 

1118 elif typ == 'DateTime': 

1119 # TODO: random dates 

1120 return 'datetime.datetime.utcnow()' 

1121 elif typ == 'Json': 1121 ↛ 1122: line 1121 didn't jump to line 1122, because the condition on line 1121 was never true

1122 return f"Json({{'{FAKER.string()}': True}})" 

1123 elif typ == 'Bytes': 1123 ↛ 1125: line 1123 didn't jump to line 1125, because the condition on line 1123 was never false

1124 return f"Base64.encode(b'{FAKER.string()}')" 

1125 elif typ == 'Decimal': 

1126 return f"Decimal('{FAKER.integer()}.{FAKER.integer() // 10000}')" 

1127 else: # pragma: no cover 

1128 raise RuntimeError(f'Sample data not supported for {typ} yet') 

1129 

1130 

1131class DefaultValue(BaseModel): 

1132 args: Any = None 

1133 name: str 

1134 

1135 

1136class _EmptyModel(BaseModel): 

1137 if PYDANTIC_V2: 

1138 model_config: ClassVar[ConfigDict] = ConfigDict(extra='forbid') 

1139 elif not TYPE_CHECKING: 

1140 

1141 class Config(BaseModel.Config): 

1142 extra: pydantic.Extra = pydantic.Extra.forbid 

1143 

1144 

1145class _ModelAllowAll(BaseModel): 

1146 if PYDANTIC_V2: 

1147 model_config: ClassVar[ConfigDict] = ConfigDict(extra='allow') 

1148 elif not TYPE_CHECKING: 

1149 

1150 class Config(BaseModel.Config): 

1151 extra: pydantic.Extra = pydantic.Extra.allow 

1152 

1153 

1154class PythonNames(BaseModel): 

1155 def client_class(self, _for_async: bool) -> str: 

1156 return 'Prisma' 

1157 

1158 

1159class PythonData(GenericData[Config]): 

1160 """Data class including the default Prisma Client Python config""" 

1161 

1162 if not PYDANTIC_V2: 

1163 

1164 class Config(BaseConfig): 

1165 arbitrary_types_allowed: bool = True 

1166 json_encoders: Dict[Type[Any], Any] = { 

1167 Path: _pathlib_serializer, 

1168 machinery.ModuleSpec: _module_spec_serializer, 

1169 } 

1170 keep_untouched: Tuple[Type[Any], ...] = (cached_property,) 

1171 

1172 names: PythonNames = PythonNames() 

1173 

1174 

1175class DefaultData(GenericData[_EmptyModel]): 

1176 """Data class without any config options""" 

1177 

1178 

1179# this has to be defined as a type alias instead of a class 

1180# as its purpose is to signify that the data is config agnostic 

1181AnyData = GenericData[Any] 

1182 

1183model_rebuild(Enum) 

1184model_rebuild(DMMF) 

1185model_rebuild(GenericData) 

1186model_rebuild(Field) 

1187model_rebuild(Model) 

1188model_rebuild(Datamodel) 

1189model_rebuild(Generator) 

1190model_rebuild(Datasource) 

1191 

1192 

1193from .errors import ( 

1194 TemplateError, 

1195 PartialTypeGeneratorError, 

1196) 

1197from .schema import Schema, ClientTypes