Coverage for databases/sync_tests/test_group_by.py: 100%
105 statements
« prev ^ index » next coverage.py v7.2.7, created at 2024-08-27 18:25 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2024-08-27 18:25 +0000
1import pytest
2from syrupy.assertion import SnapshotAssertion
4import prisma
5from prisma import Prisma
6from lib.testing import async_fixture
7from prisma.types import SortOrder
9# TODO: test all types
10# TODO: test working with the results
@async_fixture(autouse=True, scope='session')
def create_test_data(client: Prisma) -> None:
    """Seed the records that the group_by tests in this module query.

    Declared through ``async_fixture`` with ``autouse=True`` and
    ``scope='session'``. Inserts, in this order: one Scottish profile,
    twelve Danish profiles without a city, eight Danish profiles located
    in Copenhagen, and ten ``types`` records whose ``integer`` field
    runs from 0 to 9.
    """
    # the only Scottish profile
    client.profile.create(
        {
            'description': 'from scotland',
            'country': 'Scotland',
            'city': 'Edinburgh',
            'views': 250,
            'user': {'create': {'name': 'Tegan'}},
        }
    )

    # Danish profiles that leave `city` unset
    for _index in range(12):
        client.profile.create(
            {
                'description': 'description',
                'country': 'Denmark',
                'views': 500,
                'user': {'create': {'name': 'Robert'}},
            }
        )

    # Danish profiles located in Copenhagen
    for _index in range(8):
        client.profile.create(
            {
                'description': 'description',
                'country': 'Denmark',
                'city': 'Copenhagen',
                'views': 1000,
                'user': {'create': {'name': 'Robert'}},
            }
        )

    # one `types` record per integer value 0..9
    for value in range(10):
        client.types.create(
            {
                'integer': value,
            },
        )
