Coverage for src/prisma/_compat.py: 94%

138 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-08-27 18:25 +0000

1from __future__ import annotations 

2 

3import os 

4import sys 

5from typing import TYPE_CHECKING, Any, TypeVar, Callable, cast 

6from asyncio import get_running_loop as get_running_loop 

7 

8import pydantic 

9from pydantic import BaseModel 

10from pydantic.fields import FieldInfo 

11 

12from .utils import make_optional 

13 

# Generic placeholder used by the decorator shims below (`Callable[[_T], _T]`).
_T = TypeVar('_T')
# Bound to BaseModel so helpers such as model_copy/model_parse return the
# caller's concrete model class rather than plain BaseModel.
_ModelT = TypeVar('_ModelT', bound=BaseModel)


# Pydantic v2 compat
# True when the installed pydantic is a 2.x release; every shim in this module
# branches on it to pick the v1 or v2 API.
# NOTE(review): any other major version (e.g. a future 3.x) is treated as v1
# here — confirm this is intended before supporting newer pydantic releases.
PYDANTIC_V2 = pydantic.VERSION.startswith('2.')

20 

# ---- validators ----


def field_validator(
    __field: str,
    *fields: str,
    pre: bool = False,
    check_fields: bool | None = None,
    always: bool | None = None,
    allow_reuse: bool | None = None,
) -> Callable[[_T], _T]:
    """Version-agnostic field validator decorator.

    Maps onto ``pydantic.field_validator`` under pydantic v2 and onto
    ``pydantic.validator`` under pydantic v1.

    Args:
        __field: Name of the first field to validate.
        *fields: Names of any additional fields to validate.
        pre: Run before pydantic's own validation (v2 ``mode='before'``,
            v1 ``pre=True``); otherwise run after it.
        check_fields: Whether pydantic should verify that the named fields
            exist on the model. ``None`` keeps pydantic's default. Forwarded
            on both v1 and v2.
        always: v1 only; forwarded when not ``None``. Not passed on v2.
        allow_reuse: v1 only; forwarded when not ``None``. Not passed on v2.

    Returns:
        The decorator produced by pydantic.
    """
    if PYDANTIC_V2:
        # `always` / `allow_reuse` have no per-validator equivalent in
        # `pydantic.field_validator`, so they are not forwarded here.
        return cast(  # type: ignore[no-any-return]
            Any,
            pydantic.field_validator(
                __field,
                *fields,
                mode='before' if pre else 'after',
                check_fields=check_fields,
            ),
        )

    # Only forward options the caller set explicitly, so pydantic v1's own
    # defaults apply otherwise. `check_fields` is forwarded too, so that e.g.
    # `check_fields=False` behaves the same on v1 as it does on v2.
    kwargs: dict[str, Any] = {}
    if check_fields is not None:
        kwargs['check_fields'] = check_fields
    if always is not None:
        kwargs['always'] = always
    if allow_reuse is not None:
        kwargs['allow_reuse'] = allow_reuse

    return pydantic.validator(__field, *fields, pre=pre, **kwargs)  # type: ignore

50 

51 

