Coverage for src/prisma/_compat.py: 94%
138 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-08-27 18:25 +0000
1from __future__ import annotations
3import os
4import sys
5from typing import TYPE_CHECKING, Any, TypeVar, Callable, cast
6from asyncio import get_running_loop as get_running_loop
8import pydantic
9from pydantic import BaseModel
10from pydantic.fields import FieldInfo
12from .utils import make_optional
# Generic placeholders used by the helpers below: ``_T`` for arbitrary
# decorated callables and ``_ModelT`` for concrete pydantic model types.
_T = TypeVar('_T')
_ModelT = TypeVar('_ModelT', bound=BaseModel)

# Pydantic v2 compat: most helpers in this module branch on this flag.
# It is True only for 2.x releases; any other version takes the v1 paths.
PYDANTIC_V2: bool = pydantic.VERSION.startswith('2.')
# ---- validators ----


def field_validator(
    __field: str,
    *fields: str,
    pre: bool = False,
    check_fields: bool | None = None,
    always: bool | None = None,
    allow_reuse: bool | None = None,
) -> Callable[[_T], _T]:
    """Return a field-validator decorator that works on pydantic v1 and v2.

    On v2 this delegates to ``pydantic.field_validator``, mapping ``pre=True``
    to ``mode='before'`` (otherwise ``mode='after'``); ``always`` and
    ``allow_reuse`` are not forwarded there.

    On v1 this delegates to ``pydantic.validator``; ``check_fields`` is not
    forwarded, and ``always`` / ``allow_reuse`` are passed only when the
    caller set them, so pydantic's own defaults apply otherwise.
    """
    if not PYDANTIC_V2:
        v1_options = {
            option: setting
            for option, setting in (('always', always), ('allow_reuse', allow_reuse))
            if setting is not None
        }
        return pydantic.validator(__field, *fields, pre=pre, **v1_options)  # type: ignore

    return cast(  # type: ignore[no-any-return]
        Any,
        pydantic.field_validator(
            __field,
            *fields,
            mode='before' if pre else 'after',
            check_fields=check_fields,
        ),
    )
def root_validator(
    *__args: Any,
    pre: bool = False,
    skip_on_failure: bool = False,
    allow_reuse: bool = False,
) -> Callable[[_T], _T]:
    """Return a model-level validator decorator for pydantic v1 or v2.

    On v1 every argument, including positional ones, is forwarded to
    ``pydantic.root_validator``.

    On v2 only ``pre`` is used: it selects ``pydantic.model_validator``'s
    ``mode='before'`` (``pre=True``) or ``mode='after'``. Positional
    arguments, ``skip_on_failure`` and ``allow_reuse`` are ignored there.
    """
    if not PYDANTIC_V2:
        v1_root_validator = cast(Any, pydantic.root_validator)
        return v1_root_validator(  # type: ignore[no-any-return]
            *__args,
            pre=pre,
            skip_on_failure=skip_on_failure,
            allow_reuse=allow_reuse,
        )

    return pydantic.model_validator(  # type: ignore
        mode='before' if pre else 'after',
    )
if TYPE_CHECKING:
    # Static view: type checkers always see the pydantic v2 names, so
    # annotations that use them resolve regardless of the installed version.
    BaseSettings = BaseModel
    BaseSettingsConfig = (
        pydantic.BaseConfig  # pyright: ignore[reportDeprecated]
    )

    class BaseConfig: ...

    from pydantic import (
        PlainSerializer as PlainSerializer,
        GetCoreSchemaHandler as GetCoreSchemaHandler,
    )
    from pydantic_core import (
        CoreSchema as CoreSchema,
        core_schema as core_schema,
    )

    class GenericModel(BaseModel): ...

