Coverage for src/prisma/generator/utils.py: 96%

84 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-04-28 15:17 +0000

1from __future__ import annotations 

2 

3import os 

4import re 

5import shutil 

6from typing import TYPE_CHECKING, Any, Dict, List, Union, TypeVar, Iterator 

7from pathlib import Path 

8from textwrap import dedent 

9 

10from ..utils import monkeypatch 

11 

12if TYPE_CHECKING: 

13 from .models import Field, Model 

14 

15 

# Generic type variable; e.g. `Faker.from_list` uses it to tie the element
# type of its input list to the type of the value it returns.
T = TypeVar('T')

# we have to use a mapping outside of the `Sampler` class
# to avoid https://github.com/RobertCraigie/prisma-client-py/issues/402
# Maps a model name to the iterator of scalar fields that `Sampler.get_field`
# is currently drawing from; `Sampler.__init__` (re)sets the entry and
# `Sampler.get_field` replaces it once the iterator is exhausted.
SAMPLER_ITER_MAPPING: 'Dict[str, Iterator[Field]]' = {}

21 

22 

class Faker:
    """Replayable source of pseudo-random values.

    Every value is derived from the internal state of a linear congruential
    generator, following the approach described at:
    https://stackoverflow.com/a/9024521/13923613

    Two instances created with the same seed produce the same sequence.
    """

    def __init__(self, seed: int = 1) -> None:
        self._state = seed

    def __iter__(self) -> 'Faker':
        return self

    def __next__(self) -> int:
        # one LCG step; the mask keeps the state within 31 bits (non-negative)
        advanced = (self._state * 1103515245 + 12345) & 0x7FFFFFFF
        self._state = advanced
        return advanced

    def string(self) -> str:
        """Return a lowercase string built from the digits of the next integer.

        Each decimal digit ``d`` becomes ``chr(ord('a') + d)``, so only the
        letters ``a`` to ``j`` can appear.
        """
        digits = str(self.integer())
        return ''.join(chr(ord('a') + int(digit)) for digit in digits)

    def boolean(self) -> bool:
        """Return ``True`` when the next generated integer is even.

        NOTE: the multiplier and increment are both odd, so the parity of the
        state flips on every step; back-to-back calls (with no other draws in
        between) therefore alternate between ``True`` and ``False``.
        """
        return not next(self) % 2

    def integer(self) -> int:
        """Return the next raw value from the generator."""
        return next(self)

    @classmethod
    def from_list(cls, values: List[T]) -> T:
        """Pick an element from ``values``.

        TODO: actual implementation -- for now the first element is always
        returned, independent of any generator state.
        """
        assert values, 'Expected non-empty list'
        return values[0]

54 

55 

class Sampler:
    """Hands out a model's scalar fields one at a time, wrapping around.

    The in-progress iterator lives in the module-level
    ``SAMPLER_ITER_MAPPING`` (keyed by model name) instead of on the
    instance, so every ``Sampler`` for the same model name shares a single
    position, and constructing a new ``Sampler`` for that name restarts it.
    """

    model: 'Model'

    def __init__(self, model: 'Model') -> None:
        self.model = model
        # (re)start the shared iterator for this model name
        SAMPLER_ITER_MAPPING[model.name] = model.scalar_fields

    def get_field(self) -> 'Field':
        """Return the next scalar field, starting over once all were used.

        Assumes ``model.scalar_fields`` yields a fresh iterator on every
        access -- TODO confirm against ``Model``. If that fresh iterator is
        empty as well, ``StopIteration`` propagates to the caller.
        """
        key = self.model.name
        try:
            return next(SAMPLER_ITER_MAPPING[key])
        except StopIteration:
            pass

        # exhausted: swap in a fresh iterator and take its first field
        restarted = SAMPLER_ITER_MAPPING[key] = self.model.scalar_fields
        return next(restarted)

73 

74 

def is_same_path(path: Path, other: Path) -> bool:
    """Report whether ``path`` and ``other`` resolve to the same location.

    Both paths are resolved and compared as strings, ignoring surrounding
    whitespace.
    """
    first, second = (str(candidate.resolve()).strip() for candidate in (path, other))
    return first == second

77 

78 

def resolve_template_path(rootdir: Path, name: Union[str, Path]) -> Path:
    """Return the path below ``rootdir`` for template ``name``.

    A trailing ``.jinja`` extension is dropped first, e.g. ``'models.py.jinja'``
    maps to ``rootdir / 'models.py'``.
    """
    target = remove_suffix(name, '.jinja')
    return rootdir / target

81 

82 

def remove_suffix(path: Union[str, Path], suf: str) -> str:
    """Return ``path`` as a string with ``suf`` removed from its end.

    ``Path`` inputs are converted to ``str`` first. When ``suf`` is empty or
    is not a suffix of the path, the string is returned unchanged.
    """
    # modified from https://stackoverflow.com/a/18723694
    text = str(path) if isinstance(path, Path) else path
    if not suf or not text.endswith(suf):
        return text
    return text[: len(text) - len(suf)]

92 

93 