def root_validator(
    *__args: Any,
    pre: bool = False,
    skip_on_failure: bool = False,
    allow_reuse: bool = False,
) -> Callable[[_T], _T]:
    """Version-agnostic model-level validator decorator.

    Usable both as ``@root_validator(...)`` and as bare ``@root_validator``;
    in the bare form the decorated function arrives in ``__args``.

    Under pydantic v2 this maps onto ``pydantic.model_validator`` (``pre``
    selects ``mode='before'`` vs ``mode='after'``). ``skip_on_failure`` and
    ``allow_reuse`` are only forwarded under v1.

    NOTE(review): v2 'after' model validators receive the model instance
    rather than v1's values dict — validators shared across versions should
    account for that.

    Args:
        *__args: Optionally the function to decorate (bare-decorator usage).
        pre: Run before field validation instead of after.
        skip_on_failure: v1 only; forwarded to ``pydantic.root_validator``.
        allow_reuse: v1 only; forwarded to ``pydantic.root_validator``.

    Returns:
        A decorator, or the already-decorated object when a function was
        passed positionally.
    """
    if PYDANTIC_V2:
        decorator = pydantic.model_validator(  # type: ignore
            mode='before' if pre else 'after',
        )
        if __args:
            # Bare usage: apply the decorator directly, matching v1's
            # `root_validator`, which also accepts the function positionally.
            return decorator(*__args)  # type: ignore[no-any-return]
        return decorator  # type: ignore[no-any-return]

    return cast(Any, pydantic.root_validator)(  # type: ignore[no-any-return]
        *__args,
        pre=pre,
        skip_on_failure=skip_on_failure,
        allow_reuse=allow_reuse,
    )

69 

70 

if TYPE_CHECKING:
    # Static-typing view: give type checkers a single, stable set of names
    # regardless of which pydantic major version is installed.
    BaseSettings = BaseModel
    BaseSettingsConfig = (
        pydantic.BaseConfig  # pyright: ignore[reportDeprecated]
    )

    class BaseConfig: ...

    from pydantic import (
        PlainSerializer as PlainSerializer,
        GetCoreSchemaHandler as GetCoreSchemaHandler,
    )
    from pydantic_core import (
        CoreSchema as CoreSchema,
        core_schema as core_schema,
    )

    class GenericModel(BaseModel): ...

else:
    # Runtime: these serialization / core-schema names are imported on v2
    # and bound to None on v1.
    if PYDANTIC_V2:
        from pydantic import PlainSerializer, GetCoreSchemaHandler
        from pydantic_core import CoreSchema, core_schema
    else:
        core_schema = None
        CoreSchema = None
        GetCoreSchemaHandler = None
        PlainSerializer = None

    # On v2, plain BaseModel is used as the generic base; on v1 the generic
    # base comes from `pydantic.generics` and is combined with BaseModel.
    if PYDANTIC_V2:
        GenericModel = BaseModel
    else:
        from pydantic.generics import GenericModel as PydanticGenericModel

        class GenericModel(PydanticGenericModel, BaseModel): ...

    if PYDANTIC_V2:
        from pydantic import model_validator

        # Minimal v2 stand-in for v1's BaseSettings: before validation, the
        # raw input is passed through `_env_var_resolver`, which overlays
        # values from environment variables declared via `Field(env=...)`.
        class BaseSettings(BaseModel):
            @model_validator(mode='before')
            def root_validator(cls, values: Any) -> Any:
                return _env_var_resolver(cls, values)

        # v1-only config classes have no counterpart here; bound to None.
        BaseSettingsConfig = None

        BaseConfig = None

    else:
        from pydantic import (
            BaseConfig as BaseConfig,
            BaseSettings as BaseSettings,
        )

        BaseSettingsConfig = BaseSettings.Config

126 

127 

# v1 re-exports
if TYPE_CHECKING:
    # Typing-only signatures for the helpers re-exported below; the real
    # implementations are imported at runtime in the `else` branch.
    from pydantic.v1 import Extra as Extra

    def get_args(t: type[Any]) -> tuple[Any, ...]:  # noqa: ARG001
        ...

    def is_union(tp: type[Any] | None) -> bool:  # noqa: ARG001
        ...

    def get_origin(t: type[Any]) -> type[Any] | None:  # noqa: ARG001
        ...

    def is_literal_type(type_: type[Any]) -> bool:  # noqa: ARG001
        ...

    def is_typeddict(type_: type[Any]) -> bool:  # noqa: ARG001
        ...

