Coverage for src/prisma/generator/models.py: 94%
640 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-04-28 15:17 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2024-04-28 15:17 +0000
1import os
2import sys
3import enum
4import textwrap
5import importlib
6from typing import (
7 TYPE_CHECKING,
8 Any,
9 Dict,
10 List,
11 Type,
12 Tuple,
13 Union,
14 Generic,
15 TypeVar,
16 ClassVar,
17 Iterable,
18 Iterator,
19 NoReturn,
20 Optional,
21 cast,
22)
23from keyword import iskeyword
24from pathlib import Path
25from importlib import util as importlib_util, machinery
26from itertools import chain
27from contextvars import ContextVar
28from importlib.abc import InspectLoader
29from typing_extensions import Annotated, override
31import click
32import pydantic
33from pydantic.fields import PrivateAttr
35from .. import config
36from .utils import Faker, Sampler, clean_multiline
37from ..utils import DEBUG_GENERATOR, assert_never
38from ..errors import UnsupportedListTypeError
39from .._compat import (
40 PYDANTIC_V2,
41 Field as FieldInfo,
42 BaseConfig,
43 ConfigDict,
44 BaseSettings,
45 GenericModel,
46 PlainSerializer,
47 BaseSettingsConfig,
48 model_dict,
49 model_parse,
50 model_rebuild,
51 root_validator,
52 cached_property,
53 field_validator,
54)
55from .._constants import QUERY_BUILDER_ALIASES
56from ._dsl_parser import parse_schema_dsl
# Public names exported by this module: the different flavours of the root
# data model that the generator receives.
__all__ = (
    'AnyData',
    'PythonData',
    'DefaultData',
    'GenericData',
)

# Type variable bound to pydantic models; used to type validators that
# receive and return the model instance itself (see `GenericData._set_ctx`).
_ModelT = TypeVar('_ModelT', bound=pydantic.BaseModel)
# NOTE: this does not represent all the data that is passed by prisma

# Scalar types that get an `Atomic{type}Input` option in update inputs
# (see `Field.is_atomic` and `Field.get_update_input_type`).
ATOMIC_FIELD_TYPES = ['Int', 'BigInt', 'Float']

# Prisma scalar type name -> Python type expression written into the
# generated code. Quoted values are emitted as string (forward) annotations.
TYPE_MAPPING = {
    'String': '_str',
    'Bytes': "'fields.Base64'",
    'DateTime': 'datetime.datetime',
    'Boolean': '_bool',
    'Int': '_int',
    'Float': '_float',
    'BigInt': '_int',
    'Json': "'fields.Json'",
    'Decimal': 'decimal.Decimal',
}

# Scalar types for which `types.{type}Filter` / `types.{type}ListFilter`
# annotations are emitted (see `Field.where_input_type` and `get_list_types`).
FILTER_TYPES = [
    'String',
    'Bytes',
    'DateTime',
    'Boolean',
    'Int',
    'BigInt',
    'Float',
    'Json',
    'Decimal',
]
# Warning headline and explanation printed by `_recursive_type_depth_factory`
# when `recursive_type_depth` is not set explicitly.
# NOTE(review): this text was recovered from a coverage report that dropped
# blank lines (original line numbers 101, 103 and 108 are missing) and
# collapsed indentation; confirm the exact string contents against the
# real source file.
RECURSIVE_TYPE_DEPTH_WARNING = """Some types are disabled by default due to being incompatible with Mypy, it is highly recommended
to use Pyright instead and configure Prisma Python to use recursive types. To re-enable certain types:"""

RECURSIVE_TYPE_DEPTH_WARNING_DESC = """
generator client {
 provider = "prisma-client-py"
 recursive_type_depth = -1
}
If you need to use Mypy, you can also disable this message by explicitly setting the default value:
generator client {
 provider = "prisma-client-py"
 recursive_type_depth = 5
}
For more information see: https://prisma-client-py.readthedocs.io/en/stable/reference/limitations/#default-type-limitations
"""
# Shared fake-data source used by `Field._get_sample_data`.
FAKER: Faker = Faker()

# Type variable for the generator config model carried by `GenericData` / `Generator`.
ConfigT = TypeVar('ConfigT', bound=pydantic.BaseModel)

# Although we should just be able to access the config from the datamodel
# we have to do some validation that requires access to the config, this is difficult
# with heavily nested models as our current workaround only sets the datamodel context
# post-validation meaning we cannot access it in validators. To get around this we have
# a separate config context.
# TODO: better solution
data_ctx: ContextVar['AnyData'] = ContextVar('data_ctx')
config_ctx: ContextVar['Config'] = ContextVar('config_ctx')
def get_datamodel() -> 'Datamodel':
    """Return the datamodel of the generator data stored in `data_ctx`.

    Raises `LookupError` (from `ContextVar.get`) if no data has been set yet.
    """
    current_data = data_ctx.get()
    return current_data.dmmf.datamodel
def get_config() -> Union[None, pydantic.BaseModel, 'Config']:
    """Return the generator config stored in `config_ctx`, or ``None``.

    The return type is intentionally broad so that callers must handle:
    - a custom generator config being used (an arbitrary pydantic model)
    - the config being invalid and therefore never set (``None``)
    """
    maybe_config = config_ctx.get(None)
    return maybe_config
def get_list_types() -> Iterable[Tuple[str, str]]:
    """Return ``(prisma_type, python_type)`` pairs for every scalar list type.

    The pairs cover every entry of `FILTER_TYPES` followed by every enum in
    the current datamodel (as ``'enums.<Name>'``).
    """
    # WARNING: keep in sync with Field.check_supported_scalar_list_type()
    scalar_pairs = ((scalar, TYPE_MAPPING[scalar]) for scalar in FILTER_TYPES)
    enum_pairs = ((item.name, f"'enums.{item.name}'") for item in get_datamodel().enums)
    return chain(scalar_pairs, enum_pairs)
