Coverage for databases/tests/test_group_by.py: 100%

124 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-08-27 18:25 +0000

1import pytest 

2from syrupy.assertion import SnapshotAssertion 

3 

4import prisma 

5from prisma import Prisma 

6from lib.testing import async_fixture 

7from prisma.types import SortOrder 

8 

9# TODO: test all types 

10# TODO: test working with the results 

11 

12 

@async_fixture(autouse=True, scope='session')
async def create_test_data(client: Prisma) -> None:
    """Seed the database once per test session with the rows the group_by tests query.

    Profiles are created in this order:

    * 1 profile in Scotland / Edinburgh with 250 views, owned by a new user 'Tegan'
    * 12 profiles in Denmark with no city and 500 views, each owned by a new user 'Robert'
    * 8 profiles in Denmark / Copenhagen with 1000 views, each owned by a new user 'Robert'

    Afterwards 10 ``types`` records are created whose ``integer`` field runs from 0 to 9.
    """
    new_profile = client.profile.create

    # The single Scottish profile.
    await new_profile(
        {
            'description': 'from scotland',
            'country': 'Scotland',
            'city': 'Edinburgh',
            'views': 250,
            'user': {'create': {'name': 'Tegan'}},
        }
    )

    # Danish profiles without a city, so grouping by city produces a null group.
    for _unused in range(12):
        await new_profile(
            {
                'description': 'description',
                'country': 'Denmark',
                'views': 500,
                'user': {'create': {'name': 'Robert'}},
            }
        )

    # Danish profiles located in Copenhagen.
    for _unused in range(8):
        await new_profile(
            {
                'description': 'description',
                'country': 'Denmark',
                'city': 'Copenhagen',
                'views': 1000,
                'user': {'create': {'name': 'Robert'}},
            }
        )

    new_types = client.types.create
    for value in range(10):
        await new_types({'integer': value})

54 

55 

