Coverage for src/prisma/generator/models.py: 94%

641 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-08-27 18:25 +0000

1import os 

2import sys 

3import enum 

4import textwrap 

5import importlib 

6from typing import ( 

7 TYPE_CHECKING, 

8 Any, 

9 Dict, 

10 List, 

11 Type, 

12 Tuple, 

13 Union, 

14 Generic, 

15 TypeVar, 

16 ClassVar, 

17 Iterable, 

18 Iterator, 

19 NoReturn, 

20 Optional, 

21 cast, 

22) 

23from keyword import iskeyword 

24from pathlib import Path 

25from importlib import util as importlib_util, machinery 

26from itertools import chain 

27from contextvars import ContextVar 

28from importlib.abc import InspectLoader 

29from typing_extensions import Annotated, override 

30 

31import click 

32import pydantic 

33from pydantic.fields import PrivateAttr 

34 

35from .. import config 

36from .utils import Faker, Sampler, clean_multiline 

37from ..utils import DEBUG_GENERATOR, assert_never 

38from ..errors import UnsupportedListTypeError 

39from .._compat import ( 

40 PYDANTIC_V2, 

41 Field as FieldInfo, 

42 BaseConfig, 

43 ConfigDict, 

44 BaseSettings, 

45 GenericModel, 

46 PlainSerializer, 

47 BaseSettingsConfig, 

48 model_dict, 

49 model_parse, 

50 model_rebuild, 

51 root_validator, 

52 cached_property, 

53 field_validator, 

54) 

55from .._constants import QUERY_BUILDER_ALIASES 

56from ._dsl_parser import parse_schema_dsl 

57 

# Public names exported by this module.
__all__ = (
    'AnyData',
    'PythonData',
    'DefaultData',
    'GenericData',
)

# Type variable bound to any pydantic model; used to annotate validators that
# receive and return the model instance itself (see `GenericData._set_ctx`).
_ModelT = TypeVar('_ModelT', bound=pydantic.BaseModel)

# NOTE: this does not represent all the data that is passed by prisma

# Prisma scalar types checked by `Model.atomic_fields` / `Field.is_atomic`.
ATOMIC_FIELD_TYPES = ['Int', 'BigInt', 'Float']

# Prisma scalar type name -> Python type expression emitted into generated code.
# Quoted values are emitted as string (forward-reference) annotations.
TYPE_MAPPING = {
    'String': '_str',
    'Bytes': "'fields.Base64'",
    'DateTime': 'datetime.datetime',
    'Boolean': '_bool',
    'Int': '_int',
    'Float': '_float',
    'BigInt': '_int',
    'Json': "'fields.Json'",
    'Decimal': 'decimal.Decimal',
}
# Scalar types for which `types.<Type>Filter` / `types.<Type>ListFilter` inputs are
# referenced (see `Field.where_input_type`).
# NOTE: the order differs from TYPE_MAPPING (BigInt / Float); `get_list_types()`
# iterates this list, so reordering it changes the order of its output.
FILTER_TYPES = [
    'String',
    'Bytes',
    'DateTime',
    'Boolean',
    'Int',
    'BigInt',
    'Float',
    'Json',
    'Decimal',
]

# Headline printed (in yellow) by `_recursive_type_depth_factory` when the
# `recursive_type_depth` option is not set explicitly.
RECURSIVE_TYPE_DEPTH_WARNING = """Some types are disabled by default due to being incompatible with Mypy, it is highly recommended
to use Pyright instead and configure Prisma Python to use recursive types. To re-enable certain types:"""

# Follow-up text printed right after RECURSIVE_TYPE_DEPTH_WARNING, showing the
# schema snippets that either enable recursive types (-1) or silence the
# warning by setting the default (5) explicitly.
RECURSIVE_TYPE_DEPTH_WARNING_DESC = """
generator client {
 provider = "prisma-client-py"
 recursive_type_depth = -1
}

If you need to use Mypy, you can also disable this message by explicitly setting the default value:

generator client {
 provider = "prisma-client-py"
 recursive_type_depth = 5
}

For more information see: https://prisma-client-py.readthedocs.io/en/stable/reference/limitations/#default-type-limitations
"""

111 

# Module-wide fake-data helper (`Faker` comes from `.utils`).
FAKER: Faker = Faker()


# Type variable for the generator's config model; `GenericData` and
# `Generator` are generic over it.
ConfigT = TypeVar('ConfigT', bound=pydantic.BaseModel)

# Although we should just be able to access the config from the datamodel
# we have to do some validation that requires access to the config, this is difficult
# with heavily nested models as our current workaround only sets the datamodel context
# post-validation meaning we cannot access it in validators. To get around this we have
# a separate config context.
# TODO: better solution
#
# data_ctx: set to the validated root data (see `GenericData`), read by `get_datamodel()`.
# config_ctx: set in `Config.__init__`, read by `get_config()`.
data_ctx: ContextVar['AnyData'] = ContextVar('data_ctx')
config_ctx: ContextVar['Config'] = ContextVar('config_ctx')

125 

126 

def get_datamodel() -> 'Datamodel':
    """Return the datamodel of the generator data stored in ``data_ctx``.

    Raises ``LookupError`` when no data has been set in the current context.
    """
    current_data = data_ctx.get()
    return current_data.dmmf.datamodel

129 

130 

# typed to ensure the caller has to handle the cases where:
# - a custom generator config is being used
# - the config is invalid and therefore could not be set
def get_config() -> Union[None, pydantic.BaseModel, 'Config']:
    """Return the config stored in ``config_ctx``, or ``None`` if it was never set."""
    current: Union[None, pydantic.BaseModel, 'Config'] = config_ctx.get(None)
    return current

136 

137 

def get_list_types() -> Iterable[Tuple[str, str]]:
    """Yield ``(prisma_type, python_type)`` pairs for every type usable in a scalar list.

    Scalar filter types come first (in ``FILTER_TYPES`` order), followed by every
    enum in the current datamodel.
    """
    # WARNING: do not edit this function without also editing Field.is_supported_scalar_list_type()
    scalar_pairs = ((type_name, TYPE_MAPPING[type_name]) for type_name in FILTER_TYPES)
    enum_pairs = ((model_enum.name, f"'enums.{model_enum.name}'") for model_enum in get_datamodel().enums)
    return chain(scalar_pairs, enum_pairs)

144 

145 

