Coverage for databases/tests/test_group_by.py: 100%
124 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-08-27 18:25 +0000
1import pytest
2from syrupy.assertion import SnapshotAssertion
4import prisma
5from prisma import Prisma
6from lib.testing import async_fixture
7from prisma.types import SortOrder
9# TODO: test all types
10# TODO: test working with the results
@async_fixture(autouse=True, scope='session')
async def create_test_data(client: Prisma) -> None:
    """Seed the rows that the ``group_by`` tests in this module aggregate over.

    Runs automatically once per session and inserts, in this order:

    - 1 ``Scotland``/``Edinburgh`` profile with 250 views, owned by a new user ``Tegan``;
    - 12 ``Denmark`` profiles with no city and 500 views, each owned by a new user ``Robert``;
    - 8 ``Denmark``/``Copenhagen`` profiles with 1000 views, each owned by a new user ``Robert``;
    - 10 ``types`` rows whose ``integer`` column runs from 0 to 9.
    """
    await client.profile.create(
        {
            'description': 'from scotland',
            'country': 'Scotland',
            'city': 'Edinburgh',
            'views': 250,
            'user': {'create': {'name': 'Tegan'}},
        }
    )

    # These profiles omit the ``city`` key entirely rather than passing None.
    danish_profiles_without_city = 12
    for _ in range(danish_profiles_without_city):
        await client.profile.create(
            {
                'description': 'description',
                'country': 'Denmark',
                'views': 500,
                'user': {'create': {'name': 'Robert'}},
            }
        )

    copenhagen_profiles = 8
    for _ in range(copenhagen_profiles):
        await client.profile.create(
            {
                'description': 'description',
                'country': 'Denmark',
                'city': 'Copenhagen',
                'views': 1000,
                'user': {'create': {'name': 'Robert'}},
            }
        )

    for value in range(10):
        await client.types.create({'integer': value})