def sql_param(num: int = 1) -> str:
    """Return the SQL positional placeholder for the active datasource provider.

    PostgreSQL gets a numbered placeholder (``$<num>``); MongoDB is rejected;
    anything else gets ``?``, the style SQLite and MySQL use.
    """
    # TODO: add case for sqlserver
    provider = data_ctx.get().datasources[0].active_provider

    # TODO: test
    if provider == 'mongodb':  # pragma: no cover
        raise RuntimeError('no-op')

    return f'${num}' if provider == 'postgresql' else '?'
def raise_err(msg: str) -> NoReturn:
    """Raise a `TemplateError` carrying *msg* (exposed to templates via `to_params`)."""
    error = TemplateError(msg)
    raise error
def type_as_string(typ: str) -> str:
    """Wrap a type expression in single quotes unless it is already quoted.

    ``enums.Role`` becomes ``'enums.Role'``; strings starting with ``'`` or
    ``"`` are returned unchanged.
    """
    # TODO: use this function internally in this module
    already_quoted = typ.startswith(("'", '"'))
    return typ if already_quoted else f"'{typ}'"
def format_documentation(doc: str, indent: int = 4) -> str:
    """Indent every line after the first of a schema comment.

    A line consisting of *indent* spaces is appended, so the closing quotes
    of a generated docstring line up with its body. For example ``'Foo\\nBar'``
    becomes ``'Foo\\n    Bar\\n    '``. Empty input is returned unchanged.
    Whitespace-only lines are left unindented (``textwrap.indent`` default).
    """
    if not doc:
        return doc

    pad = ' ' * indent
    lines = doc.splitlines()
    body = [textwrap.indent(line, pad) for line in lines[1:]]
    return '\n'.join([lines[0], *body, pad])
def _module_spec_serializer(spec: machinery.ModuleSpec) -> str:
    """Serialize a module spec as its origin (file path)."""
    origin = spec.origin
    assert origin is not None, 'Cannot serialize module with no origin'
    return origin
def _pathlib_serializer(path: Path) -> str:
    """Serialize a path as its absolute string form."""
    absolute = path.absolute()
    return os.fspath(absolute)
def _recursive_type_depth_factory() -> int:
    """Default factory for `Config.recursive_type_depth`.

    Prints a warning (headline in yellow) describing how to change or
    silence the default, then returns the default depth of 5.
    """
    headline = click.style(f'\n{RECURSIVE_TYPE_DEPTH_WARNING}', fg='yellow')
    click.echo(headline)
    click.echo(f'{RECURSIVE_TYPE_DEPTH_WARNING_DESC}\n')
    return 5
class BaseModel(pydantic.BaseModel):
    """Base class for the generator data models in this module.

    Allows arbitrary (non-pydantic) field types and ignores `cached_property`
    attributes so they are not treated as fields. Under Pydantic v1 it also
    registers JSON encoders for `Path` and `ModuleSpec` values.
    """

    if PYDANTIC_V2:
        model_config: ClassVar[ConfigDict] = ConfigDict(
            arbitrary_types_allowed=True,
            ignored_types=(cached_property,),
        )
    else:

        class Config(BaseConfig):
            arbitrary_types_allowed: bool = True
            json_encoders: Dict[Type[Any], Any] = {
                Path: _pathlib_serializer,
                machinery.ModuleSpec: _module_spec_serializer,
            }
            keep_untouched: Tuple[Type[Any], ...] = (cached_property,)
class InterfaceChoices(str, enum.Enum):
    """Valid values for the `interface` generator config option."""

    sync = 'sync'
    asyncio = 'asyncio'
class EngineType(str, enum.Enum):
    """Valid values for the `engine_type` generator config option.

    Only `binary` is currently accepted by `Config.engine_type_validator`.
    """

    binary = 'binary'
    library = 'library'
    dataproxy = 'dataproxy'

    @override
    def __str__(self) -> str:
        # render as the raw value (e.g. 'binary') instead of 'EngineType.binary'
        return self.value
class Module(BaseModel):
    """A Python module referenced by the `partial_type_generator` config option.

    The option may be a file path (relative to the current working directory)
    or an importable module name; either way it is resolved to a `ModuleSpec`.
    """

    if TYPE_CHECKING:
        spec: machinery.ModuleSpec
    else:
        if PYDANTIC_V2:
            # serialize the spec as its origin path
            spec: Annotated[
                machinery.ModuleSpec,
                PlainSerializer(lambda x: _module_spec_serializer(x), return_type=str),
            ]
        else:
            spec: machinery.ModuleSpec

    if PYDANTIC_V2:
        model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
    else:

        class Config(BaseModel.Config):
            arbitrary_types_allowed: bool = True

    # for some reason this is needed in Pydantic v2
    @root_validator(pre=True, skip_on_failure=True)
    @classmethod
    def partial_type_generator_converter(cls, values: object) -> Any:
        # a bare string input is treated as the `spec` field value
        if isinstance(values, str):
            return {'spec': values}
        return values

    @field_validator('spec', pre=True, allow_reuse=True)
    @classmethod
    def spec_validator(cls, value: Optional[str]) -> machinery.ModuleSpec:
        """Resolve *value* to a module spec.

        ``None`` falls back to ``prisma/partial_types.py``. If the path exists
        relative to the current working directory it is loaded as a file,
        otherwise *value* is looked up as an importable module name.

        Raises:
            ValueError: if *value* is a relative import (starts with ``.``)
                that is not an existing file, or if nothing can be found.
        """
        spec: Optional[machinery.ModuleSpec] = None

        # TODO: this should really work based off of the schema path
        # and this should support checking just partial_types.py if we are in a `prisma` dir
        if value is None:
            value = 'prisma/partial_types.py'

        path = Path.cwd().joinpath(value)
        if path.exists():
            # Pass the already-resolved absolute path so the spec and its
            # loader keep pointing at the same file even if the working
            # directory changes before `run()` executes it.
            spec = importlib_util.spec_from_file_location('prisma.partial_type_generator', str(path))
        elif value.startswith('.'):
            raise ValueError(f'No file found at {value} and relative imports are not allowed.')
        else:
            try:
                spec = importlib_util.find_spec(value)
            except ModuleNotFoundError:
                spec = None

        if spec is None:
            raise ValueError(f'Could not find a python file or module at {value}')

        return spec

    def run(self) -> None:
        """Execute the module in a new module object.

        Raises:
            PartialTypeGeneratorError: wrapping any exception raised while
                executing the module.
        """
        importlib.invalidate_caches()
        mod = importlib_util.module_from_spec(self.spec)
        loader = self.spec.loader
        assert loader is not None, 'Expected an import loader to exist.'
        assert isinstance(loader, InspectLoader), f'Cannot execute module from loader type: {type(loader)}'

        try:
            loader.exec_module(mod)
        except Exception as exc:
            raise PartialTypeGeneratorError() from exc