def sql_param(num: int = 1) -> str:
    """Return the SQL placeholder for positional parameter ``num``.

    Uses ``$<num>`` for PostgreSQL and ``?`` otherwise; MongoDB is rejected.
    The provider is read from the first datasource in ``data_ctx``.
    """
    # TODO: add case for sqlserver
    provider = data_ctx.get().datasources[0].active_provider

    # TODO: test
    if provider == 'mongodb':  # pragma: no cover
        raise RuntimeError('no-op')

    if provider == 'postgresql':
        return f'${num}'

    # SQLite and MySQL use this style so just default to it
    return '?'

158 

159 

def raise_err(msg: str) -> NoReturn:
    """Raise a ``TemplateError`` carrying ``msg``.

    Exposed to the Jinja templates through ``GenericData.to_params``.
    """
    error = TemplateError(msg)
    raise error

162 

163 

def type_as_string(typ: str) -> str:
    """Ensure a type string is wrapped with a string, e.g.

    enums.Role -> 'enums.Role'

    Strings that already start with a single or double quote are returned as-is.
    """
    # TODO: use this function internally in this module
    already_quoted = typ.startswith(("'", '"'))
    return typ if already_quoted else f"'{typ}'"

173 

174 

def format_documentation(doc: str, indent: int = 4) -> str:
    """Format a schema comment by indenting nested lines, e.g.

    '''Foo
    Bar'''

    Becomes

    '''Foo
        Bar
        '''

    Every line after the first is prefixed with ``indent`` spaces (whitespace-only
    lines are left untouched by ``textwrap.indent``), and a final line holding just
    the prefix is appended. An empty ``doc`` is returned unchanged.
    """
    if not doc:
        # empty string, nothing to do
        return doc

    padding = ' ' * indent
    head, *tail = doc.splitlines()
    pieces = [head]
    pieces.extend(textwrap.indent(line, padding) for line in tail)
    pieces.append(padding)
    return '\n'.join(pieces)

200 

201 

202def _module_spec_serializer(spec: machinery.ModuleSpec) -> str: 

203 assert spec.origin is not None, 'Cannot serialize module with no origin' 

204 return spec.origin 

205 

206 

207def _pathlib_serializer(path: Path) -> str: 

208 return str(path.absolute()) 

209 

210 

def _recursive_type_depth_factory() -> int:
    """Default factory for ``Config.recursive_type_depth``.

    Prints the recursive-type warning (headline in yellow, then the explanatory
    snippet) and returns the default depth of 5.
    """
    headline = click.style(f'\n{RECURSIVE_TYPE_DEPTH_WARNING}', fg='yellow')
    click.echo(headline)
    click.echo(f'{RECURSIVE_TYPE_DEPTH_WARNING_DESC}\n')
    return 5

220 

221 

class BaseModel(pydantic.BaseModel):
    """Shared pydantic base model for this module.

    Allows arbitrary (non-pydantic) field types and tells pydantic to leave
    ``cached_property`` attributes alone rather than treating them as fields.
    """

    if PYDANTIC_V2:
        model_config: ClassVar[ConfigDict] = ConfigDict(
            arbitrary_types_allowed=True,
            ignored_types=(cached_property,),
        )
    else:

        class Config(BaseConfig):
            arbitrary_types_allowed: bool = True
            # v1 only: JSON encoders for non-JSON-native types; v2 uses
            # per-field serializers instead (see `Module.spec`)
            json_encoders: Dict[Type[Any], Any] = {
                Path: _pathlib_serializer,
                machinery.ModuleSpec: _module_spec_serializer,
            }
            # v1 equivalent of v2's `ignored_types`
            keep_untouched: Tuple[Type[Any], ...] = (cached_property,)

237 

238 

class InterfaceChoices(str, enum.Enum):
    """Values accepted by the ``interface`` generator config option."""

    sync = 'sync'
    asyncio = 'asyncio'

242 

243 

class EngineType(str, enum.Enum):
    """Values for the ``engine_type`` generator config option.

    Only ``binary`` passes ``Config.engine_type_validator``; the others raise.
    """

    binary = 'binary'
    library = 'library'
    dataproxy = 'dataproxy'

    @override
    def __str__(self) -> str:
        # render as the raw value (e.g. 'binary') instead of 'EngineType.binary'
        return self.value

252 

253 

class Module(BaseModel):
    """A Python module referenced by the ``partial_type_generator`` option.

    Accepts either a file path (resolved against the current working directory)
    or an importable module name and stores the resolved ``ModuleSpec``.
    """

    if TYPE_CHECKING:
        spec: machinery.ModuleSpec
    elif PYDANTIC_V2:
        spec: Annotated[
            machinery.ModuleSpec,
            PlainSerializer(lambda x: _module_spec_serializer(x), return_type=str),
        ]
    else:
        spec: machinery.ModuleSpec

    if PYDANTIC_V2:
        model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
    else:

        class Config(BaseModel.Config):
            arbitrary_types_allowed: bool = True

    # for some reason this is needed in Pydantic v2
    # (allows a bare string to be passed instead of a `{'spec': ...}` mapping)
    @root_validator(pre=True, skip_on_failure=True)
    @classmethod
    def partial_type_generator_converter(cls, values: object) -> Any:
        return {'spec': values} if isinstance(values, str) else values

    @field_validator('spec', pre=True, allow_reuse=True)
    @classmethod
    def spec_validator(cls, value: Optional[str]) -> machinery.ModuleSpec:
        """Resolve ``value`` (file path or module name) to a ``ModuleSpec``.

        ``None`` falls back to ``prisma/partial_types.py``. Raises ``ValueError``
        for relative module names with no matching file, or when nothing is found.
        """
        # TODO: this should really work based off of the schema path
        # and this should support checking just partial_types.py if we are in a `prisma` dir
        location = 'prisma/partial_types.py' if value is None else value

        resolved: Optional[machinery.ModuleSpec]
        if Path.cwd().joinpath(location).exists():
            resolved = importlib_util.spec_from_file_location('prisma.partial_type_generator', location)
        elif location.startswith('.'):
            raise ValueError(f'No file found at {location} and relative imports are not allowed.')
        else:
            try:
                resolved = importlib_util.find_spec(location)
            except ModuleNotFoundError:
                resolved = None

        if resolved is None:
            raise ValueError(f'Could not find a python file or module at {location}')

        return resolved

    def run(self) -> None:
        """Import and execute the module; failures are re-raised as ``PartialTypeGeneratorError``."""
        importlib.invalidate_caches()
        module = importlib_util.module_from_spec(self.spec)

        module_loader = self.spec.loader
        assert module_loader is not None, 'Expected an import loader to exist.'
        assert isinstance(
            module_loader, InspectLoader
        ), f'Cannot execute module from loader type: {type(module_loader)}'

        try:
            module_loader.exec_module(module)
        except Exception as exc:
            raise PartialTypeGeneratorError() from exc