@pytest.mark.asyncio
@pytest.mark.persist_data
async def test_group_by(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """Group by a single field with no filters, aggregations or paging."""
    # Snapshot assertions are compared in order, so keep users before profiles.
    users_by_name = await client.user.group_by(by=['name'], order={'name': 'asc'})
    assert users_by_name == snapshot

    profiles_by_country = await client.profile.group_by(by=['country'], order={'country': 'asc'})
    assert profiles_by_country == snapshot
@pytest.mark.asyncio
@pytest.mark.persist_data
async def test_docs_example(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """Reproduce the ``groupBy`` example from the Prisma Client API reference.

    Source: https://www.prisma.io/docs/reference/api-reference/prisma-client-reference#groupby
    """
    assert (
        await client.profile.group_by(
            by=['country', 'city'],
            count={'_all': True, 'city': True},
            sum={'views': True},
            having={'views': {'_avg': {'gt': 200}}},
            order={'country': 'desc'},
        )
        == snapshot
    )
@pytest.mark.asyncio
@pytest.mark.persist_data
@pytest.mark.parametrize('order', ['asc', 'desc'])
async def test_order(snapshot: SnapshotAssertion, client: Prisma, order: SortOrder) -> None:
    """Groups can be sorted in either direction by the grouped field."""
    grouped = await client.profile.group_by(
        by=['country'],
        order={'country': order},
    )
    assert grouped == snapshot
@pytest.mark.asyncio
@pytest.mark.persist_data
async def test_order_list(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """Test ordering results by a list of grouped fields.

    Orders by ``country`` ascending, then ``city`` descending.
    """
    import itertools

    results = await client.profile.group_by(
        by=['country', 'city'],
        order=[
            {'country': 'asc'},
            {'city': 'desc'},
        ],
    )
    # SQLite and PostgreSQL place NULLs differently when sorting (e.g. for a
    # descending sort PostgreSQL puts NULLs first while SQLite puts them last),
    # and changing the NULL ordering is not supported yet.
    #
    # Normalise by moving NULL-city rows to the front of their own country
    # group only. Rows are not re-sorted across countries, so the
    # database-produced `country asc` ordering under test is kept. Because
    # `sorted` is stable, the non-NULL cities keep the database's
    # `city desc` order.
    normalised = []
    for _, rows in itertools.groupby(results, key=lambda p: p.get('country')):
        normalised.extend(sorted(rows, key=lambda p: p.get('city') is not None))
    assert normalised == snapshot
@pytest.mark.asyncio
@pytest.mark.persist_data
async def test_order_multiple_fields(client: Prisma) -> None:
    """Test that ordering by multiple fields in a single dict is not supported.

    Passing more than one key in the same ``order`` dict raises a
    ``prisma.errors.DataError``. Multi-field ordering is expressed as a
    list of single-key dicts instead (see ``test_order_list``).
    """
    with pytest.raises(prisma.errors.DataError):
        await client.profile.group_by(
            ['country', 'city'],
            order={
                'city': 'desc',
                'country': 'asc',
            },
        )
@pytest.mark.asyncio
@pytest.mark.persist_data
async def test_order_mismatched_arguments(client: Prisma) -> None:
    """Ordering by a field that is not listed in ``by`` is rejected with an ``InputError``."""
    with pytest.raises(prisma.errors.InputError) as exc_info:
        await client.profile.group_by(
            by=['city'],
            order={'country': 'asc'},  # pyright: ignore
        )

    assert exc_info.match(
        r'Every field used for orderBy must be included in the by-arguments of the query\. Missing fields: country'
    )
@pytest.mark.asyncio
@pytest.mark.persist_data
@pytest.mark.parametrize('order', ['asc', 'desc'])
async def test_take(snapshot: SnapshotAssertion, client: Prisma, order: SortOrder) -> None:
    """Passing ``take`` caps how many groups are returned."""
    first_group = await client.profile.group_by(
        by=['country'],
        order={'country': order},
        take=1,
    )
    assert first_group == snapshot
@pytest.mark.asyncio
@pytest.mark.persist_data
async def test_take_missing_order_argument(client: Prisma) -> None:
    """Using ``take`` without also passing ``order`` raises a ``TypeError``."""
    with pytest.raises(TypeError, match="Missing argument: 'order' which is required when 'take' is present"):
        await client.profile.group_by(by=['country'], take=1)
@pytest.mark.asyncio
@pytest.mark.persist_data
@pytest.mark.parametrize('order', ['asc', 'desc'])
async def test_skip(snapshot: SnapshotAssertion, client: Prisma, order: SortOrder) -> None:
    """Passing ``skip`` drops leading groups from the ordered result."""
    remaining_groups = await client.profile.group_by(
        by=['country'],
        order={'country': order},
        skip=1,
    )
    assert remaining_groups == snapshot
@pytest.mark.asyncio
@pytest.mark.persist_data
async def test_skip_missing_order_argument(client: Prisma) -> None:
    """Using ``skip`` without also passing ``order`` raises a ``TypeError``."""
    with pytest.raises(TypeError, match="Missing argument: 'order' which is required when 'skip' is present"):
        await client.profile.group_by(by=['country'], skip=1)
@pytest.mark.asyncio
@pytest.mark.persist_data
async def test_where(client: Prisma) -> None:
    """``where`` narrows the rows that are grouped, both by equality and by ``contains``."""
    denmark_only = await client.profile.group_by(
        by=['country'],
        where={'country': 'Denmark'},
        order={'country': 'asc'},
    )
    assert len(denmark_only) == 1
    assert denmark_only[0].get('country') == 'Denmark'

    scottish_only = await client.profile.group_by(
        by=['country'],
        where={'description': {'contains': 'scotland'}},
        order={'country': 'asc'},
    )
    assert len(scottish_only) == 1
    assert scottish_only[0].get('country') == 'Scotland'
@pytest.mark.asyncio
@pytest.mark.persist_data
async def test_having_missing_field_in_by(client: Prisma) -> None:
    """Having filters must be an aggregation filter or be included in by.

    A plain (non-aggregate) ``views`` filter in ``having`` while grouping
    only by ``country`` is rejected with a ``prisma.errors.InputError``.
    """
    with pytest.raises(prisma.errors.InputError) as exc:
        await client.profile.group_by(
            by=['country'],
            count=True,
            having={
                'views': {
                    'gt': 50,
                },
            },
            order={
                'country': 'asc',
            },
        )

    # `exc.match` performs a regex search, so literal dots are escaped to
    # match only real periods (consistent with test_order_mismatched_arguments).
    assert exc.match(
        r'Input error\. Every field used in `having` filters must either be an aggregation filter '
        r'or be included in the selection of the query\. Missing fields: views'
    )
@pytest.mark.asyncio
@pytest.mark.persist_data
async def test_having_aggregation(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """An ``_avg`` filter in ``having`` keeps only groups whose average qualifies.

    Both sides of the same threshold are snapshotted, "greater than" first.
    """
    above_threshold = await client.profile.group_by(
        by=['country'],
        count=True,
        having={'views': {'_avg': {'gt': 600}}},
        order={'country': 'asc'},
    )
    assert above_threshold == snapshot

    below_threshold = await client.profile.group_by(
        by=['country'],
        count=True,
        having={'views': {'_avg': {'lt': 600}}},
        order={'country': 'asc'},
    )
    assert below_threshold == snapshot
@pytest.mark.asyncio
@pytest.mark.persist_data
async def test_having_aggregation_nested(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """Aggregation filters still apply when nested inside ``OR`` / ``NOT`` clauses.

    The thresholds are chosen around the rows created by ``create_test_data``.
    Denmark averages 700 views per profile. Scotland has a single profile with
    250 views, so its sum and minimum are both 250.
    """
    # No country averages 1000 views, so only the `_sum == 250` branch can match.
    sum_branch_only = await client.profile.group_by(
        by=['country'],
        count=True,
        having={
            'OR': [
                {'views': {'_avg': {'equals': 1000}}},
                {'views': {'_sum': {'equals': 250}}},
            ],
        },
        order={'country': 'asc'},
    )
    assert sum_branch_only == snapshot

    # Each branch can match a different country.
    either_branch = await client.profile.group_by(
        by=['country'],
        count=True,
        having={
            'OR': [
                {'views': {'_avg': {'equals': 700}}},
                {'views': {'_sum': {'equals': 250}}},
            ],
        },
        order={'country': 'asc'},
    )
    assert either_branch == snapshot

    # The NOT sits next to the `_sum` filter inside the second branch, so that
    # branch additionally requires `_min != 250`.
    negated_second_branch = await client.profile.group_by(
        by=['country'],
        count=True,
        having={
            'OR': [
                {'views': {'_avg': {'equals': 700}}},
                {
                    'views': {'_sum': {'equals': 250}},
                    'NOT': [
                        {'views': {'_min': {'equals': 250}}},
                    ],
                },
            ],
        },
        order={'country': 'asc'},
    )
    assert negated_second_branch == snapshot
@pytest.mark.asyncio
@pytest.mark.persist_data
async def test_count(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """Exercise each accepted shape of the ``count`` argument.

    Snapshotted in order:
    1. ``count=True``;
    2. ``{'_all': True}``;
    3. the ``city`` column alone;
    4. ``city`` and ``country`` together.
    """
    count_true = await client.profile.group_by(
        by=['country'],
        count=True,
        order={'country': 'asc'},
    )
    assert count_true == snapshot

    count_all = await client.profile.group_by(
        by=['country'],
        count={'_all': True},
        order={'country': 'asc'},
    )
    assert count_all == snapshot

    count_city = await client.profile.group_by(
        by=['country'],
        count={'city': True},
        order={'country': 'asc'},
    )
    assert count_city == snapshot

    count_city_and_country = await client.profile.group_by(
        by=['country'],
        count={'city': True, 'country': True},
        order={'country': 'asc'},
    )
    assert count_city_and_country == snapshot
@pytest.mark.asyncio
@pytest.mark.persist_data
async def test_avg(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """``avg`` aggregates numeric fields per group.

    Covers profile ``views``, then the ``integer`` and ``bigint`` fields of ``types``.
    """
    profile_averages = await client.profile.group_by(
        by=['country'],
        avg={'views': True},
        order={'country': 'asc'},
    )
    assert profile_averages == snapshot

    types_averages = await client.types.group_by(
        by=['string'],
        avg={'integer': True, 'bigint': True},
        order={'string': 'asc'},
    )
    assert types_averages == snapshot
@pytest.mark.asyncio
@pytest.mark.persist_data
async def test_sum(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """``sum`` totals the ``views`` of every profile in each country group."""
    totals = await client.profile.group_by(
        by=['country'],
        sum={'views': True},
        order={'country': 'asc'},
    )
    assert totals == snapshot
@pytest.mark.asyncio
@pytest.mark.persist_data
async def test_min(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """``min`` reports the smallest ``views`` value within each country group."""
    minimums = await client.profile.group_by(
        by=['country'],
        min={'views': True},
        order={'country': 'asc'},
    )
    assert minimums == snapshot
@pytest.mark.asyncio
@pytest.mark.persist_data
async def test_max(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """``max`` reports the largest ``views`` value within each country group."""
    maximums = await client.profile.group_by(
        by=['country'],
        max={'views': True},
        order={'country': 'asc'},
    )
    assert maximums == snapshot