elif PYDANTIC_V2:
    # Runtime, pydantic v2: import the real v2 objects; the v1-only config
    # names are set to None.
    from pydantic import GetCoreSchemaHandler, PlainSerializer, model_validator
    from pydantic_core import CoreSchema, core_schema

    # ``GenericModel`` is simply an alias of ``BaseModel`` on v2.
    GenericModel = BaseModel

    class BaseSettings(BaseModel):
        # pydantic-settings is not a dependency, so environment variables
        # are applied by hand on the raw input before field validation.
        @model_validator(mode='before')
        def root_validator(cls, values: Any) -> Any:
            return _env_var_resolver(cls, values)

    BaseSettingsConfig = None
    BaseConfig = None

else:
    # Runtime, pydantic v1: use the library's own settings/config classes
    # and stub out the v2-only names with None.
    from pydantic import (
        BaseConfig as BaseConfig,
        BaseSettings as BaseSettings,
    )
    from pydantic.generics import GenericModel as PydanticGenericModel

    CoreSchema = None
    core_schema = None
    GetCoreSchemaHandler = None
    PlainSerializer = None

    class GenericModel(PydanticGenericModel, BaseModel): ...

    BaseSettingsConfig = BaseSettings.Config
# v1 re-exports
if TYPE_CHECKING:
    # Signature-only stand-ins so type checkers see precise types for the
    # helpers that are re-exported from pydantic's v1 typing module at runtime.
    from pydantic.v1 import Extra as Extra

    def get_args(t: type[Any]) -> tuple[Any, ...]:  # noqa: ARG001
        ...

    def is_union(tp: type[Any] | None) -> bool:  # noqa: ARG001
        ...

    def get_origin(t: type[Any]) -> type[Any] | None:  # noqa: ARG001
        ...

    def is_literal_type(type_: type[Any]) -> bool:  # noqa: ARG001
        ...

    def is_typeddict(type_: type[Any]) -> bool:  # noqa: ARG001
        ...

elif PYDANTIC_V2:
    # pydantic v2 exposes the v1 API under the ``pydantic.v1`` namespace.
    from pydantic.v1 import Extra as Extra
    from pydantic.v1.typing import (
        get_args as get_args,
        get_origin as get_origin,
        is_literal_type as is_literal_type,
        is_typeddict as is_typeddict,
        is_union as is_union,
    )
else:
    from pydantic import Extra as Extra
    from pydantic.typing import (
        get_args as get_args,
        get_origin as get_origin,
        is_literal_type as is_literal_type,
        is_typeddict as is_typeddict,
        is_union as is_union,
    )
# refactored config: ``ConfigDict`` only exists on pydantic v2, so it is
# None at runtime on v1.
if TYPE_CHECKING:
    from pydantic import ConfigDict as ConfigDict
elif PYDANTIC_V2:
    from pydantic import ConfigDict
else:
    ConfigDict = None

# Key under which ``Field(env=...)`` records the environment variable name
# inside a field's ``json_schema_extra`` on pydantic v2.
ENV_VAR_KEY = '$env'
# minimal re-implementation of BaseSettings for v2
def _env_var_resolver(model_cls: type[BaseModel], values: Any) -> dict[str, Any]:
    """Overlay environment variable values onto the raw input for ``model_cls``.

    For every field of ``model_cls`` that declares an env var (as reported by
    ``_get_field_env_var``), the variable's value, when set, replaces whatever
    was supplied for that field.

    Args:
        model_cls: the pydantic v2 model whose ``model_fields`` are inspected.
        values: the raw input; must be a ``dict``.

    Returns:
        A new dict with the env var values applied. ``values`` itself is
        left untouched.
    """
    assert isinstance(values, dict)

    # Copy before writing. ``values`` may be the caller's own dict, because this
    # runs as a ``mode='before'`` validator, and env values must not leak
    # back into it.
    resolved: dict[str, Any] = dict(values)

    for key, field_info in model_cls.model_fields.items():
        env_var = _get_field_env_var(field_info, name=key)
        if not env_var:
            continue

        assert isinstance(env_var, str)

        # Note: we always want to prioritise the env var
        # over the value given due to how config loading works
        value = os.environ.get(env_var)
        if value is not None:
            resolved[key] = value

    return resolved