else:
    # Under v2 these v1-era helpers are taken from the `pydantic.v1`
    # namespace; under v1 they are imported from their original modules.
    if PYDANTIC_V2:
        from pydantic.v1 import Extra as Extra
        from pydantic.v1.typing import (
            get_args as get_args,
            is_union as is_union,
            get_origin as get_origin,
            is_typeddict as is_typeddict,
            is_literal_type as is_literal_type,
        )
    else:
        from pydantic import Extra as Extra
        from pydantic.typing import (
            get_args as get_args,
            is_union as is_union,
            get_origin as get_origin,
            is_typeddict as is_typeddict,
            is_literal_type as is_literal_type,
        )

166 

167 

# refactored config
# `ConfigDict` (pydantic v2's dict-based model configuration) is imported on
# v2; on v1 the name is bound to None at runtime.
if TYPE_CHECKING:
    from pydantic import ConfigDict as ConfigDict
else:
    if PYDANTIC_V2:
        from pydantic import ConfigDict
    else:
        ConfigDict = None

176 

177 

# Key under which `Field(env=...)` records the environment-variable name in a
# v2 field's `json_schema_extra`; read back by `_get_field_env_var`.
ENV_VAR_KEY = '$env'

179 

180 

# minimal re-implementation of BaseSettings for v2
def _env_var_resolver(model_cls: type[BaseModel], values: Any) -> dict[str, Any]:
    """Overlay environment-variable values onto raw model input.

    Used as the ``mode='before'`` model validator of the v2 ``BaseSettings``
    shim. For every field of ``model_cls`` that declares an environment
    variable (see ``_get_field_env_var``), the variable's string value, when
    set, replaces whatever was supplied for that field.

    Args:
        model_cls: Model class whose fields are inspected.
        values: Raw validation input; expected to be a dict.

    Returns:
        A new dict with environment overrides applied. The ``values`` object
        passed in is left unmodified.

    Raises:
        AssertionError: If ``values`` is not a dict (when assertions are
            enabled).
    """
    assert isinstance(values, dict)

    # Work on a copy: the input may be a caller-owned dict (e.g. one passed
    # to `model_validate`), and resolving env vars must not mutate it.
    resolved = dict(values)

    for key, field_info in model_cls.model_fields.items():
        env_var = _get_field_env_var(field_info, name=key)
        if not env_var:
            continue

        assert isinstance(env_var, str)

        # Note: we always want to prioritise the env var
        # over the value given due to how config loading works
        value = os.environ.get(env_var)
        if value is not None:
            resolved[key] = value

    return resolved

199 

200 

def _get_field_env_var(field: FieldInfo, name: str) -> str | None:
    """Look up the environment variable bound to a model field.

    On pydantic v1 the name is read from the field's ``extra['env']``. On v2
    it is read from the ``ENV_VAR_KEY`` entry of ``json_schema_extra``, where
    ``Field(env=...)`` stores it.

    Args:
        field: The field definition to inspect.
        name: The field's name, used only in the error message.

    Returns:
        The variable name, or ``None`` when the field does not declare one.

    Raises:
        RuntimeError: On v2, if ``json_schema_extra`` is a callable.
    """
    if PYDANTIC_V2:
        schema_extra = field.json_schema_extra
        if schema_extra:
            if callable(schema_extra):
                raise RuntimeError(f'Unexpected json schema for field "{name}" is a function')

            candidate = schema_extra.get(ENV_VAR_KEY)
            # Only a non-empty string counts as a declared variable name.
            if candidate and isinstance(candidate, str):
                return candidate
        return None

    return field.field_info.extra.get('env')  # type: ignore

217 

218 

def is_field_required(field: FieldInfo) -> bool:
    """Report whether ``field`` must be supplied when building its model.

    Uses the ``required`` attribute of the v1 field object, or
    ``FieldInfo.is_required()`` on pydantic v2.
    """
    if not PYDANTIC_V2:
        return field.required  # type: ignore
    return field.is_required()

223 

224 

