Esempio n. 1
0
class TestModelCreate(test.TestCase):
    """Creating models with generated or explicit primary keys, and cloning them."""

    async def test_save_generated(self):
        # Integer PK filled in by the database: the saved row must round-trip.
        created = await Tournament.create(name="Test")
        self.assertEqual(created, await Tournament.get(id=created.id))

    async def test_save_non_generated(self):
        # Same round-trip for a model whose primary key is a UUID.
        created = await UUIDFkRelatedNullModel.create(name="Test")
        self.assertEqual(created, await UUIDFkRelatedNullModel.get(id=created.id))

    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_save_generated_custom_id(self):
        # An explicitly supplied integer PK is kept as-is.
        explicit_id = 12345
        created = await Tournament.create(id=explicit_id, name="Test")
        self.assertEqual(created.id, explicit_id)
        self.assertEqual(created, await Tournament.get(id=explicit_id))

    async def test_save_non_generated_custom_id(self):
        # An explicitly supplied UUID PK is kept as-is.
        explicit_id = uuid4()
        created = await UUIDFkRelatedNullModel.create(id=explicit_id, name="Test")
        self.assertEqual(created.id, explicit_id)
        self.assertEqual(created, await UUIDFkRelatedNullModel.get(id=explicit_id))

    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_save_generated_duplicate_custom_id(self):
        # Reusing an explicit integer PK must violate the PK constraint.
        explicit_id = 12345
        await Tournament.create(id=explicit_id, name="TestOriginal")
        with self.assertRaises(IntegrityError):
            await Tournament.create(id=explicit_id, name="Test")

    async def test_save_non_generated_duplicate_custom_id(self):
        # Reusing an explicit UUID PK must violate the PK constraint.
        explicit_id = uuid4()
        await UUIDFkRelatedNullModel.create(id=explicit_id, name="TestOriginal")
        with self.assertRaises(IntegrityError):
            await UUIDFkRelatedNullModel.create(id=explicit_id, name="Test")

    async def test_clone_pk_required_error(self):
        # clone() without a new pk is rejected for RequiredPKModel.
        source = await RequiredPKModel.create(id="A", name="name_a")
        with self.assertRaises(ParamsError):
            source.clone()

    async def test_clone_pk_required(self):
        # Cloning with a fresh pk and saving yields a second row.
        source = await RequiredPKModel.create(id="A", name="name_a")
        duplicate = source.clone(pk="B")
        await duplicate.save()
        rows = await RequiredPKModel.all()
        self.assertEqual(len(list(rows)), 2)

    async def test_implicit_clone_pk_required_none(self):
        # Blanking the required pk and saving fails validation.
        source = await RequiredPKModel.create(id="A", name="name_a")
        source.pk = None
        with self.assertRaises(ValidationError):
            await source.save()
class TestOrderByNested(test.TestCase):
    """Ordering and selecting values through a foreign-key relation."""

    @test.requireCapability(dialect=NotEQ("oracle"))
    async def test_basic(self):
        await Event.create(
            name="Event 1", tournament=await Tournament.create(name="Tournament 1", desc="B")
        )
        await Event.create(
            name="Event 2", tournament=await Tournament.create(name="Tournament 2", desc="A")
        )

        self.assertEqual(
            await Event.all().order_by("-name").values("name"),
            [{"name": "Event 2"}, {"name": "Event 1"}],
        )

        # A query without ORDER BY returns rows in an unspecified order, so pin
        # it by event name: "Event 1" -> desc "B", "Event 2" -> desc "A".
        self.assertEqual(
            await Event.all()
            .prefetch_related("tournament")
            .order_by("name")
            .values("tournament__desc"),
            [{"tournament__desc": "B"}, {"tournament__desc": "A"}],
        )

        # Ordering by a field on the related model.
        self.assertEqual(
            await Event.all()
            .prefetch_related("tournament")
            .order_by("tournament__desc")
            .values("tournament__desc"),
            [{"tournament__desc": "A"}, {"tournament__desc": "B"}],
        )
Esempio n. 3
0
class TestConcurrencyIsolated(test.IsolatedTestCase):
    """Concurrent reads, writes, get_or_create calls and transactions via asyncio.gather."""

    async def test_concurrency_read(self):
        await Tournament.create(name="Test")
        reference = await Tournament.first()
        readers = (Tournament.first() for _ in range(100))
        results = await asyncio.gather(*readers)
        self.assertEqual(results, [reference] * 100)

    async def test_concurrency_create(self):
        writers = (Tournament.create(name="Test") for _ in range(100))
        created = await asyncio.gather(*writers)
        stored = await Tournament.all()
        self.assertEqual(set(created), set(stored))

    async def create_trans_concurrent(self):
        """Create 100 tournaments concurrently inside a single transaction."""
        async with in_transaction():
            writers = (Tournament.create(name="Test") for _ in range(100))
            await asyncio.gather(*writers)

    async def test_nonconcurrent_get_or_create(self):
        outcomes = [await UniqueName.get_or_create(name="c") for _ in range(10)]
        # Exactly one call reports a newly created row ...
        fresh = sum(1 for outcome in outcomes if outcome[1] is True)
        self.assertEqual(fresh, 1)
        # ... and every call hands back the same object.
        first_obj = outcomes[0][0]
        for outcome in outcomes:
            self.assertEqual(outcome[0], first_obj)

    @test.skipIf(sys.version_info < (3, 7),
                 "aiocontextvars backport not handling this well")
    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_concurrent_get_or_create(self):
        callers = (UniqueName.get_or_create(name="d") for _ in range(10))
        outcomes = await asyncio.gather(*callers)
        fresh = sum(1 for outcome in outcomes if outcome[1] is True)
        self.assertEqual(fresh, 1)
        first_obj = outcomes[0][0]
        for outcome in outcomes:
            self.assertEqual(outcome[0], first_obj)

    @test.skipIf(sys.version_info < (3, 7),
                 "aiocontextvars backport not handling this well")
    @test.requireCapability(supports_transactions=True)
    async def test_concurrency_transactions_concurrent(self):
        # 10 transactions x 100 concurrent inserts each.
        transactions = (self.create_trans_concurrent() for _ in range(10))
        await asyncio.gather(*transactions)
        self.assertEqual(await Tournament.all().count(), 1000)

    async def create_trans(self):
        """Create one tournament inside its own transaction."""
        async with in_transaction():
            await Tournament.create(name="Test")

    @test.skipIf(sys.version_info < (3, 7),
                 "aiocontextvars backport not handling this well")
    @test.requireCapability(supports_transactions=True)
    async def test_concurrency_transactions(self):
        transactions = (self.create_trans() for _ in range(100))
        await asyncio.gather(*transactions)
        self.assertEqual(await Tournament.all().count(), 100)
Esempio n. 4
0
class TestExplain(test.TestCase):
    """Smoke test that ``QuerySet.explain()`` runs without errors."""

    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_explain(self):
        # The format of an explain plan is backend-specific and deliberately
        # not asserted on; only check that the call succeeds and returns
        # some non-trivial amount of information.
        query = Tournament.all()
        plan_text = str(await query.explain())
        self.assertGreater(len(plan_text), 20)
Esempio n. 5
0
class TestFuzz(test.TestCase):
    """Round-trip awkward strings through CharFields create/get/update/filter."""

    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_char_fuzz(self):
        like_ops = (
            "startswith",
            "endswith",
            "contains",
            "istartswith",
            "iendswith",
            "icontains",
        )
        for char in DODGY_STRINGS:
            # PostgreSQL cannot store NUL characters in text, so skip those there.
            if "\x00" in char and self._db.capabilities.dialect == "postgres":
                continue

            created = await CharFields.create(char=char)
            pk = created.pk

            # Fetching by primary key must read the value back unchanged.
            self.assertEqual(char, (await CharFields.get(pk=pk)).char)

            # Overwrite through a queryset update, restore it, and re-read.
            await CharFields.filter(pk=pk).update(char="a")
            await CharFields.filter(pk=pk).update(char=char)
            self.assertEqual(char, (await CharFields.get(pk=pk)).char)

            # An exact-value filter must find this very row.
            exact = await CharFields.get(pk=pk, char=char)
            self.assertEqual(pk, exact.pk)
            self.assertEqual(char, exact.char)

            # LIKE-based lookups are looser, so require all of them at once.
            like_filters = {f"char__{op}": char for op in like_ops}
            fuzzy = await CharFields.get(pk=pk, **like_filters)
            self.assertEqual(pk, fuzzy.pk)
            self.assertEqual(char, fuzzy.char)
