Coverage for databases/sync_tests/test_group_by.py: 100%

105 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-08-27 18:25 +0000

1import pytest 

2from syrupy.assertion import SnapshotAssertion 

3 

4import prisma 

5from prisma import Prisma 

6from lib.testing import async_fixture 

7from prisma.types import SortOrder 

8 

9# TODO: test all types 

10# TODO: test working with the results 

11 

12 

@async_fixture(autouse=True, scope='session')
def create_test_data(client: Prisma) -> None:
    """Seed, once per test session, the records that the group_by tests query.

    Profiles are created in this order:

    * one Scottish profile (city ``Edinburgh``, 250 views, user ``Tegan``)
    * twelve Danish profiles with no ``city`` set (500 views each, user ``Robert``)
    * eight Danish profiles in ``Copenhagen`` (1000 views each, user ``Robert``)

    Afterwards ten ``types`` records are created whose ``integer`` field runs from 0 to 9.
    """
    new_profile = client.profile.create

    new_profile(
        {
            'description': 'from scotland',
            'country': 'Scotland',
            'city': 'Edinburgh',
            'views': 250,
            'user': {'create': {'name': 'Tegan'}},
        }
    )

    # Danish profiles that deliberately leave `city` unset, so grouping by
    # `city` produces a null group.
    for _ in range(12):
        new_profile(
            {
                'description': 'description',
                'country': 'Denmark',
                'views': 500,
                'user': {'create': {'name': 'Robert'}},
            }
        )

    # Danish profiles that do have a city.
    for _ in range(8):
        new_profile(
            {
                'description': 'description',
                'country': 'Denmark',
                'city': 'Copenhagen',
                'views': 1000,
                'user': {'create': {'name': 'Robert'}},
            }
        )

    new_types_record = client.types.create
    for value in range(10):
        new_types_record({'integer': value})

54 

55 