def _get_field_env_var(field: FieldInfo, name: str) -> str | None:
    """Return the environment variable name attached to ``field``, if any.

    On pydantic v1 this is ``field.field_info.extra['env']``, or None when that
    key is absent. On v2 it is the ``ENV_VAR_KEY`` entry of
    ``field.json_schema_extra``, returned only when it is a non-empty string.

    Raises:
        RuntimeError: on v2, when ``json_schema_extra`` is a callable and
            therefore cannot be inspected for the key.
    """
    if PYDANTIC_V2:
        schema_extra = field.json_schema_extra
        if not schema_extra:
            return None

        if callable(schema_extra):
            raise RuntimeError(f'Unexpected json schema for field "{name}" is a function')

        candidate = schema_extra.get(ENV_VAR_KEY)
        return candidate if candidate and isinstance(candidate, str) else None

    return field.field_info.extra.get('env')  # type: ignore
def is_field_required(field: FieldInfo) -> bool:
    """Return whether ``field`` is required, for either pydantic major version."""
    if not PYDANTIC_V2:
        # v1 exposes this as a plain attribute on the field object.
        return field.required  # type: ignore
    return field.is_required()
def model_fields(model: type[BaseModel]) -> dict[str, FieldInfo]:
    """Return the field-name -> field mapping declared on ``model``.

    On v2 this is ``model_fields``; on v1 it is ``__fields__``. The v1 values
    are presumably v1 field objects rather than ``FieldInfo`` (hence the
    ``type: ignore``).
    """
    if not PYDANTIC_V2:
        return model.__fields__  # type: ignore
    return model.model_fields
def model_field_type(field: FieldInfo) -> type | None:
    """Return the type recorded on ``field``.

    On v2 this reads ``field.annotation``; on v1 it reads ``field.type_``.
    NOTE(review): on v1 ``type_`` may be the inner type of a container
    annotation rather than the full annotation returned on v2. Confirm that
    callers do not depend on either shape.
    """
    if not PYDANTIC_V2:
        return field.type_  # type: ignore
    return field.annotation
def model_copy(model: _ModelT, deep: bool = False) -> _ModelT:
    """Return a copy of ``model``; ``deep`` is forwarded to pydantic's copy method."""
    if not PYDANTIC_V2:
        return model.copy(deep=deep)  # pyright: ignore[reportDeprecated]
    return model.model_copy(deep=deep)
def model_json(
    model: BaseModel,
    *,
    indent: int | None = None,
    exclude: set[str] | None = None,
) -> str:
    """Serialise ``model`` to a JSON string.

    ``indent`` and ``exclude`` are forwarded unchanged to ``model_dump_json``
    (v2) or ``json`` (v1).
    """
    if not PYDANTIC_V2:
        return model.json(indent=indent, exclude=exclude)  # pyright: ignore[reportDeprecated]
    return model.model_dump_json(indent=indent, exclude=exclude)
def model_dict(
    model: BaseModel,
    *,
    by_alias: bool = False,
    exclude: set[str] | None = None,
    exclude_unset: bool = False,
) -> dict[str, Any]:
    """Convert ``model`` to a plain dict.

    ``by_alias``, ``exclude`` and ``exclude_unset`` are forwarded unchanged to
    ``model_dump`` (v2) or ``dict`` (v1).
    """
    if not PYDANTIC_V2:
        return model.dict(  # pyright: ignore[reportDeprecated]
            by_alias=by_alias,
            exclude=exclude,
            exclude_unset=exclude_unset,
        )

    return model.model_dump(
        by_alias=by_alias,
        exclude=exclude,
        exclude_unset=exclude_unset,
    )
def model_rebuild(model: type[BaseModel]) -> None:
    """Rebuild ``model``'s schema/forward references.

    Calls ``model_rebuild`` on v2 or ``update_forward_refs`` on v1; any return
    value is discarded.
    """
    if not PYDANTIC_V2:
        model.update_forward_refs()  # pyright: ignore[reportDeprecated]
        return

    model.model_rebuild()
def model_parse(model: type[_ModelT], obj: Any) -> _ModelT:
    """Validate the Python object ``obj`` into an instance of ``model``."""
    if not PYDANTIC_V2:
        return model.parse_obj(obj)  # pyright: ignore[reportDeprecated]
    return model.model_validate(obj)