def model_fields(model: type[BaseModel]) -> dict[str, FieldInfo]:
    """Return the field-name -> field-definition mapping of ``model``.

    Reads ``model_fields`` on pydantic v2 and ``__fields__`` on v1.
    """
    fields_attr = 'model_fields' if PYDANTIC_V2 else '__fields__'
    return getattr(model, fields_attr)  # type: ignore[no-any-return]

229 

230 

def model_field_type(field: FieldInfo) -> type | None:
    """Return the declared type of ``field``.

    v1 field objects expose it as ``type_``; v2 ``FieldInfo`` as
    ``annotation``.
    """
    if not PYDANTIC_V2:
        return field.type_  # type: ignore
    return field.annotation

236 

237 

def model_copy(model: _ModelT, deep: bool = False) -> _ModelT:
    """Return a copy of ``model``; a deep copy when ``deep`` is true.

    Dispatches to ``model_copy`` on pydantic v2 and ``copy`` on v1.
    """
    return (
        model.model_copy(deep=deep)
        if PYDANTIC_V2
        else model.copy(deep=deep)  # pyright: ignore[reportDeprecated]
    )

243 

244 

def model_json(
    model: BaseModel,
    *,
    indent: int | None = None,
    exclude: set[str] | None = None,
) -> str:
    """Serialise ``model`` to a JSON string.

    Args:
        model: Instance to serialise.
        indent: Indentation forwarded to pydantic.
        exclude: Field names to leave out.

    Returns:
        The JSON text from ``model_dump_json`` (v2) or ``json`` (v1).
    """
    options: dict[str, Any] = {'indent': indent, 'exclude': exclude}
    if not PYDANTIC_V2:
        return model.json(**options)  # pyright: ignore[reportDeprecated]
    return model.model_dump_json(**options)

258 

259 

def model_dict(
    model: BaseModel,
    *,
    by_alias: bool = False,
    exclude: set[str] | None = None,
    exclude_unset: bool = False,
) -> dict[str, Any]:
    """Convert ``model`` to a plain dict.

    Args:
        model: Instance to convert.
        by_alias: Key the output by field aliases instead of names.
        exclude: Field names to leave out.
        exclude_unset: Omit fields that were never explicitly set.

    Returns:
        The dict from ``model_dump`` (v2) or ``dict`` (v1).
    """
    options: dict[str, Any] = {
        'by_alias': by_alias,
        'exclude': exclude,
        'exclude_unset': exclude_unset,
    }
    dump = model.model_dump if PYDANTIC_V2 else model.dict  # pyright: ignore[reportDeprecated]
    return dump(**options)

279 

280 

def model_rebuild(model: type[BaseModel]) -> None:
    """Re-resolve forward references on ``model``.

    Calls ``update_forward_refs`` on pydantic v1 and ``model_rebuild`` on v2.
    """
    if not PYDANTIC_V2:
        model.update_forward_refs()  # pyright: ignore[reportDeprecated]
        return
    model.model_rebuild()

286 

287 

def model_parse(model: type[_ModelT], obj: Any) -> _ModelT:
    """Validate ``obj`` into an instance of ``model``.

    Uses ``model_validate`` on pydantic v2 and ``parse_obj`` on v1.
    """
    validate = model.model_validate if PYDANTIC_V2 else model.parse_obj  # pyright: ignore[reportDeprecated]
    return validate(obj)

293 

294 

def model_parse_json(model: type[_ModelT], obj: str) -> _ModelT:
    """Parse the JSON text ``obj`` into an instance of ``model``.

    Uses ``model_validate_json`` on pydantic v2 and ``parse_raw`` on v1.
    """
    if not PYDANTIC_V2:
        return model.parse_raw(obj)  # pyright: ignore[reportDeprecated]
    return model.model_validate_json(obj)

300 

301 