class GenericData(GenericModel, Generic[ConfigT]):
    """Root model for the data that prisma provides to the generator.

    WARNING: only one instance of this class may exist at any given time and
    instances should only be constructed using the Data.parse_obj() method.
    Constructing an instance that way stores it in `data_ctx`, which the
    helper functions in this module read from.
    """

    datamodel: str
    version: str
    generator: 'Generator[ConfigT]'
    dmmf: 'DMMF' = FieldInfo(alias='dmmf')
    schema_path: Path = FieldInfo(alias='schemaPath')
    datasources: List['Datasource'] = FieldInfo(alias='datasources')
    other_generators: List['Generator[_ModelAllowAll]'] = FieldInfo(alias='otherGenerators')
    binary_paths: 'BinaryPaths' = FieldInfo(alias='binaryPaths', default_factory=lambda: BinaryPaths())

    if PYDANTIC_V2:
        # post-validation hook: publish the validated instance to `data_ctx`
        @root_validator(pre=False)
        def _set_ctx(self: _ModelT) -> _ModelT:
            data_ctx.set(cast('GenericData[ConfigT]', self))
            return self

    else:

        # Pydantic v1 has no post-validation instance hook, so the context
        # is set after parsing instead.
        @classmethod
        @override
        def parse_obj(cls, obj: Any) -> 'GenericData[ConfigT]':
            data = super().parse_obj(obj)  # pyright: ignore[reportDeprecated]
            data_ctx.set(data)
            return data

    def to_params(self) -> Dict[str, Any]:
        """Get the parameters that should be sent to Jinja templates.

        Returns a new dict holding this model's attributes plus the derived
        `type_schema` and `client_types` objects and the template helper
        functions (keyed by their ``__name__``). The instance's own attribute
        dict is not written to.
        """
        # `vars(self)` is the instance's live `__dict__`; copy it so the
        # template-only keys below do not leak into the model itself.
        params = dict(vars(self))
        params['type_schema'] = Schema.from_data(self)
        params['client_types'] = ClientTypes.from_data(self)

        # add utility functions
        for func in [
            sql_param,
            raise_err,
            type_as_string,
            get_list_types,
            clean_multiline,
            format_documentation,
            model_dict,
        ]:
            params[func.__name__] = func

        return params

    @root_validator(pre=True, allow_reuse=True, skip_on_failure=True)
    @classmethod
    def validate_version(cls, values: Dict[Any, Any]) -> Dict[Any, Any]:
        """Reject data from an unexpected Prisma version unless DEBUG_GENERATOR is set."""
        # TODO: test this
        version = values.get('version')
        if not DEBUG_GENERATOR and version != config.expected_engine_version:
            raise ValueError(
                f'Prisma Client Python expected Prisma version: {config.expected_engine_version} '
                f'but got: {version}\n'
                ' If this is intentional, set the PRISMA_PY_DEBUG_GENERATOR environment '
                'variable to 1 and try again.\n'
                f' If you are using the Node CLI then you must switch to v{config.prisma_version}, e.g. '
                f'npx prisma@{config.prisma_version} generate\n'
                ' or generate the client using the Python CLI, e.g. python3 -m prisma generate'
            )
        return values
class BinaryPaths(BaseModel):
    """This class represents the paths to engine binaries.

    Each property in this class is a mapping of platform name to absolute path, for example:

    ```py
    # This is what will be set on an M1 chip if there are no other `binaryTargets` set
    binary_paths.query_engine == {
        'darwin-arm64': '/Users/robert/.cache/prisma-python/binaries/3.13.0/efdf9b1183dddfd4258cd181a72125755215ab7b/node_modules/prisma/query-engine-darwin-arm64'
    }
    ```

    This is only available if the generator explicitly requests them using the `requires_engines` manifest property.
    """

    query_engine: Dict[str, str] = FieldInfo(
        default_factory=dict,
        alias='queryEngine',
    )
    introspection_engine: Dict[str, str] = FieldInfo(
        default_factory=dict,
        alias='introspectionEngine',
    )
    migration_engine: Dict[str, str] = FieldInfo(
        default_factory=dict,
        alias='migrationEngine',
    )
    libquery_engine: Dict[str, str] = FieldInfo(
        default_factory=dict,
        alias='libqueryEngine',
    )
    prisma_format: Dict[str, str] = FieldInfo(
        default_factory=dict,
        alias='prismaFmt',
    )

    # keep any additional keys prisma sends that are not declared above
    if PYDANTIC_V2:
        model_config: ClassVar[ConfigDict] = ConfigDict(extra='allow')
    else:

        class Config(BaseModel.Config):  # pyright: ignore[reportDeprecated]
            extra: Any = (
                pydantic.Extra.allow  # pyright: ignore[reportDeprecated]
            )