318 

319 

class GenericData(GenericModel, Generic[ConfigT]):
    """Root model for the data that prisma provides to the generator.

    WARNING: only one instance of this class may exist at any given time and
    instances should only be constructed using the Data.parse_obj() method
    """

    datamodel: str
    version: str
    generator: 'Generator[ConfigT]'
    dmmf: 'DMMF' = FieldInfo(alias='dmmf')
    schema_path: Path = FieldInfo(alias='schemaPath')
    datasources: List['Datasource'] = FieldInfo(alias='datasources')
    other_generators: List['Generator[_ModelAllowAll]'] = FieldInfo(alias='otherGenerators')
    # the lambda defers the name lookup: BinaryPaths is defined later in this module
    binary_paths: 'BinaryPaths' = FieldInfo(alias='binaryPaths', default_factory=lambda: BinaryPaths())

    if PYDANTIC_V2:

        # publish the validated instance through `data_ctx` so module helpers
        # such as `get_datamodel()` can reach it
        @root_validator(pre=False)
        def _set_ctx(self: _ModelT) -> _ModelT:
            data_ctx.set(cast('GenericData[ConfigT]', self))
            return self

    else:

        @classmethod
        @override
        def parse_obj(cls, obj: Any) -> 'GenericData[ConfigT]':
            data = super().parse_obj(obj)  # pyright: ignore[reportDeprecated]
            data_ctx.set(data)
            return data

    def to_params(self) -> Dict[str, Any]:
        """Get the parameters that should be sent to Jinja templates

        Returns a new dict holding every model field plus the derived
        ``type_schema`` / ``client_types`` objects and the template helper
        functions (keyed by function name). The model instance is not modified.
        """
        # `vars(self)` is the instance's own `__dict__`; copy it so the extra template
        # keys below are not injected into the model itself
        params = dict(vars(self))
        params['type_schema'] = Schema.from_data(self)
        params['client_types'] = ClientTypes.from_data(self)

        # add utility functions
        for func in [
            sql_param,
            raise_err,
            type_as_string,
            get_list_types,
            clean_multiline,
            format_documentation,
            model_dict,
        ]:
            params[func.__name__] = func

        return params

    @root_validator(pre=True, allow_reuse=True, skip_on_failure=True)
    @classmethod
    def validate_version(cls, values: Dict[Any, Any]) -> Dict[Any, Any]:
        """Reject a Prisma engine version mismatch unless ``DEBUG_GENERATOR`` is enabled."""
        # TODO: test this
        version = values.get('version')
        if not DEBUG_GENERATOR and version != config.expected_engine_version:
            raise ValueError(
                f'Prisma Client Python expected Prisma version: {config.expected_engine_version} '
                f'but got: {version}\n'
                '  If this is intentional, set the PRISMA_PY_DEBUG_GENERATOR environment '
                'variable to 1 and try again.\n'
                f'  If you are using the Node CLI then you must switch to v{config.prisma_version}, e.g. '
                f'npx prisma@{config.prisma_version} generate\n'
                '  or generate the client using the Python CLI, e.g. python3 -m prisma generate'
            )
        return values

388 

389 

class BinaryPaths(BaseModel):
    """This class represents the paths to engine binaries.

    Each property in this class is a mapping of platform name to absolute path, for example:

    ```py
    # This is what will be set on an M1 chip if there are no other `binaryTargets` set
    binary_paths.query_engine == {
        'darwin-arm64': '/Users/robert/.cache/prisma-python/binaries/3.13.0/efdf9b1183dddfd4258cd181a72125755215ab7b/node_modules/prisma/query-engine-darwin-arm64'
    }
    ```

    This is only available if the generator explicitly requests them using the `requires_engines` manifest property.
    """

    # every mapping defaults to empty when the corresponding key is not sent
    query_engine: Dict[str, str] = FieldInfo(
        default_factory=dict,
        alias='queryEngine',
    )
    introspection_engine: Dict[str, str] = FieldInfo(
        default_factory=dict,
        alias='introspectionEngine',
    )
    migration_engine: Dict[str, str] = FieldInfo(
        default_factory=dict,
        alias='migrationEngine',
    )
    libquery_engine: Dict[str, str] = FieldInfo(
        default_factory=dict,
        alias='libqueryEngine',
    )
    prisma_format: Dict[str, str] = FieldInfo(
        default_factory=dict,
        alias='prismaFmt',
    )

    # unknown keys are kept instead of being rejected
    if PYDANTIC_V2:
        model_config: ClassVar[ConfigDict] = ConfigDict(extra='allow')
    else:

        class Config(BaseModel.Config):  # pyright: ignore[reportDeprecated]
            extra: Any = (
                pydantic.Extra.allow  # pyright: ignore[reportDeprecated]
            )

434 

435 

class Datasource(BaseModel):
    """One entry of the ``datasources`` list sent to the generator (see ``GenericData.datasources``)."""

    # TODO: provider enums
    name: str
    provider: str
    # read by `sql_param()` to choose the SQL placeholder style
    active_provider: str = FieldInfo(alias='activeProvider')
    url: 'OptionalValueFromEnvVar'

    source_file_path: Optional[Path] = FieldInfo(alias='sourceFilePath')

444 

445 

class Generator(GenericModel, Generic[ConfigT]):
    """A generator entry sent by Prisma, generic over its config model.

    Used for both ``GenericData.generator`` and ``GenericData.other_generators``.
    """

    name: str
    output: 'ValueFromEnvVar'
    provider: 'OptionalValueFromEnvVar'
    config: ConfigT
    binary_targets: List['ValueFromEnvVar'] = FieldInfo(alias='binaryTargets')
    preview_features: List[str] = FieldInfo(alias='previewFeatures')

    @field_validator('binary_targets')
    @classmethod
    def warn_binary_targets(cls, targets: List['ValueFromEnvVar']) -> List['ValueFromEnvVar']:
        """Print a warning to stdout when more than one binary target is configured."""
        # Prisma by default sends one binary target which is the current platform.
        if len(targets) <= 1:
            return targets

        message = 'Warning: ' + 'The binaryTargets option is not officially supported by Prisma Client Python.'
        click.echo(click.style(message, fg='yellow'), file=sys.stdout)
        return targets

    def has_preview_feature(self, feature: str) -> bool:
        """Return whether ``feature`` is listed in the generator's preview features."""
        return any(enabled == feature for enabled in self.preview_features)