def model_json_schema(model: type[BaseModel]) -> dict[str, Any]:
    """Return the JSON schema of ``model`` as a dict.

    Uses ``model_json_schema`` on pydantic v2 and ``schema`` on v1.
    """
    if not PYDANTIC_V2:
        return model.schema()  # pyright: ignore[reportDeprecated]
    return model.model_json_schema()

307 

308 

def Field(*, env: str | None = None, **extra: Any) -> Any:
    """Create a pydantic field, optionally bound to an environment variable.

    Under pydantic v1, ``env`` is passed straight through to
    ``pydantic.Field``. Under v2, the variable name is stored in the field's
    ``json_schema_extra`` under ``ENV_VAR_KEY``, where ``_get_field_env_var``
    reads it back. Any ``json_schema_extra`` dict supplied by the caller is
    merged with that entry rather than clashing with it.

    Args:
        env: Name of the environment variable backing this field, if any.
        **extra: Remaining keyword arguments forwarded to ``pydantic.Field``.

    Returns:
        Whatever ``pydantic.Field`` returns.

    Raises:
        TypeError: On v2, if ``env`` is given together with a callable
            ``json_schema_extra`` (there is no dict to record ``env`` in).
    """
    if not PYDANTIC_V2:
        return pydantic.Field(**extra, env=env)  # type: ignore

    # we store environment variable metadata in $env
    # as a workaround to support BaseSettings behaviour ourselves
    # as we can't depend on pydantic-settings
    # Pop any caller-supplied value so it is passed exactly once below.
    json_schema_extra = extra.pop('json_schema_extra', None)
    if env:
        if callable(json_schema_extra):
            raise TypeError('Field(env=...) cannot be combined with a callable json_schema_extra')
        json_schema_extra = {**(json_schema_extra or {}), ENV_VAR_KEY: env}

    return pydantic.Field(**extra, json_schema_extra=json_schema_extra)  # type: ignore

321 

322 

# `functools.cached_property` is only used on Python 3.8+; older interpreters
# fall back to the third-party `cached_property` package.
if sys.version_info[:2] < (3, 8):
    # The `cached_property` backport doesn't define type hints, so for type
    # checkers it is presented as a plain (read-only) `property`, which has
    # the same attribute-access shape even though it does not cache.
    if TYPE_CHECKING:
        cached_property = property
    else:
        from cached_property import cached_property as cached_property
else:
    from functools import cached_property as cached_property

332 

333 

if TYPE_CHECKING:
    import nodejs as _nodejs

    # Typed through `make_optional` (from .utils) — presumably so type
    # checkers treat `nodejs` as possibly None, mirroring the runtime branch.
    nodejs = make_optional(_nodejs)
else:
    # `nodejs` is an optional dependency: the name is None when the package
    # cannot be imported.
    try:
        import nodejs
    except ImportError:
        nodejs = None

343 

344 

# Note: this shim is due to an inconsistency with string enums
# that was fixed in Python 3.11, for reference see:
# - https://blog.pecar.me/python-enum#there-be-dragons
# - https://github.com/python/cpython/issues/100458
if TYPE_CHECKING:
    if sys.version_info >= (3, 11):
        from enum import StrEnum as StrEnum
    else:
        # Note: we have to define our own `StrEnum`
        # class as the backport we're using doesn't
        # define good types.
        from enum import Enum

        class StrEnum(str, Enum): ...
else:
    # At runtime: the stdlib `enum.StrEnum` on 3.11+, otherwise the
    # third-party `strenum` backport.
    if sys.version_info >= (3, 11):
        from enum import StrEnum as StrEnum
    else:
        from strenum import StrEnum as StrEnum

364 

365 

def removeprefix(string: str, prefix: str) -> str:
    """Return ``string`` with a leading ``prefix`` removed, if present.

    Behaves like ``str.removeprefix`` (Python 3.9+): when ``string`` does not
    start with ``prefix``, it is returned unchanged.
    """
    return string[len(prefix):] if string.startswith(prefix) else string