Esempio n. 6
0
class TestBulk(test.TruncationTestCase):
    """``bulk_create`` with integer/UUID PKs, batching, upserts and transactions.

    Queries without ``order_by`` return rows in an unspecified order, so the
    expectations below are order-insensitive (``assertListSortEqual``),
    partition rows by id, or add an explicit ``order_by`` rather than relying
    on whichever row the database happens to return first.
    """

    async def test_bulk_create(self):
        await UniqueName.bulk_create([UniqueName() for _ in range(1000)])
        all_ = await UniqueName.all().values("id", "name")
        # The generated ids are expected to be contiguous; anchor on the
        # smallest one, since all_[0] is not guaranteed to be the minimum.
        inc = min(row["id"] for row in all_)
        self.assertListSortEqual(
            all_,
            [{"id": val + inc, "name": None} for val in range(1000)],
            sorted_key="id",
        )

    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_bulk_create_update_fields(self):
        await UniqueName.bulk_create([UniqueName(name="name")])
        await UniqueName.bulk_create(
            [UniqueName(name="name", optional="optional")],
            update_fields=["optional"],
            on_conflict=["name"],
        )
        all_ = await UniqueName.all().values("name", "optional")
        self.assertListSortEqual(all_, [{"name": "name", "optional": "optional"}])

    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_bulk_create_more_that_one_update_fields(self):
        await UniqueName.bulk_create([UniqueName(name="name")])
        await UniqueName.bulk_create(
            [
                UniqueName(name="name",
                           optional="optional",
                           other_optional="other_optional")
            ],
            update_fields=["optional", "other_optional"],
            on_conflict=["name"],
        )
        all_ = await UniqueName.all().values("name", "optional", "other_optional")
        self.assertListSortEqual(all_, [{
            "name": "name",
            "optional": "optional",
            "other_optional": "other_optional"
        }])

    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_bulk_create_with_batch_size(self):
        await UniqueName.bulk_create(
            [UniqueName(id=id_ + 1) for id_ in range(1000)], batch_size=100)
        all_ = await UniqueName.all().values("id", "name")
        self.assertListSortEqual(
            all_,
            [{"id": val + 1, "name": None} for val in range(1000)],
            sorted_key="id",
        )

    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_bulk_create_with_specified(self):
        await UniqueName.bulk_create(
            [UniqueName(id=id_) for id_ in range(1000, 2000)])
        all_ = await UniqueName.all().values("id", "name")
        self.assertListSortEqual(
            all_,
            [{"id": id_, "name": None} for id_ in range(1000, 2000)],
            sorted_key="id",
        )

    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_bulk_create_mix_specified(self):
        specified_ids = range(10000, 11000)
        await UniqueName.bulk_create(
            [UniqueName(id=id_) for id_ in specified_ids] +
            [UniqueName() for _ in range(1000)])

        all_ = await UniqueName.all().values("id", "name")
        self.assertEqual(len(all_), 2000)

        # Split by id instead of by position: the unordered result is not
        # guaranteed to come back in insertion order.
        explicit = [row for row in all_ if row["id"] in specified_ids]
        generated = [row for row in all_ if row["id"] not in specified_ids]

        self.assertListSortEqual(
            explicit,
            [{"id": id_, "name": None} for id_ in specified_ids],
            sorted_key="id",
        )
        # The generated ids are expected to be contiguous from their minimum.
        inc = min(row["id"] for row in generated)
        self.assertListSortEqual(
            generated,
            [{"id": val + inc, "name": None} for val in range(1000)],
            sorted_key="id",
        )

    async def test_bulk_create_uuidpk(self):
        await UUIDPkModel.bulk_create([UUIDPkModel() for _ in range(1000)])
        res = await UUIDPkModel.all().values_list("id", flat=True)
        self.assertEqual(len(res), 1000)
        self.assertIsInstance(res[0], UUID)

    @test.requireCapability(supports_transactions=True)
    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_bulk_create_in_transaction(self):
        async with in_transaction():
            await UniqueName.bulk_create([UniqueName() for _ in range(1000)])
        # Order explicitly so that both the anchor and the order-sensitive
        # assertEqual below are well defined.
        all_ = await UniqueName.all().order_by("id").values("id", "name")
        inc = all_[0]["id"]
        self.assertEqual(all_, [{"id": val + inc, "name": None} for val in range(1000)])

    @test.requireCapability(supports_transactions=True)
    async def test_bulk_create_uuidpk_in_transaction(self):
        async with in_transaction():
            await UUIDPkModel.bulk_create([UUIDPkModel() for _ in range(1000)])
        res = await UUIDPkModel.all().values_list("id", flat=True)
        self.assertEqual(len(res), 1000)
        self.assertIsInstance(res[0], UUID)

    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_bulk_create_fail(self):
        # Names 0..9 appear twice, violating the unique constraint.
        with self.assertRaises(IntegrityError):
            await UniqueName.bulk_create(
                [UniqueName(name=str(i)) for i in range(10)] +
                [UniqueName(name=str(i)) for i in range(10)])

    async def test_bulk_create_uuidpk_fail(self):
        val = uuid4()
        with self.assertRaises(IntegrityError):
            await UUIDPkModel.bulk_create(
                [UUIDPkModel(id=val) for _ in range(10)])

    @test.requireCapability(supports_transactions=True, dialect=NotEQ("mssql"))
    async def test_bulk_create_in_transaction_fail(self):
        with self.assertRaises(IntegrityError):
            async with in_transaction():
                await UniqueName.bulk_create(
                    [UniqueName(name=str(i)) for i in range(10)] +
                    [UniqueName(name=str(i)) for i in range(10)])

    @test.requireCapability(supports_transactions=True)
    async def test_bulk_create_uuidpk_in_transaction_fail(self):
        val = uuid4()
        with self.assertRaises(IntegrityError):
            async with in_transaction():
                await UUIDPkModel.bulk_create(
                    [UUIDPkModel(id=val) for _ in range(10)])

    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_bulk_create_ignore_conflicts(self):
        name1 = UniqueName(name="name1")
        name2 = UniqueName(name="name2")
        await UniqueName.bulk_create([name1, name2])
        # Re-inserting the same rows is tolerated only with ignore_conflicts.
        await UniqueName.bulk_create([name1, name2], ignore_conflicts=True)
        with self.assertRaises(IntegrityError):
            await UniqueName.bulk_create([name1, name2])