class Datasource(BaseModel):
    """A `datasource` entry from the generator data."""

    # TODO: provider enums
    name: str
    provider: str
    # read from the first datasource by `sql_param` to choose the placeholder style
    active_provider: str = FieldInfo(alias='activeProvider')
    url: 'OptionalValueFromEnvVar'
class Generator(GenericModel, Generic[ConfigT]):
    """A `generator` entry from the generator data, with a typed `config`."""

    name: str
    output: 'ValueFromEnvVar'
    provider: 'OptionalValueFromEnvVar'
    config: ConfigT
    binary_targets: List['ValueFromEnvVar'] = FieldInfo(alias='binaryTargets')
    preview_features: List[str] = FieldInfo(alias='previewFeatures')

    @field_validator('binary_targets')
    @classmethod
    def warn_binary_targets(cls, targets: List['ValueFromEnvVar']) -> List['ValueFromEnvVar']:
        """Print a warning to stdout when more than one binary target is set."""
        # Prisma by default sends one binary target which is the current
        # platform, so only warn when additional targets were configured.
        if len(targets) <= 1:
            return targets

        warning = click.style(
            'Warning: ' + 'The binaryTargets option is not officially supported by Prisma Client Python.',
            fg='yellow',
        )
        click.echo(warning, file=sys.stdout)
        return targets

    def has_preview_feature(self, feature: str) -> bool:
        """Return whether *feature* appears in this generator's ``previewFeatures``."""
        return any(enabled == feature for enabled in self.preview_features)
class ValueFromEnvVar(BaseModel):
    """A required schema value plus the optional name of the env var it may come from (`fromEnvVar`)."""

    value: str
    from_env_var: Optional[str] = FieldInfo(alias='fromEnvVar')
class OptionalValueFromEnvVar(BaseModel):
    """A schema value that is either given literally or via an environment variable."""

    value: Optional[str] = None
    from_env_var: Optional[str] = FieldInfo(alias='fromEnvVar')

    def resolve(self) -> str:
        """Return the literal value, falling back to the named environment variable.

        Raises:
            AssertionError: if neither a value nor an env var name is present.
            RuntimeError: if the environment variable is not set.
        """
        if self.value is not None:
            return self.value

        var_name = self.from_env_var
        assert var_name is not None, 'from_env_var should not be None'

        resolved = os.environ.get(var_name)
        if resolved is None:
            raise RuntimeError(f'Environment variable not found: {var_name}')
        return resolved