def copy_tree(src: Path, dst: Path) -> None:
    """Recursively copy the contents of the ``src`` directory into ``dst``.

    Compiled / cache artefacts are skipped while copying:
    - *.pyc
    - __pycache__

    The ``os.makedirs`` patch below is there so that copying into an
    already-existing ``dst`` does not fail.
    """
    # shutil.copytree calls os.makedirs on the destination, which raises if
    # the directory already exists. The `dirs_exist_ok` argument would avoid
    # that, but it was only added in Python 3.8, so instead os.makedirs is
    # temporarily replaced by a wrapper that always passes exist_ok=True.
    skip_artifacts = shutil.ignore_patterns('*.pyc', '__pycache__')

    def _makedirs_exist_ok(
        makedirs: Any,
        name: str,
        mode: int = 0o777,
        exist_ok: bool = True,  # noqa: ARG001
    ) -> None:
        # presumably `monkeypatch` hands the original os.makedirs in as
        # `makedirs` -- confirm in ..utils. The caller's exist_ok is ignored.
        makedirs(name, mode, exist_ok=True)

    with monkeypatch(os, 'makedirs', _makedirs_exist_ok):
        shutil.copytree(str(src), str(dst), ignore=skip_artifacts)

121 

122 

def clean_multiline(string: str) -> str:
    """Drop leading newlines and the first line's leading whitespace.

    Only the first line is dedented; every following line is kept verbatim.
    An input that is empty after stripping leading newlines fails the
    ``assert``.
    """
    stripped = string.lstrip('\n')
    assert stripped, 'Expected non-empty string'
    first, *rest = stripped.splitlines()
    return '\n'.join([dedent(first), *rest])

128 

129 

# https://github.com/nficano/humps/blob/master/humps/main.py

# A run of uppercase letters/digits that is followed by another uppercase
# letter/digit or by the end of the string, e.g. the "HTTP" in "HTTPResponse";
# `to_snake_case` title-cases these runs ("HTTP" -> "Http") before splitting.
ACRONYM_RE = re.compile(r'([A-Z\d]+)(?=[A-Z\d]|$)')
# A maximal run of characters that are neither "-" nor "_" (one "word");
# `to_pascal_case` upper-cases the first character of each match.
PASCAL_RE = re.compile(r'([^\-_]+)')
# A segment that starts at an uppercase letter, together with any adjacent
# "-"/"_" characters; the capturing group makes `SPLIT_RE.split` keep the
# segments, which `to_snake_case` then joins with "_".
SPLIT_RE = re.compile(r'([\-_]*[A-Z][^A-Z]*[\-_]*)')
# One or more "-"/"_" separators preceded by a non-separator character,
# plus the character that follows them; `to_camel_case` replaces each match
# with that following character upper-cased.
UNDERSCORE_RE = re.compile(r'(?<=[^\-_])[\-_]+[^\-_]')

136 

137 

def to_snake_case(input_str: str) -> str:
    """Convert ``input_str`` to snake_case.

    Input that is already camelCase or PascalCase (e.g. ``'fooBar'``,
    ``'HTTPResponse'``) is split on its uppercase boundaries, with acronym
    runs title-cased first: ``'fooBar' -> 'foo_bar'``,
    ``'HTTPResponse' -> 'http_response'``. Any other input has each
    character outside ``[a-zA-Z0-9]`` replaced by ``'_'`` and is then
    lower-cased with leading/trailing underscores stripped:
    ``'foo-bar' -> 'foo_bar'``.
    """
    looks_cased = to_camel_case(input_str) == input_str or to_pascal_case(input_str) == input_str
    if not looks_cased:
        sanitized = re.sub(r'[^a-zA-Z0-9]', '_', input_str)
        return sanitized.lower().strip('_')

    titled = ACRONYM_RE.sub(lambda acronym: acronym.group(0).title(), input_str)
    words = [chunk for chunk in SPLIT_RE.split(titled) if chunk]
    return '_'.join(words).lower()

148 

149 

def to_camel_case(input_str: str) -> str:
    """Convert ``input_str`` to camelCase.

    The first character is lower-cased unless the first two characters are
    already upper-case (so ``'HTTPServer'`` is left alone). Every run of
    ``'-'``/``'_'`` that follows a non-separator character is then removed
    and the character after it upper-cased: ``'foo_bar' -> 'fooBar'``.
    """
    text = input_str
    if text and not text[:2].isupper():
        text = text[0].lower() + text[1:]

    def _upper_last(separator_run: re.Match[str]) -> str:
        return separator_run.group(0)[-1].upper()

    return UNDERSCORE_RE.sub(_upper_last, text)

154 

155 

def to_pascal_case(input_str: str) -> str:
    """Convert ``input_str`` to PascalCase, e.g. ``'foo_bar' -> 'FooBar'``.

    Each ``'-'``/``'_'``-delimited word is capitalised, the result is passed
    through ``to_camel_case``, and finally its first character is
    upper-cased. An empty string is returned unchanged.
    """

    def _capitalize_word(word_match: re.Match[str]) -> str:
        word = word_match.group(1)
        return word[0].upper() + word[1:]

    camel = to_camel_case(PASCAL_RE.sub(_capitalize_word, input_str))
    if not camel:
        return camel
    return camel[0].upper() + camel[1:]

162 

163 

def to_constant_case(input_str: str) -> str:
    """Convert ``input_str`` to CONSTANT_CASE (snake case, upper-cased).

    Examples:

        foo_bar -> FOO_BAR
        fooBar -> FOO_BAR
    """
    snake = to_snake_case(input_str)
    return snake.upper()