Esempio n. 7
0
class TestQueryset(test.TestCase):
    async def asyncSetUp(self):
        """Seed 30 IntFields rows (intnum = 10, 13, ..., 97) before each test."""
        await super().asyncSetUp()
        seeded = []
        for value in range(10, 100, 3):
            seeded.append(await IntFields.create(intnum=value))
        self.intfields = seeded
        self.db = connections.get("models")

    async def test_all_count(self):
        """All seeded rows are counted; none has intnum_null set yet."""
        total = await IntFields.all().count()
        marked = await IntFields.filter(intnum_null=80).count()
        self.assertEqual(total, 30)
        self.assertEqual(marked, 0)

    async def test_exists(self):
        """``exists()`` reports whether any seeded row matches the filter."""
        expectations = (
            ({"intnum": 0}, False),
            ({"intnum": 10}, True),
            ({"intnum__gt": 10}, True),
            ({"intnum__lt": 10}, False),
        )
        for lookup, should_exist in expectations:
            found = await IntFields.filter(**lookup).exists()
            if should_exist:
                self.assertTrue(found)
            else:
                self.assertFalse(found)

    async def test_limit_count(self):
        """``count()`` respects a preceding ``limit()``."""
        limited = IntFields.all().limit(10)
        self.assertEqual(await limited.count(), 10)

    async def test_limit_negative(self):
        """A negative limit is rejected with ParamsError."""
        expected_message = "Limit should be non-negative number"
        with self.assertRaisesRegex(ParamsError, expected_message):
            await IntFields.all().limit(-10)

    async def test_offset_count(self):
        """``count()`` respects a preceding ``offset()`` (30 rows - 10 skipped)."""
        shifted = IntFields.all().offset(10)
        self.assertEqual(await shifted.count(), 20)

    async def test_offset_negative(self):
        """A negative offset is rejected with ParamsError."""
        expected_message = "Offset should be non-negative number"
        with self.assertRaisesRegex(ParamsError, expected_message):
            await IntFields.all().offset(-10)

    async def test_join_count(self):
        """Counting works both plainly and when filtering across a relation."""
        tournament = await Tournament.create(name="moo")
        await MinRelation.create(tournament=tournament)

        total = await MinRelation.all().count()
        self.assertEqual(total, 1)
        linked = await MinRelation.filter(tournament__id=tournament.id).count()
        self.assertEqual(linked, 1)

    async def test_modify_dataset(self):
        """Queryset ``update()`` reports affected rows and is visible to later filters."""
        # intnum >= 70 selects 70, 73, ..., 97: ten rows.
        updated = await IntFields.filter(intnum__gte=70).update(intnum_null=80)
        self.assertEqual(updated, 10)
        marked = await IntFields.filter(intnum_null=80).count()
        self.assertEqual(marked, 10)
        unmarked = await IntFields.filter(intnum_null__isnull=True).count()
        self.assertEqual(unmarked, 20)

        # Fill every remaining NULL; ``=None`` must then match nothing.
        await IntFields.filter(intnum_null__isnull=True).update(intnum_null=-1)
        still_null = await IntFields.filter(intnum_null=None).count()
        self.assertEqual(still_null, 0)
        filled = await IntFields.filter(intnum_null=-1).count()
        self.assertEqual(filled, 20)

    async def test_distinct(self):
        """``distinct()`` collapses duplicate intnum_null values to one each."""
        await IntFields.filter(intnum__gte=70).update(intnum_null=80)
        await IntFields.filter(intnum_null__isnull=True).update(intnum_null=-1)

        def unique_nulls():
            return IntFields.all().order_by("intnum_null").distinct()

        flat = await unique_nulls().values_list("intnum_null", flat=True)
        self.assertEqual(flat, [-1, 80])

        as_dicts = await unique_nulls().values("intnum_null")
        self.assertEqual(as_dicts, [{"intnum_null": -1}, {"intnum_null": 80}])

    async def test_limit_offset_values_list(self):
        """limit/offset/ordering combined with flat ``values_list``."""
        def first_ten(order="intnum"):
            return IntFields.all().order_by(order).limit(10)

        self.assertEqual(
            await first_ten().values_list("intnum", flat=True),
            list(range(10, 40, 3)),
        )

        # Pages 2-4 of ten; the fourth lies past the 30 seeded rows.
        pages = ((10, range(40, 70, 3)), (20, range(70, 100, 3)), (30, ()))
        for skip, expected in pages:
            self.assertEqual(
                await first_ten().offset(skip).values_list("intnum", flat=True),
                list(expected),
            )

        self.assertEqual(
            await first_ten("-intnum").values_list("intnum", flat=True),
            list(range(97, 67, -3)),
        )

        self.assertEqual(
            await first_ten().filter(intnum__gte=40).values_list("intnum", flat=True),
            list(range(40, 70, 3)),
        )

    async def test_limit_offset_values(self):
        """limit/offset/ordering combined with ``values``."""
        def first_five(order="intnum"):
            return IntFields.all().order_by(order).limit(5)

        def as_rows(numbers):
            return [{"intnum": number} for number in numbers]

        self.assertEqual(
            await first_five().values("intnum"),
            as_rows(range(10, 25, 3)),
        )

        self.assertEqual(
            await first_five().offset(10).values("intnum"),
            as_rows(range(40, 55, 3)),
        )

        # Past the 30 seeded rows: nothing left.
        self.assertEqual(await first_five().offset(30).values("intnum"), [])

        self.assertEqual(
            await first_five("-intnum").values("intnum"),
            as_rows(range(97, 82, -3)),
        )

        self.assertEqual(
            await first_five().filter(intnum__gte=40).values("intnum"),
            as_rows(range(40, 55, 3)),
        )

    async def test_in_bulk(self):
        """``in_bulk`` returns a dict keyed by exactly the requested PKs.

        Neither the ``limit(2)`` sample nor the ``IN (...)`` lookup behind
        ``in_bulk`` has an ORDER BY, so key order is unspecified. The test
        therefore checks key membership, not order, and that each value is
        the row for its key.
        """
        id_list = [
            item.pk for item in await IntFields.all().only("id").limit(2)
        ]
        ret = await IntFields.in_bulk(id_list=id_list)
        self.assertCountEqual(list(ret.keys()), id_list)
        for pk, obj in ret.items():
            self.assertEqual(obj.pk, pk)

    async def test_first(self):
        """``first()`` honours ordering and filters, and yields None when nothing matches."""
        def ordered_from(minimum):
            return IntFields.all().order_by("intnum").filter(intnum__gte=minimum)

        hit = await ordered_from(40).first()
        self.assertEqual(hit.intnum, 40)
        hit_dict = await ordered_from(40).first().values()
        self.assertEqual(hit_dict["intnum"], 40)
        # Positional values_list(): index 1 holds intnum (id comes first).
        hit_tuple = await ordered_from(40).first().values_list()
        self.assertEqual(hit_tuple[1], 40)

        # No seeded row reaches 400.
        self.assertEqual(await ordered_from(400).first(), None)
        self.assertEqual(await ordered_from(400).first().values(), None)
        self.assertEqual(await ordered_from(400).first().values_list(), None)

    async def test_get_or_none(self):
        """``get_or_none``: one hit, None on no hit, MultipleObjectsReturned on many."""
        hit = await IntFields.all().get_or_none(intnum=40)
        self.assertEqual(hit.intnum, 40)
        hit_dict = await IntFields.all().get_or_none(intnum=40).values()
        self.assertEqual(hit_dict["intnum"], 40)
        hit_tuple = await IntFields.all().get_or_none(intnum=40).values_list()
        self.assertEqual(hit_tuple[1], 40)

        def ordered():
            return IntFields.all().order_by("intnum")

        # Nothing reaches 400: every flavour returns None.
        self.assertEqual(await ordered().get_or_none(intnum__gte=400), None)
        self.assertEqual(
            await ordered().get_or_none(intnum__gte=400).values(), None)
        self.assertEqual(
            await ordered().get_or_none(intnum__gte=400).values_list(), None)

        # Many rows reach 40: an ambiguous match still raises.
        with self.assertRaises(MultipleObjectsReturned):
            await ordered().get_or_none(intnum__gte=40)
        with self.assertRaises(MultipleObjectsReturned):
            await ordered().get_or_none(intnum__gte=40).values()
        with self.assertRaises(MultipleObjectsReturned):
            await ordered().get_or_none(intnum__gte=40).values_list()

    async def test_get(self):
        """``get`` from querysets, chained ``all()`` and the model: hit, miss, many."""
        # Rows with intnum >= 70 share intnum_null=80, making that lookup ambiguous.
        await IntFields.filter(intnum__gte=70).update(intnum_null=80)

        hit_sources = (
            IntFields.all,
            lambda: IntFields.all().all().all().all().all(),
            lambda: IntFields,
        )
        for source in hit_sources:
            self.assertEqual((await source().get(intnum=40)).intnum, 40)
            self.assertEqual(
                (await source().get(intnum=40).values())["intnum"], 40)
            self.assertEqual(
                (await source().get(intnum=40).values_list())[1], 40)

        failures = (
            (DoesNotExist, {"intnum": 41}),
            (MultipleObjectsReturned, {"intnum_null": 80}),
        )
        for error, lookup in failures:
            for source in (IntFields.all, lambda: IntFields):
                with self.assertRaises(error):
                    await source().get(**lookup)
                with self.assertRaises(error):
                    await source().get(**lookup).values()
                with self.assertRaises(error):
                    await source().get(**lookup).values_list()

    async def test_delete(self):
        """Instance delete removes one row; queryset delete reports rows removed."""
        victim = await IntFields.get(intnum=40)
        await victim.delete()

        with self.assertRaises(DoesNotExist):
            await IntFields.get(intnum=40)
        self.assertEqual(await IntFields.all().count(), 29)

        # intnum >= 70 matches exactly ten rows, so limit(10) keeps them all.
        doomed = (
            IntFields.all()
            .order_by("intnum")
            .limit(10)
            .filter(intnum__gte=70)
        )
        removed = await doomed.delete()
        self.assertEqual(removed, 10)
        self.assertEqual(await IntFields.all().count(), 19)

    @test.requireCapability(support_update_limit_order_by=True)
    async def test_delete_limit(self):
        """``limit(1).delete()`` removes a single row."""
        single = IntFields.all().limit(1)
        await single.delete()
        remaining = await IntFields.all().count()
        self.assertEqual(remaining, 29)

    @test.requireCapability(support_update_limit_order_by=True)
    async def test_delete_limit_order_by(self):
        """``limit(1).order_by("-id").delete()`` removes only the highest-id row."""
        newest = IntFields.all().limit(1).order_by("-id")
        await newest.delete()
        self.assertEqual(await IntFields.all().count(), 29)
        # The test expects the highest id to be the last seeded row (intnum=97).
        with self.assertRaises(DoesNotExist):
            await IntFields.get(intnum=97)

    async def test_async_iter(self):
        """``async for`` over a queryset yields as many rows as ``count()``."""
        iterated = [row async for row in IntFields.all()]
        self.assertEqual(await IntFields.all().count(), len(iterated))

    async def test_update_basic(self):
        """A plain queryset update changes only the targeted column."""
        row = await IntFields.create(intnum=2147483647)
        await IntFields.filter(id=row.id).update(intnum=2147483646)
        reloaded = await IntFields.get(id=row.id)
        self.assertEqual(reloaded.intnum, 2147483646)
        self.assertEqual(reloaded.intnum_null, None)

    async def test_update_f_expression(self):
        """An ``F()`` expression updates a column relative to its current value."""
        row = await IntFields.create(intnum=2147483647)
        decremented = F("intnum") - 1
        await IntFields.filter(id=row.id).update(intnum=decremented)
        reloaded = await IntFields.get(id=row.id)
        self.assertEqual(reloaded.intnum, 2147483646)

    async def test_update_badparam(self):
        """Updating a non-existent field raises FieldError."""
        row = await IntFields.create(intnum=2147483647)
        target = IntFields.filter(id=row.id)
        with self.assertRaisesRegex(FieldError, "Unknown keyword argument"):
            await target.update(badparam=1)

    async def test_update_pk(self):
        """The primary key cannot be changed through a queryset update."""
        row = await IntFields.create(intnum=2147483647)
        target = IntFields.filter(id=row.id)
        with self.assertRaisesRegex(IntegrityError, "is PK and can not be updated"):
            await target.update(id=1)

    async def test_update_virtual(self):
        """A virtual (relation) field cannot be set through a queryset update."""
        tournament = await Tournament.create(name="moo")
        relation = await MinRelation.create(tournament=tournament)
        target = MinRelation.filter(id=relation.id)
        with self.assertRaisesRegex(FieldError, "is virtual and can not be updated"):
            await target.update(participants=[])

    async def test_bad_ordering(self):
        """Ordering by an unknown field raises FieldError."""
        expected_message = "Unknown field moo1fip for model IntFields"
        with self.assertRaisesRegex(FieldError, expected_message):
            await IntFields.all().order_by("moo1fip")

    async def test_duplicate_values(self):
        """``values()`` rejects the same key requested twice."""
        repeated = ("intnum", "intnum")
        with self.assertRaisesRegex(FieldError, "Duplicate key intnum"):
            await IntFields.all().values(*repeated)

    async def test_duplicate_values_list(self):
        """Unlike ``values()``, ``values_list()`` accepts a repeated field without error."""
        repeated = ("intnum", "intnum")
        await IntFields.all().values_list(*repeated)

    async def test_duplicate_values_kw(self):
        """A keyword alias clashing with a positional key is rejected."""
        aliases = {"intnum": "intnum_null"}
        with self.assertRaisesRegex(FieldError, "Duplicate key intnum"):
            await IntFields.all().values("intnum", **aliases)

    async def test_duplicate_values_kw_badmap(self):
        """A keyword alias pointing at an unknown field is rejected."""
        expected_message = 'Unknown field "intnum2" for model "IntFields"'
        with self.assertRaisesRegex(FieldError, expected_message):
            await IntFields.all().values(intnum="intnum2")

    async def test_bad_values(self):
        """``values()`` with an unknown field raises FieldError."""
        expected_message = 'Unknown field "int2num" for model "IntFields"'
        with self.assertRaisesRegex(FieldError, expected_message):
            await IntFields.all().values("int2num")

    async def test_bad_values_list(self):
        """``values_list()`` with an unknown field raises FieldError."""
        expected_message = 'Unknown field "int2num" for model "IntFields"'
        with self.assertRaisesRegex(FieldError, expected_message):
            await IntFields.all().values_list("int2num")

    async def test_many_flat_values_list(self):
        """``flat=True`` is refused when more than one field is requested."""
        expected_message = "You can flat value_list only if contains one field"
        columns = ("intnum", "intnum_null")
        with self.assertRaisesRegex(TypeError, expected_message):
            await IntFields.all().values_list(*columns, flat=True)

    async def test_all_flat_values_list(self):
        """``flat=True`` is refused when no field (i.e. all fields) is requested."""
        expected_message = "You can flat value_list only if contains one field"
        with self.assertRaisesRegex(TypeError, expected_message):
            await IntFields.all().values_list(flat=True)

    async def test_all_values_list(self):
        """``values_list()`` with no args yields (id, intnum, intnum_null) tuples."""
        rows = await IntFields.all().order_by("id").values_list()
        expected = (self.intfields[2].id, 16, None)
        self.assertEqual(rows[2], expected)

    async def test_all_values(self):
        """``values()`` with no args yields a dict of every column."""
        rows = await IntFields.all().order_by("id").values()
        expected = dict(id=self.intfields[2].id, intnum=16, intnum_null=None)
        self.assertEqual(rows[2], expected)

    async def test_order_by_bad_value(self):
        """Ordering a ``values_list`` by an unknown field raises FieldError."""
        expected_message = "Unknown field badid for model IntFields"
        with self.assertRaisesRegex(FieldError, expected_message):
            await IntFields.all().order_by("badid").values_list()

    async def test_annotate_order_expression(self):
        """An arithmetic annotation can be selected and ordered on."""
        annotated = IntFields.annotate(idp=F("id") + 1).order_by("-idp")
        row = await annotated.first().values_list("id", "idp")
        row_id, id_plus_one = row[0], row[1]
        self.assertEqual(row_id + 1, id_plus_one)

    async def test_annotate_expression_filter(self):
        """Filtering applies to an annotation that reuses the name ``intnum``."""
        # 23 equals the seeded rows with intnum + 1 > 30, i.e. intnum in 31..97.
        shifted = IntFields.annotate(intnum=F("intnum") + 1)
        matching = await shifted.filter(intnum__gt=30).count()
        self.assertEqual(matching, 23)

    async def test_get_raw_sql(self):
        """QuerySet.sql() returns the rendered SELECT statement as a string."""
        rendered = IntFields.all().sql()
        self.assertRegex(rendered, r"^SELECT.+FROM.+")

    @test.requireCapability(support_index_hint=True)
    async def test_force_index(self):
        """force_index() emits a FORCE INDEX hint, identically on every build."""
        expected = (
            "SELECT `id` `id` FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=1"
        )
        # Build and render a fresh queryset twice to show the output is stable.
        for _ in range(2):
            rendered = IntFields.filter(pk=1).only("id").force_index("index_name").sql()
            self.assertEqual(rendered, expected)

    @test.requireCapability(support_index_hint=True)
    async def test_force_index_avaiable_in_more_query(self):
        """The FORCE INDEX hint is kept by values/values_list/count/exists."""

        def hinted():
            # Each derived query starts from its own freshly hinted queryset.
            return IntFields.filter(pk=1).force_index("index_name")

        cases = (
            (
                lambda qs: qs.values("id"),
                "SELECT `id` `id` FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=1",
            ),
            (
                lambda qs: qs.values_list("id"),
                "SELECT `id` `0` FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=1",
            ),
            (
                lambda qs: qs.count(),
                "SELECT COUNT(*) FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=1",
            ),
            (
                lambda qs: qs.exists(),
                "SELECT 1 FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=1 LIMIT 1",
            ),
        )
        for derive, expected in cases:
            self.assertEqual(derive(hinted()).sql(), expected)

    @test.requireCapability(support_index_hint=True)
    async def test_use_index(self):
        """use_index() emits a USE INDEX hint, identically on every build."""
        expected = (
            "SELECT `id` `id` FROM `intfields` USE INDEX (`index_name`) WHERE `id`=1"
        )
        # Build and render a fresh queryset twice to show the output is stable.
        for _ in range(2):
            rendered = IntFields.filter(pk=1).only("id").use_index("index_name").sql()
            self.assertEqual(rendered, expected)

    @test.requireCapability(support_index_hint=True)
    async def test_use_index_avaiable_in_more_query(self):
        """The USE INDEX hint is kept by values/values_list/count/exists."""

        def hinted():
            # Each derived query starts from its own freshly hinted queryset.
            return IntFields.filter(pk=1).use_index("index_name")

        cases = (
            (
                lambda qs: qs.values("id"),
                "SELECT `id` `id` FROM `intfields` USE INDEX (`index_name`) WHERE `id`=1",
            ),
            (
                lambda qs: qs.values_list("id"),
                "SELECT `id` `0` FROM `intfields` USE INDEX (`index_name`) WHERE `id`=1",
            ),
            (
                lambda qs: qs.count(),
                "SELECT COUNT(*) FROM `intfields` USE INDEX (`index_name`) WHERE `id`=1",
            ),
            (
                lambda qs: qs.exists(),
                "SELECT 1 FROM `intfields` USE INDEX (`index_name`) WHERE `id`=1 LIMIT 1",
            ),
        )
        for derive, expected in cases:
            self.assertEqual(derive(hinted()).sql(), expected)

    @test.requireCapability(support_for_update=True)
    async def test_select_for_update(self):
        """select_for_update() renders FOR UPDATE and its NOWAIT/SKIP LOCKED/OF forms.

        Expected SQL is only spelled out for postgres and mysql. Any other
        dialect still renders all four statements (so generation errors
        surface), but is then reported as skipped instead of silently passing
        without a single assertion.
        """
        sql1 = IntFields.filter(pk=1).only("id").select_for_update().sql()
        sql2 = IntFields.filter(pk=1).only("id").select_for_update(nowait=True).sql()
        sql3 = IntFields.filter(pk=1).only("id").select_for_update(skip_locked=True).sql()
        sql4 = IntFields.filter(pk=1).only("id").select_for_update(of=("intfields",)).sql()

        expected_by_dialect = {
            "postgres": (
                'SELECT "id" "id" FROM "intfields" WHERE "id"=1 FOR UPDATE',
                'SELECT "id" "id" FROM "intfields" WHERE "id"=1 FOR UPDATE NOWAIT',
                'SELECT "id" "id" FROM "intfields" WHERE "id"=1 FOR UPDATE SKIP LOCKED',
                'SELECT "id" "id" FROM "intfields" WHERE "id"=1 FOR UPDATE OF "intfields"',
            ),
            "mysql": (
                "SELECT `id` `id` FROM `intfields` WHERE `id`=1 FOR UPDATE",
                "SELECT `id` `id` FROM `intfields` WHERE `id`=1 FOR UPDATE NOWAIT",
                "SELECT `id` `id` FROM `intfields` WHERE `id`=1 FOR UPDATE SKIP LOCKED",
                "SELECT `id` `id` FROM `intfields` WHERE `id`=1 FOR UPDATE OF `intfields`",
            ),
        }

        dialect = self.db.schema_generator.DIALECT
        expected = expected_by_dialect.get(dialect)
        if expected is None:
            # Previously this case fell through with zero assertions and passed.
            self.skipTest(
                "no expected FOR UPDATE SQL for dialect {!r}".format(dialect)
            )
        for actual, wanted in zip((sql1, sql2, sql3, sql4), expected):
            self.assertEqual(actual, wanted)

    async def test_select_related(self):
        """select_related() makes both FK relations available on the fetched Event."""
        tournament = await Tournament.create(name="1")
        reporter = await Reporter.create(name="Reporter")
        created = await Event.create(
            name="1", tournament=tournament, reporter=reporter
        )

        related_query = Event.all().select_related("tournament", "reporter")
        fetched = await related_query.get(pk=created.pk)
        self.assertEqual(fetched.tournament.pk, tournament.pk)
        self.assertEqual(fetched.reporter.pk, reporter.pk)

    async def test_select_related_with_two_same_models(self):
        """Two FKs to the same model are joined without mixing up their rows."""
        parent_node = await Node.create(name="1")
        child_node = await Node.create(name="2")
        created = await Tree.create(parent=parent_node, child=child_node)

        related_query = Tree.all().select_related("parent", "child")
        fetched = await related_query.get(pk=created.pk)
        for loaded, original in (
            (fetched.parent, parent_node),
            (fetched.child, child_node),
        ):
            self.assertEqual(loaded.pk, original.pk)
            self.assertEqual(loaded.name, original.name)

    @test.requireCapability(dialect="postgres")
    async def test_postgres_search(self):
        """The ``__search`` lookup finds "hello world" when searching "hello"."""
        full_name = "hello world"
        await Tournament.create(name=full_name)
        match = await Tournament.filter(name__search="hello").first()
        self.assertEqual(match.name, full_name)

    async def test_subquery_select(self):
        """A Subquery can be selected as an annotation column."""
        t1 = await Tournament.create(name="1")
        inner = Tournament.filter(pk=t1.pk).values("id")
        outer = Tournament.filter(pk=t1.pk).annotate(ids=Subquery(inner))
        rows = await outer.values("ids", "id")
        self.assertEqual(rows, [{"id": t1.pk, "ids": t1.pk}])

    async def test_subquery_filter(self):
        """A Subquery can be used as the right-hand side of a filter."""
        t1 = await Tournament.create(name="1")
        inner = Tournament.filter(pk=t1.pk).values("id")
        found = await Tournament.filter(pk=Subquery(inner)).first()
        self.assertEqual(found, t1)

    async def test_raw_sql_count(self):
        """A RawSQL aggregate annotation is returned by values()."""
        t1 = await Tournament.create(name="1")
        counted = Tournament.filter(pk=t1.pk).annotate(count=RawSQL("count(*)"))
        self.assertEqual(await counted.values("count"), [{"count": 1}])

    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_raw_sql_select(self):
        """A RawSQL annotation can be filtered on and selected."""
        t1 = await Tournament.create(id=1, name="1")
        shifted = Tournament.filter(pk=t1.pk).annotate(idp=RawSQL("id + 1"))
        rows = await shifted.filter(idp=2).values("idp")
        self.assertEqual(rows, [{"idp": 2}])

    async def test_raw_sql_filter(self):
        """Filtering pk against RawSQL("id + 1") can never match a row."""
        self.assertEqual(await Tournament.filter(pk=RawSQL("id + 1")), [])

    async def test_annotation_field_priorior_to_model_field(self):
        # An annotation may reuse the name of a real model field ("id" here).
        # When values are selected, the annotated expression must take
        # priority over the model column of the same name.
        t1 = await Tournament.create(name="1")
        shadowed = Tournament.filter(pk=t1.pk).annotate(id=RawSQL("id + 1"))
        self.assertEqual(await shadowed.values("id"), [{"id": t1.pk + 1}])