class Config(BaseSettings):
    """Custom generator config options.

    Each option can also be supplied through the `PRISMA_PY_CONFIG_*`
    environment variable named in its field definition.
    """

    interface: InterfaceChoices = FieldInfo(default=InterfaceChoices.asyncio, env='PRISMA_PY_CONFIG_INTERFACE')
    partial_type_generator: Optional[Module] = FieldInfo(default=None, env='PRISMA_PY_CONFIG_PARTIAL_TYPE_GENERATOR')
    # the default factory prints a Mypy-compatibility warning whenever it is used
    recursive_type_depth: int = FieldInfo(
        default_factory=_recursive_type_depth_factory,
        env='PRISMA_PY_CONFIG_RECURSIVE_TYPE_DEPTH',
    )
    engine_type: EngineType = FieldInfo(default=EngineType.binary, env='PRISMA_PY_CONFIG_ENGINE_TYPE')

    # this should be a list of experimental features
    # https://github.com/prisma/prisma/issues/12442
    enable_experimental_decimal: bool = FieldInfo(default=False, env='PRISMA_PY_CONFIG_ENABLE_EXPERIMENTAL_DECIMAL')

    # this seems to be the only good method for setting the contextvar as
    # we don't control the actual construction of the object like we do for
    # the Data model.
    # we do not expose this to type checkers so that the generated __init__
    # signature is preserved.
    if not TYPE_CHECKING:

        def __init__(self, **kwargs: object) -> None:
            super().__init__(**kwargs)
            config_ctx.set(self)

    if PYDANTIC_V2:
        model_config: ClassVar[ConfigDict] = ConfigDict(
            extra='forbid',
            use_enum_values=True,
            populate_by_name=True,
        )
    else:
        if not TYPE_CHECKING:

            class Config(BaseSettingsConfig):
                extra: pydantic.Extra = pydantic.Extra.forbid
                use_enum_values: bool = True
                env_prefix: str = 'prisma_py_config_'
                allow_population_by_field_name: bool = True

                @classmethod
                def customise_sources(cls, init_settings, env_settings, file_secret_settings):
                    # prioritise env settings over init settings
                    return env_settings, init_settings, file_secret_settings

    @root_validator(pre=True, skip_on_failure=True)
    @classmethod
    def transform_engine_type(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Map `PRISMA_CLIENT_ENGINE_TYPE` or the schema's `engineType` onto `engine_type`."""
        # prioritise env variable over schema option
        engine_type = os.environ.get('PRISMA_CLIENT_ENGINE_TYPE')
        if engine_type is None:
            engine_type = values.get('engineType')

        # only add engine_type if it is present
        if engine_type is not None:
            values['engine_type'] = engine_type
            # drop the camelCase key so `extra='forbid'` does not reject it
            values.pop('engineType', None)

        return values

    @root_validator(pre=True, skip_on_failure=True)
    @classmethod
    def removed_http_option_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Point users of the removed `http` option at the equivalent `interface` value."""
        http = values.get('http')
        if http is not None:
            if http in {'aiohttp', 'httpx-async'}:
                option = 'asyncio'
            elif http in {'requests', 'httpx-sync'}:
                option = 'sync'
            else:  # pragma: no cover
                # invalid http option, let pydantic handle the error
                return values

            raise ValueError(
                'The http option has been removed in favour of the interface option.\n'
                ' Please remove the http option from your Prisma schema and replace it with:\n'
                f' interface = "{option}"'
            )
        return values

    if PYDANTIC_V2:

        @root_validator(pre=True, skip_on_failure=True)
        @classmethod
        def partial_type_generator_converter(cls, values: Dict[str, Any]) -> Dict[str, Any]:
            """Resolve `partial_type_generator` into a `Module`.

            A missing default location is tolerated; an explicitly configured
            but unresolvable value re-raises the `ValueError`.
            """
            # ensure env resolving happens
            values = cast(Dict[str, Any], cls.root_validator(values))  # type: ignore

            value = values.get('partial_type_generator')

            try:
                values['partial_type_generator'] = Module(
                    spec=value  # pyright: ignore[reportArgumentType]
                )
            except ValueError:
                if value is None:
                    # no config value passed and the default location was not found
                    return values
                raise

            return values

    else:

        @field_validator('partial_type_generator', pre=True, always=True, allow_reuse=True)
        @classmethod
        def _partial_type_generator_converter(cls, value: Optional[str]) -> Optional[Module]:
            """Pydantic v1 counterpart of the v2 root validator above."""
            try:
                return Module(
                    spec=value  # pyright: ignore[reportArgumentType]
                )
            except ValueError:
                if value is None:
                    # no config value passed and the default location was not found
                    return None
                raise

    @field_validator('recursive_type_depth', always=True, allow_reuse=True)
    @classmethod
    def recursive_type_depth_validator(cls, value: int) -> int:
        """Accept -1 or any integer greater than 1."""
        if value < -1 or value in {0, 1}:
            raise ValueError('Value must equal -1 or be greater than 1.')
        return value

    @field_validator('engine_type', always=True, allow_reuse=True)
    @classmethod
    def engine_type_validator(cls, value: EngineType) -> EngineType:
        """Only the binary engine is currently supported."""
        if value == EngineType.binary:
            return value
        elif value == EngineType.dataproxy:  # pragma: no cover
            raise ValueError('Prisma Client Python does not support the Prisma Data Proxy yet.')
        elif value == EngineType.library:  # pragma: no cover
            raise ValueError('Prisma Client Python does not support native engine bindings yet.')
        else:  # pragma: no cover
            assert_never(value)
class DMMFEnumType(BaseModel):
    """A named enum type from the DMMF schema's `enumTypes.prisma` list."""

    name: str
    values: List[object]
class DMMFEnumTypes(BaseModel):
    """The `enumTypes` section of the DMMF schema."""

    prisma: List[DMMFEnumType]
class PrismaSchema(BaseModel):
    """The subset of the DMMF `schema` section used by the generator."""

    enum_types: DMMFEnumTypes = FieldInfo(alias='enumTypes')
class DMMF(BaseModel):
    """The DMMF document: the parsed datamodel plus the (aliased) `schema` section."""

    datamodel: 'Datamodel'
    prisma_schema: PrismaSchema = FieldInfo(alias='schema')
class Datamodel(BaseModel):
    """The parsed datamodel: enums, models and (unsupported) composite types."""

    enums: List['Enum']
    models: List['Model']

    # not implemented yet
    types: List[object]

    @field_validator('types')
    @classmethod
    def no_composite_types_validator(cls, types: List[object]) -> List[object]:
        """Reject schemas that define composite types."""
        if types:
            raise ValueError(
                'Composite types are not supported yet. Please indicate you need this here: https://github.com/RobertCraigie/prisma-client-py/issues/314'
            )

        return types
class Enum(BaseModel):
    """An enum declared in the Prisma schema."""

    name: str
    db_name: Optional[str] = FieldInfo(alias='dbName')
    values: List['EnumValue']
class EnumValue(BaseModel):
    """A single member of a schema `Enum`."""

    name: str
    db_name: Optional[str] = FieldInfo(alias='dbName')
class ModelExtension(BaseModel):
    """Options parsed from the DSL in a model's documentation comment (see `Model.validate_dsl_extension`)."""

    instance_name: Optional[str] = None

    @field_validator('instance_name')
    @classmethod
    def instance_name_validator(cls, name: Optional[str]) -> Optional[str]:
        """Reject a non-empty custom instance name that is not a valid Python identifier."""
        if name and not name.isidentifier():
            raise ValueError(f'Custom Model instance_name "{name}" is not a valid Python identifier')
        return name
class Model(BaseModel):
    """A model declared in the Prisma schema datamodel."""

    name: str
    documentation: Optional[str] = None
    db_name: Optional[str] = FieldInfo(alias='dbName')
    is_generated: bool = FieldInfo(alias='isGenerated')
    compound_primary_key: Optional['PrimaryKey'] = FieldInfo(alias='primaryKey')
    unique_indexes: List['UniqueIndex'] = FieldInfo(alias='uniqueIndexes')
    all_fields: List['Field'] = FieldInfo(alias='fields')

    # stores the parsed DSL, not an actual field defined by prisma
    extension: Optional[ModelExtension] = None

    _sampler: Sampler = PrivateAttr()

    def __init__(self, **data: Any) -> None:
        super().__init__(**data)
        # created after validation so the sampler receives the populated model
        self._sampler = Sampler(self)

    @root_validator(pre=True, allow_reuse=True)
    @classmethod
    def validate_dsl_extension(cls, values: Dict[Any, Any]) -> Dict[Any, Any]:
        """Parse DSL arguments out of `documentation` into the `extension` field."""
        documentation = values.get('documentation')
        if documentation:
            parsed = parse_schema_dsl(documentation)
            if parsed['type'] == 'invalid':
                raise ValueError(parsed['error'])
            if parsed['type'] == 'ok':
                values['extension'] = model_parse(ModelExtension, parsed['value']['arguments'])
        return values

    @field_validator('name')
    @classmethod
    def name_validator(cls, name: str) -> str:
        """Reject names that shadow a keyword directly or via the lower-cased client property."""
        checks = (
            (name, f'Model name "{name}" shadows a Python keyword; '),
            (name.lower(), f'Model name "{name}" results in a client property that shadows a Python keyword; '),
        )
        for candidate, problem in checks:
            if iskeyword(candidate):
                raise ValueError(problem + f'use a different model name with \'@@map("{name}")\'.')
        return name

    @property
    def related_models(self) -> Iterator['Model']:
        """Yield the target model of each relational field, in field order."""
        models = get_datamodel().models
        yield from (model for field in self.relational_fields for model in models if field.type == model.name)

    @property
    def relational_fields(self) -> Iterator['Field']:
        """Yield every field that has a relation."""
        yield from (field for field in self.all_fields if field.is_relational)

    @property
    def scalar_fields(self) -> Iterator['Field']:
        """Yield every field that has no relation."""
        yield from (field for field in self.all_fields if not field.is_relational)

    @property
    def atomic_fields(self) -> Iterator['Field']:
        """Yield every field whose type supports atomic updates."""
        yield from (field for field in self.all_fields if field.is_atomic)

    @property
    def required_array_fields(self) -> Iterator['Field']:
        """Yield required, non-relational list fields."""
        yield from (
            field for field in self.all_fields if field.is_list and not field.relation_name and field.is_required
        )

    # TODO: support combined unique constraints
    @cached_property
    def id_field(self) -> Optional['Field']:
        """Find a field that can be passed to the model's `WhereUnique` filter"""
        return next((field for field in self.scalar_fields if field.is_id or field.is_unique), None)

    @property
    def has_relational_fields(self) -> bool:
        """Whether at least one field is relational."""
        return next(self.relational_fields, None) is not None

    @property
    def instance_name(self) -> str:
        """The name of this model in the generated client class, e.g.

        `User` -> `Prisma().user`
        """
        extension = self.extension
        if extension and extension.instance_name:
            return extension.instance_name
        return self.name.lower()

    @property
    def plural_name(self) -> str:
        """`instance_name` with an ``s`` appended unless it already ends in one."""
        singular = self.instance_name
        return singular if singular.endswith('s') else f'{singular}s'

    def resolve_field(self, name: str) -> 'Field':
        """Return the field called *name*, raising `LookupError` if there is none."""
        found = next((field for field in self.all_fields if field.name == name), None)
        if found is None:
            raise LookupError(f'Could not find a field with name: {name}')
        return found

    def sampler(self) -> Sampler:
        """Return the sampler created for this model in `__init__`."""
        return self._sampler
class Constraint(BaseModel):
    """A named constraint (primary key or unique index) over one or more fields."""

    name: str
    fields: List[str]

    @root_validator(pre=True, allow_reuse=True, skip_on_failure=True)
    @classmethod
    def resolve_name(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Default a missing / non-string `name` to the field names joined by ``_``.

        If `fields` itself is missing, the values are returned unchanged so
        pydantic reports its normal "field required" validation error instead
        of this validator crashing with a bare `KeyError` (which neither
        Pydantic v1 nor v2 converts into a validation error).
        """
        name = values.get('name')
        if isinstance(name, str):
            return values

        fields = values.get('fields')
        if fields is None:
            return values

        values['name'] = '_'.join(fields)
        return values
class PrimaryKey(Constraint):
    """A model's compound primary key (`Model.compound_primary_key`)."""

    pass
class UniqueIndex(Constraint):
    """A unique index on a model (`Model.unique_indexes`)."""

    pass
class Field(BaseModel):
    """A single field of a model in the Prisma schema datamodel.

    Besides the raw data sent by prisma, this exposes the type strings that
    the templates write into the generated client.
    """

    name: str
    documentation: Optional[str] = None

    # TODO: switch to enums
    kind: str
    type: str

    is_id: bool = FieldInfo(alias='isId')
    is_list: bool = FieldInfo(alias='isList')
    is_unique: bool = FieldInfo(alias='isUnique')
    is_required: bool = FieldInfo(alias='isRequired')
    is_read_only: bool = FieldInfo(alias='isReadOnly')
    is_generated: bool = FieldInfo(alias='isGenerated')
    is_updated_at: bool = FieldInfo(alias='isUpdatedAt')

    default: Optional[Union['DefaultValue', object, List[object]]] = None
    has_default_value: bool = FieldInfo(alias='hasDefaultValue')

    relation_name: Optional[str] = FieldInfo(alias='relationName', default=None)
    relation_on_delete: Optional[str] = FieldInfo(alias='relationOnDelete', default=None)
    relation_to_fields: Optional[List[str]] = FieldInfo(
        alias='relationToFields',
        default=None,
    )
    relation_from_fields: Optional[List[str]] = FieldInfo(
        alias='relationFromFields',
        default=None,
    )

    # Last value returned by `get_sample_data()`. It needs an explicit None
    # default: a `PrivateAttr()` without one is simply left unset, so calling
    # `get_sample_data(increment=False)` before any sampling raised
    # `AttributeError`.
    _last_sampled: Optional[str] = PrivateAttr(default=None)

    @root_validator(pre=True, skip_on_failure=True)
    @classmethod
    def scalar_type_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject scalar fields whose type has no entry in `TYPE_MAPPING`."""
        kind = values.get('kind')
        type_ = values.get('type')

        if kind == 'scalar':
            if type_ is not None and type_ not in TYPE_MAPPING:
                raise ValueError(f'Unsupported scalar field type: {type_}')

        return values

    @field_validator('type')
    @classmethod
    def experimental_decimal_validator(cls, typ: str) -> str:
        """Require the `enable_experimental_decimal` flag for Decimal fields."""
        if typ == 'Decimal':
            config = get_config()

            # skip validating the experimental flag if we are
            # being called from a custom generator
            if isinstance(config, Config) and not config.enable_experimental_decimal:
                raise ValueError(
                    'Support for the Decimal type is experimental\n'
                    ' As such you must set the `enable_experimental_decimal` config flag to true\n'
                    ' for more information see: https://github.com/RobertCraigie/prisma-client-py/issues/106'
                )

        return typ

    @field_validator('name')
    @classmethod
    def name_validator(cls, name: str) -> str:
        """Reject field names that would clash in the generated client."""
        # `hasattr` rather than a truthiness check on `getattr`: attributes
        # whose value is falsy (e.g. an empty `model_fields` dict) still
        # shadow the BaseModel attribute.
        if hasattr(BaseModel, name):
            raise ValueError(
                f'Field name "{name}" shadows a BaseModel attribute; '
                f'use a different field name with \'@map("{name}")\'.'
            )

        if iskeyword(name):
            raise ValueError(
                f'Field name "{name}" shadows a Python keyword; ' f'use a different field name with \'@map("{name}")\'.'
            )

        if name == 'prisma':
            raise ValueError(
                'Field name "prisma" shadows a Prisma Client Python method; '
                'use a different field name with \'@map("prisma")\'.'
            )

        if name in QUERY_BUILDER_ALIASES:
            raise ValueError(
                f'Field name "{name}" shadows an internal keyword; '
                f'use a different field name with \'@map("{name}")\''
            )

        return name

    # TODO: cache the properties
    @property
    def python_type(self) -> str:
        """The Python type expression for this field, wrapped in `List[...]` for lists."""
        type_ = self._actual_python_type
        if self.is_list:
            return f'List[{type_}]'
        return type_

    @property
    def python_type_as_string(self) -> str:
        """`python_type` as a quoted string annotation (inner quotes escaped for lists)."""
        type_ = self._actual_python_type
        if self.is_list:
            type_ = type_.replace("'", "\\'")
            return f"'List[{type_}]'"

        if not type_.startswith("'"):
            type_ = f"'{type_}'"

        return type_

    @property
    def _actual_python_type(self) -> str:
        """The element type: an enum/model forward reference or the `TYPE_MAPPING` entry."""
        if self.kind == 'enum':
            return f"'enums.{self.type}'"

        if self.kind == 'object':
            return f"'models.{self.type}'"

        try:
            return TYPE_MAPPING[self.type]
        except KeyError as exc:
            # TODO: handle this better
            raise RuntimeError(
                f'Could not parse {self.name} due to unknown type: {self.type}',
            ) from exc

    @property
    def create_input_type(self) -> str:
        """Type used for this field in create inputs (nested inputs for relations)."""
        if self.kind != 'object':
            return self.python_type

        if self.is_list:
            return f"'{self.type}CreateManyNestedWithoutRelationsInput'"

        return f"'{self.type}CreateNestedWithoutRelationsInput'"

    @property
    def where_input_type(self) -> str:
        """Type used for this field in `where` filters."""
        typ = self.type
        if self.is_relational:
            if self.is_list:
                return f"'{typ}ListRelationFilter'"
            return f"'{typ}RelationFilter'"

        if self.is_list:
            self.check_supported_scalar_list_type()
            return f"'types.{typ}ListFilter'"

        if typ in FILTER_TYPES:
            if self.is_optional:
                return f"Union[None, {self._actual_python_type}, 'types.{typ}Filter']"
            return f"Union[{self._actual_python_type}, 'types.{typ}Filter']"

        return self.python_type

    @property
    def where_aggregates_input_type(self) -> str:
        """Type used for this field in aggregate filters; invalid for relational fields."""
        if self.is_relational:  # pragma: no cover
            raise RuntimeError('This type is not valid for relational fields')

        typ = self.type
        if typ in FILTER_TYPES:
            return f"Union[{self._actual_python_type}, 'types.{typ}WithAggregatesFilter']"
        return self.python_type

    @property
    def relational_args_type(self) -> str:
        """Name of the args type used when including this relation."""
        if self.is_list:
            return f'FindMany{self.type}Args'
        return f'{self.type}Args'

    @property
    def required_on_create(self) -> bool:
        """Whether a value must be supplied when creating a record."""
        return (
            self.is_required
            and not self.is_updated_at
            and not self.has_default_value
            and not self.relation_name
            and not self.is_list
        )

    @property
    def is_optional(self) -> bool:
        """True unless the field is required and non-relational."""
        return not (self.is_required and not self.relation_name)

    @property
    def is_relational(self) -> bool:
        """Whether the field has a relation name."""
        return self.relation_name is not None

    @property
    def is_atomic(self) -> bool:
        """Whether the field's type is one of `ATOMIC_FIELD_TYPES`."""
        return self.type in ATOMIC_FIELD_TYPES

    @property
    def is_number(self) -> bool:
        """Whether the field's type is Int, BigInt or Float."""
        return self.type in {'Int', 'BigInt', 'Float'}

    def maybe_optional(self, typ: str) -> str:
        """Wrap the given type string within `Optional` if applicable"""
        if self.is_required or self.is_relational:
            return typ
        return f'Optional[{typ}]'

    def get_update_input_type(self) -> str:
        """Type used for this field in update inputs."""
        if self.kind == 'object':
            if self.is_list:
                return f"'{self.type}UpdateManyWithoutRelationsInput'"
            return f"'{self.type}UpdateOneWithoutRelationsInput'"

        if self.is_list:
            self.check_supported_scalar_list_type()
            return f"'types.{self.type}ListUpdate'"

        if self.is_atomic:
            return f'Union[Atomic{self.type}Input, {self.python_type}]'

        return self.python_type

    def check_supported_scalar_list_type(self) -> None:
        """Raise `UnsupportedListTypeError` for list types without filter support.

        Keep in sync with `get_list_types()`.
        """
        if self.type not in FILTER_TYPES and self.kind != 'enum':  # pragma: no branch
            raise UnsupportedListTypeError(self.type)

    def get_relational_model(self) -> Optional['Model']:
        """Return the model this field relates to, or None if not relational / not found."""
        if not self.is_relational:
            return None

        name = self.type
        for model in get_datamodel().models:
            if model.name == name:
                return model
        return None

    def get_corresponding_enum(self) -> Optional['Enum']:
        """Return the datamodel enum whose name matches this field's type, if any."""
        typ = self.type
        for enum in get_datamodel().enums:
            if enum.name == typ:
                return enum
        return None  # pragma: no cover

    def get_sample_data(self, *, increment: bool = True) -> str:
        """Return a Python source snippet with an example value for this field.

        With ``increment=False`` the previously sampled snippet is returned
        again if there is one; otherwise a new one is generated.
        """
        # returning the same data that was last sampled is useful
        # for documenting methods like upsert() where data is duplicated
        if not increment and self._last_sampled is not None:
            return self._last_sampled

        sampled = self._get_sample_data()
        if self.is_list:
            sampled = f'[{sampled}]'

        self._last_sampled = sampled
        return sampled

    def _get_sample_data(self) -> str:
        """Generate a fresh example value snippet for a single (non-list) value."""
        if self.is_relational:  # pragma: no cover
            raise RuntimeError('Data sampling for relational fields not supported yet')

        if self.kind == 'enum':
            enum = self.get_corresponding_enum()
            assert enum is not None, self.type
            return f'enums.{enum.name}.{FAKER.from_list(enum.values).name}'

        typ = self.type
        if typ == 'Boolean':
            return str(FAKER.boolean())
        elif typ == 'Int':
            return str(FAKER.integer())
        elif typ == 'String':
            return f"'{FAKER.string()}'"
        elif typ == 'Float':
            return f'{FAKER.integer()}.{FAKER.integer() // 10000}'
        elif typ == 'BigInt':  # pragma: no cover
            return str(FAKER.integer() * 12)
        elif typ == 'DateTime':
            # TODO: random dates
            return 'datetime.datetime.utcnow()'
        elif typ == 'Json':
            return f"Json({{'{FAKER.string()}': True}})"
        elif typ == 'Bytes':
            return f"Base64.encode(b'{FAKER.string()}')"
        elif typ == 'Decimal':
            return f"Decimal('{FAKER.integer()}.{FAKER.integer() // 10000}')"
        else:  # pragma: no cover
            raise RuntimeError(f'Sample data not supported for {typ} yet')
class DefaultValue(BaseModel):
    """A function-style field default (a `name` plus optional `args`) from the datamodel."""

    args: Any = None
    name: str
class _EmptyModel(BaseModel):
    """A config model that accepts no options at all (extra keys are forbidden)."""

    if PYDANTIC_V2:
        model_config: ClassVar[ConfigDict] = ConfigDict(extra='forbid')
    elif not TYPE_CHECKING:

        class Config(BaseModel.Config):
            extra: pydantic.Extra = pydantic.Extra.forbid
class _ModelAllowAll(BaseModel):
    """A config model that keeps any options given; used for `other_generators` configs."""

    if PYDANTIC_V2:
        model_config: ClassVar[ConfigDict] = ConfigDict(extra='allow')
    elif not TYPE_CHECKING:

        class Config(BaseModel.Config):
            extra: pydantic.Extra = pydantic.Extra.allow
class PythonNames(BaseModel):
    """Names used in the generated client."""

    def client_class(self, _for_async: bool) -> str:
        # the same class name is used regardless of the interface
        return 'Prisma'
class PythonData(GenericData[Config]):
    """Data class including the default Prisma Client Python config"""

    # GenericData derives from GenericModel rather than this module's
    # BaseModel, so the Pydantic v1 config options are repeated here.
    if not PYDANTIC_V2:

        class Config(BaseConfig):
            arbitrary_types_allowed: bool = True
            json_encoders: Dict[Type[Any], Any] = {
                Path: _pathlib_serializer,
                machinery.ModuleSpec: _module_spec_serializer,
            }
            keep_untouched: Tuple[Type[Any], ...] = (cached_property,)

    names: PythonNames = PythonNames()
class DefaultData(GenericData[_EmptyModel]):
    """Data class without any config options (any generator config key is rejected)."""
# this has to be defined as a type alias instead of a class
# as its purpose is to signify that the data is config agnostic
AnyData = GenericData[Any]
# Resolve the string forward references used in the annotations above now
# that every referenced class has been defined.
model_rebuild(Enum)
model_rebuild(DMMF)
model_rebuild(GenericData)
model_rebuild(Field)
model_rebuild(Model)
model_rebuild(Datamodel)
model_rebuild(Generator)
model_rebuild(Datasource)
1193from .errors import (
1194 TemplateError,
1195 PartialTypeGeneratorError,
1196)
1197from .schema import Schema, ClientTypes