@pytest.mark.asyncio
@pytest.mark.persist_data
async def test_group_by(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """Basic test grouping by 1 field with no additional filters.

    Groups users by name and profiles by country, each ordered ascending by
    the grouped field, and snapshots both results in that order.
    """
    users_by_name = await client.user.group_by(['name'], order={'name': 'asc'})
    assert users_by_name == snapshot

    profiles_by_country = await client.profile.group_by(['country'], order={'country': 'asc'})
    assert profiles_by_country == snapshot

78 

79 

@pytest.mark.asyncio
@pytest.mark.persist_data
async def test_docs_example(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """Test the example given in the Prisma documentation:
    https://www.prisma.io/docs/reference/api-reference/prisma-client-reference#groupby

    Groups profiles by country and city, counts all rows and non-null cities,
    sums views, keeps only groups whose average views exceed 200 and orders
    by country descending.
    """
    grouped = await client.profile.group_by(
        by=['country', 'city'],
        count={'_all': True, 'city': True},
        sum={'views': True},
        order={'country': 'desc'},
        having={'views': {'_avg': {'gt': 200}}},
    )
    assert grouped == snapshot

107 

108 

@pytest.mark.asyncio
@pytest.mark.persist_data
@pytest.mark.parametrize('order', ['asc', 'desc'])
async def test_order(snapshot: SnapshotAssertion, client: Prisma, order: SortOrder) -> None:
    """Grouped results honour the requested sort direction on the grouped field."""
    grouped = await client.profile.group_by(['country'], order={'country': order})
    assert grouped == snapshot

115 

116 

@pytest.mark.asyncio
@pytest.mark.persist_data
async def test_order_list(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """Test ordering results by a list of grouped fields.

    Multiple sort keys are passed as a list of single-key dicts: country
    ascending first, then city descending.
    """
    results = await client.profile.group_by(
        by=['country', 'city'],
        order=[
            {'country': 'asc'},
            {'city': 'desc'},
        ],
    )
    # SQLite and PostgreSQL place NULLs differently by default when sorting, and
    # changing the NULL ordering is not supported yet. To make the snapshot
    # database-independent, move groups with a null ``city`` to the front
    # (``False`` sorts before ``True``). ``sorted`` is stable, so the database's
    # ordering is otherwise preserved within each partition.
    # NOTE: this only normalises correctly because the seeded null-city group
    # belongs to the first country in ascending order; a null city in a later
    # country would be pulled ahead of earlier countries.
    results = sorted(results, key=lambda p: p.get('city') is not None)
    assert results == snapshot

133 

134 

@pytest.mark.asyncio
@pytest.mark.persist_data
async def test_order_multiple_fields(client: Prisma) -> None:
    """Ordering by multiple fields within a single dict is not supported.

    Multiple sort fields must instead be given as a list of single-key dicts
    (as exercised by ``test_order_list``); putting several keys in one
    ``order`` dict raises ``prisma.errors.DataError``.
    """
    with pytest.raises(prisma.errors.DataError):
        await client.profile.group_by(
            ['country', 'city'],
            order={
                'city': 'desc',
                'country': 'asc',
            },
        )

147 

148 

@pytest.mark.asyncio
@pytest.mark.persist_data
async def test_order_mismatched_arguments(client: Prisma) -> None:
    """Ordering by a field that is not listed in ``by`` raises an InputError."""
    expected_message = (
        r'Every field used for orderBy must be included in the by-arguments of the query\. '
        r'Missing fields: country'
    )

    with pytest.raises(prisma.errors.InputError) as exc_info:
        await client.profile.group_by(
            ['city'],
            order={  # pyright: ignore
                'country': 'asc',
            },
        )

    assert exc_info.match(expected_message)

164 

165 

@pytest.mark.asyncio
@pytest.mark.persist_data
@pytest.mark.parametrize('order', ['asc', 'desc'])
async def test_take(
    snapshot: SnapshotAssertion,
    client: Prisma,
    order: SortOrder,
) -> None:
    """``take=1`` returns only the first group for the given sort direction."""
    first_group = await client.profile.group_by(['country'], take=1, order={'country': order})
    assert first_group == snapshot

183 

184 

@pytest.mark.asyncio
@pytest.mark.persist_data
async def test_take_missing_order_argument(client: Prisma) -> None:
    """Passing ``take`` without ``order`` raises a TypeError."""
    with pytest.raises(TypeError) as exc_info:
        await client.profile.group_by(['country'], take=1)

    assert exc_info.match("Missing argument: 'order' which is required when 'take' is present")

193 

194 

@pytest.mark.asyncio
@pytest.mark.persist_data
@pytest.mark.parametrize('order', ['asc', 'desc'])
async def test_skip(
    snapshot: SnapshotAssertion,
    client: Prisma,
    order: SortOrder,
) -> None:
    """``skip=1`` drops the first group for the given sort direction."""
    remaining_groups = await client.profile.group_by(['country'], skip=1, order={'country': order})
    assert remaining_groups == snapshot

212 

213 

@pytest.mark.asyncio
@pytest.mark.persist_data
async def test_skip_missing_order_argument(client: Prisma) -> None:
    """Passing ``skip`` without ``order`` raises a TypeError."""
    with pytest.raises(TypeError) as exc_info:
        await client.profile.group_by(['country'], skip=1)

    assert exc_info.match("Missing argument: 'order' which is required when 'skip' is present")

222 

223 

@pytest.mark.asyncio
@pytest.mark.persist_data
async def test_where(client: Prisma) -> None:
    """``where`` restricts which records are grouped.

    Filtering on the grouped field itself (country) and on a non-grouped
    field (description) must each leave exactly one matching country group.
    """
    danish_groups = await client.profile.group_by(
        ['country'],
        where={'country': 'Denmark'},
        order={'country': 'asc'},
    )
    assert [group.get('country') for group in danish_groups] == ['Denmark']

    scottish_groups = await client.profile.group_by(
        ['country'],
        where={'description': {'contains': 'scotland'}},
        order={'country': 'asc'},
    )
    assert [group.get('country') for group in scottish_groups] == ['Scotland']

253 

254 

@pytest.mark.asyncio
@pytest.mark.persist_data
async def test_having_missing_field_in_by(client: Prisma) -> None:
    """Having filters must be an aggregation filter or be included in by.

    ``views`` is filtered directly (not through an aggregate such as
    ``_avg``) while only ``country`` is grouped, so an InputError is raised.
    """
    with pytest.raises(prisma.errors.InputError) as exc:
        await client.profile.group_by(
            by=['country'],
            count=True,
            having={
                'views': {
                    'gt': 50,
                },
            },
            order={
                'country': 'asc',
            },
        )

    # ``ExceptionInfo.match`` performs a regex search, so the literal periods are
    # escaped to match only '.' (consistent with test_order_mismatched_arguments).
    assert exc.match(
        r'Input error\. Every field used in `having` filters must either be an aggregation filter '
        r'or be included in the selection of the query\. Missing fields: views'
    )
276 ) 

277 

278 

@pytest.mark.asyncio
@pytest.mark.persist_data
async def test_having_aggregation(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """An aggregate ``having`` filter keeps only the matching country groups.

    Snapshots the groups whose average views are above 600, then those whose
    average views are below 600.
    """
    above_600 = await client.profile.group_by(
        by=['country'],
        count=True,
        having={'views': {'_avg': {'gt': 600}}},
        order={'country': 'asc'},
    )
    assert above_600 == snapshot

    below_600 = await client.profile.group_by(
        by=['country'],
        count=True,
        having={'views': {'_avg': {'lt': 600}}},
        order={'country': 'asc'},
    )
    assert below_600 == snapshot

317 

318 

@pytest.mark.asyncio
@pytest.mark.persist_data
async def test_having_aggregation_nested(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """Aggregate filters nested inside ``OR`` / ``NOT`` in ``having`` filter groups correctly.

    Three queries are snapshotted in turn:

    1. average views == 1000 OR summed views == 250
    2. average views == 700 OR summed views == 250
    3. average views == 700 OR (summed views == 250 combined with NOT minimum views == 250)
    """
    avg_1000_or_sum_250 = await client.profile.group_by(
        by=['country'],
        count=True,
        having={
            'OR': [
                {'views': {'_avg': {'equals': 1000}}},
                {'views': {'_sum': {'equals': 250}}},
            ],
        },
        order={'country': 'asc'},
    )
    assert avg_1000_or_sum_250 == snapshot

    avg_700_or_sum_250 = await client.profile.group_by(
        by=['country'],
        count=True,
        having={
            'OR': [
                {'views': {'_avg': {'equals': 700}}},
                {'views': {'_sum': {'equals': 250}}},
            ],
        },
        order={'country': 'asc'},
    )
    assert avg_700_or_sum_250 == snapshot

    # The second OR branch pairs a ``_sum`` filter with a nested NOT on ``_min``.
    avg_700_or_sum_250_not_min_250 = await client.profile.group_by(
        by=['country'],
        count=True,
        having={
            'OR': [
                {'views': {'_avg': {'equals': 700}}},
                {
                    'views': {'_sum': {'equals': 250}},
                    'NOT': [
                        {'views': {'_min': {'equals': 250}}},
                    ],
                },
            ],
        },
        order={'country': 'asc'},
    )
    assert avg_700_or_sum_250_not_min_250 == snapshot

412 

413 

@pytest.mark.asyncio
@pytest.mark.persist_data
async def test_count(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """Counting grouped records.

    Snapshots, in order: ``count=True``, ``_all``, a single field (city) and
    two fields (city and country), each grouped by country ascending.
    """
    count_true = await client.profile.group_by(['country'], count=True, order={'country': 'asc'})
    assert count_true == snapshot

    count_all = await client.profile.group_by(['country'], count={'_all': True}, order={'country': 'asc'})
    assert count_all == snapshot

    count_city = await client.profile.group_by(['country'], count={'city': True}, order={'country': 'asc'})
    assert count_city == snapshot

    count_city_and_country = await client.profile.group_by(
        ['country'],
        count={'city': True, 'country': True},
        order={'country': 'asc'},
    )
    assert count_city_and_country == snapshot

465 

466 

@pytest.mark.asyncio
@pytest.mark.persist_data
async def test_avg(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """Average aggregation.

    Snapshots the average profile views per country, then the averages of
    the ``integer`` and ``bigint`` columns of ``types`` grouped by ``string``.
    """
    avg_views = await client.profile.group_by(
        ['country'],
        avg={'views': True},
        order={'country': 'asc'},
    )
    assert avg_views == snapshot

    avg_numeric_types = await client.types.group_by(
        ['string'],
        avg={'integer': True, 'bigint': True},
        order={'string': 'asc'},
    )
    assert avg_numeric_types == snapshot

487 

488 

@pytest.mark.asyncio
@pytest.mark.persist_data
async def test_sum(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """Sum aggregation of profile views per country."""
    summed = await client.profile.group_by(['country'], sum={'views': True}, order={'country': 'asc'})
    assert summed == snapshot

505 

506 

@pytest.mark.asyncio
@pytest.mark.persist_data
async def test_min(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """Minimum aggregation of profile views per country."""
    minimums = await client.profile.group_by(['country'], min={'views': True}, order={'country': 'asc'})
    assert minimums == snapshot

523 

524 

@pytest.mark.asyncio
@pytest.mark.persist_data
async def test_max(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """Maximum aggregation of profile views per country."""
    maximums = await client.profile.group_by(['country'], max={'views': True}, order={'country': 'asc'})
    assert maximums == snapshot