Esempio n. 8
0
class TestValues(test.TestCase):
    """Tests for ``values()`` / ``values_list()`` across relations and annotations."""

    async def test_values_related_fk(self):
        tournament = await Tournament.create(name="New Tournament")
        await Event.create(name="Test", tournament_id=tournament.id)

        events = await Event.filter(name="Test").values("name", "tournament__name")
        self.assertEqual(
            events[0], {"name": "Test", "tournament__name": "New Tournament"}
        )

    async def test_values_list_related_fk(self):
        tournament = await Tournament.create(name="New Tournament")
        await Event.create(name="Test", tournament_id=tournament.id)

        events = await Event.filter(name="Test").values_list(
            "name", "tournament__name"
        )
        self.assertEqual(events[0], ("Test", "New Tournament"))

    async def test_values_related_rfk(self):
        tournament = await Tournament.create(name="New Tournament")
        await Event.create(name="Test", tournament_id=tournament.id)

        tournaments = await Tournament.filter(name="New Tournament").values(
            "name", "events__name"
        )
        self.assertEqual(
            tournaments[0], {"name": "New Tournament", "events__name": "Test"}
        )

    async def test_values_list_related_rfk(self):
        tournament = await Tournament.create(name="New Tournament")
        await Event.create(name="Test", tournament_id=tournament.id)

        tournaments = await Tournament.filter(name="New Tournament").values_list(
            "name", "events__name"
        )
        self.assertEqual(tournaments[0], ("New Tournament", "Test"))

    async def test_values_related_m2m(self):
        tournament = await Tournament.create(name="New Tournament")
        event = await Event.create(name="Test", tournament_id=tournament.id)
        team = await Team.create(name="Some Team")
        await event.participants.add(team)

        # The query is on Event, so the rows are events (not tournaments).
        events = await Event.filter(name="Test").values("name", "participants__name")
        self.assertEqual(
            events[0], {"name": "Test", "participants__name": "Some Team"}
        )

    async def test_values_list_related_m2m(self):
        tournament = await Tournament.create(name="New Tournament")
        event = await Event.create(name="Test", tournament_id=tournament.id)
        team = await Team.create(name="Some Team")
        await event.participants.add(team)

        events = await Event.filter(name="Test").values_list(
            "name", "participants__name"
        )
        self.assertEqual(events[0], ("Test", "Some Team"))

    async def test_values_related_fk_itself(self):
        tournament = await Tournament.create(name="New Tournament")
        await Event.create(name="Test", tournament_id=tournament.id)

        with self.assertRaisesRegex(
            ValueError, 'Selecting relation "tournament" is not possible'
        ):
            await Event.filter(name="Test").values("name", "tournament")

    async def test_values_list_related_fk_itself(self):
        tournament = await Tournament.create(name="New Tournament")
        await Event.create(name="Test", tournament_id=tournament.id)

        with self.assertRaisesRegex(
            ValueError, 'Selecting relation "tournament" is not possible'
        ):
            await Event.filter(name="Test").values_list("name", "tournament")

    async def test_values_related_rfk_itself(self):
        tournament = await Tournament.create(name="New Tournament")
        await Event.create(name="Test", tournament_id=tournament.id)

        with self.assertRaisesRegex(
            ValueError, 'Selecting relation "events" is not possible'
        ):
            await Tournament.filter(name="New Tournament").values("name", "events")

    async def test_values_list_related_rfk_itself(self):
        tournament = await Tournament.create(name="New Tournament")
        await Event.create(name="Test", tournament_id=tournament.id)

        with self.assertRaisesRegex(
            ValueError, 'Selecting relation "events" is not possible'
        ):
            await Tournament.filter(name="New Tournament").values_list(
                "name", "events"
            )

    async def test_values_related_m2m_itself(self):
        tournament = await Tournament.create(name="New Tournament")
        event = await Event.create(name="Test", tournament_id=tournament.id)
        team = await Team.create(name="Some Team")
        await event.participants.add(team)

        with self.assertRaisesRegex(
            ValueError, 'Selecting relation "participants" is not possible'
        ):
            await Event.filter(name="Test").values("name", "participants")

    async def test_values_list_related_m2m_itself(self):
        tournament = await Tournament.create(name="New Tournament")
        event = await Event.create(name="Test", tournament_id=tournament.id)
        team = await Team.create(name="Some Team")
        await event.participants.add(team)

        with self.assertRaisesRegex(
            ValueError, 'Selecting relation "participants" is not possible'
        ):
            await Event.filter(name="Test").values_list("name", "participants")

    async def test_values_bad_key(self):
        tournament = await Tournament.create(name="New Tournament")
        await Event.create(name="Test", tournament_id=tournament.id)

        with self.assertRaisesRegex(
            FieldError, 'Unknown field "neem" for model "Event"'
        ):
            await Event.filter(name="Test").values("name", "neem")

    async def test_values_list_bad_key(self):
        tournament = await Tournament.create(name="New Tournament")
        await Event.create(name="Test", tournament_id=tournament.id)

        with self.assertRaisesRegex(
            FieldError, 'Unknown field "neem" for model "Event"'
        ):
            await Event.filter(name="Test").values_list("name", "neem")

    async def test_values_related_bad_key(self):
        tournament = await Tournament.create(name="New Tournament")
        await Event.create(name="Test", tournament_id=tournament.id)

        with self.assertRaisesRegex(
            FieldError, 'Unknown field "neem" for model "Tournament"'
        ):
            await Event.filter(name="Test").values("name", "tournament__neem")

    async def test_values_list_related_bad_key(self):
        tournament = await Tournament.create(name="New Tournament")
        await Event.create(name="Test", tournament_id=tournament.id)

        with self.assertRaisesRegex(
            FieldError, 'Unknown field "neem" for model "Tournament"'
        ):
            await Event.filter(name="Test").values_list("name", "tournament__neem")

    # Use NotEQ, as test_values_annotations_length does. A bare "!mssql"
    # string is (assuming requireCapability compares values literally) never
    # equal to any dialect name, which would skip this test on every backend.
    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_values_list_annotations_length(self):
        await Tournament.create(name="Championship")
        await Tournament.create(name="Super Bowl")

        tournaments = await Tournament.annotate(
            name_length=Length("name")
        ).values_list("name", "name_length")
        self.assertListSortEqual(
            tournaments, [("Championship", 12), ("Super Bowl", 10)]
        )

    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_values_annotations_length(self):
        await Tournament.create(name="Championship")
        await Tournament.create(name="Super Bowl")

        tournaments = await Tournament.annotate(
            name_slength=Length("name")
        ).values("name", "name_slength")
        self.assertListSortEqual(
            tournaments,
            [
                {"name": "Championship", "name_slength": 12},
                {"name": "Super Bowl", "name_slength": 10},
            ],
            sorted_key="name",
        )

    async def test_values_list_annotations_trim(self):
        await Tournament.create(name="  x")
        await Tournament.create(name=" y ")

        tournaments = await Tournament.annotate(
            name_trim=Trim("name")
        ).values_list("name", "name_trim")
        self.assertListSortEqual(tournaments, [("  x", "x"), (" y ", "y")])

    async def test_values_annotations_trim(self):
        await Tournament.create(name="  x")
        await Tournament.create(name=" y ")

        tournaments = await Tournament.annotate(
            name_trim=Trim("name")
        ).values("name", "name_trim")
        self.assertListSortEqual(
            tournaments,
            [
                {"name": "  x", "name_trim": "x"},
                {"name": " y ", "name_trim": "y"},
            ],
            sorted_key="name",
        )