@pytest.mark.persist_data
def test_group_by(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """Group by a single field, without any additional filters."""
    users_by_name = client.user.group_by(by=['name'], order={'name': 'asc'})
    assert users_by_name == snapshot

    profiles_by_country = client.profile.group_by(by=['country'], order={'country': 'asc'})
    assert profiles_by_country == snapshot
@pytest.mark.persist_data
def test_docs_example(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """Run the groupBy example from the Prisma documentation.

    https://www.prisma.io/docs/reference/api-reference/prisma-client-reference#groupby
    """
    grouped = client.profile.group_by(
        by=['country', 'city'],
        count={'_all': True, 'city': True},
        sum={'views': True},
        order={'country': 'desc'},
        # only keep groups whose average view count is above 200
        having={'views': {'_avg': {'gt': 200}}},
    )
    assert grouped == snapshot
@pytest.mark.persist_data
@pytest.mark.parametrize('order', ['asc', 'desc'])
def test_order(snapshot: SnapshotAssertion, client: Prisma, order: SortOrder) -> None:
    """Results can be ordered by a grouped field, in either direction."""
    grouped = client.profile.group_by(
        by=['country'],
        order={'country': order},
    )
    assert grouped == snapshot
@pytest.mark.persist_data
def test_order_list(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """Results can be ordered by a list of grouped fields."""

    def _has_city(row) -> bool:
        """Sort key: False for rows whose city is None, True otherwise."""
        return row.get('city') is not None

    grouped = client.profile.group_by(
        by=['country', 'city'],
        order=[{'country': 'asc'}, {'city': 'desc'}],
    )

    # SQLite and PostgreSQL have different default behaviour for sorting
    # by nulls and we don't support changing it yet, so normalise the
    # result here: rows without a city first, original order otherwise.
    normalised = sorted(grouped, key=_has_city)
    assert normalised == snapshot
@pytest.mark.persist_data
def test_order_multiple_fields(client: Prisma) -> None:
    """Ordering by multiple fields in a single mapping is not supported."""
    with pytest.raises(prisma.errors.DataError):
        client.profile.group_by(
            by=['country', 'city'],
            order={'city': 'desc', 'country': 'asc'},
        )
@pytest.mark.persist_data
def test_order_mismatched_arguments(client: Prisma) -> None:
    """Ordering by a field that is not being grouped raises an InputError."""
    with pytest.raises(prisma.errors.InputError) as exc:
        client.profile.group_by(
            by=['city'],
            order={'country': 'asc'},  # pyright: ignore
        )

    # the pattern is a regular expression, hence the escaped full stop
    expected_message = (
        r'Every field used for orderBy must be included in the by-arguments of the query\. '
        r'Missing fields: country'
    )
    assert exc.match(expected_message)
@pytest.mark.persist_data
@pytest.mark.parametrize('order', ['asc', 'desc'])
def test_take(
    snapshot: SnapshotAssertion,
    client: Prisma,
    order: SortOrder,
) -> None:
    """The take argument limits the number of records returned."""
    first_group = client.profile.group_by(by=['country'], take=1, order={'country': order})
    assert first_group == snapshot
@pytest.mark.persist_data
def test_take_missing_order_argument(client: Prisma) -> None:
    """Using take without the order argument raises a TypeError."""
    with pytest.raises(TypeError) as exc:
        client.profile.group_by(by=['country'], take=1)

    expected_message = "Missing argument: 'order' which is required when 'take' is present"
    assert exc.match(expected_message)
@pytest.mark.persist_data
@pytest.mark.parametrize('order', ['asc', 'desc'])
def test_skip(
    snapshot: SnapshotAssertion,
    client: Prisma,
    order: SortOrder,
) -> None:
    """The skip argument skips grouped records."""
    remaining_groups = client.profile.group_by(by=['country'], skip=1, order={'country': order})
    assert remaining_groups == snapshot
@pytest.mark.persist_data
def test_skip_missing_order_argument(client: Prisma) -> None:
    """Using skip without the order argument raises a TypeError."""
    with pytest.raises(TypeError) as exc:
        client.profile.group_by(by=['country'], skip=1)

    expected_message = "Missing argument: 'order' which is required when 'skip' is present"
    assert exc.match(expected_message)
@pytest.mark.persist_data
def test_where(client: Prisma) -> None:
    """The where argument filters records before they are grouped."""
    # exact match on the grouped field
    danish = client.profile.group_by(
        by=['country'],
        where={'country': 'Denmark'},
        order={'country': 'asc'},
    )
    assert len(danish) == 1
    assert danish[0].get('country') == 'Denmark'

    # substring match on a field that is not part of the grouping
    scottish = client.profile.group_by(
        by=['country'],
        where={'description': {'contains': 'scotland'}},
        order={'country': 'asc'},
    )
    assert len(scottish) == 1
    assert scottish[0].get('country') == 'Scotland'
@pytest.mark.persist_data
def test_having_missing_field_in_by(client: Prisma) -> None:
    """Having filters must be an aggregation filter or be included in by.

    Filtering on ``views`` directly while only grouping by ``country`` is
    expected to raise ``prisma.errors.InputError``.
    """
    with pytest.raises(prisma.errors.InputError) as exc:
        client.profile.group_by(
            by=['country'],
            count=True,
            having={
                'views': {
                    'gt': 50,
                },
            },
            order={
                'country': 'asc',
            },
        )

    # `exc.match` treats its argument as a regular expression, so the full
    # stops are escaped to match literally instead of matching any
    # character (same convention as test_order_mismatched_arguments).
    assert exc.match(
        r'Input error\. Every field used in `having` filters must either be an aggregation filter '
        r'or be included in the selection of the query\. Missing fields: views'
    )
@pytest.mark.persist_data
def test_having_aggregation(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """Having filters on an aggregated value select the matching groups."""
    # average views above 600 first, then below 600; the order matters
    # because each comparison consumes the next snapshot
    for comparison in ('gt', 'lt'):
        grouped = client.profile.group_by(
            by=['country'],
            count=True,
            having={'views': {'_avg': {comparison: 600}}},
            order={'country': 'asc'},
        )
        assert grouped == snapshot
@pytest.mark.persist_data
def test_having_aggregation_nested(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """Having aggregation filters nested within OR / NOT statements filter records."""

    def avg_views_equals(value: int):
        """Build a filter on the average of `views`."""
        return {'views': {'_avg': {'equals': value}}}

    def sum_views_equals_250():
        """Build a filter requiring the sum of `views` to equal 250."""
        return {'views': {'_sum': {'equals': 250}}}

    def countries_having(having):
        """Group profiles by country, counted, with the given having filter."""
        return client.profile.group_by(
            by=['country'],
            count=True,
            having=having,
            order={'country': 'asc'},
        )

    # average of 1000 OR sum of 250
    grouped = countries_having(
        {
            'OR': [
                avg_views_equals(1000),
                sum_views_equals_250(),
            ],
        }
    )
    assert grouped == snapshot

    # average of 700 OR sum of 250
    grouped = countries_having(
        {
            'OR': [
                avg_views_equals(700),
                sum_views_equals_250(),
            ],
        }
    )
    assert grouped == snapshot

    # average of 700 OR (sum of 250 combined with NOT minimum of 250)
    grouped = countries_having(
        {
            'OR': [
                avg_views_equals(700),
                {
                    **sum_views_equals_250(),
                    'NOT': [
                        {'views': {'_min': {'equals': 250}}},
                    ],
                },
            ],
        }
    )
    assert grouped == snapshot
@pytest.mark.persist_data
def test_count(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """Counting records, for each supported form of the count argument."""
    # kept in this exact order as every comparison consumes the next snapshot
    count_selections = (
        True,
        {'_all': True},
        {'city': True},
        {'city': True, 'country': True},
    )
    for selection in count_selections:
        grouped = client.profile.group_by(
            by=['country'],
            count=selection,
            order={'country': 'asc'},
        )
        assert grouped == snapshot
@pytest.mark.persist_data
def test_avg(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """Getting the average of grouped records."""
    profile_averages = client.profile.group_by(
        by=['country'],
        avg={
            'views': True,
        },
        order={
            'country': 'asc',
        },
    )
    assert profile_averages == snapshot

    types_averages = client.types.group_by(
        by=['string'],
        avg={
            'integer': True,
            'bigint': True,
        },
        order={
            'string': 'asc',
        },
    )
    assert types_averages == snapshot
@pytest.mark.persist_data
def test_sum(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """Getting the sum of `views` for each country."""
    totals = client.profile.group_by(by=['country'], sum={'views': True}, order={'country': 'asc'})
    assert totals == snapshot
@pytest.mark.persist_data
def test_min(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """Getting the minimum value of `views` for each country."""
    minimums = client.profile.group_by(by=['country'], min={'views': True}, order={'country': 'asc'})
    assert minimums == snapshot
@pytest.mark.persist_data
def test_max(snapshot: SnapshotAssertion, client: Prisma) -> None:
    """Getting the maximum value of `views` for each country."""
    maximums = client.profile.group_by(by=['country'], max={'views': True}, order={'country': 'asc'})
    assert maximums == snapshot