Coverage for src/prisma/generator/models.py: 94%
641 statements
coverage.py v7.2.7, created at 2024-08-27 18:25 +0000
1import os
2import sys
3import enum
4import textwrap
5import importlib
6from typing import (
7 TYPE_CHECKING,
8 Any,
9 Dict,
10 List,
11 Type,
12 Tuple,
13 Union,
14 Generic,
15 TypeVar,
16 ClassVar,
17 Iterable,
18 Iterator,
19 NoReturn,
20 Optional,
21 cast,
22)
23from keyword import iskeyword
24from pathlib import Path
25from importlib import util as importlib_util, machinery
26from itertools import chain
27from contextvars import ContextVar
28from importlib.abc import InspectLoader
29from typing_extensions import Annotated, override
31import click
32import pydantic
33from pydantic.fields import PrivateAttr
35from .. import config
36from .utils import Faker, Sampler, clean_multiline
37from ..utils import DEBUG_GENERATOR, assert_never
38from ..errors import UnsupportedListTypeError
39from .._compat import (
40 PYDANTIC_V2,
41 Field as FieldInfo,
42 BaseConfig,
43 ConfigDict,
44 BaseSettings,
45 GenericModel,
46 PlainSerializer,
47 BaseSettingsConfig,
48 model_dict,
49 model_parse,
50 model_rebuild,
51 root_validator,
52 cached_property,
53 field_validator,
54)
55from .._constants import QUERY_BUILDER_ALIASES
56from ._dsl_parser import parse_schema_dsl
58__all__ = (
59 'AnyData',
60 'PythonData',
61 'DefaultData',
62 'GenericData',
63)
65_ModelT = TypeVar('_ModelT', bound=pydantic.BaseModel)
67# NOTE: this does not represent all the data that is passed by prisma
69ATOMIC_FIELD_TYPES = ['Int', 'BigInt', 'Float']
71TYPE_MAPPING = {
72 'String': '_str',
73 'Bytes': "'fields.Base64'",
74 'DateTime': 'datetime.datetime',
75 'Boolean': '_bool',
76 'Int': '_int',
77 'Float': '_float',
78 'BigInt': '_int',
79 'Json': "'fields.Json'",
80 'Decimal': 'decimal.Decimal',
81}
82FILTER_TYPES = [
83 'String',
84 'Bytes',
85 'DateTime',
86 'Boolean',
87 'Int',
88 'BigInt',
89 'Float',
90 'Json',
91 'Decimal',
92]
93RECURSIVE_TYPE_DEPTH_WARNING = """Some types are disabled by default due to being incompatible with Mypy; it is highly recommended
94to use Pyright instead and configure Prisma Python to use recursive types. To re-enable certain types:"""
96RECURSIVE_TYPE_DEPTH_WARNING_DESC = """
97generator client {
98 provider = "prisma-client-py"
99 recursive_type_depth = -1
100}
102If you need to use Mypy, you can also disable this message by explicitly setting the default value:
104generator client {
105 provider = "prisma-client-py"
106 recursive_type_depth = 5
107}
109For more information see: https://prisma-client-py.readthedocs.io/en/stable/reference/limitations/#default-type-limitations
110"""
112FAKER: Faker = Faker()
115ConfigT = TypeVar('ConfigT', bound=pydantic.BaseModel)
117# Although we should just be able to access the config from the datamodel,
118# we have to do some validation that requires access to the config. This is difficult
119# with heavily nested models, as our current workaround only sets the datamodel context
120# post-validation, meaning we cannot access it in validators. To get around this we have
121# a separate config context.
122# TODO: better solution
123data_ctx: ContextVar['AnyData'] = ContextVar('data_ctx')
124config_ctx: ContextVar['Config'] = ContextVar('config_ctx')
127def get_datamodel() -> 'Datamodel':
128 return data_ctx.get().dmmf.datamodel
131# typed to ensure the caller has to handle the cases where:
132# - a custom generator config is being used
133# - the config is invalid and therefore could not be set
134def get_config() -> Union[None, pydantic.BaseModel, 'Config']:
135 return config_ctx.get(None)
138def get_list_types() -> Iterable[Tuple[str, str]]:
139 # WARNING: do not edit this function without also editing Field.check_supported_scalar_list_type()
140 return chain(
141 ((t, TYPE_MAPPING[t]) for t in FILTER_TYPES),
142 ((enum.name, f"'enums.{enum.name}'") for enum in get_datamodel().enums),
143 )
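# Illustrative sketch (not part of the module): get_list_types() yields each scalar
# filter type paired with its Python type string, followed by every enum defined in
# the schema (a `Role` enum is assumed here purely for illustration):
#
#     ('String', '_str'), ('Bytes', "'fields.Base64'"), ..., ('Role', "'enums.Role'")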
146def sql_param(num: int = 1) -> str:
147 # TODO: add case for sqlserver
148 active_provider = data_ctx.get().datasources[0].active_provider
149 if active_provider == 'postgresql':
150 return f'${num}'
152 # TODO: test
153 if active_provider == 'mongodb': # pragma: no cover
154 raise RuntimeError('no-op')
156 # SQLite and MySQL use this style so just default to it
157 return '?'
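# Illustrative sketch (not part of the module): sql_param() returns the placeholder
# style used in raw SQL snippets for the active datasource provider:
#
#     sql_param()       # -> '$1' on postgresql
#     sql_param(num=3)  # -> '$3' on postgresql
#     sql_param()       # -> '?'  on sqlite / mysql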
160def raise_err(msg: str) -> NoReturn:
161 raise TemplateError(msg)
164def type_as_string(typ: str) -> str:
165 """Ensure a type string is wrapped with a string, e.g.
167 enums.Role -> 'enums.Role'
168 """
169 # TODO: use this function internally in this module
170 if not typ.startswith("'") and not typ.startswith('"'):
171 return f"'{typ}'"
172 return typ
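# Illustrative sketch (not part of the module): type_as_string() only adds quotes when
# they are missing, so already-quoted forward references pass through unchanged:
#
#     type_as_string('enums.Role')     # -> "'enums.Role'"
#     type_as_string("'fields.Json'")  # -> "'fields.Json'"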
175def format_documentation(doc: str, indent: int = 4) -> str:
176 """Format a schema comment by indenting nested lines, e.g.
178 '''Foo
179 Bar'''
181 Becomes
183 '''Foo
184 Bar
185 '''
186 """
187 if not doc:  # 187 ↛ 189: condition was never true
188 # empty string, nothing to do
189 return doc
191 prefix = ' ' * indent
192 first, *rest = doc.splitlines()
193 return '\n'.join(
194 [
195 first,
196 *[textwrap.indent(line, prefix) for line in rest],
197 prefix,
198 ]
199 )
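# Illustrative sketch (not part of the module): format_documentation() leaves the first
# line untouched, indents the remaining lines, and appends a trailing indent so the
# closing quotes of the generated docstring line up:
#
#     format_documentation('Foo\nBar')  # -> 'Foo\n    Bar\n    '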
202def _module_spec_serializer(spec: machinery.ModuleSpec) -> str:
203 assert spec.origin is not None, 'Cannot serialize module with no origin'
204 return spec.origin
207def _pathlib_serializer(path: Path) -> str:
208 return str(path.absolute())
211def _recursive_type_depth_factory() -> int:
212 click.echo(
213 click.style(
214 f'\n{RECURSIVE_TYPE_DEPTH_WARNING}',
215 fg='yellow',
216 )
217 )
218 click.echo(f'{RECURSIVE_TYPE_DEPTH_WARNING_DESC}\n')
219 return 5
222class BaseModel(pydantic.BaseModel):
223 if PYDANTIC_V2:
224 model_config: ClassVar[ConfigDict] = ConfigDict(
225 arbitrary_types_allowed=True,
226 ignored_types=(cached_property,),
227 )
228 else:
230 class Config(BaseConfig):
231 arbitrary_types_allowed: bool = True
232 json_encoders: Dict[Type[Any], Any] = {
233 Path: _pathlib_serializer,
234 machinery.ModuleSpec: _module_spec_serializer,
235 }
236 keep_untouched: Tuple[Type[Any], ...] = (cached_property,)
239class InterfaceChoices(str, enum.Enum):
240 sync = 'sync'
241 asyncio = 'asyncio'
244class EngineType(str, enum.Enum):
245 binary = 'binary'
246 library = 'library'
247 dataproxy = 'dataproxy'
249 @override
250 def __str__(self) -> str:
251 return self.value
254class Module(BaseModel):
255 if TYPE_CHECKING:
256 spec: machinery.ModuleSpec
257 else:
258 if PYDANTIC_V2:
259 spec: Annotated[
260 machinery.ModuleSpec,
261 PlainSerializer(lambda x: _module_spec_serializer(x), return_type=str),
262 ]
263 else:
264 spec: machinery.ModuleSpec
266 if PYDANTIC_V2:
267 model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
268 else:
270 class Config(BaseModel.Config):
271 arbitrary_types_allowed: bool = True
273 # for some reason this is needed in Pydantic v2
274 @root_validator(pre=True, skip_on_failure=True)
275 @classmethod
276 def partial_type_generator_converter(cls, values: object) -> Any:
277 if isinstance(values, str):
278 return {'spec': values}
279 return values
281 @field_validator('spec', pre=True, allow_reuse=True)
282 @classmethod
283 def spec_validator(cls, value: Optional[str]) -> machinery.ModuleSpec:
284 spec: Optional[machinery.ModuleSpec] = None
286 # TODO: this should really work based off of the schema path
287 # and this should support checking just partial_types.py if we are in a `prisma` dir
288 if value is None:
289 value = 'prisma/partial_types.py'
291 path = Path.cwd().joinpath(value)
292 if path.exists():
293 spec = importlib_util.spec_from_file_location('prisma.partial_type_generator', value)
294 elif value.startswith('.'):  # 294 ↛ 295: condition was never true
295 raise ValueError(f'No file found at {value} and relative imports are not allowed.')
296 else:
297 try:
298 spec = importlib_util.find_spec(value)
299 except ModuleNotFoundError:
300 spec = None
302 if spec is None:
303 raise ValueError(f'Could not find a python file or module at {value}')
305 return spec
307 def run(self) -> None:
308 importlib.invalidate_caches()
309 mod = importlib_util.module_from_spec(self.spec)
310 loader = self.spec.loader
311 assert loader is not None, 'Expected an import loader to exist.'
312 assert isinstance(loader, InspectLoader), f'Cannot execute module from loader type: {type(loader)}'
314 try:
315 loader.exec_module(mod)
316 except Exception as exc:
317 raise PartialTypeGeneratorError() from exc
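# Illustrative sketch (not part of the module): Module accepts either a file path or an
# importable dotted path for the partial type generator; both are resolved to a
# ModuleSpec by spec_validator ('myapp.prisma_partials' below is a hypothetical name):
#
#     Module(spec='prisma/partial_types.py')  # a file relative to the current directory
#     Module(spec='myapp.prisma_partials')    # a module importable from sys.path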
320class GenericData(GenericModel, Generic[ConfigT]):
321 """Root model for the data that prisma provides to the generator.
323 WARNING: only one instance of this class may exist at any given time and
324 instances should only be constructed using the Data.parse_obj() method
325 """
327 datamodel: str
328 version: str
329 generator: 'Generator[ConfigT]'
330 dmmf: 'DMMF' = FieldInfo(alias='dmmf')
331 schema_path: Path = FieldInfo(alias='schemaPath')
332 datasources: List['Datasource'] = FieldInfo(alias='datasources')
333 other_generators: List['Generator[_ModelAllowAll]'] = FieldInfo(alias='otherGenerators')
334 binary_paths: 'BinaryPaths' = FieldInfo(alias='binaryPaths', default_factory=lambda: BinaryPaths())  # 334 ↛ exit: the default_factory lambda was never run
336 if PYDANTIC_V2:
338 @root_validator(pre=False)
339 def _set_ctx(self: _ModelT) -> _ModelT:
340 data_ctx.set(cast('GenericData[ConfigT]', self))
341 return self
343 else:
345 @classmethod
346 @override
347 def parse_obj(cls, obj: Any) -> 'GenericData[ConfigT]':
348 data = super().parse_obj(obj) # pyright: ignore[reportDeprecated]
349 data_ctx.set(data)
350 return data
352 def to_params(self) -> Dict[str, Any]:
353 """Get the parameters that should be sent to Jinja templates"""
354 params = vars(self)
355 params['type_schema'] = Schema.from_data(self)
356 params['client_types'] = ClientTypes.from_data(self)
358 # add utility functions
359 for func in [
360 sql_param,
361 raise_err,
362 type_as_string,
363 get_list_types,
364 clean_multiline,
365 format_documentation,
366 model_dict,
367 ]:
368 params[func.__name__] = func
370 return params
372 @root_validator(pre=True, allow_reuse=True, skip_on_failure=True)
373 @classmethod
374 def validate_version(cls, values: Dict[Any, Any]) -> Dict[Any, Any]:
375 # TODO: test this
376 version = values.get('version')
377 if not DEBUG_GENERATOR and version != config.expected_engine_version:  # 377 ↛ 378: condition was never true
378 raise ValueError(
379 f'Prisma Client Python expected Prisma version: {config.expected_engine_version} '
380 f'but got: {version}\n'
381 ' If this is intentional, set the PRISMA_PY_DEBUG_GENERATOR environment '
382 'variable to 1 and try again.\n'
383 f' If you are using the Node CLI then you must switch to v{config.prisma_version}, e.g. '
384 f'npx prisma@{config.prisma_version} generate\n'
385 ' or generate the client using the Python CLI, e.g. python3 -m prisma generate'
386 )
387 return values
390class BinaryPaths(BaseModel):
391 """This class represents the paths to engine binaries.
393 Each property in this class is a mapping of platform name to absolute path, for example:
395 ```py
396 # This is what will be set on an M1 chip if there are no other `binaryTargets` set
397 binary_paths.query_engine == {
398 'darwin-arm64': '/Users/robert/.cache/prisma-python/binaries/3.13.0/efdf9b1183dddfd4258cd181a72125755215ab7b/node_modules/prisma/query-engine-darwin-arm64'
399 }
400 ```
402 These paths are only available if the generator explicitly requests the corresponding engines using the `requires_engines` manifest property.
403 """
405 query_engine: Dict[str, str] = FieldInfo(
406 default_factory=dict,
407 alias='queryEngine',
408 )
409 introspection_engine: Dict[str, str] = FieldInfo(
410 default_factory=dict,
411 alias='introspectionEngine',
412 )
413 migration_engine: Dict[str, str] = FieldInfo(
414 default_factory=dict,
415 alias='migrationEngine',
416 )
417 libquery_engine: Dict[str, str] = FieldInfo(
418 default_factory=dict,
419 alias='libqueryEngine',
420 )
421 prisma_format: Dict[str, str] = FieldInfo(
422 default_factory=dict,
423 alias='prismaFmt',
424 )
426 if PYDANTIC_V2:
427 model_config: ClassVar[ConfigDict] = ConfigDict(extra='allow')
428 else:
430 class Config(BaseModel.Config): # pyright: ignore[reportDeprecated]
431 extra: Any = (
432 pydantic.Extra.allow # pyright: ignore[reportDeprecated]
433 )
436class Datasource(BaseModel):
437 # TODO: provider enums
438 name: str
439 provider: str
440 active_provider: str = FieldInfo(alias='activeProvider')
441 url: 'OptionalValueFromEnvVar'
443 source_file_path: Optional[Path] = FieldInfo(alias='sourceFilePath')
446class Generator(GenericModel, Generic[ConfigT]):
447 name: str
448 output: 'ValueFromEnvVar'
449 provider: 'OptionalValueFromEnvVar'
450 config: ConfigT
451 binary_targets: List['ValueFromEnvVar'] = FieldInfo(alias='binaryTargets')
452 preview_features: List[str] = FieldInfo(alias='previewFeatures')
454 @field_validator('binary_targets')
455 @classmethod
456 def warn_binary_targets(cls, targets: List['ValueFromEnvVar']) -> List['ValueFromEnvVar']:
457 # Prisma by default sends one binary target which is the current platform.
458 if len(targets) > 1:
459 click.echo(
460 click.style(
461 'Warning: ' + 'The binaryTargets option is not officially supported by Prisma Client Python.',
462 fg='yellow',
463 ),
464 file=sys.stdout,
465 )
467 return targets
469 def has_preview_feature(self, feature: str) -> bool:
470 return feature in self.preview_features
473class ValueFromEnvVar(BaseModel):
474 value: str
475 from_env_var: Optional[str] = FieldInfo(alias='fromEnvVar')
478class OptionalValueFromEnvVar(BaseModel):
479 value: Optional[str] = None
480 from_env_var: Optional[str] = FieldInfo(alias='fromEnvVar')
482 def resolve(self) -> str:
483 value = self.value
484 if value is not None:
485 return value
487 env_var = self.from_env_var
488 assert env_var is not None, 'from_env_var should not be None'
489 value = os.environ.get(env_var)
490 if value is None:  # 490 ↛ 491: condition was never true
491 raise RuntimeError(f'Environment variable not found: {env_var}')
493 return value
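# Illustrative sketch (not part of the module): resolve() prefers an inline value and
# only falls back to the environment variable named by `fromEnvVar` (DATABASE_URL below
# is just an example variable name):
#
#     OptionalValueFromEnvVar(value='postgres://…', fromEnvVar=None).resolve()
#     # -> 'postgres://…'
#     OptionalValueFromEnvVar(value=None, fromEnvVar='DATABASE_URL').resolve()
#     # -> os.environ['DATABASE_URL'], or RuntimeError if the variable is not set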
496class Config(BaseSettings):
497 """Custom generator config options."""
499 interface: InterfaceChoices = FieldInfo(default=InterfaceChoices.asyncio, env='PRISMA_PY_CONFIG_INTERFACE')
500 partial_type_generator: Optional[Module] = FieldInfo(default=None, env='PRISMA_PY_CONFIG_PARTIAL_TYPE_GENERATOR')
501 recursive_type_depth: int = FieldInfo(
502 default_factory=_recursive_type_depth_factory,
503 env='PRISMA_PY_CONFIG_RECURSIVE_TYPE_DEPTH',
504 )
505 engine_type: EngineType = FieldInfo(default=EngineType.binary, env='PRISMA_PY_CONFIG_ENGINE_TYPE')
507 # this should be a list of experimental features
508 # https://github.com/prisma/prisma/issues/12442
509 enable_experimental_decimal: bool = FieldInfo(default=False, env='PRISMA_PY_CONFIG_ENABLE_EXPERIMENTAL_DECIMAL')
511 # this seems to be the only good method for setting the contextvar as
512 # we don't control the actual construction of the object like we do for
513 # the Data model.
514 # we do not expose this to type checkers so that the generated __init__
515 # signature is preserved.
516 if not TYPE_CHECKING:
518 def __init__(self, **kwargs: object) -> None:
519 super().__init__(**kwargs)
520 config_ctx.set(self)
522 if PYDANTIC_V2:
523 model_config: ClassVar[ConfigDict] = ConfigDict(
524 extra='forbid',
525 use_enum_values=True,
526 populate_by_name=True,
527 )
528 else:
529 if not TYPE_CHECKING:
531 class Config(BaseSettingsConfig):
532 extra: pydantic.Extra = pydantic.Extra.forbid
533 use_enum_values: bool = True
534 env_prefix: str = 'prisma_py_config_'
535 allow_population_by_field_name: bool = True
537 @classmethod
538 def customise_sources(cls, init_settings, env_settings, file_secret_settings):
539 # prioritise env settings over init settings
540 return env_settings, init_settings, file_secret_settings
542 @root_validator(pre=True, skip_on_failure=True)
543 @classmethod
544 def transform_engine_type(cls, values: Dict[str, Any]) -> Dict[str, Any]:
545 # prioritise env variable over schema option
546 engine_type = os.environ.get('PRISMA_CLIENT_ENGINE_TYPE')
547 if engine_type is None:  # 547 ↛ 551: condition was never false
548 engine_type = values.get('engineType')
550 # only add engine_type if it is present
551 if engine_type is not None:
552 values['engine_type'] = engine_type
553 values.pop('engineType', None)
555 return values
557 @root_validator(pre=True, skip_on_failure=True)
558 @classmethod
559 def removed_http_option_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]:
560 http = values.get('http')
561 if http is not None:
562 if http in {'aiohttp', 'httpx-async'}:
563 option = 'asyncio'
564 elif http in {'requests', 'httpx-sync'}:
565 option = 'sync'
566 else: # pragma: no cover
567 # invalid http option, let pydantic handle the error
568 return values
570 raise ValueError(
571 'The http option has been removed in favour of the interface option.\n'
572 ' Please remove the http option from your Prisma schema and replace it with:\n'
573 f' interface = "{option}"'
574 )
575 return values
577 if PYDANTIC_V2:
579 @root_validator(pre=True, skip_on_failure=True)
580 @classmethod
581 def partial_type_generator_converter(cls, values: Dict[str, Any]) -> Dict[str, Any]:
582 # ensure env resolving happens
583 values = cast(Dict[str, Any], cls.root_validator(values)) # type: ignore
585 value = values.get('partial_type_generator')
587 try:
588 values['partial_type_generator'] = Module(
589 spec=value # pyright: ignore[reportArgumentType]
590 )
591 except ValueError:
592 if value is None:
593 # no config value passed and the default location was not found
594 return values
595 raise
597 return values
599 else:
601 @field_validator('partial_type_generator', pre=True, always=True, allow_reuse=True)
602 @classmethod
603 def _partial_type_generator_converter(cls, value: Optional[str]) -> Optional[Module]:
604 try:
605 return Module(
606 spec=value # pyright: ignore[reportArgumentType]
607 )
608 except ValueError:
609 if value is None:
610 # no config value passed and the default location was not found
611 return None
612 raise
614 @field_validator('recursive_type_depth', always=True, allow_reuse=True)
615 @classmethod
616 def recursive_type_depth_validator(cls, value: int) -> int:
617 if value < -1 or value in {0, 1}:
618 raise ValueError('Value must equal -1 or be greater than 1.')
619 return value
621 @field_validator('engine_type', always=True, allow_reuse=True)
622 @classmethod
623 def engine_type_validator(cls, value: EngineType) -> EngineType:
624 if value == EngineType.binary:
625 return value
626 elif value == EngineType.dataproxy: # pragma: no cover
627 raise ValueError('Prisma Client Python does not support the Prisma Data Proxy yet.')
628 elif value == EngineType.library: # pragma: no cover
629 raise ValueError('Prisma Client Python does not support native engine bindings yet.')
630 else: # pragma: no cover
631 assert_never(value)
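# Illustrative sketch (not part of the module): recursive_type_depth accepts -1
# (fully recursive types, recommended with Pyright) or any value greater than 1;
# 0 and 1 are rejected by recursive_type_depth_validator:
#
#     Config(recursive_type_depth=-1)  # ok
#     Config(recursive_type_depth=5)   # ok (the default)
#     Config(recursive_type_depth=1)   # validation error: 'Value must equal -1 or be greater than 1.'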
634class DMMFEnumType(BaseModel):
635 name: str
636 values: List[object]
639class DMMFEnumTypes(BaseModel):
640 prisma: List[DMMFEnumType]
643class PrismaSchema(BaseModel):
644 enum_types: DMMFEnumTypes = FieldInfo(alias='enumTypes')
647class DMMF(BaseModel):
648 datamodel: 'Datamodel'
649 prisma_schema: PrismaSchema = FieldInfo(alias='schema')
652class Datamodel(BaseModel):
653 enums: List['Enum']
654 models: List['Model']
656 # not implemented yet
657 types: List[object]
659 @field_validator('types')
660 @classmethod
661 def no_composite_types_validator(cls, types: List[object]) -> object:
662 if types:
663 raise ValueError(
664 'Composite types are not supported yet. Please indicate you need this here: https://github.com/RobertCraigie/prisma-client-py/issues/314'
665 )
667 return types
670class Enum(BaseModel):
671 name: str
672 db_name: Optional[str] = FieldInfo(alias='dbName')
673 values: List['EnumValue']
676class EnumValue(BaseModel):
677 name: str
678 db_name: Optional[str] = FieldInfo(alias='dbName')
681class ModelExtension(BaseModel):
682 instance_name: Optional[str] = None
684 @field_validator('instance_name')
685 @classmethod
686 def instance_name_validator(cls, name: Optional[str]) -> Optional[str]:
687 if not name:  # 687 ↛ 688: condition was never true
688 return name
690 if not name.isidentifier():
691 raise ValueError(f'Custom Model instance_name "{name}" is not a valid Python identifier')
693 return name
696class Model(BaseModel):
697 name: str
698 documentation: Optional[str] = None
699 db_name: Optional[str] = FieldInfo(alias='dbName')
700 is_generated: bool = FieldInfo(alias='isGenerated')
701 compound_primary_key: Optional['PrimaryKey'] = FieldInfo(alias='primaryKey')
702 unique_indexes: List['UniqueIndex'] = FieldInfo(alias='uniqueIndexes')
703 all_fields: List['Field'] = FieldInfo(alias='fields')
705 # stores the parsed DSL, not an actual field defined by prisma
706 extension: Optional[ModelExtension] = None
708 _sampler: Sampler = PrivateAttr()
710 def __init__(self, **data: Any) -> None:
711 super().__init__(**data)
712 self._sampler = Sampler(self)
714 @root_validator(pre=True, allow_reuse=True)
715 @classmethod
716 def validate_dsl_extension(cls, values: Dict[Any, Any]) -> Dict[Any, Any]:
717 documentation = values.get('documentation')
718 if not documentation:
719 return values
721 parsed = parse_schema_dsl(documentation)
722 if parsed['type'] == 'invalid':  # 722 ↛ 723: condition was never true
723 raise ValueError(parsed['error'])
725 if parsed['type'] == 'ok':  # 725 ↛ 728: condition was never false
726 values['extension'] = model_parse(ModelExtension, parsed['value']['arguments'])
728 return values
730 @field_validator('name')
731 @classmethod
732 def name_validator(cls, name: str) -> str:
733 if iskeyword(name):
734 raise ValueError(
735 f'Model name "{name}" shadows a Python keyword; '
736 f'use a different model name with \'@@map("{name}")\'.'
737 )
739 if iskeyword(name.lower()):
740 raise ValueError(
741 f'Model name "{name}" results in a client property that shadows a Python keyword; '
742 f'use a different model name with \'@@map("{name}")\'.'
743 )
745 return name
747 @property
748 def related_models(self) -> Iterator['Model']:
749 models = get_datamodel().models
750 for field in self.relational_fields:
751 for model in models:
752 if field.type == model.name:
753 yield model
755 @property
756 def relational_fields(self) -> Iterator['Field']:
757 for field in self.all_fields:
758 if field.is_relational:
759 yield field
761 @property
762 def scalar_fields(self) -> Iterator['Field']:
763 for field in self.all_fields:
764 if not field.is_relational:
765 yield field
767 @property
768 def atomic_fields(self) -> Iterator['Field']:
769 for field in self.all_fields:
770 if field.type in ATOMIC_FIELD_TYPES:
771 yield field
773 @property
774 def required_array_fields(self) -> Iterator['Field']:
775 for field in self.all_fields:
776 if field.is_list and not field.relation_name and field.is_required:
777 yield field
779 # TODO: support combined unique constraints
780 @cached_property
781 def id_field(self) -> Optional['Field']:
782 """Find a field that can be passed to the model's `WhereUnique` filter"""
783 for field in self.scalar_fields: # pragma: no branch
784 if field.is_id or field.is_unique:
785 return field
786 return None
788 @property
789 def has_relational_fields(self) -> bool:
790 try:
791 next(self.relational_fields)
792 except StopIteration:
793 return False
794 else:
795 return True
797 @property
798 def instance_name(self) -> str:
799 """The name of this model in the generated client class, e.g.
801 `User` -> `Prisma().user`
802 """
803 if self.extension and self.extension.instance_name:
804 return self.extension.instance_name
806 return self.name.lower()
808 @property
809 def plural_name(self) -> str:
810 name = self.instance_name
811 if name.endswith('s'):
812 return name
813 return f'{name}s'
815 def resolve_field(self, name: str) -> 'Field':
816 for field in self.all_fields:  # 816 ↛ 820: the loop never ran to completion
817 if field.name == name:
818 return field
820 raise LookupError(f'Could not find a field with name: {name}')
822 def sampler(self) -> Sampler:
823 return self._sampler
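# Illustrative sketch (not part of the module): instance_name is the lower-cased model
# name unless overridden via the schema-comment DSL, and plural_name naively appends an
# 's' only when the name does not already end in one:
#
#     Model 'User'    -> Prisma().user,    plural_name 'users'
#     Model 'Address' -> Prisma().address, plural_name 'address'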
826class Constraint(BaseModel):
827 name: str
828 fields: List[str]
830 @root_validator(pre=True, allow_reuse=True, skip_on_failure=True)
831 @classmethod
832 def resolve_name(cls, values: Dict[str, Any]) -> Dict[str, Any]:
833 name = values.get('name')
834 if isinstance(name, str):
835 return values
837 values['name'] = '_'.join(values['fields'])
838 return values
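# Illustrative sketch (not part of the module): when Prisma does not provide a name for
# a compound constraint, resolve_name() derives one by joining the field names:
#
#     Constraint(name=None, fields=['first_name', 'last_name']).name
#     # -> 'first_name_last_name'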
841class PrimaryKey(Constraint):
842 pass
845class UniqueIndex(Constraint):
846 pass
849class Field(BaseModel):
850 name: str
851 documentation: Optional[str] = None
853 # TODO: switch to enums
854 kind: str
855 type: str
857 is_id: bool = FieldInfo(alias='isId')
858 is_list: bool = FieldInfo(alias='isList')
859 is_unique: bool = FieldInfo(alias='isUnique')
860 is_required: bool = FieldInfo(alias='isRequired')
861 is_read_only: bool = FieldInfo(alias='isReadOnly')
862 is_generated: bool = FieldInfo(alias='isGenerated')
863 is_updated_at: bool = FieldInfo(alias='isUpdatedAt')
865 default: Optional[Union['DefaultValue', object, List[object]]] = None
866 has_default_value: bool = FieldInfo(alias='hasDefaultValue')
868 relation_name: Optional[str] = FieldInfo(alias='relationName', default=None)
869 relation_on_delete: Optional[str] = FieldInfo(alias='relationOnDelete', default=None)
870 relation_to_fields: Optional[List[str]] = FieldInfo(
871 alias='relationToFields',
872 default=None,
873 )
874 relation_from_fields: Optional[List[str]] = FieldInfo(
875 alias='relationFromFields',
876 default=None,
877 )
879 _last_sampled: Optional[str] = PrivateAttr()
881 @root_validator(pre=True, skip_on_failure=True)
882 @classmethod
883 def scalar_type_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]:
884 kind = values.get('kind')
885 type_ = values.get('type')
887 if kind == 'scalar':
888 if type_ is not None and type_ not in TYPE_MAPPING:  # 888 ↛ 889: condition was never true
889 raise ValueError(f'Unsupported scalar field type: {type_}')
891 return values
893 @field_validator('type')
894 @classmethod
895 def experimental_decimal_validator(cls, typ: str) -> str:
896 if typ == 'Decimal':
897 config = get_config()
899 # skip validating the experimental flag if we are
900 # being called from a custom generator
901 if isinstance(config, Config) and not config.enable_experimental_decimal:
902 raise ValueError(
903 'Support for the Decimal type is experimental\n'
904 ' As such you must set the `enable_experimental_decimal` config flag to true\n'
905 ' for more information see: https://github.com/RobertCraigie/prisma-client-py/issues/106'
906 )
908 return typ
910 @field_validator('name')
911 @classmethod
912 def name_validator(cls, name: str) -> str:
913 if getattr(BaseModel, name, None):
914 raise ValueError(
915 f'Field name "{name}" shadows a BaseModel attribute; '
916 f'use a different field name with \'@map("{name}")\'.'
917 )
919 if iskeyword(name):
920 raise ValueError(
921 f'Field name "{name}" shadows a Python keyword; ' f'use a different field name with \'@map("{name}")\'.'
922 )
924 if name == 'prisma':
925 raise ValueError(
926 'Field name "prisma" shadows a Prisma Client Python method; '
927 'use a different field name with \'@map("prisma")\'.'
928 )
930 if name in QUERY_BUILDER_ALIASES:
931 raise ValueError(
932 f'Field name "{name}" shadows an internal keyword; '
933 f'use a different field name with \'@map("{name}")\''
934 )
936 return name
938 # TODO: cache the properties
939 @property
940 def python_type(self) -> str:
941 type_ = self._actual_python_type
942 if self.is_list:
943 return f'List[{type_}]'
944 return type_
946 @property
947 def python_type_as_string(self) -> str:
948 type_ = self._actual_python_type
949 if self.is_list:
950 type_ = type_.replace("'", "\\'")
951 return f"'List[{type_}]'"
953 if not type_.startswith("'"):
954 type_ = f"'{type_}'"
956 return type_
958 @property
959 def _actual_python_type(self) -> str:
960 if self.kind == 'enum':
961 return f"'enums.{self.type}'"
963 if self.kind == 'object':
964 return f"'models.{self.type}'"
966 try:
967 return TYPE_MAPPING[self.type]
968 except KeyError as exc:
969 # TODO: handle this better
970 raise RuntimeError(
971 f'Could not parse {self.name} due to unknown type: {self.type}',
972 ) from exc
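# Illustrative sketch (not part of the module): the three field kinds map to Python type
# strings as follows (scalars go through TYPE_MAPPING, enums and objects become quoted
# forward references, and list fields are wrapped in List[...]):
#
#     kind='scalar', type='Int'            -> python_type '_int'
#     kind='enum',   type='Role'           -> python_type "'enums.Role'"
#     kind='object', type='Post', is_list  -> python_type "List['models.Post']"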
974 @property
975 def create_input_type(self) -> str:
976 if self.kind != 'object':
977 return self.python_type
979 if self.is_list:
980 return f"'{self.type}CreateManyNestedWithoutRelationsInput'"
982 return f"'{self.type}CreateNestedWithoutRelationsInput'"
984 @property
985 def where_input_type(self) -> str:
986 typ = self.type
987 if self.is_relational:
988 if self.is_list:
989 return f"'{typ}ListRelationFilter'"
990 return f"'{typ}RelationFilter'"
992 if self.is_list:
993 self.check_supported_scalar_list_type()
994 return f"'types.{typ}ListFilter'"
996 if typ in FILTER_TYPES:
997 if self.is_optional:
998 return f"Union[None, {self._actual_python_type}, 'types.{typ}Filter']"
999 return f"Union[{self._actual_python_type}, 'types.{typ}Filter']"
1001 return self.python_type
1003 @property
1004 def where_aggregates_input_type(self) -> str:
1005 if self.is_relational: # pragma: no cover
1006 raise RuntimeError('This type is not valid for relational fields')
1008 typ = self.type
1009 if typ in FILTER_TYPES:
1010 return f"Union[{self._actual_python_type}, 'types.{typ}WithAggregatesFilter']"
1011 return self.python_type
1013 @property
1014 def relational_args_type(self) -> str:
1015 if self.is_list:
1016 return f'FindMany{self.type}Args'
1017 return f'{self.type}Args'
1019 @property
1020 def required_on_create(self) -> bool:
1021 return (
1022 self.is_required
1023 and not self.is_updated_at
1024 and not self.has_default_value
1025 and not self.relation_name
1026 and not self.is_list
1027 )
1029 @property
1030 def is_optional(self) -> bool:
1031 return not (self.is_required and not self.relation_name)
1033 @property
1034 def is_relational(self) -> bool:
1035 return self.relation_name is not None
1037 @property
1038 def is_atomic(self) -> bool:
1039 return self.type in ATOMIC_FIELD_TYPES
1041 @property
1042 def is_number(self) -> bool:
1043 return self.type in {'Int', 'BigInt', 'Float'}
1045 def maybe_optional(self, typ: str) -> str:
1046 """Wrap the given type string within `Optional` if applicable"""
1047 if self.is_required or self.is_relational:
1048 return typ
1049 return f'Optional[{typ}]'
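# Illustrative sketch (not part of the module): maybe_optional() wraps a type string in
# Optional[...] only for non-required scalar fields; required and relational fields are
# returned unchanged:
#
#     optional scalar field:        maybe_optional('_str')  # -> 'Optional[_str]'
#     required or relational field: maybe_optional('_str')  # -> '_str'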
1051 def get_update_input_type(self) -> str:
1052 if self.kind == 'object':
1053 if self.is_list:
1054 return f"'{self.type}UpdateManyWithoutRelationsInput'"
1055 return f"'{self.type}UpdateOneWithoutRelationsInput'"
1057 if self.is_list:
1058 self.check_supported_scalar_list_type()
1059 return f"'types.{self.type}ListUpdate'"
1061 if self.is_atomic:
1062 return f'Union[Atomic{self.type}Input, {self.python_type}]'
1064 return self.python_type
1066 def check_supported_scalar_list_type(self) -> None:
1067 if self.type not in FILTER_TYPES and self.kind != 'enum': # pragma: no branch
1068 raise UnsupportedListTypeError(self.type)
1070 def get_relational_model(self) -> Optional['Model']:
1071 if not self.is_relational:  # 1071 ↛ 1072: condition was never true
1072 return None
1074 name = self.type
1075 for model in get_datamodel().models:  # 1075 ↛ 1078: the loop never ran to completion
1076 if model.name == name:
1077 return model
1078 return None
1080 def get_corresponding_enum(self) -> Optional['Enum']:
1081 typ = self.type
1082 for enum in get_datamodel().enums:
1083 if enum.name == typ:  # 1083 ↛ 1082: condition was never false
1084 return enum
1085 return None # pragma: no cover
1087 def get_sample_data(self, *, increment: bool = True) -> str:
1088 # returning the same data that was last sampled is useful
1089 # for documenting methods like upsert() where data is duplicated
1090 if not increment and self._last_sampled is not None:
1091 return self._last_sampled
1093 sampled = self._get_sample_data()
1094 if self.is_list:
1095 sampled = f'[{sampled}]'
1097 self._last_sampled = sampled
1098 return sampled
1100 def _get_sample_data(self) -> str:
1101 if self.is_relational: # pragma: no cover
1102 raise RuntimeError('Data sampling for relational fields not supported yet')
1104 if self.kind == 'enum':
1105 enum = self.get_corresponding_enum()
1106 assert enum is not None, self.type
1107 return f'enums.{enum.name}.{FAKER.from_list(enum.values).name}'
1109 typ = self.type
1110 if typ == 'Boolean':
1111 return str(FAKER.boolean())
1112 elif typ == 'Int':
1113 return str(FAKER.integer())
1114 elif typ == 'String':
1115 return f"'{FAKER.string()}'"
1116 elif typ == 'Float':
1117 return f'{FAKER.integer()}.{FAKER.integer() // 10000}'
1118 elif typ == 'BigInt': # pragma: no cover
1119 return str(FAKER.integer() * 12)
1120 elif typ == 'DateTime':
1121 # TODO: random dates
1122 return 'datetime.datetime.utcnow()'
1123 elif typ == 'Json':  # 1123 ↛ 1124: condition was never true
1124 return f"Json({{'{FAKER.string()}': True}})"
1125 elif typ == 'Bytes':  # 1125 ↛ 1127: condition was never false
1126 return f"Base64.encode(b'{FAKER.string()}')"
1127 elif typ == 'Decimal':
1128 return f"Decimal('{FAKER.integer()}.{FAKER.integer() // 10000}')"
1129 else: # pragma: no cover
1130 raise RuntimeError(f'Sample data not supported for {typ} yet')
1133class DefaultValue(BaseModel):
1134 args: Any = None
1135 name: str
1138class _EmptyModel(BaseModel):
1139 if PYDANTIC_V2:
1140 model_config: ClassVar[ConfigDict] = ConfigDict(extra='forbid')
1141 elif not TYPE_CHECKING:
1143 class Config(BaseModel.Config):
1144 extra: pydantic.Extra = pydantic.Extra.forbid
1147class _ModelAllowAll(BaseModel):
1148 if PYDANTIC_V2:
1149 model_config: ClassVar[ConfigDict] = ConfigDict(extra='allow')
1150 elif not TYPE_CHECKING:
1152 class Config(BaseModel.Config):
1153 extra: pydantic.Extra = pydantic.Extra.allow
1156class PythonNames(BaseModel):
1157 def client_class(self, _for_async: bool) -> str:
1158 return 'Prisma'
1161class PythonData(GenericData[Config]):
1162 """Data class including the default Prisma Client Python config"""
1164 if not PYDANTIC_V2:
1166 class Config(BaseConfig):
1167 arbitrary_types_allowed: bool = True
1168 json_encoders: Dict[Type[Any], Any] = {
1169 Path: _pathlib_serializer,
1170 machinery.ModuleSpec: _module_spec_serializer,
1171 }
1172 keep_untouched: Tuple[Type[Any], ...] = (cached_property,)
1174 names: PythonNames = PythonNames()
1177class DefaultData(GenericData[_EmptyModel]):
1178 """Data class without any config options"""
1181# this has to be defined as a type alias instead of a class
1182# as its purpose is to signify that the data is config agnostic
1183AnyData = GenericData[Any]
1185model_rebuild(Enum)
1186model_rebuild(DMMF)
1187model_rebuild(GenericData)
1188model_rebuild(Field)
1189model_rebuild(Model)
1190model_rebuild(Datamodel)
1191model_rebuild(Generator)
1192model_rebuild(Datasource)
1195from .errors import (
1196 TemplateError,
1197 PartialTypeGeneratorError,
1198)
1199from .schema import Schema, ClientTypes