Esempio n. 9
0
class StraightFieldTests(test.TestCase):
    """CRUD, relation and annotation tests for the ``StraightFields`` model.

    The primary key is exposed as ``eyedee`` and is interchangeable with
    ``pk``. ``fk``/``fkrev`` form a self-referential FK pair, and
    ``rel_to``/``rel_from`` form a self-referential M2M pair.
    """

    def setUp(self) -> None:
        self.model = StraightFields

    # -- basic lookups -----------------------------------------------------

    async def test_get_all(self):
        first = await self.model.create(chars="aaa")
        self.assertIsNotNone(first.eyedee, str(dir(first)))
        second = await self.model.create(chars="bbb")

        self.assertListSortEqual(await self.model.all(), [first, second])

    async def test_get_by_pk(self):
        created = await self.model.create(chars="aaa")
        loaded = await self.model.get(eyedee=created.eyedee)
        self.assertEqual(created, loaded)

    async def test_get_by_chars(self):
        created = await self.model.create(chars="aaa")
        loaded = await self.model.get(chars="aaa")
        self.assertEqual(created, loaded)

    # -- forward FK ------------------------------------------------------------

    async def test_get_fk_forward_fetch_related(self):
        parent = await self.model.create(chars="aaa")
        child = await self.model.create(chars="bbb", fk=parent)

        loaded = await self.model.get(eyedee=child.eyedee)
        await loaded.fetch_related("fk")
        self.assertEqual(child, loaded)
        self.assertEqual(parent, loaded.fk)

    async def test_get_fk_forward_prefetch_related(self):
        parent = await self.model.create(chars="aaa")
        child = await self.model.create(chars="bbb", fk=parent)

        loaded = await self.model.get(eyedee=child.eyedee).prefetch_related("fk")
        self.assertEqual(child, loaded)
        self.assertEqual(parent, loaded.fk)

    # -- reverse FK ------------------------------------------------------------

    async def test_get_fk_reverse_await(self):
        parent = await self.model.create(chars="aaa")
        kids = [
            await self.model.create(chars="bbb", fk=parent),
            await self.model.create(chars="ccc", fk=parent),
        ]

        reloaded = await self.model.get(eyedee=parent.eyedee)
        self.assertListSortEqual(await reloaded.fkrev, kids)

    async def test_get_fk_reverse_filter(self):
        parent = await self.model.create(chars="aaa")
        kids = [
            await self.model.create(chars="bbb", fk=parent),
            await self.model.create(chars="ccc", fk=parent),
        ]

        self.assertListSortEqual(await self.model.filter(fk=parent), kids)

    async def test_get_fk_reverse_async_for(self):
        parent = await self.model.create(chars="aaa")
        kids = [
            await self.model.create(chars="bbb", fk=parent),
            await self.model.create(chars="ccc", fk=parent),
        ]

        reloaded = await self.model.get(eyedee=parent.eyedee)
        seen = []
        async for kid in reloaded.fkrev:
            seen.append(kid)
        self.assertListSortEqual(seen, kids)

    async def test_get_fk_reverse_fetch_related(self):
        parent = await self.model.create(chars="aaa")
        kids = [
            await self.model.create(chars="bbb", fk=parent),
            await self.model.create(chars="ccc", fk=parent),
        ]

        reloaded = await self.model.get(eyedee=parent.eyedee)
        await reloaded.fetch_related("fkrev")
        self.assertListSortEqual(list(reloaded.fkrev), kids)

    async def test_get_fk_reverse_prefetch_related(self):
        parent = await self.model.create(chars="aaa")
        kids = [
            await self.model.create(chars="bbb", fk=parent),
            await self.model.create(chars="ccc", fk=parent),
        ]

        reloaded = await self.model.get(eyedee=parent.eyedee).prefetch_related(
            "fkrev"
        )
        self.assertListSortEqual(list(reloaded.fkrev), kids)

    # -- M2M -------------------------------------------------------------------

    async def test_get_m2m_forward_await(self):
        source = await self.model.create(chars="aaa")
        target = await self.model.create(chars="bbb")
        await source.rel_to.add(target)

        target_again = await self.model.get(eyedee=target.eyedee)
        self.assertEqual(await target_again.rel_from, [source])

        source_again = await self.model.get(eyedee=source.eyedee)
        self.assertEqual(await source_again.rel_to, [target])

    async def test_get_m2m_reverse_await(self):
        source = await self.model.create(chars="aaa")
        target = await self.model.create(chars="bbb")
        await target.rel_from.add(source)

        # Mix ``pk`` and ``eyedee`` lookups deliberately: they are aliases.
        target_again = await self.model.get(pk=target.eyedee)
        self.assertEqual(await target_again.rel_from, [source])

        source_again = await self.model.get(eyedee=source.pk)
        self.assertEqual(await source_again.rel_to, [target])

    async def test_get_m2m_filter(self):
        source = await self.model.create(chars="aaa")
        target = await self.model.create(chars="bbb")
        await source.rel_to.add(target)

        self.assertEqual(await self.model.filter(rel_from=source), [target])
        self.assertEqual(await self.model.filter(rel_to=target), [source])

    async def test_get_m2m_forward_fetch_related(self):
        source = await self.model.create(chars="aaa")
        target = await self.model.create(chars="bbb")
        await source.rel_to.add(target)

        target_again = await self.model.get(eyedee=target.eyedee)
        await target_again.fetch_related("rel_from")
        self.assertEqual(list(target_again.rel_from), [source])

    async def test_get_m2m_reverse_fetch_related(self):
        source = await self.model.create(chars="aaa")
        target = await self.model.create(chars="bbb")
        await source.rel_to.add(target)

        source_again = await self.model.get(eyedee=source.eyedee)
        await source_again.fetch_related("rel_to")
        self.assertEqual(list(source_again.rel_to), [target])

    async def test_get_m2m_forward_prefetch_related(self):
        source = await self.model.create(chars="aaa")
        target = await self.model.create(chars="bbb")
        await source.rel_to.add(target)

        target_again = await self.model.get(eyedee=target.eyedee).prefetch_related(
            "rel_from"
        )
        self.assertEqual(list(target_again.rel_from), [source])

    async def test_get_m2m_reverse_prefetch_related(self):
        source = await self.model.create(chars="aaa")
        target = await self.model.create(chars="bbb")
        await source.rel_to.add(target)

        source_again = await self.model.get(eyedee=source.eyedee).prefetch_related(
            "rel_to"
        )
        self.assertEqual(list(source_again.rel_to), [target])

    # -- values / expressions --------------------------------------------------

    async def test_values_reverse_relation(self):
        parent = await self.model.create(chars="aaa")
        await self.model.create(chars="bbb", fk=parent)

        rows = await self.model.filter(chars="aaa").values("fkrev__chars")
        self.assertEqual(rows[0]["fkrev__chars"], "bbb")

    async def test_f_expression(self):
        created = await self.model.create(chars="aaa")
        # Copy the ``blip`` column into ``chars``. The expected "BLIP" is
        # presumably the model's default for ``blip`` (not visible here).
        await self.model.filter(eyedee=created.eyedee).update(chars=F("blip"))
        reloaded = await self.model.get(eyedee=created.eyedee)
        self.assertEqual(reloaded.chars, "BLIP")

    async def test_function(self):
        created = await self.model.create(chars="  aaa ")
        await self.model.filter(eyedee=created.eyedee).update(chars=Trim("chars"))
        reloaded = await self.model.get(eyedee=created.eyedee)
        self.assertEqual(reloaded.chars, "aaa")

    # -- annotations -----------------------------------------------------------

    async def test_aggregation_with_filter(self):
        parent = await self.model.create(chars="aaa")
        await self.model.create(chars="bbb", fk=parent)
        await self.model.create(chars="ccc", fk=parent)

        annotated = await self.model.filter(chars="aaa").annotate(
            all=Count("fkrev", _filter=Q(chars="aaa")),
            one=Count("fkrev", _filter=Q(fkrev__chars="bbb")),
            no=Count("fkrev", _filter=Q(fkrev__chars="aaa")),
        ).first()

        self.assertEqual(annotated.all, 2)
        self.assertEqual(annotated.one, 1)
        self.assertEqual(annotated.no, 0)

    async def test_filter_by_aggregation_field_coalesce(self):
        await self.model.create(chars="aaa", nullable="null")
        await self.model.create(chars="bbb")
        matched = await self.model.annotate(
            null=Coalesce("nullable", "null")
        ).filter(null="null")

        self.assertEqual(len(matched), 2)
        pairs = {(row.chars, row.null) for row in matched}
        self.assertSetEqual(pairs, {("aaa", "null"), ("bbb", "null")})

    async def test_filter_by_aggregation_field_count(self):
        await self.model.create(chars="aaa")
        await self.model.create(chars="bbb")
        matched = await self.model.annotate(
            chars_count=Count("chars")
        ).filter(chars_count=1, chars="aaa")

        self.assertEqual(len(matched), 1)
        self.assertEqual(matched[0].chars, "aaa")

    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_filter_by_aggregation_field_length(self):
        await self.model.create(chars="aaa")
        await self.model.create(chars="bbbbb")
        matched = await self.model.annotate(
            chars_length=Length("chars")
        ).filter(chars_length=3)

        self.assertEqual(len(matched), 1)
        self.assertEqual(matched[0].chars_length, 3)

    async def test_filter_by_aggregation_field_lower(self):
        await self.model.create(chars="AaA")
        matched = await self.model.annotate(
            chars_lower=Lower("chars")
        ).filter(chars_lower="aaa")

        self.assertEqual(len(matched), 1)
        self.assertEqual(matched[0].chars_lower, "aaa")

    async def test_filter_by_aggregation_field_trim(self):
        await self.model.create(chars="   aaa   ")
        matched = await self.model.annotate(
            chars_trim=Trim("chars")
        ).filter(chars_trim="aaa")

        self.assertEqual(len(matched), 1)
        self.assertEqual(matched[0].chars_trim, "aaa")

    async def test_filter_by_aggregation_field_upper(self):
        await self.model.create(chars="aAa")
        matched = await self.model.annotate(
            chars_upper=Upper("chars")
        ).filter(chars_upper="AAA")

        self.assertEqual(len(matched), 1)
        self.assertEqual(matched[0].chars_upper, "AAA")

    async def test_values_by_fk(self):
        parent = await self.model.create(chars="aaa")
        await self.model.create(chars="bbb", fk=parent)

        rows = await self.model.filter(chars="bbb").values("fk__chars")
        self.assertEqual(rows, [{"fk__chars": "aaa"}])