def model_parse_json(model: type[_ModelT], obj: str) -> _ModelT:
    """Validate the JSON string ``obj`` into an instance of ``model``."""
    if not PYDANTIC_V2:
        return model.parse_raw(obj)  # pyright: ignore[reportDeprecated]
    return model.model_validate_json(obj)
def model_json_schema(model: type[BaseModel]) -> dict[str, Any]:
    """Return the JSON schema of ``model`` as a dict."""
    if not PYDANTIC_V2:
        return model.schema()  # pyright: ignore[reportDeprecated]
    return model.model_json_schema()
def Field(*, env: str | None = None, **extra: Any) -> Any:
    """Create a pydantic field, optionally bound to an environment variable.

    Args:
        env: name of an environment variable whose value should override the
            field's input when loaded through ``BaseSettings``.
        **extra: forwarded to ``pydantic.Field``. On v2 a caller-supplied
            ``json_schema_extra`` dict is merged with the env metadata.

    Raises:
        TypeError: on pydantic v2, if ``env`` is combined with a
            ``json_schema_extra`` that is not a dict (e.g. a callable), since the
            env metadata cannot be merged into it.
    """
    if not PYDANTIC_V2:
        return pydantic.Field(**extra, env=env)  # type: ignore

    # we store environment variable metadata in $env
    # as a workaround to support BaseSettings behaviour ourselves
    # as we can't depend on pydantic-settings
    #
    # Pull any caller-supplied ``json_schema_extra`` out of ``extra`` so it is
    # merged with the env metadata rather than passed twice. Passing it twice
    # used to fail with "got multiple values for keyword argument".
    json_schema_extra = extra.pop('json_schema_extra', None)
    if env:
        if json_schema_extra is None:
            json_schema_extra = {ENV_VAR_KEY: env}
        elif isinstance(json_schema_extra, dict):
            json_schema_extra = {**json_schema_extra, ENV_VAR_KEY: env}
        else:
            raise TypeError('Cannot combine `env` with a non-dict `json_schema_extra`')

    return pydantic.Field(**extra, json_schema_extra=json_schema_extra)  # type: ignore
if sys.version_info >= (3, 8):
    from functools import cached_property as cached_property
elif TYPE_CHECKING:
    # The ``cached_property`` backport doesn't define type hints, so type
    # checkers see ``property`` instead; it is functionally equivalent to the
    # standard property for this purpose anyway.
    cached_property = property
else:
    from cached_property import cached_property as cached_property
if not TYPE_CHECKING:
    # ``nodejs`` is an optional dependency: at runtime the name is the module
    # when it can be imported, and None otherwise.
    try:
        import nodejs
    except ImportError:
        nodejs = None
else:
    import nodejs as _nodejs

    # Type checkers see the module wrapped by ``make_optional`` so that uses
    # of ``nodejs`` account for the runtime None case.
    nodejs = make_optional(_nodejs)
345# Note: this shim is due to an inconsistency with string enums
346# that was fixed in Python3.11, for reference see:
347# - https://blog.pecar.me/python-enum#there-be-dragons
348# - https://github.com/python/cpython/issues/100458
349if TYPE_CHECKING:
350 if sys.version_info >= (3, 11):
351 from enum import StrEnum as StrEnum
352 else:
353 # Note: we have to define our own `StrEnum`
354 # class as the backport we're using doesn't
355 # define good types.
356 from enum import Enum
358 class StrEnum(str, Enum): ...
359else:
360 if sys.version_info >= (3, 11):
361 from enum import StrEnum as StrEnum
362 else:
363 from strenum import StrEnum as StrEnum
def removeprefix(string: str, prefix: str) -> str:
    """Return ``string`` with a leading ``prefix`` removed.

    Mirrors ``str.removeprefix``: the input is returned unchanged when it does
    not start with ``prefix`` (and an empty ``prefix`` removes nothing).
    """
    return string[len(prefix) :] if string.startswith(prefix) else string