@pytest.mark.persist_data
def test_group_by(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """Group by a single field, with nothing but an ordering applied"""
    users_by_name = client.user.group_by(by=['name'], order={'name': 'asc'})
    assert users_by_name == snapshot

    profiles_by_country = client.profile.group_by(by=['country'], order={'country': 'asc'})
    assert profiles_by_country == snapshot

77 

78 

@pytest.mark.persist_data
def test_docs_example(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """Reproduce the group_by example from the Prisma documentation:
    https://www.prisma.io/docs/reference/api-reference/prisma-client-reference#groupby
    """
    assert (
        client.profile.group_by(
            by=['country', 'city'],
            having={'views': {'_avg': {'gt': 200}}},
            count={'_all': True, 'city': True},
            sum={'views': True},
            order={'country': 'desc'},
        )
        == snapshot
    )

105 

106 

@pytest.mark.persist_data
@pytest.mark.parametrize('order', ['asc', 'desc'])
def test_order(snapshot: SnapshotAssertion, client: Prisma, order: SortOrder) -> None:
    """Groups come back sorted by the grouped field in the requested direction"""
    grouped = client.profile.group_by(
        by=['country'],
        order={'country': order},
    )
    assert grouped == snapshot

112 

113 

@pytest.mark.persist_data
def test_order_list(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """A list of single-field orderings sorts by each grouped field in turn"""
    grouped = client.profile.group_by(
        by=['country', 'city'],
        order=[{'country': 'asc'}, {'city': 'desc'}],
    )

    # SQLite and PostgreSQL place NULLs differently by default, and changing
    # the null ordering is not supported yet. To get a database-independent
    # result, move every group without a city ahead of those that have one.
    # `sorted` is stable, so the relative order within each part is kept.
    normalised = sorted(grouped, key=lambda group: group.get('city') is not None)
    assert normalised == snapshot

129 

130 

@pytest.mark.persist_data
def test_order_multiple_fields(client: Prisma) -> None:
    """Ordering by more than one field inside a single order mapping is not supported"""
    with pytest.raises(prisma.errors.DataError):
        client.profile.group_by(
            by=['country', 'city'],
            order={'city': 'desc', 'country': 'asc'},
        )

142 

143 

@pytest.mark.persist_data
def test_order_mismatched_arguments(client: Prisma) -> None:
    """Ordering by a field that is not part of `by` raises an InputError"""
    with pytest.raises(prisma.errors.InputError) as exc:
        client.profile.group_by(
            by=['city'],
            order={'country': 'asc'},  # pyright: ignore
        )

    expected_message = (
        r'Every field used for orderBy must be included in the by-arguments of the query\. '
        r'Missing fields: country'
    )
    assert exc.match(expected_message)

158 

159 

@pytest.mark.persist_data
@pytest.mark.parametrize('order', ['asc', 'desc'])
def test_take(
    snapshot: SnapshotAssertion,
    client: Prisma,
    order: SortOrder,
) -> None:
    """`take` caps the number of groups that are returned"""
    limited = client.profile.group_by(
        by=['country'],
        order={'country': order},
        take=1,
    )
    assert limited == snapshot

176 

177 

@pytest.mark.persist_data
def test_take_missing_order_argument(client: Prisma) -> None:
    """Passing `take` without `order` raises a TypeError"""
    with pytest.raises(TypeError) as exc:
        client.profile.group_by(by=['country'], take=1)

    assert exc.match("Missing argument: 'order' which is required when 'take' is present")

185 

186 

@pytest.mark.persist_data
@pytest.mark.parametrize('order', ['asc', 'desc'])
def test_skip(
    snapshot: SnapshotAssertion,
    client: Prisma,
    order: SortOrder,
) -> None:
    """`skip` drops the leading groups from the result"""
    remaining = client.profile.group_by(
        by=['country'],
        order={'country': order},
        skip=1,
    )
    assert remaining == snapshot

203 

204 

@pytest.mark.persist_data
def test_skip_missing_order_argument(client: Prisma) -> None:
    """Passing `skip` without `order` raises a TypeError"""
    with pytest.raises(TypeError) as exc:
        client.profile.group_by(by=['country'], skip=1)

    assert exc.match("Missing argument: 'order' which is required when 'skip' is present")

212 

213 

@pytest.mark.persist_data
def test_where(client: Prisma) -> None:
    """`where` restricts which records take part in the grouping"""
    # Filter on the grouped field itself.
    danish = client.profile.group_by(
        by=['country'],
        where={'country': 'Denmark'},
        order={'country': 'asc'},
    )
    assert len(danish) == 1
    assert danish[0].get('country') == 'Denmark'

    # Filter on a field that is not being grouped.
    scottish = client.profile.group_by(
        by=['country'],
        where={'description': {'contains': 'scotland'}},
        order={'country': 'asc'},
    )
    assert len(scottish) == 1
    assert scottish[0].get('country') == 'Scotland'

242 

243 

@pytest.mark.persist_data
def test_having_missing_field_in_by(client: Prisma) -> None:
    """Having filters must be an aggregation filter or be included in by"""
    with pytest.raises(prisma.errors.InputError) as exc:
        client.profile.group_by(
            by=['country'],
            count=True,
            having={
                # plain (non-aggregate) filter on a field that is not in `by`
                'views': {
                    'gt': 50,
                },
            },
            order={
                'country': 'asc',
            },
        )

    # `ExceptionInfo.match` applies the pattern with `re.search`, so the literal
    # full stops in the expected message are escaped; unescaped they would
    # match any character.
    assert exc.match(
        r'Input error\. Every field used in `having` filters must either be an aggregation filter '
        r'or be included in the selection of the query\. Missing fields: views'
    )

265 

266 

@pytest.mark.persist_data
def test_having_aggregation(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """An aggregate filter in `having` keeps only the groups that satisfy it"""
    # Groups whose average view count is above 600.
    avg_views_above = client.profile.group_by(
        by=['country'],
        count=True,
        having={'views': {'_avg': {'gt': 600}}},
        order={'country': 'asc'},
    )
    assert avg_views_above == snapshot

    # Groups whose average view count is below 600.
    avg_views_below = client.profile.group_by(
        by=['country'],
        count=True,
        having={'views': {'_avg': {'lt': 600}}},
        order={'country': 'asc'},
    )
    assert avg_views_below == snapshot

304 

305 

@pytest.mark.persist_data
def test_having_aggregation_nested(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """Aggregate filters nested inside OR / NOT conditions in `having` are applied correctly"""
    # OR of an average filter and a sum filter.
    first = client.profile.group_by(
        by=['country'],
        count=True,
        having={
            'OR': [
                {'views': {'_avg': {'equals': 1000}}},
                {'views': {'_sum': {'equals': 250}}},
            ],
        },
        order={'country': 'asc'},
    )
    assert first == snapshot

    # Same shape with a different average threshold.
    second = client.profile.group_by(
        by=['country'],
        count=True,
        having={
            'OR': [
                {'views': {'_avg': {'equals': 700}}},
                {'views': {'_sum': {'equals': 250}}},
            ],
        },
        order={'country': 'asc'},
    )
    assert second == snapshot

    # The second OR branch additionally carries a NOT on the minimum value.
    third = client.profile.group_by(
        by=['country'],
        count=True,
        having={
            'OR': [
                {'views': {'_avg': {'equals': 700}}},
                {
                    'views': {'_sum': {'equals': 250}},
                    'NOT': [
                        {'views': {'_min': {'equals': 250}}},
                    ],
                },
            ],
        },
        order={'country': 'asc'},
    )
    assert third == snapshot

397 assert results == snapshot 

398 

399 

@pytest.mark.persist_data
def test_count(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """Count grouped records using several forms of the `count` argument"""
    # Boolean shorthand.
    counted_all = client.profile.group_by(by=['country'], count=True, order={'country': 'asc'})
    assert counted_all == snapshot

    # Explicit `_all` selection.
    counted_explicit_all = client.profile.group_by(
        by=['country'],
        count={'_all': True},
        order={'country': 'asc'},
    )
    assert counted_explicit_all == snapshot

    # A single non-grouped field.
    counted_city = client.profile.group_by(
        by=['country'],
        count={'city': True},
        order={'country': 'asc'},
    )
    assert counted_city == snapshot

    # A mix of a non-grouped field and the grouped field.
    counted_city_and_country = client.profile.group_by(
        by=['country'],
        count={'city': True, 'country': True},
        order={'country': 'asc'},
    )
    assert counted_city_and_country == snapshot

450 

451 

@pytest.mark.persist_data
def test_avg(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """Average numeric fields per group"""
    profile_averages = client.profile.group_by(
        by=['country'],
        avg={'views': True},
        order={'country': 'asc'},
    )
    assert profile_averages == snapshot

    types_averages = client.types.group_by(
        by=['string'],
        avg={'integer': True, 'bigint': True},
        order={'string': 'asc'},
    )
    assert types_averages == snapshot

471 

472 

@pytest.mark.persist_data
def test_sum(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """Sum a numeric field per group"""
    totals = client.profile.group_by(
        by=['country'],
        sum={'views': True},
        order={'country': 'asc'},
    )
    assert totals == snapshot

488 

489 

@pytest.mark.persist_data
def test_min(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """Smallest value of a field per group"""
    minimums = client.profile.group_by(
        by=['country'],
        min={'views': True},
        order={'country': 'asc'},
    )
    assert minimums == snapshot

505 

506 

@pytest.mark.persist_data
def test_max(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """Largest value of a field per group"""
    maximums = client.profile.group_by(
        by=['country'],
        max={'views': True},
        order={'country': 'asc'},
    )
    assert maximums == snapshot