471 

472 

class ValueFromEnvVar(BaseModel):
    """A required value plus the optional name of the environment variable it came from."""

    value: str
    from_env_var: Optional[str] = FieldInfo(alias='fromEnvVar')

476 

477 

class OptionalValueFromEnvVar(BaseModel):
    """A value that is either given directly or read from an environment variable."""

    value: Optional[str] = None
    from_env_var: Optional[str] = FieldInfo(alias='fromEnvVar')

    def resolve(self) -> str:
        """Return ``value`` if set, otherwise the content of the ``from_env_var`` variable.

        Raises ``RuntimeError`` when the environment variable is not set.
        """
        if self.value is not None:
            return self.value

        variable_name = self.from_env_var
        assert variable_name is not None, 'from_env_var should not be None'

        resolved = os.environ.get(variable_name)
        if resolved is None:
            raise RuntimeError(f'Environment variable not found: {variable_name}')
        return resolved

494 

495 

class Config(BaseSettings):
    """Custom generator config options."""

    interface: InterfaceChoices = FieldInfo(default=InterfaceChoices.asyncio, env='PRISMA_PY_CONFIG_INTERFACE')
    partial_type_generator: Optional[Module] = FieldInfo(default=None, env='PRISMA_PY_CONFIG_PARTIAL_TYPE_GENERATOR')
    # when unset, the factory prints a warning and falls back to 5
    recursive_type_depth: int = FieldInfo(
        default_factory=_recursive_type_depth_factory,
        env='PRISMA_PY_CONFIG_RECURSIVE_TYPE_DEPTH',
    )
    engine_type: EngineType = FieldInfo(default=EngineType.binary, env='PRISMA_PY_CONFIG_ENGINE_TYPE')

    # this should be a list of experimental features
    # https://github.com/prisma/prisma/issues/12442
    enable_experimental_decimal: bool = FieldInfo(default=False, env='PRISMA_PY_CONFIG_ENABLE_EXPERIMENTAL_DECIMAL')

    # this seems to be the only good method for setting the contextvar as
    # we don't control the actual construction of the object like we do for
    # the Data model.
    # we do not expose this to type checkers so that the generated __init__
    # signature is preserved.
    if not TYPE_CHECKING:

        def __init__(self, **kwargs: object) -> None:
            super().__init__(**kwargs)
            # make the validated config available to `get_config()`
            config_ctx.set(self)

    if PYDANTIC_V2:
        model_config: ClassVar[ConfigDict] = ConfigDict(
            extra='forbid',
            use_enum_values=True,
            populate_by_name=True,
        )
    else:
        if not TYPE_CHECKING:

            class Config(BaseSettingsConfig):
                extra: pydantic.Extra = pydantic.Extra.forbid
                use_enum_values: bool = True
                env_prefix: str = 'prisma_py_config_'
                allow_population_by_field_name: bool = True

                @classmethod
                def customise_sources(cls, init_settings, env_settings, file_secret_settings):
                    # prioritise env settings over init settings
                    return env_settings, init_settings, file_secret_settings

    @root_validator(pre=True, skip_on_failure=True)
    @classmethod
    def transform_engine_type(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Map the schema's camelCase ``engineType`` option onto ``engine_type``.

        ``PRISMA_CLIENT_ENGINE_TYPE`` takes precedence over the schema value, and
        the ``engineType`` key is always removed from ``values``.
        """
        # prioritise env variable over schema option
        engine_type = os.environ.get('PRISMA_CLIENT_ENGINE_TYPE')
        if engine_type is None:
            engine_type = values.get('engineType')

        # only add engine_type if it is present
        if engine_type is not None:
            values['engine_type'] = engine_type
        values.pop('engineType', None)

        return values

    @root_validator(pre=True, skip_on_failure=True)
    @classmethod
    def removed_http_option_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Fail with a migration hint when the removed ``http`` option is used."""
        http = values.get('http')
        if http is not None:
            if http in {'aiohttp', 'httpx-async'}:
                option = 'asyncio'
            elif http in {'requests', 'httpx-sync'}:
                option = 'sync'
            else:  # pragma: no cover
                # invalid http option, let pydantic handle the error
                return values

            raise ValueError(
                'The http option has been removed in favour of the interface option.\n'
                '  Please remove the http option from your Prisma schema and replace it with:\n'
                f'  interface = "{option}"'
            )
        return values

    if PYDANTIC_V2:

        @root_validator(pre=True, skip_on_failure=True)
        @classmethod
        def partial_type_generator_converter(cls, values: Dict[str, Any]) -> Dict[str, Any]:
            """Resolve ``partial_type_generator`` into a ``Module`` (v2 code path).

            A missing value that cannot be found at the default location is left
            unset; an explicitly given value that cannot be resolved re-raises.
            """
            # ensure env resolving happens
            # NOTE(review): presumably the settings base class's own root validator
            # that merges environment values — confirm in `.._compat`
            values = cast(Dict[str, Any], cls.root_validator(values))  # type: ignore

            value = values.get('partial_type_generator')

            try:
                values['partial_type_generator'] = Module(
                    spec=value  # pyright: ignore[reportArgumentType]
                )
            except ValueError:
                if value is None:
                    # no config value passed and the default location was not found
                    return values
                raise

            return values

    else:

        @field_validator('partial_type_generator', pre=True, always=True, allow_reuse=True)
        @classmethod
        def _partial_type_generator_converter(cls, value: Optional[str]) -> Optional[Module]:
            """Resolve ``partial_type_generator`` into a ``Module`` (v1 code path)."""
            try:
                return Module(
                    spec=value  # pyright: ignore[reportArgumentType]
                )
            except ValueError:
                if value is None:
                    # no config value passed and the default location was not found
                    return None
                raise

    @field_validator('recursive_type_depth', always=True, allow_reuse=True)
    @classmethod
    def recursive_type_depth_validator(cls, value: int) -> int:
        """Accept ``-1`` or any integer >= 2."""
        if value < -1 or value in {0, 1}:
            raise ValueError('Value must equal -1 or be greater than 1.')
        return value

    @field_validator('engine_type', always=True, allow_reuse=True)
    @classmethod
    def engine_type_validator(cls, value: EngineType) -> EngineType:
        """Only the ``binary`` engine is currently supported; other engine types raise."""
        if value == EngineType.binary:
            return value
        elif value == EngineType.dataproxy:  # pragma: no cover
            raise ValueError('Prisma Client Python does not support the Prisma Data Proxy yet.')
        elif value == EngineType.library:  # pragma: no cover
            raise ValueError('Prisma Client Python does not support native engine bindings yet.')
        else:  # pragma: no cover
            assert_never(value)

632 

633 

class DMMFEnumType(BaseModel):
    """A named enum type from the DMMF schema; its values are kept untyped."""

    name: str
    values: List[object]

637 

638 

class DMMFEnumTypes(BaseModel):
    """Container for the DMMF schema's enum types (only the ``prisma`` list is parsed)."""

    prisma: List[DMMFEnumType]

641 

642 

class PrismaSchema(BaseModel):
    """The DMMF ``schema`` section; only ``enumTypes`` is parsed."""

    enum_types: DMMFEnumTypes = FieldInfo(alias='enumTypes')

645 

646 

class DMMF(BaseModel):
    """The DMMF document passed to the generator (``GenericData.dmmf``)."""

    datamodel: 'Datamodel'
    # exposed as `prisma_schema`; NOTE(review): presumably renamed to avoid
    # clashing with pydantic's `BaseModel.schema` — confirm
    prisma_schema: PrismaSchema = FieldInfo(alias='schema')

650 

651 

class Datamodel(BaseModel):
    """The DMMF datamodel: the schema's enums and models (composite types are rejected)."""

    enums: List['Enum']
    models: List['Model']

    # not implemented yet
    types: List[object]

    @field_validator('types')
    @classmethod
    def no_composite_types_validator(cls, types: List[object]) -> List[object]:
        """Raise ``ValueError`` if the schema declares any composite types; otherwise pass through."""
        if types:
            raise ValueError(
                'Composite types are not supported yet. Please indicate you need this here: https://github.com/RobertCraigie/prisma-client-py/issues/314'
            )

        return types

668 

669 

class Enum(BaseModel):
    """An enum declared in the datamodel."""

    name: str
    # database-level name, if one is set
    db_name: Optional[str] = FieldInfo(alias='dbName')
    values: List['EnumValue']

674 

675 

class EnumValue(BaseModel):
    """A single member of a datamodel ``Enum``."""

    name: str
    # database-level name, if one is set
    db_name: Optional[str] = FieldInfo(alias='dbName')

679 

680 

class ModelExtension(BaseModel):
    """Options parsed from a model's documentation DSL (see ``Model.validate_dsl_extension``)."""

    # custom attribute name for the model on the generated client (see `Model.instance_name`)
    instance_name: Optional[str] = None

    @field_validator('instance_name')
    @classmethod
    def instance_name_validator(cls, name: Optional[str]) -> Optional[str]:
        """Ensure a custom ``instance_name`` can be used as a Python attribute name.

        Empty / ``None`` values pass through unchanged. Raises ``ValueError`` for
        names that are not identifiers or that are reserved Python keywords.
        """
        if not name:
            return name

        if not name.isidentifier():
            raise ValueError(f'Custom Model instance_name "{name}" is not a valid Python identifier')

        # `isidentifier()` accepts reserved words such as `class`, which would produce an
        # unusable client attribute (e.g. `Prisma().class` is a syntax error)
        if iskeyword(name):
            raise ValueError(f'Custom Model instance_name "{name}" shadows a Python keyword')

        return name

694 

695 

class Model(BaseModel):
    """A datamodel model plus the derived helpers used when rendering templates."""

    name: str
    documentation: Optional[str] = None
    db_name: Optional[str] = FieldInfo(alias='dbName')
    is_generated: bool = FieldInfo(alias='isGenerated')
    compound_primary_key: Optional['PrimaryKey'] = FieldInfo(alias='primaryKey')
    unique_indexes: List['UniqueIndex'] = FieldInfo(alias='uniqueIndexes')
    all_fields: List['Field'] = FieldInfo(alias='fields')

    # stores the parsed DSL, not an actual field defined by prisma
    extension: Optional[ModelExtension] = None

    _sampler: Sampler = PrivateAttr()

    def __init__(self, **data: Any) -> None:
        super().__init__(**data)
        self._sampler = Sampler(self)

    @root_validator(pre=True, allow_reuse=True)
    @classmethod
    def validate_dsl_extension(cls, values: Dict[Any, Any]) -> Dict[Any, Any]:
        """Parse the model's documentation as schema DSL and store it under ``extension``."""
        documentation = values.get('documentation')
        if documentation:
            parsed = parse_schema_dsl(documentation)
            outcome = parsed['type']
            if outcome == 'invalid':
                raise ValueError(parsed['error'])
            if outcome == 'ok':
                values['extension'] = model_parse(ModelExtension, parsed['value']['arguments'])
        return values

    @field_validator('name')
    @classmethod
    def name_validator(cls, name: str) -> str:
        """Reject model names that are, or lower-case to, a Python keyword."""
        hint = f'use a different model name with \'@@map("{name}")\'.'
        if iskeyword(name):
            raise ValueError(f'Model name "{name}" shadows a Python keyword; ' + hint)
        if iskeyword(name.lower()):
            raise ValueError(f'Model name "{name}" results in a client property that shadows a Python keyword; ' + hint)
        return name

    @property
    def related_models(self) -> Iterator['Model']:
        """Yield the model targeted by each relational field (repeats are possible)."""
        known_models = get_datamodel().models
        for relation in self.relational_fields:
            yield from (candidate for candidate in known_models if candidate.name == relation.type)

    @property
    def relational_fields(self) -> Iterator['Field']:
        yield from (field for field in self.all_fields if field.is_relational)

    @property
    def scalar_fields(self) -> Iterator['Field']:
        yield from (field for field in self.all_fields if not field.is_relational)

    @property
    def atomic_fields(self) -> Iterator['Field']:
        yield from (field for field in self.all_fields if field.type in ATOMIC_FIELD_TYPES)

    @property
    def required_array_fields(self) -> Iterator['Field']:
        """Yield required, non-relational list fields."""
        yield from (
            field for field in self.all_fields if field.is_list and not field.relation_name and field.is_required
        )

    # TODO: support combined unique constraints
    @cached_property
    def id_field(self) -> Optional['Field']:
        """Find a field that can be passed to the model's `WhereUnique` filter"""
        return next(
            (field for field in self.scalar_fields if field.is_id or field.is_unique),
            None,
        )

    @property
    def has_relational_fields(self) -> bool:
        return any(field.is_relational for field in self.all_fields)

    @property
    def instance_name(self) -> str:
        """The name of this model in the generated client class, e.g.

        `User` -> `Prisma().user`
        """
        extension = self.extension
        if extension and extension.instance_name:
            return extension.instance_name
        return self.name.lower()

    @property
    def plural_name(self) -> str:
        """``instance_name`` with an ``s`` appended unless it already ends in one."""
        singular = self.instance_name
        return singular if singular.endswith('s') else f'{singular}s'

    def resolve_field(self, name: str) -> 'Field':
        """Return the field called ``name``; raises ``LookupError`` if there is none."""
        match = next((field for field in self.all_fields if field.name == name), None)
        if match is None:
            raise LookupError(f'Could not find a field with name: {name}')
        return match

    def sampler(self) -> Sampler:
        return self._sampler

824 

825 

class Constraint(BaseModel):
    """A named constraint (compound primary key or unique index) over one or more fields."""

    name: str
    fields: List[str]

    @root_validator(pre=True, allow_reuse=True, skip_on_failure=True)
    @classmethod
    def resolve_name(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Default ``name`` to the constraint's field names joined with ``_``.

        Applies when ``name`` is absent or not a string. If ``fields`` is missing
        too, ``values`` is returned untouched so pydantic reports the missing
        field(s) as a normal validation error.
        """
        if isinstance(values.get('name'), str):
            return values

        fields = values.get('fields')
        if fields is None:
            # previously `values['fields']` raised a bare KeyError here, which pydantic
            # does not turn into a ValidationError
            return values

        values['name'] = '_'.join(fields)
        return values

839 

840 

class PrimaryKey(Constraint):
    """Compound primary key of a model (``Model.compound_primary_key``)."""

    pass

843 

844 

class UniqueIndex(Constraint):
    """A unique index of a model (``Model.unique_indexes``)."""

    pass

847 

848 

849class Field(BaseModel): 

850 name: str 

851 documentation: Optional[str] = None 

852 

853 # TODO: switch to enums 

854 kind: str 

855 type: str 

856 

857 is_id: bool = FieldInfo(alias='isId') 

858 is_list: bool = FieldInfo(alias='isList') 

859 is_unique: bool = FieldInfo(alias='isUnique') 

860 is_required: bool = FieldInfo(alias='isRequired') 

861 is_read_only: bool = FieldInfo(alias='isReadOnly') 

862 is_generated: bool = FieldInfo(alias='isGenerated') 

863 is_updated_at: bool = FieldInfo(alias='isUpdatedAt') 

864 

865 default: Optional[Union['DefaultValue', object, List[object]]] = None 

866 has_default_value: bool = FieldInfo(alias='hasDefaultValue') 

867 

868 relation_name: Optional[str] = FieldInfo(alias='relationName', default=None) 

869 relation_on_delete: Optional[str] = FieldInfo(alias='relationOnDelete', default=None) 

870 relation_to_fields: Optional[List[str]] = FieldInfo( 

871 alias='relationToFields', 

872 default=None, 

873 ) 

874 relation_from_fields: Optional[List[str]] = FieldInfo( 

875 alias='relationFromFields', 

876 default=None, 

877 ) 

878 

879 _last_sampled: Optional[str] = PrivateAttr() 

880 

881 @root_validator(pre=True, skip_on_failure=True) 

882 @classmethod 

883 def scalar_type_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]: 

884 kind = values.get('kind') 

885 type_ = values.get('type') 

886 

887 if kind == 'scalar': 

888 if type_ is not None and type_ not in TYPE_MAPPING: 888 ↛ 889line 888 didn't jump to line 889, because the condition on line 888 was never true

889 raise ValueError(f'Unsupported scalar field type: {type_}') 

890 

891 return values 

892 

893 @field_validator('type') 

894 @classmethod 

895 def experimental_decimal_validator(cls, typ: str) -> str: 

896 if typ == 'Decimal': 

897 config = get_config() 

898 

899 # skip validating the experimental flag if we are 

900 # being called from a custom generator 

901 if isinstance(config, Config) and not config.enable_experimental_decimal: 

902 raise ValueError( 

903 'Support for the Decimal type is experimental\n' 

904 ' As such you must set the `enable_experimental_decimal` config flag to true\n' 

905 ' for more information see: https://github.com/RobertCraigie/prisma-client-py/issues/106' 

906 ) 

907 

908 return typ 

909 

910 @field_validator('name') 

911 @classmethod 

912 def name_validator(cls, name: str) -> str: 

913 if getattr(BaseModel, name, None): 

914 raise ValueError( 

915 f'Field name "{name}" shadows a BaseModel attribute; ' 

916 f'use a different field name with \'@map("{name}")\'.' 

917 ) 

918 

919 if iskeyword(name): 

920 raise ValueError( 

921 f'Field name "{name}" shadows a Python keyword; ' f'use a different field name with \'@map("{name}")\'.' 

922 ) 

923 

924 if name == 'prisma': 

925 raise ValueError( 

926 'Field name "prisma" shadows a Prisma Client Python method; ' 

927 'use a different field name with \'@map("prisma")\'.' 

928 ) 

929 

930 if name in QUERY_BUILDER_ALIASES: 

931 raise ValueError( 

932 f'Field name "{name}" shadows an internal keyword; ' 

933 f'use a different field name with \'@map("{name}")\'' 

934 ) 

935 

936 return name 

937 

938 # TODO: cache the properties 

939 @property 

940 def python_type(self) -> str: 

941 type_ = self._actual_python_type 

942 if self.is_list: 

943 return f'List[{type_}]' 

944 return type_ 

945 

946 @property 

947 def python_type_as_string(self) -> str: 

948 type_ = self._actual_python_type 

949 if self.is_list: 

950 type_ = type_.replace("'", "\\'") 

951 return f"'List[{type_}]'" 

952 

953 if not type_.startswith("'"): 

954 type_ = f"'{type_}'" 

955 

956 return type_ 

957 

958 @property 

959 def _actual_python_type(self) -> str: 

960 if self.kind == 'enum': 

961 return f"'enums.{self.type}'" 

962 

963 if self.kind == 'object': 

964 return f"'models.{self.type}'" 

965 

966 try: 

967 return TYPE_MAPPING[self.type] 

968 except KeyError as exc: 

969 # TODO: handle this better 

970 raise RuntimeError( 

971 f'Could not parse {self.name} due to unknown type: {self.type}', 

972 ) from exc 

973 

974 @property 

975 def create_input_type(self) -> str: 

976 if self.kind != 'object': 

977 return self.python_type 

978 

979 if self.is_list: 

980 return f"'{self.type}CreateManyNestedWithoutRelationsInput'" 

981 

982 return f"'{self.type}CreateNestedWithoutRelationsInput'" 

983 

@property
def where_input_type(self) -> str:
    """Annotation (as source text) accepted when filtering on this field.

    - relational fields: ``<Type>ListRelationFilter`` for lists,
      ``<Type>RelationFilter`` otherwise;
    - scalar list fields: ``types.<Type>ListFilter``, after checking the list
      type is supported;
    - scalar types in ``FILTER_TYPES``: a ``Union`` of the raw value type and
      ``types.<Type>Filter``, with ``None`` first when the field is optional;
    - anything else: ``python_type``.
    """
    typ = self.type

    if self.is_relational:
        cardinality = 'List' if self.is_list else ''
        return f"'{typ}{cardinality}RelationFilter'"

    if self.is_list:
        # raises UnsupportedListTypeError for list types without filter support
        self.check_supported_scalar_list_type()
        return f"'types.{typ}ListFilter'"

    if typ not in FILTER_TYPES:
        return self.python_type

    members = ['None'] if self.is_optional else []
    members += [self._actual_python_type, f"'types.{typ}Filter'"]
    return f"Union[{', '.join(members)}]"

1002 

@property
def where_aggregates_input_type(self) -> str:
    """Annotation (as source text) for this field in aggregate-aware filters.

    Types in ``FILTER_TYPES`` become a ``Union`` of the raw value type and
    ``types.<Type>WithAggregatesFilter``; other types use ``python_type``.

    Raises:
        RuntimeError: for relational fields, which have no such filter.
    """
    if self.is_relational:  # pragma: no cover
        raise RuntimeError('This type is not valid for relational fields')

    typ = self.type
    if typ not in FILTER_TYPES:
        return self.python_type

    return f"Union[{self._actual_python_type}, 'types.{typ}WithAggregatesFilter']"

1012 

@property
def relational_args_type(self) -> str:
    """Name of the args type for this relation: ``FindMany<Type>Args`` for lists, ``<Type>Args`` otherwise."""
    prefix = 'FindMany' if self.is_list else ''
    return f'{prefix}{self.type}Args'

1018 

@property
def required_on_create(self) -> bool:
    """Whether a value must be supplied for this field when creating a record.

    The field must be required, and none of these may be set: ``@updatedAt``
    handling, a default value, a relation name, or list-ness.
    """
    disqualifiers = (
        self.is_updated_at,
        self.has_default_value,
        self.relation_name,
        self.is_list,
    )
    return self.is_required and not any(disqualifiers)

1028 

@property
def is_optional(self) -> bool:
    """Whether this field may be omitted: it is not required, or it has a (truthy) relation name."""
    return not self.is_required or bool(self.relation_name)

1032 

@property
def is_relational(self) -> bool:
    """Whether a relation name is set for this field (any non-``None`` value counts)."""
    relation = self.relation_name
    return relation is not None

1036 

@property
def is_atomic(self) -> bool:
    """Whether this field's type is listed in ``ATOMIC_FIELD_TYPES``.

    Such fields accept an ``Atomic<Type>Input`` in update inputs
    (see ``get_update_input_type``).
    """
    atomic_types = ATOMIC_FIELD_TYPES
    return self.type in atomic_types

1040 

@property
def is_number(self) -> bool:
    """Whether this field's type is ``Int``, ``BigInt`` or ``Float`` (``Decimal`` is not included)."""
    numeric_types = {'BigInt', 'Float', 'Int'}
    return self.type in numeric_types

1044 

def maybe_optional(self, typ: str) -> str:
    """Return ``typ`` wrapped in ``Optional[...]`` for non-required, non-relational fields.

    Required fields and relational fields get ``typ`` back unchanged.
    """
    wrap = not self.is_required and not self.is_relational
    return f'Optional[{typ}]' if wrap else typ

1050 

def get_update_input_type(self) -> str:
    """Annotation (as source text) used for this field in update inputs.

    - object fields: ``<Type>UpdateManyWithoutRelationsInput`` for lists,
      ``<Type>UpdateOneWithoutRelationsInput`` otherwise;
    - scalar lists: ``types.<Type>ListUpdate``, after checking the list type
      is supported;
    - atomic scalars: a ``Union`` of ``Atomic<Type>Input`` and ``python_type``;
    - everything else: ``python_type``.
    """
    if self.kind == 'object':
        cardinality = 'Many' if self.is_list else 'One'
        return f"'{self.type}Update{cardinality}WithoutRelationsInput'"

    if self.is_list:
        self.check_supported_scalar_list_type()
        return f"'types.{self.type}ListUpdate'"

    if not self.is_atomic:
        return self.python_type

    return f'Union[Atomic{self.type}Input, {self.python_type}]'

1065 

def check_supported_scalar_list_type(self) -> None:
    """Raise ``UnsupportedListTypeError`` if this scalar list type is not supported.

    Enum lists are always accepted. Any other type must appear in ``FILTER_TYPES``.
    """
    if self.kind != 'enum' and self.type not in FILTER_TYPES:  # pragma: no branch
        raise UnsupportedListTypeError(self.type)

1069 

def get_relational_model(self) -> Optional['Model']:
    """Return the datamodel model whose name matches this relational field's type.

    Returns ``None`` for non-relational fields, and also when no model in
    ``get_datamodel().models`` has a matching name.
    """
    if not self.is_relational:
        return None

    target = self.type
    return next(
        (model for model in get_datamodel().models if model.name == target),
        None,
    )

1079 

def get_corresponding_enum(self) -> Optional['Enum']:
    """Return the datamodel enum whose name matches this field's type, or ``None`` if none does."""
    wanted = self.type
    # the local name avoids shadowing the module-level ``enum`` import
    matches = (candidate for candidate in get_datamodel().enums if candidate.name == wanted)
    return next(matches, None)

1086 

def get_sample_data(self, *, increment: bool = True) -> str:
    """Return a sample value for this field as Python source text.

    List fields wrap the sampled value in ``[...]``. The result is remembered
    in ``self._last_sampled``. With ``increment=False``, the remembered value
    is returned again when one exists. This lets documentation for methods
    like ``upsert()`` repeat the same data.
    """
    if not increment:
        previous = self._last_sampled
        if previous is not None:
            return previous

    value = self._get_sample_data()
    sampled = f'[{value}]' if self.is_list else value
    self._last_sampled = sampled
    return sampled

1099 

1100 def _get_sample_data(self) -> str: 

1101 if self.is_relational: # pragma: no cover 

1102 raise RuntimeError('Data sampling for relational fields not supported yet') 

1103 

1104 if self.kind == 'enum': 

1105 enum = self.get_corresponding_enum() 

1106 assert enum is not None, self.type 

1107 return f'enums.{enum.name}.{FAKER.from_list(enum.values).name}' 

1108 

1109 typ = self.type 

1110 if typ == 'Boolean': 

1111 return str(FAKER.boolean()) 

1112 elif typ == 'Int': 

1113 return str(FAKER.integer()) 

1114 elif typ == 'String': 

1115 return f"'{FAKER.string()}'" 

1116 elif typ == 'Float': 

1117 return f'{FAKER.integer()}.{FAKER.integer() // 10000}' 

1118 elif typ == 'BigInt': # pragma: no cover 

1119 return str(FAKER.integer() * 12) 

1120 elif typ == 'DateTime': 

1121 # TODO: random dates 

1122 return 'datetime.datetime.utcnow()' 

1123 elif typ == 'Json': 1123 ↛ 1124line 1123 didn't jump to line 1124, because the condition on line 1123 was never true

1124 return f"Json({{'{FAKER.string()}': True}})" 

1125 elif typ == 'Bytes': 1125 ↛ 1127line 1125 didn't jump to line 1127, because the condition on line 1125 was never false

1126 return f"Base64.encode(b'{FAKER.string()}')" 

1127 elif typ == 'Decimal': 

1128 return f"Decimal('{FAKER.integer()}.{FAKER.integer() // 10000}')" 

1129 else: # pragma: no cover 

1130 raise RuntimeError(f'Sample data not supported for {typ} yet') 

1131 

1132 

class DefaultValue(BaseModel):
    """A default value: a ``name`` plus optional ``args``.

    NOTE(review): presumably mirrors a Prisma ``@default(...)`` call, where
    ``name`` is the function (e.g. ``now``) and ``args`` its arguments;
    confirm against the DMMF schema.
    """

    # Arguments of the default; optional, None when not provided.
    args: Any = None
    # Required identifier of the default.
    name: str

1136 

1137 

class _EmptyModel(BaseModel):
    """Model that declares no fields of its own and forbids extra keys.

    Used in this module as the config type parameter of ``DefaultData``.
    """

    if PYDANTIC_V2:
        # pydantic v2 style: configuration is a ``model_config`` ConfigDict.
        model_config: ClassVar[ConfigDict] = ConfigDict(extra='forbid')
    elif not TYPE_CHECKING:
        # pydantic v1 style: configuration is an inner ``Config`` class. The
        # ``not TYPE_CHECKING`` guard hides it from static type checkers.

        class Config(BaseModel.Config):
            extra: pydantic.Extra = pydantic.Extra.forbid

1145 

1146 

class _ModelAllowAll(BaseModel):
    """Model that declares no fields of its own and accepts (keeps) any extra keys."""

    if PYDANTIC_V2:
        # pydantic v2 style: configuration is a ``model_config`` ConfigDict.
        model_config: ClassVar[ConfigDict] = ConfigDict(extra='allow')
    elif not TYPE_CHECKING:
        # pydantic v1 style: configuration is an inner ``Config`` class. The
        # ``not TYPE_CHECKING`` guard hides it from static type checkers.

        class Config(BaseModel.Config):
            extra: pydantic.Extra = pydantic.Extra.allow

1154 

1155 

class PythonNames(BaseModel):
    """Names used by the Python client; currently only the client class name."""

    def client_class(self, _for_async: bool) -> str:
        """Return the client class name.

        ``_for_async`` is accepted but ignored, so the same name is returned
        either way.
        """
        class_name = 'Prisma'
        return class_name

1159 

1160 

class PythonData(GenericData[Config]):
    """Data class including the default Prisma Client Python config."""

    # NOTE: the ``Config`` in the base-class subscript above is evaluated before
    # this class body runs, so it refers to the module-level name rather than
    # the nested pydantic v1 ``Config`` defined below.
    if not PYDANTIC_V2:

        # pydantic v1 only: allow arbitrary field types, encode ``Path`` and
        # ``ModuleSpec`` values with this module's custom serializers when
        # producing JSON, and leave ``cached_property`` attributes untouched
        # (not treated as fields).
        class Config(BaseConfig):
            arbitrary_types_allowed: bool = True
            json_encoders: Dict[Type[Any], Any] = {
                Path: _pathlib_serializer,
                machinery.ModuleSpec: _module_spec_serializer,
            }
            keep_untouched: Tuple[Type[Any], ...] = (cached_property,)

    # Naming helpers; defaults to a fresh ``PythonNames`` instance.
    names: PythonNames = PythonNames()

1175 

1176 

class DefaultData(GenericData[_EmptyModel]):
    """Data class without any config options.

    The config type is ``_EmptyModel``, which forbids extra keys.
    NOTE(review): presumably this makes validation reject any supplied
    config options; confirm how ``GenericData`` uses its type parameter.
    """

1179 

1180 

# This has to be a type alias rather than a subclass, because its purpose is
# to signify that the data is config-agnostic.
AnyData = GenericData[Any]

# Rebuild every model now that all classes in this module are defined.
# NOTE(review): ``model_rebuild`` comes from ``.._compat``; presumably it
# resolves string forward references such as ``Optional['Model']`` (pydantic
# v2 ``model_rebuild`` / v1 ``update_forward_refs``). Confirm in ``_compat``.
model_rebuild(Enum)
model_rebuild(DMMF)
model_rebuild(GenericData)
model_rebuild(Field)
model_rebuild(Model)
model_rebuild(Datamodel)
model_rebuild(Generator)
model_rebuild(Datasource)

1193 

1194 

1195from .errors import ( 

1196 TemplateError, 

1197 PartialTypeGeneratorError, 

1198) 

1199from .schema import Schema, ClientTypes