Esempio n. 10
0
class TestUpdate(test.TestCase):
    async def test_update(self):
        """A queryset-wide ``update()`` rewrites every row and returns the count."""
        for name in ("1", "3"):
            await Tournament.create(name=name)

        updated = await Tournament.all().update(name="2")
        self.assertEqual(updated, 2)

        first = await Tournament.first()
        self.assertEqual(first.name, "2")

    async def test_bulk_update(self):
        """``bulk_update()`` persists per-object ``name`` changes in one call."""
        first = await Tournament.create(name="1")
        second = await Tournament.create(name="2")
        objs = [first, second]

        first.name = "0"
        second.name = "1"
        rows_affected = await Tournament.bulk_update(
            objs, fields=["name"], batch_size=100)
        self.assertEqual(rows_affected, 2)

        for obj, expected in zip(objs, ("0", "1")):
            self.assertEqual((await Tournament.get(pk=obj.pk)).name, expected)

    async def test_bulk_update_datetime(self):
        """``bulk_update()`` round-trips timezone-aware (UTC) datetimes."""
        initial = datetime(2021, 1, 1, tzinfo=pytz.utc)
        objs = [await DatetimeFields.create(datetime=initial) for _ in range(2)]

        new_values = [
            datetime(2021, 1, 2, tzinfo=pytz.utc),
            datetime(2021, 1, 3, tzinfo=pytz.utc),
        ]
        for obj, value in zip(objs, new_values):
            obj.datetime = value

        rows_affected = await DatetimeFields.bulk_update(objs,
                                                         fields=["datetime"])
        self.assertEqual(rows_affected, 2)

        for obj, value in zip(objs, new_values):
            fetched = await DatetimeFields.get(pk=obj.pk)
            self.assertEqual(fetched.datetime, value)

    async def test_bulk_update_pk_uuid(self):
        """``bulk_update()`` updates a UUID ``data`` column on ``UUIDFields``."""
        objs = [await UUIDFields.create(data=uuid.uuid4()) for _ in range(2)]
        for obj in objs:
            obj.data = uuid.uuid4()

        rows_affected = await UUIDFields.bulk_update(objs, fields=["data"])
        self.assertEqual(rows_affected, 2)

        for obj in objs:
            fetched = await UUIDFields.get(pk=obj.pk)
            self.assertEqual(fetched.data, obj.data)

    async def test_bulk_update_json_value(self):
        """``bulk_update()`` stores both list and dict JSON payloads."""
        objs = [await JSONFields.create(data={}) for _ in range(2)]
        for obj, payload in zip(objs, ([0], {"a": 1})):
            obj.data = payload

        rows_affected = await JSONFields.bulk_update(objs, fields=["data"])
        self.assertEqual(rows_affected, 2)

        for obj in objs:
            fetched = await JSONFields.get(pk=obj.pk)
            self.assertEqual(fetched.data, obj.data)

    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_bulk_update_smallint_none(self):
        """``bulk_update()`` can reset a nullable small-int column to NULL.

        Runs only when the dialect is not MSSQL.
        """
        objs = [
            await SmallIntFields.create(smallintnum=1, smallintnum_null=1),
            await SmallIntFields.create(smallintnum=2, smallintnum_null=2),
        ]
        for obj in objs:
            obj.smallintnum_null = None

        rows_affected = await SmallIntFields.bulk_update(
            objs, fields=["smallintnum_null"])
        self.assertEqual(rows_affected, 2)

        for obj in objs:
            refreshed = await SmallIntFields.get(pk=obj.pk)
            # None is a singleton: check identity rather than equality.
            self.assertIsNone(refreshed.smallintnum_null)

    async def test_update_auto_now(self):
        """An explicit ``updated_at`` given to queryset ``update()`` is kept.

        Only the calendar date of the stored value is compared.
        """
        obj = await DefaultUpdate.create()

        yesterday = datetime.now() - timedelta(days=1)
        await DefaultUpdate.filter(pk=obj.pk).update(updated_at=yesterday)

        reloaded = await DefaultUpdate.get(pk=obj.pk)
        self.assertEqual(reloaded.updated_at.date(), yesterday.date())

    async def test_update_relation(self):
        """Queryset ``update()`` accepts a model instance for a foreign key."""
        first, second = [
            await Tournament.create(name=name) for name in ("1", "2")
        ]

        await Event.create(name="1", tournament=first)
        await Event.all().update(tournament=second)

        event = await Event.first()
        self.assertEqual(event.tournament_id, second.id)

    @test.requireCapability(dialect=In("mysql", "sqlite"))
    async def test_update_with_custom_function(self):
        """A user-defined SQL ``Function`` works in ``save()`` and ``update()``.

        Runs only on the MySQL and SQLite dialects.
        """

        class JsonSet(Function):
            # Renders as JSON_SET(field, expression, value).
            def __init__(self, field: F, expression: str, value: Any):
                super().__init__("JSON_SET", field, expression, value)

        record = await JSONFields.create(data={})
        self.assertEqual(record.data_default, {"a": 1})

        # Path 1: assign the function to an attribute and save the instance.
        record.data_default = JsonSet(F("data_default"), "$.a", 2)
        await record.save()
        reloaded = await JSONFields.get(pk=record.pk)
        self.assertEqual(reloaded.data_default, {"a": 2})

        # Path 2: pass the function through a queryset update().
        await JSONFields.filter(pk=record.pk).update(
            data_default=JsonSet(F("data_default"), "$.a", 3))
        reloaded = await JSONFields.get(pk=record.pk)
        self.assertEqual(reloaded.data_default, {"a": 3})

    async def test_refresh_from_db(self):
        """``refresh_from_db()`` loads DB values after an F-expression save.

        After saving ``F("intnum") + 1`` the in-memory ``intnum`` is not yet
        the computed integer. ``refresh_from_db(fields=[...])`` and a full
        ``refresh_from_db()`` must both load the stored value, and neither may
        disturb ``intnum_null``.
        """

        def assert_int_value(value, expected):
            # Identity checks (`assertIs`) on ints only pass thanks to
            # CPython's small-int cache, so compare by value. The isinstance
            # check keeps the test strict: an un-refreshed expression object
            # may overload `==` and would otherwise slip through.
            self.assertIsInstance(value, int)
            self.assertEqual(value, expected)

        int_field = await IntFields.create(intnum=1, intnum_null=2)
        int_field_in_db = await IntFields.get(pk=int_field.pk)
        int_field_in_db.intnum = F("intnum") + 1
        await int_field_in_db.save(update_fields=["intnum"])
        # Saving does not turn the expression into the computed value.
        self.assertIsNot(int_field_in_db.intnum, 2)
        assert_int_value(int_field_in_db.intnum_null, 2)

        await int_field_in_db.refresh_from_db(fields=["intnum"])
        assert_int_value(int_field_in_db.intnum, 2)
        assert_int_value(int_field_in_db.intnum_null, 2)

        int_field_in_db.intnum = F("intnum") + 1
        await int_field_in_db.save()
        self.assertIsNot(int_field_in_db.intnum, 3)
        assert_int_value(int_field_in_db.intnum_null, 2)

        await int_field_in_db.refresh_from_db()
        assert_int_value(int_field_in_db.intnum, 3)
        assert_int_value(int_field_in_db.intnum_null, 2)

    @test.requireCapability(support_update_limit_order_by=True)
    async def test_update_with_limit_ordering(self):
        """``update()`` honours ``limit()`` and ``order_by()`` when supported.

        Two rows share ``name="1"``. Ordering by ``-id`` with ``limit(1)`` must
        update only ``t2`` (presumably the row with the larger id), leaving
        exactly one row named "1".
        """
        await Tournament.create(name="1")
        t2 = await Tournament.create(name="1")
        await Tournament.filter(name="1").limit(1).order_by("-id").update(
            name="2")
        # Compare strings by value; `assertIs` only passed because CPython
        # happens to cache one-character strings.
        self.assertEqual((await Tournament.get(pk=t2.pk)).name, "2")
        self.assertEqual(await Tournament.filter(name="1").count(), 1)