class TestModelCreate(test.TestCase):
    """Round-trip ``create``/``get`` checks and primary-key handling on creation."""

    async def test_save_generated(self):
        # Integer-keyed model: the created instance must equal the one read back by id.
        created = await Tournament.create(name="Test")
        fetched = await Tournament.get(id=created.id)
        self.assertEqual(created, fetched)

    async def test_save_non_generated(self):
        # Same round trip for the UUID-keyed model (the "non-generated" case per the name).
        created = await UUIDFkRelatedNullModel.create(name="Test")
        fetched = await UUIDFkRelatedNullModel.get(id=created.id)
        self.assertEqual(created, fetched)

    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_save_generated_custom_id(self):
        # An explicitly supplied id must be kept as-is.
        custom_pk = 12345
        created = await Tournament.create(id=custom_pk, name="Test")
        self.assertEqual(created.id, custom_pk)
        self.assertEqual(created, await Tournament.get(id=custom_pk))

    async def test_save_non_generated_custom_id(self):
        custom_pk = uuid4()
        created = await UUIDFkRelatedNullModel.create(id=custom_pk, name="Test")
        self.assertEqual(created.id, custom_pk)
        self.assertEqual(created, await UUIDFkRelatedNullModel.get(id=custom_pk))

    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_save_generated_duplicate_custom_id(self):
        # Re-using an existing primary key must be rejected by the database.
        custom_pk = 12345
        await Tournament.create(id=custom_pk, name="TestOriginal")
        with self.assertRaises(IntegrityError):
            await Tournament.create(id=custom_pk, name="Test")

    async def test_save_non_generated_duplicate_custom_id(self):
        custom_pk = uuid4()
        await UUIDFkRelatedNullModel.create(id=custom_pk, name="TestOriginal")
        with self.assertRaises(IntegrityError):
            await UUIDFkRelatedNullModel.create(id=custom_pk, name="Test")

    async def test_clone_pk_required_error(self):
        # Cloning a model whose pk must be supplied, without passing one, is an error.
        original = await RequiredPKModel.create(id="A", name="name_a")
        with self.assertRaises(ParamsError):
            original.clone()

    async def test_clone_pk_required(self):
        original = await RequiredPKModel.create(id="A", name="name_a")
        duplicate = original.clone(pk="B")
        await duplicate.save()
        self.assertEqual(len(list(await RequiredPKModel.all())), 2)

    async def test_implicit_clone_pk_required_none(self):
        # Clearing a required pk and saving must fail validation.
        original = await RequiredPKModel.create(id="A", name="name_a")
        original.pk = None
        with self.assertRaises(ValidationError):
            await original.save()
class TestOrderByNested(test.TestCase):
    """Ordering and projecting through the ``Event.tournament`` relation."""

    @test.requireCapability(dialect=NotEQ("oracle"))
    async def test_basic(self):
        # Event names and tournament descs sort in opposite directions, so ordering
        # by one or the other gives distinguishable results.
        await Event.create(
            name="Event 1", tournament=await Tournament.create(name="Tournament 1", desc="B")
        )
        await Event.create(
            name="Event 2", tournament=await Tournament.create(name="Tournament 2", desc="A")
        )

        self.assertEqual(
            await Event.all().order_by("-name").values("name"),
            [{"name": "Event 2"}, {"name": "Event 1"}],
        )
        # Without an explicit order_by the row order is up to the database, so pin it
        # by event name: "Event 1" -> desc "B", "Event 2" -> desc "A".
        self.assertEqual(
            await Event.all()
            .prefetch_related("tournament")
            .order_by("name")
            .values("tournament__desc"),
            [{"tournament__desc": "B"}, {"tournament__desc": "A"}],
        )
        # Ordering by a field that lives on the related model.
        self.assertEqual(
            await Event.all()
            .prefetch_related("tournament")
            .order_by("tournament__desc")
            .values("tournament__desc"),
            [{"tournament__desc": "A"}, {"tournament__desc": "B"}],
        )
class TestConcurrencyIsolated(test.IsolatedTestCase):
    """Concurrent reads, writes, transactions and get_or_create via ``asyncio.gather``."""

    def _assert_single_creation(self, results):
        """Assert exactly one get_or_create result reports creation and all share one object."""
        created_count = sum(1 for result in results if result[1] is True)
        self.assertEqual(created_count, 1)
        reference = results[0][0]
        for result in results:
            self.assertEqual(result[0], reference)

    async def test_concurrency_read(self):
        await Tournament.create(name="Test")
        expected = await Tournament.first()
        reads = await asyncio.gather(*(Tournament.first() for _ in range(100)))
        self.assertEqual(reads, [expected] * 100)

    async def test_concurrency_create(self):
        written = await asyncio.gather(*(Tournament.create(name="Test") for _ in range(100)))
        stored = await Tournament.all()
        self.assertEqual(set(written), set(stored))

    async def create_trans_concurrent(self):
        """Create 100 tournaments concurrently inside a single transaction."""
        async with in_transaction():
            await asyncio.gather(*(Tournament.create(name="Test") for _ in range(100)))

    async def test_nonconcurrent_get_or_create(self):
        # Sequential calls: only the first one should create the row.
        results = []
        for _ in range(10):
            results.append(await UniqueName.get_or_create(name="c"))
        self._assert_single_creation(results)

    @test.skipIf(sys.version_info < (3, 7), "aiocontextvars backport not handling this well")
    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_concurrent_get_or_create(self):
        # Concurrent calls racing on the same unique name must still create exactly one row.
        results = await asyncio.gather(*(UniqueName.get_or_create(name="d") for _ in range(10)))
        self._assert_single_creation(results)

    @test.skipIf(sys.version_info < (3, 7), "aiocontextvars backport not handling this well")
    @test.requireCapability(supports_transactions=True)
    async def test_concurrency_transactions_concurrent(self):
        # 10 concurrent transactions x 100 creates each.
        await asyncio.gather(*(self.create_trans_concurrent() for _ in range(10)))
        self.assertEqual(await Tournament.all().count(), 1000)

    async def create_trans(self):
        """Create one tournament inside its own transaction."""
        async with in_transaction():
            await Tournament.create(name="Test")

    @test.skipIf(sys.version_info < (3, 7), "aiocontextvars backport not handling this well")
    @test.requireCapability(supports_transactions=True)
    async def test_concurrency_transactions(self):
        await asyncio.gather(*(self.create_trans() for _ in range(100)))
        self.assertEqual(await Tournament.all().count(), 100)
class TestExplain(test.TestCase):
    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_explain(self):
        """Smoke-test ``QuerySet.explain()`` on each backend.

        The plan format is backend-specific and is deliberately not inspected; the
        test only requires that the call succeeds and yields a non-trivial result
        (more than 20 characters once stringified).
        """
        plan_text = str(await Tournament.all().explain())
        self.assertGreater(len(plan_text), 20)
class TestFuzz(test.TestCase):
    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_char_fuzz(self):
        """Round-trip every string in ``DODGY_STRINGS`` through a CharFields row.

        Each value is created, read back by pk, overwritten and restored through a
        queryset update, then looked up by exact match and by every LIKE-style lookup.
        """
        for value in DODGY_STRINGS:
            if "\x00" in value and self._db.capabilities.dialect == "postgres":
                # PostgreSQL text columns cannot hold NUL characters, so skip these.
                continue

            created = await CharFields.create(char=value)
            pk = created.pk

            # Read back by primary key.
            self.assertEqual(value, (await CharFields.get(pk=pk)).char)

            # Overwrite with a placeholder, restore via queryset update, re-read.
            await CharFields.filter(pk=pk).update(char="a")
            await CharFields.filter(pk=pk).update(char=value)
            self.assertEqual(value, (await CharFields.get(pk=pk)).char)

            # An exact-match filter on the value must find this row.
            exact = await CharFields.get(pk=pk, char=value)
            self.assertEqual(pk, exact.pk)
            self.assertEqual(value, exact.char)

            # Pattern lookups are looser than equality, so all of them must match too.
            patterned = await CharFields.get(
                pk=pk,
                char__startswith=value,
                char__endswith=value,
                char__contains=value,
                char__istartswith=value,
                char__iendswith=value,
                char__icontains=value,
            )
            self.assertEqual(pk, patterned.pk)
            self.assertEqual(value, patterned.char)
class TestBulk(test.TruncationTestCase):
    """``Model.bulk_create``: pk assignment, batching, upserts, transactions and failures.

    Queries below have no ``order_by``, so assertions never rely on row order:
    generated ids are anchored on the smallest id returned, and lists are compared
    with ``assertListSortEqual(..., sorted_key="id")``.
    """

    async def test_bulk_create(self):
        await UniqueName.bulk_create([UniqueName() for _ in range(1000)])
        all_ = await UniqueName.all().values("id", "name")
        # Generated ids are expected to be consecutive from some backend-dependent
        # start; take the minimum rather than whichever row happens to come first.
        inc = min(row["id"] for row in all_)
        self.assertListSortEqual(
            all_, [{"id": val + inc, "name": None} for val in range(1000)], sorted_key="id"
        )

    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_bulk_create_update_fields(self):
        # The second insert conflicts on "name" and must update "optional" instead.
        await UniqueName.bulk_create([UniqueName(name="name")])
        await UniqueName.bulk_create(
            [UniqueName(name="name", optional="optional")],
            update_fields=["optional"],
            on_conflict=["name"],
        )
        all_ = await UniqueName.all().values("name", "optional")
        self.assertListSortEqual(all_, [{"name": "name", "optional": "optional"}])

    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_bulk_create_more_that_one_update_fields(self):
        await UniqueName.bulk_create([UniqueName(name="name")])
        await UniqueName.bulk_create(
            [UniqueName(name="name", optional="optional", other_optional="other_optional")],
            update_fields=["optional", "other_optional"],
            on_conflict=["name"],
        )
        all_ = await UniqueName.all().values("name", "optional", "other_optional")
        self.assertListSortEqual(
            all_,
            [{"name": "name", "optional": "optional", "other_optional": "other_optional"}],
        )

    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_bulk_create_with_batch_size(self):
        await UniqueName.bulk_create([UniqueName(id=id_ + 1) for id_ in range(1000)], batch_size=100)
        all_ = await UniqueName.all().values("id", "name")
        self.assertListSortEqual(
            all_, [{"id": val + 1, "name": None} for val in range(1000)], sorted_key="id"
        )

    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_bulk_create_with_specified(self):
        await UniqueName.bulk_create([UniqueName(id=id_) for id_ in range(1000, 2000)])
        all_ = await UniqueName.all().values("id", "name")
        self.assertListSortEqual(
            all_, [{"id": id_, "name": None} for id_ in range(1000, 2000)], sorted_key="id"
        )

    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_bulk_create_mix_specified(self):
        specified_ids = range(10000, 11000)
        await UniqueName.bulk_create(
            [UniqueName(id=id_) for id_ in specified_ids] + [UniqueName() for _ in range(1000)]
        )
        all_ = await UniqueName.all().values("id", "name")
        self.assertEqual(len(all_), 2000)
        # Partition by id instead of by position: an unordered query does not
        # guarantee that the explicitly keyed rows come back first.
        specified = [row for row in all_ if row["id"] in specified_ids]
        generated = [row for row in all_ if row["id"] not in specified_ids]
        self.assertListSortEqual(
            specified, [{"id": id_, "name": None} for id_ in specified_ids], sorted_key="id"
        )
        inc = min(row["id"] for row in generated)
        self.assertListSortEqual(
            generated, [{"id": val + inc, "name": None} for val in range(1000)], sorted_key="id"
        )

    async def test_bulk_create_uuidpk(self):
        await UUIDPkModel.bulk_create([UUIDPkModel() for _ in range(1000)])
        res = await UUIDPkModel.all().values_list("id", flat=True)
        self.assertEqual(len(res), 1000)
        self.assertIsInstance(res[0], UUID)

    @test.requireCapability(supports_transactions=True)
    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_bulk_create_in_transaction(self):
        async with in_transaction():
            await UniqueName.bulk_create([UniqueName() for _ in range(1000)])
        all_ = await UniqueName.all().values("id", "name")
        # Same order-insensitive check as test_bulk_create.
        inc = min(row["id"] for row in all_)
        self.assertListSortEqual(
            all_, [{"id": val + inc, "name": None} for val in range(1000)], sorted_key="id"
        )

    @test.requireCapability(supports_transactions=True)
    async def test_bulk_create_uuidpk_in_transaction(self):
        async with in_transaction():
            await UUIDPkModel.bulk_create([UUIDPkModel() for _ in range(1000)])
        res = await UUIDPkModel.all().values_list("id", flat=True)
        self.assertEqual(len(res), 1000)
        self.assertIsInstance(res[0], UUID)

    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_bulk_create_fail(self):
        # Every name appears twice, which violates the uniqueness constraint.
        with self.assertRaises(IntegrityError):
            await UniqueName.bulk_create(
                [UniqueName(name=str(i)) for i in range(10)]
                + [UniqueName(name=str(i)) for i in range(10)]
            )

    async def test_bulk_create_uuidpk_fail(self):
        val = uuid4()
        with self.assertRaises(IntegrityError):
            await UUIDPkModel.bulk_create([UUIDPkModel(id=val) for _ in range(10)])

    @test.requireCapability(supports_transactions=True, dialect=NotEQ("mssql"))
    async def test_bulk_create_in_transaction_fail(self):
        with self.assertRaises(IntegrityError):
            async with in_transaction():
                await UniqueName.bulk_create(
                    [UniqueName(name=str(i)) for i in range(10)]
                    + [UniqueName(name=str(i)) for i in range(10)]
                )

    @test.requireCapability(supports_transactions=True)
    async def test_bulk_create_uuidpk_in_transaction_fail(self):
        val = uuid4()
        with self.assertRaises(IntegrityError):
            async with in_transaction():
                await UUIDPkModel.bulk_create([UUIDPkModel(id=val) for _ in range(10)])

    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_bulk_create_ignore_conflicts(self):
        name1 = UniqueName(name="name1")
        name2 = UniqueName(name="name2")
        await UniqueName.bulk_create([name1, name2])
        # Re-inserting the same rows succeeds only when conflicts are ignored.
        await UniqueName.bulk_create([name1, name2], ignore_conflicts=True)
        with self.assertRaises(IntegrityError):
            await UniqueName.bulk_create([name1, name2])
class TestQueryset(test.TestCase):
    """QuerySet API coverage against a fixed ``IntFields`` dataset built in setUp."""

    async def asyncSetUp(self):
        await super().asyncSetUp()
        # 30 rows with intnum = 10, 13, ..., 97 (range(10, 100, 3)); intnum_null not passed.
        self.intfields = [await IntFields.create(intnum=val) for val in range(10, 100, 3)]
        self.db = connections.get("models")

    async def test_all_count(self):
        self.assertEqual(await IntFields.all().count(), 30)
        self.assertEqual(await IntFields.filter(intnum_null=80).count(), 0)

    async def test_exists(self):
        ret = await IntFields.filter(intnum=0).exists()
        self.assertFalse(ret)

        ret = await IntFields.filter(intnum=10).exists()
        self.assertTrue(ret)

        ret = await IntFields.filter(intnum__gt=10).exists()
        self.assertTrue(ret)

        ret = await IntFields.filter(intnum__lt=10).exists()
        self.assertFalse(ret)

    async def test_limit_count(self):
        self.assertEqual(await IntFields.all().limit(10).count(), 10)

    async def test_limit_negative(self):
        with self.assertRaisesRegex(ParamsError, "Limit should be non-negative number"):
            await IntFields.all().limit(-10)

    async def test_offset_count(self):
        self.assertEqual(await IntFields.all().offset(10).count(), 20)

    async def test_offset_negative(self):
        with self.assertRaisesRegex(ParamsError, "Offset should be non-negative number"):
            await IntFields.all().offset(-10)

    async def test_join_count(self):
        tour = await Tournament.create(name="moo")
        await MinRelation.create(tournament=tour)

        self.assertEqual(await MinRelation.all().count(), 1)
        self.assertEqual(await MinRelation.filter(tournament__id=tour.id).count(), 1)

    async def test_modify_dataset(self):
        # intnum >= 70 matches the 10 values 70, 73, ..., 97.
        rows_affected = await IntFields.filter(intnum__gte=70).update(intnum_null=80)
        self.assertEqual(rows_affected, 10)
        self.assertEqual(await IntFields.filter(intnum_null=80).count(), 10)
        self.assertEqual(await IntFields.filter(intnum_null__isnull=True).count(), 20)
        await IntFields.filter(intnum_null__isnull=True).update(intnum_null=-1)
        self.assertEqual(await IntFields.filter(intnum_null=None).count(), 0)
        self.assertEqual(await IntFields.filter(intnum_null=-1).count(), 20)

    async def test_distinct(self):
        await IntFields.filter(intnum__gte=70).update(intnum_null=80)
        await IntFields.filter(intnum_null__isnull=True).update(intnum_null=-1)

        self.assertEqual(
            await IntFields.all()
            .order_by("intnum_null")
            .distinct()
            .values_list("intnum_null", flat=True),
            [-1, 80],
        )

        self.assertEqual(
            await IntFields.all().order_by("intnum_null").distinct().values("intnum_null"),
            [{"intnum_null": -1}, {"intnum_null": 80}],
        )

    async def test_limit_offset_values_list(self):
        self.assertEqual(
            await IntFields.all().order_by("intnum").limit(10).values_list("intnum", flat=True),
            [10, 13, 16, 19, 22, 25, 28, 31, 34, 37],
        )

        self.assertEqual(
            await IntFields.all()
            .order_by("intnum")
            .limit(10)
            .offset(10)
            .values_list("intnum", flat=True),
            [40, 43, 46, 49, 52, 55, 58, 61, 64, 67],
        )

        self.assertEqual(
            await IntFields.all()
            .order_by("intnum")
            .limit(10)
            .offset(20)
            .values_list("intnum", flat=True),
            [70, 73, 76, 79, 82, 85, 88, 91, 94, 97],
        )

        self.assertEqual(
            await IntFields.all()
            .order_by("intnum")
            .limit(10)
            .offset(30)
            .values_list("intnum", flat=True),
            [],
        )

        self.assertEqual(
            await IntFields.all().order_by("-intnum").limit(10).values_list("intnum", flat=True),
            [97, 94, 91, 88, 85, 82, 79, 76, 73, 70],
        )

        self.assertEqual(
            await IntFields.all()
            .order_by("intnum")
            .limit(10)
            .filter(intnum__gte=40)
            .values_list("intnum", flat=True),
            [40, 43, 46, 49, 52, 55, 58, 61, 64, 67],
        )

    async def test_limit_offset_values(self):
        self.assertEqual(
            await IntFields.all().order_by("intnum").limit(5).values("intnum"),
            [{"intnum": 10}, {"intnum": 13}, {"intnum": 16}, {"intnum": 19}, {"intnum": 22}],
        )

        self.assertEqual(
            await IntFields.all().order_by("intnum").limit(5).offset(10).values("intnum"),
            [{"intnum": 40}, {"intnum": 43}, {"intnum": 46}, {"intnum": 49}, {"intnum": 52}],
        )

        self.assertEqual(
            await IntFields.all().order_by("intnum").limit(5).offset(30).values("intnum"), []
        )

        self.assertEqual(
            await IntFields.all().order_by("-intnum").limit(5).values("intnum"),
            [{"intnum": 97}, {"intnum": 94}, {"intnum": 91}, {"intnum": 88}, {"intnum": 85}],
        )

        self.assertEqual(
            await IntFields.all()
            .order_by("intnum")
            .limit(5)
            .filter(intnum__gte=40)
            .values("intnum"),
            [{"intnum": 40}, {"intnum": 43}, {"intnum": 46}, {"intnum": 49}, {"intnum": 52}],
        )

    async def test_in_bulk(self):
        id_list = [item.pk for item in await IntFields.all().only("id").limit(2)]

        ret = await IntFields.in_bulk(id_list=id_list)
        # The pk list comes from an unordered query, and key order of the returned
        # mapping is not what this test is about, so compare contents, not order.
        self.assertEqual(len(ret), len(id_list))
        self.assertEqual(set(ret), set(id_list))

    async def test_first(self):
        self.assertEqual(
            (await IntFields.all().order_by("intnum").filter(intnum__gte=40).first()).intnum,
            40,
        )
        self.assertEqual(
            (
                await IntFields.all().order_by("intnum").filter(intnum__gte=40).first().values()
            )["intnum"],
            40,
        )
        # values_list() without fields yields (id, intnum, intnum_null); index 1 is intnum.
        self.assertEqual(
            (
                await IntFields.all()
                .order_by("intnum")
                .filter(intnum__gte=40)
                .first()
                .values_list()
            )[1],
            40,
        )

        # No match: first() resolves to None in every form.
        self.assertEqual(
            await IntFields.all().order_by("intnum").filter(intnum__gte=400).first(), None
        )
        self.assertEqual(
            await IntFields.all().order_by("intnum").filter(intnum__gte=400).first().values(),
            None,
        )
        self.assertEqual(
            await IntFields.all()
            .order_by("intnum")
            .filter(intnum__gte=400)
            .first()
            .values_list(),
            None,
        )

    async def test_get_or_none(self):
        self.assertEqual((await IntFields.all().get_or_none(intnum=40)).intnum, 40)
        self.assertEqual((await IntFields.all().get_or_none(intnum=40).values())["intnum"], 40)
        self.assertEqual((await IntFields.all().get_or_none(intnum=40).values_list())[1], 40)

        self.assertEqual(
            await IntFields.all().order_by("intnum").get_or_none(intnum__gte=400), None
        )
        self.assertEqual(
            await IntFields.all().order_by("intnum").get_or_none(intnum__gte=400).values(), None
        )
        self.assertEqual(
            await IntFields.all().order_by("intnum").get_or_none(intnum__gte=400).values_list(),
            None,
        )

        # More than one match is still an error for get_or_none.
        with self.assertRaises(MultipleObjectsReturned):
            await IntFields.all().order_by("intnum").get_or_none(intnum__gte=40)

        with self.assertRaises(MultipleObjectsReturned):
            await IntFields.all().order_by("intnum").get_or_none(intnum__gte=40).values()

        with self.assertRaises(MultipleObjectsReturned):
            await IntFields.all().order_by("intnum").get_or_none(intnum__gte=40).values_list()

    async def test_get(self):
        # Make intnum_null non-unique (10 rows = 80) for the MultipleObjectsReturned checks.
        await IntFields.filter(intnum__gte=70).update(intnum_null=80)

        # Test get
        self.assertEqual((await IntFields.all().get(intnum=40)).intnum, 40)
        self.assertEqual((await IntFields.all().get(intnum=40).values())["intnum"], 40)
        self.assertEqual((await IntFields.all().get(intnum=40).values_list())[1], 40)

        self.assertEqual(
            (await IntFields.all().all().all().all().all().get(intnum=40)).intnum, 40
        )
        self.assertEqual(
            (await IntFields.all().all().all().all().all().get(intnum=40).values())["intnum"],
            40,
        )
        self.assertEqual(
            (await IntFields.all().all().all().all().all().get(intnum=40).values_list())[1], 40
        )

        self.assertEqual((await IntFields.get(intnum=40)).intnum, 40)
        self.assertEqual((await IntFields.get(intnum=40).values())["intnum"], 40)
        self.assertEqual((await IntFields.get(intnum=40).values_list())[1], 40)

        with self.assertRaises(DoesNotExist):
            await IntFields.all().get(intnum=41)

        with self.assertRaises(DoesNotExist):
            await IntFields.all().get(intnum=41).values()

        with self.assertRaises(DoesNotExist):
            await IntFields.all().get(intnum=41).values_list()

        with self.assertRaises(DoesNotExist):
            await IntFields.get(intnum=41)

        with self.assertRaises(DoesNotExist):
            await IntFields.get(intnum=41).values()

        with self.assertRaises(DoesNotExist):
            await IntFields.get(intnum=41).values_list()

        with self.assertRaises(MultipleObjectsReturned):
            await IntFields.all().get(intnum_null=80)

        with self.assertRaises(MultipleObjectsReturned):
            await IntFields.all().get(intnum_null=80).values()

        with self.assertRaises(MultipleObjectsReturned):
            await IntFields.all().get(intnum_null=80).values_list()

        with self.assertRaises(MultipleObjectsReturned):
            await IntFields.get(intnum_null=80)

        with self.assertRaises(MultipleObjectsReturned):
            await IntFields.get(intnum_null=80).values()

        with self.assertRaises(MultipleObjectsReturned):
            await IntFields.get(intnum_null=80).values_list()

    async def test_delete(self):
        # Instance delete.
        await (await IntFields.get(intnum=40)).delete()

        with self.assertRaises(DoesNotExist):
            await IntFields.get(intnum=40)

        self.assertEqual(await IntFields.all().count(), 29)

        # Queryset delete; intnum >= 70 covers 10 rows.
        rows_affected = (
            await IntFields.all().order_by("intnum").limit(10).filter(intnum__gte=70).delete()
        )
        self.assertEqual(rows_affected, 10)

        self.assertEqual(await IntFields.all().count(), 19)

    @test.requireCapability(support_update_limit_order_by=True)
    async def test_delete_limit(self):
        await IntFields.all().limit(1).delete()
        self.assertEqual(await IntFields.all().count(), 29)

    @test.requireCapability(support_update_limit_order_by=True)
    async def test_delete_limit_order_by(self):
        # The highest id belongs to the last row created in setUp (intnum=97).
        await IntFields.all().limit(1).order_by("-id").delete()
        self.assertEqual(await IntFields.all().count(), 29)
        with self.assertRaises(DoesNotExist):
            await IntFields.get(intnum=97)

    async def test_async_iter(self):
        counter = 0
        async for _ in IntFields.all():
            counter += 1

        self.assertEqual(await IntFields.all().count(), counter)

    async def test_update_basic(self):
        obj0 = await IntFields.create(intnum=2147483647)
        await IntFields.filter(id=obj0.id).update(intnum=2147483646)
        obj = await IntFields.get(id=obj0.id)
        self.assertEqual(obj.intnum, 2147483646)
        self.assertEqual(obj.intnum_null, None)

    async def test_update_f_expression(self):
        obj0 = await IntFields.create(intnum=2147483647)
        await IntFields.filter(id=obj0.id).update(intnum=F("intnum") - 1)
        obj = await IntFields.get(id=obj0.id)
        self.assertEqual(obj.intnum, 2147483646)

    async def test_update_badparam(self):
        obj0 = await IntFields.create(intnum=2147483647)
        with self.assertRaisesRegex(FieldError, "Unknown keyword argument"):
            await IntFields.filter(id=obj0.id).update(badparam=1)

    async def test_update_pk(self):
        obj0 = await IntFields.create(intnum=2147483647)
        with self.assertRaisesRegex(IntegrityError, "is PK and can not be updated"):
            await IntFields.filter(id=obj0.id).update(id=1)

    async def test_update_virtual(self):
        tour = await Tournament.create(name="moo")
        obj0 = await MinRelation.create(tournament=tour)
        with self.assertRaisesRegex(FieldError, "is virtual and can not be updated"):
            await MinRelation.filter(id=obj0.id).update(participants=[])

    async def test_bad_ordering(self):
        with self.assertRaisesRegex(FieldError, "Unknown field moo1fip for model IntFields"):
            await IntFields.all().order_by("moo1fip")

    async def test_duplicate_values(self):
        with self.assertRaisesRegex(FieldError, "Duplicate key intnum"):
            await IntFields.all().values("intnum", "intnum")

    async def test_duplicate_values_list(self):
        # Unlike values(), values_list() accepts a repeated field; passing means no error.
        await IntFields.all().values_list("intnum", "intnum")

    async def test_duplicate_values_kw(self):
        with self.assertRaisesRegex(FieldError, "Duplicate key intnum"):
            await IntFields.all().values("intnum", intnum="intnum_null")

    async def test_duplicate_values_kw_badmap(self):
        with self.assertRaisesRegex(FieldError, 'Unknown field "intnum2" for model "IntFields"'):
            await IntFields.all().values(intnum="intnum2")

    async def test_bad_values(self):
        with self.assertRaisesRegex(FieldError, 'Unknown field "int2num" for model "IntFields"'):
            await IntFields.all().values("int2num")

    async def test_bad_values_list(self):
        with self.assertRaisesRegex(FieldError, 'Unknown field "int2num" for model "IntFields"'):
            await IntFields.all().values_list("int2num")

    async def test_many_flat_values_list(self):
        with self.assertRaisesRegex(
            TypeError, "You can flat value_list only if contains one field"
        ):
            await IntFields.all().values_list("intnum", "intnum_null", flat=True)

    async def test_all_flat_values_list(self):
        with self.assertRaisesRegex(
            TypeError, "You can flat value_list only if contains one field"
        ):
            await IntFields.all().values_list(flat=True)

    async def test_all_values_list(self):
        data = await IntFields.all().order_by("id").values_list()
        self.assertEqual(data[2], (self.intfields[2].id, 16, None))

    async def test_all_values(self):
        data = await IntFields.all().order_by("id").values()
        self.assertEqual(data[2], {"id": self.intfields[2].id, "intnum": 16, "intnum_null": None})

    async def test_order_by_bad_value(self):
        with self.assertRaisesRegex(FieldError, "Unknown field badid for model IntFields"):
            await IntFields.all().order_by("badid").values_list()

    async def test_annotate_order_expression(self):
        data = (
            await IntFields.annotate(idp=F("id") + 1)
            .order_by("-idp")
            .first()
            .values_list("id", "idp")
        )
        self.assertEqual(data[0] + 1, data[1])

    async def test_annotate_expression_filter(self):
        # The annotation shadows intnum, so the filter sees intnum + 1 > 30,
        # i.e. original intnum in 31, 34, ..., 97 (23 rows).
        count = await IntFields.annotate(intnum=F("intnum") + 1).filter(intnum__gt=30).count()
        self.assertEqual(count, 23)

    async def test_get_raw_sql(self):
        sql = IntFields.all().sql()
        self.assertRegex(sql, r"^SELECT.+FROM.+")

    @test.requireCapability(support_index_hint=True)
    async def test_force_index(self):
        sql = IntFields.filter(pk=1).only("id").force_index("index_name").sql()
        self.assertEqual(
            sql,
            "SELECT `id` `id` FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=1",
        )

        # Building the same query a second time must produce identical SQL.
        sql_again = IntFields.filter(pk=1).only("id").force_index("index_name").sql()
        self.assertEqual(
            sql_again,
            "SELECT `id` `id` FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=1",
        )

    @test.requireCapability(support_index_hint=True)
    async def test_force_index_avaiable_in_more_query(self):
        sql_ValuesQuery = IntFields.filter(pk=1).force_index("index_name").values("id").sql()
        self.assertEqual(
            sql_ValuesQuery,
            "SELECT `id` `id` FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=1",
        )

        sql_ValuesListQuery = (
            IntFields.filter(pk=1).force_index("index_name").values_list("id").sql()
        )
        self.assertEqual(
            sql_ValuesListQuery,
            "SELECT `id` `0` FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=1",
        )

        sql_CountQuery = IntFields.filter(pk=1).force_index("index_name").count().sql()
        self.assertEqual(
            sql_CountQuery,
            "SELECT COUNT(*) FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=1",
        )

        sql_ExistsQuery = IntFields.filter(pk=1).force_index("index_name").exists().sql()
        self.assertEqual(
            sql_ExistsQuery,
            "SELECT 1 FROM `intfields` FORCE INDEX (`index_name`) WHERE `id`=1 LIMIT 1",
        )

    @test.requireCapability(support_index_hint=True)
    async def test_use_index(self):
        sql = IntFields.filter(pk=1).only("id").use_index("index_name").sql()
        self.assertEqual(
            sql,
            "SELECT `id` `id` FROM `intfields` USE INDEX (`index_name`) WHERE `id`=1",
        )

        sql_again = IntFields.filter(pk=1).only("id").use_index("index_name").sql()
        self.assertEqual(
            sql_again,
            "SELECT `id` `id` FROM `intfields` USE INDEX (`index_name`) WHERE `id`=1",
        )

    @test.requireCapability(support_index_hint=True)
    async def test_use_index_avaiable_in_more_query(self):
        sql_ValuesQuery = IntFields.filter(pk=1).use_index("index_name").values("id").sql()
        self.assertEqual(
            sql_ValuesQuery,
            "SELECT `id` `id` FROM `intfields` USE INDEX (`index_name`) WHERE `id`=1",
        )

        sql_ValuesListQuery = IntFields.filter(pk=1).use_index("index_name").values_list("id").sql()
        self.assertEqual(
            sql_ValuesListQuery,
            "SELECT `id` `0` FROM `intfields` USE INDEX (`index_name`) WHERE `id`=1",
        )

        sql_CountQuery = IntFields.filter(pk=1).use_index("index_name").count().sql()
        self.assertEqual(
            sql_CountQuery,
            "SELECT COUNT(*) FROM `intfields` USE INDEX (`index_name`) WHERE `id`=1",
        )

        sql_ExistsQuery = IntFields.filter(pk=1).use_index("index_name").exists().sql()
        self.assertEqual(
            sql_ExistsQuery,
            "SELECT 1 FROM `intfields` USE INDEX (`index_name`) WHERE `id`=1 LIMIT 1",
        )

    @test.requireCapability(support_for_update=True)
    async def test_select_for_update(self):
        sql1 = IntFields.filter(pk=1).only("id").select_for_update().sql()
        sql2 = IntFields.filter(pk=1).only("id").select_for_update(nowait=True).sql()
        sql3 = IntFields.filter(pk=1).only("id").select_for_update(skip_locked=True).sql()
        sql4 = IntFields.filter(pk=1).only("id").select_for_update(of=("intfields",)).sql()

        # Only the postgres and mysql renderings are pinned; other dialects that
        # support FOR UPDATE only need to build the SQL without raising.
        dialect = self.db.schema_generator.DIALECT
        if dialect == "postgres":
            self.assertEqual(
                sql1,
                'SELECT "id" "id" FROM "intfields" WHERE "id"=1 FOR UPDATE',
            )
            self.assertEqual(
                sql2,
                'SELECT "id" "id" FROM "intfields" WHERE "id"=1 FOR UPDATE NOWAIT',
            )
            self.assertEqual(
                sql3,
                'SELECT "id" "id" FROM "intfields" WHERE "id"=1 FOR UPDATE SKIP LOCKED',
            )
            self.assertEqual(
                sql4,
                'SELECT "id" "id" FROM "intfields" WHERE "id"=1 FOR UPDATE OF "intfields"',
            )
        elif dialect == "mysql":
            self.assertEqual(
                sql1,
                "SELECT `id` `id` FROM `intfields` WHERE `id`=1 FOR UPDATE",
            )
            self.assertEqual(
                sql2,
                "SELECT `id` `id` FROM `intfields` WHERE `id`=1 FOR UPDATE NOWAIT",
            )
            self.assertEqual(
                sql3,
                "SELECT `id` `id` FROM `intfields` WHERE `id`=1 FOR UPDATE SKIP LOCKED",
            )
            self.assertEqual(
                sql4,
                "SELECT `id` `id` FROM `intfields` WHERE `id`=1 FOR UPDATE OF `intfields`",
            )

    async def test_select_related(self):
        tournament = await Tournament.create(name="1")
        reporter = await Reporter.create(name="Reporter")
        event = await Event.create(name="1", tournament=tournament, reporter=reporter)
        event = await Event.all().select_related("tournament", "reporter").get(pk=event.pk)
        self.assertEqual(event.tournament.pk, tournament.pk)
        self.assertEqual(event.reporter.pk, reporter.pk)

    async def test_select_related_with_two_same_models(self):
        # Two FKs to the same model must be joined under distinct aliases.
        parent_node = await Node.create(name="1")
        child_node = await Node.create(name="2")
        tree = await Tree.create(parent=parent_node, child=child_node)
        tree = await Tree.all().select_related("parent", "child").get(pk=tree.pk)
        self.assertEqual(tree.parent.pk, parent_node.pk)
        self.assertEqual(tree.parent.name, parent_node.name)
        self.assertEqual(tree.child.pk, child_node.pk)
        self.assertEqual(tree.child.name, child_node.name)

    @test.requireCapability(dialect="postgres")
    async def test_postgres_search(self):
        name = "hello world"
        await Tournament.create(name=name)
        ret = await Tournament.filter(name__search="hello").first()
        self.assertEqual(ret.name, name)

    async def test_subquery_select(self):
        t1 = await Tournament.create(name="1")
        ret = (
            await Tournament.filter(pk=t1.pk)
            .annotate(ids=Subquery(Tournament.filter(pk=t1.pk).values("id")))
            .values("ids", "id")
        )
        self.assertEqual(ret, [{"id": t1.pk, "ids": t1.pk}])

    async def test_subquery_filter(self):
        t1 = await Tournament.create(name="1")
        ret = await Tournament.filter(pk=Subquery(Tournament.filter(pk=t1.pk).values("id"))).first()
        self.assertEqual(ret, t1)

    async def test_raw_sql_count(self):
        t1 = await Tournament.create(name="1")
        ret = await Tournament.filter(pk=t1.pk).annotate(count=RawSQL("count(*)")).values("count")
        self.assertEqual(ret, [{"count": 1}])

    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_raw_sql_select(self):
        t1 = await Tournament.create(id=1, name="1")
        ret = (
            await Tournament.filter(pk=t1.pk)
            .annotate(idp=RawSQL("id + 1"))
            .filter(idp=2)
            .values("idp")
        )
        self.assertEqual(ret, [{"idp": 2}])

    async def test_raw_sql_filter(self):
        # No row can satisfy id = id + 1.
        ret = await Tournament.filter(pk=RawSQL("id + 1"))
        self.assertEqual(ret, [])

    async def test_annotation_field_priorior_to_model_field(self):
        # Sometimes, field name in annotates also exist in model field sets
        # and may need lift the former's priority in select query construction.
        t1 = await Tournament.create(name="1")
        ret = await Tournament.filter(pk=t1.pk).annotate(id=RawSQL("id + 1")).values("id")
        self.assertEqual(ret, [{"id": t1.pk + 1}])
class TestValues(test.TestCase):
    """Tests for ``values()`` / ``values_list()`` over fields, relations and annotations.

    Every test seeds its own rows. Relation tests select through a forward FK
    (``tournament__``), a reverse FK (``events__``) or an M2M (``participants__``).
    """

    async def test_values_related_fk(self):
        tournament = await Tournament.create(name="New Tournament")
        await Event.create(name="Test", tournament_id=tournament.id)

        event2 = await Event.filter(name="Test").values("name", "tournament__name")
        self.assertEqual(event2[0], {"name": "Test", "tournament__name": "New Tournament"})

    async def test_values_list_related_fk(self):
        tournament = await Tournament.create(name="New Tournament")
        await Event.create(name="Test", tournament_id=tournament.id)

        event2 = await Event.filter(name="Test").values_list("name", "tournament__name")
        self.assertEqual(event2[0], ("Test", "New Tournament"))

    async def test_values_related_rfk(self):
        tournament = await Tournament.create(name="New Tournament")
        await Event.create(name="Test", tournament_id=tournament.id)

        tournament2 = await Tournament.filter(name="New Tournament").values(
            "name", "events__name"
        )
        self.assertEqual(tournament2[0], {"name": "New Tournament", "events__name": "Test"})

    async def test_values_list_related_rfk(self):
        tournament = await Tournament.create(name="New Tournament")
        await Event.create(name="Test", tournament_id=tournament.id)

        tournament2 = await Tournament.filter(name="New Tournament").values_list(
            "name", "events__name"
        )
        self.assertEqual(tournament2[0], ("New Tournament", "Test"))

    async def test_values_related_m2m(self):
        tournament = await Tournament.create(name="New Tournament")
        event = await Event.create(name="Test", tournament_id=tournament.id)
        team = await Team.create(name="Some Team")
        await event.participants.add(team)

        # The query runs over Event, so the result is named after events.
        event2 = await Event.filter(name="Test").values("name", "participants__name")
        self.assertEqual(event2[0], {"name": "Test", "participants__name": "Some Team"})

    async def test_values_list_related_m2m(self):
        tournament = await Tournament.create(name="New Tournament")
        event = await Event.create(name="Test", tournament_id=tournament.id)
        team = await Team.create(name="Some Team")
        await event.participants.add(team)

        event2 = await Event.filter(name="Test").values_list("name", "participants__name")
        self.assertEqual(event2[0], ("Test", "Some Team"))

    # Selecting a relation itself (instead of one of its fields) must raise ValueError.

    async def test_values_related_fk_itself(self):
        tournament = await Tournament.create(name="New Tournament")
        await Event.create(name="Test", tournament_id=tournament.id)

        with self.assertRaisesRegex(ValueError, 'Selecting relation "tournament" is not possible'):
            await Event.filter(name="Test").values("name", "tournament")

    async def test_values_list_related_fk_itself(self):
        tournament = await Tournament.create(name="New Tournament")
        await Event.create(name="Test", tournament_id=tournament.id)

        with self.assertRaisesRegex(ValueError, 'Selecting relation "tournament" is not possible'):
            await Event.filter(name="Test").values_list("name", "tournament")

    async def test_values_related_rfk_itself(self):
        tournament = await Tournament.create(name="New Tournament")
        await Event.create(name="Test", tournament_id=tournament.id)

        with self.assertRaisesRegex(ValueError, 'Selecting relation "events" is not possible'):
            await Tournament.filter(name="New Tournament").values("name", "events")

    async def test_values_list_related_rfk_itself(self):
        tournament = await Tournament.create(name="New Tournament")
        await Event.create(name="Test", tournament_id=tournament.id)

        with self.assertRaisesRegex(ValueError, 'Selecting relation "events" is not possible'):
            await Tournament.filter(name="New Tournament").values_list("name", "events")

    async def test_values_related_m2m_itself(self):
        tournament = await Tournament.create(name="New Tournament")
        event = await Event.create(name="Test", tournament_id=tournament.id)
        team = await Team.create(name="Some Team")
        await event.participants.add(team)

        with self.assertRaisesRegex(
            ValueError, 'Selecting relation "participants" is not possible'
        ):
            await Event.filter(name="Test").values("name", "participants")

    async def test_values_list_related_m2m_itself(self):
        tournament = await Tournament.create(name="New Tournament")
        event = await Event.create(name="Test", tournament_id=tournament.id)
        team = await Team.create(name="Some Team")
        await event.participants.add(team)

        with self.assertRaisesRegex(
            ValueError, 'Selecting relation "participants" is not possible'
        ):
            await Event.filter(name="Test").values_list("name", "participants")

    # Unknown field names must raise FieldError naming the model that lacks the field.

    async def test_values_bad_key(self):
        tournament = await Tournament.create(name="New Tournament")
        await Event.create(name="Test", tournament_id=tournament.id)

        with self.assertRaisesRegex(FieldError, 'Unknown field "neem" for model "Event"'):
            await Event.filter(name="Test").values("name", "neem")

    async def test_values_list_bad_key(self):
        tournament = await Tournament.create(name="New Tournament")
        await Event.create(name="Test", tournament_id=tournament.id)

        with self.assertRaisesRegex(FieldError, 'Unknown field "neem" for model "Event"'):
            await Event.filter(name="Test").values_list("name", "neem")

    async def test_values_related_bad_key(self):
        tournament = await Tournament.create(name="New Tournament")
        await Event.create(name="Test", tournament_id=tournament.id)

        with self.assertRaisesRegex(FieldError, 'Unknown field "neem" for model "Tournament"'):
            await Event.filter(name="Test").values("name", "tournament__neem")

    async def test_values_list_related_bad_key(self):
        tournament = await Tournament.create(name="New Tournament")
        await Event.create(name="Test", tournament_id=tournament.id)

        with self.assertRaisesRegex(FieldError, 'Unknown field "neem" for model "Tournament"'):
            await Event.filter(name="Test").values_list("name", "tournament__neem")

    # Gate with the NotEQ condition, like every other dialect gate in this file.
    # A bare string such as "!mssql" would be compared literally against the
    # connection's dialect and would not express "any dialect except mssql".
    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_values_list_annotations_length(self):
        await Tournament.create(name="Championship")
        await Tournament.create(name="Super Bowl")

        tournaments = await Tournament.annotate(name_length=Length("name")).values_list(
            "name", "name_length"
        )
        self.assertListSortEqual(tournaments, [("Championship", 12), ("Super Bowl", 10)])

    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_values_annotations_length(self):
        await Tournament.create(name="Championship")
        await Tournament.create(name="Super Bowl")

        tournaments = await Tournament.annotate(name_slength=Length("name")).values(
            "name", "name_slength"
        )
        self.assertListSortEqual(
            tournaments,
            [
                {"name": "Championship", "name_slength": 12},
                {"name": "Super Bowl", "name_slength": 10},
            ],
            sorted_key="name",
        )

    async def test_values_list_annotations_trim(self):
        await Tournament.create(name=" x")
        await Tournament.create(name=" y ")

        tournaments = await Tournament.annotate(name_trim=Trim("name")).values_list(
            "name", "name_trim"
        )
        self.assertListSortEqual(tournaments, [(" x", "x"), (" y ", "y")])

    async def test_values_annotations_trim(self):
        await Tournament.create(name=" x")
        await Tournament.create(name=" y ")

        tournaments = await Tournament.annotate(name_trim=Trim("name")).values(
            "name", "name_trim"
        )
        self.assertListSortEqual(
            tournaments,
            [{"name": " x", "name_trim": "x"}, {"name": " y ", "name_trim": "y"}],
            sorted_key="name",
        )
class StraightFieldTests(test.TestCase):
    """Exercise ``StraightFields``: lookups, FK/M2M traversal, expressions, annotations.

    ``setUp`` binds the model to ``self.model``. Relation tests go through its
    ``fk`` / ``fkrev`` and ``rel_to`` / ``rel_from`` relations.
    """

    def setUp(self) -> None:
        self.model = StraightFields

    async def _parent_with_two_children(self):
        # Create "aaa", then "bbb" and "ccc" whose fk points at it (in that order).
        parent = await self.model.create(chars="aaa")
        child_b = await self.model.create(chars="bbb", fk=parent)
        child_c = await self.model.create(chars="ccc", fk=parent)
        return parent, child_b, child_c

    async def _linked_pair(self):
        # Create "aaa" and "bbb" and link them via aaa.rel_to -> bbb.
        source = await self.model.create(chars="aaa")
        target = await self.model.create(chars="bbb")
        await source.rel_to.add(target)
        return source, target

    # ---- plain lookups -------------------------------------------------------

    async def test_get_all(self):
        first = await self.model.create(chars="aaa")
        self.assertIsNotNone(first.eyedee, str(dir(first)))
        second = await self.model.create(chars="bbb")

        self.assertListSortEqual(await self.model.all(), [first, second])

    async def test_get_by_pk(self):
        created = await self.model.create(chars="aaa")
        fetched = await self.model.get(eyedee=created.eyedee)
        self.assertEqual(created, fetched)

    async def test_get_by_chars(self):
        created = await self.model.create(chars="aaa")
        fetched = await self.model.get(chars="aaa")
        self.assertEqual(created, fetched)

    # ---- forward FK -----------------------------------------------------------

    async def test_get_fk_forward_fetch_related(self):
        parent = await self.model.create(chars="aaa")
        child = await self.model.create(chars="bbb", fk=parent)

        reloaded = await self.model.get(eyedee=child.eyedee)
        await reloaded.fetch_related("fk")
        self.assertEqual(child, reloaded)
        self.assertEqual(parent, reloaded.fk)

    async def test_get_fk_forward_prefetch_related(self):
        parent = await self.model.create(chars="aaa")
        child = await self.model.create(chars="bbb", fk=parent)

        reloaded = await self.model.get(eyedee=child.eyedee).prefetch_related("fk")
        self.assertEqual(child, reloaded)
        self.assertEqual(parent, reloaded.fk)

    # ---- reverse FK -----------------------------------------------------------

    async def test_get_fk_reverse_await(self):
        parent, child_b, child_c = await self._parent_with_two_children()

        reloaded = await self.model.get(eyedee=parent.eyedee)
        self.assertListSortEqual(await reloaded.fkrev, [child_b, child_c])

    async def test_get_fk_reverse_filter(self):
        parent, child_b, child_c = await self._parent_with_two_children()

        self.assertListSortEqual(await self.model.filter(fk=parent), [child_b, child_c])

    async def test_get_fk_reverse_async_for(self):
        parent, child_b, child_c = await self._parent_with_two_children()

        reloaded = await self.model.get(eyedee=parent.eyedee)
        collected = [item async for item in reloaded.fkrev]
        self.assertListSortEqual(collected, [child_b, child_c])

    async def test_get_fk_reverse_fetch_related(self):
        parent, child_b, child_c = await self._parent_with_two_children()

        reloaded = await self.model.get(eyedee=parent.eyedee)
        await reloaded.fetch_related("fkrev")
        self.assertListSortEqual(list(reloaded.fkrev), [child_b, child_c])

    async def test_get_fk_reverse_prefetch_related(self):
        parent, child_b, child_c = await self._parent_with_two_children()

        reloaded = await self.model.get(eyedee=parent.eyedee).prefetch_related("fkrev")
        self.assertListSortEqual(list(reloaded.fkrev), [child_b, child_c])

    # ---- many-to-many ---------------------------------------------------------

    async def test_get_m2m_forward_await(self):
        source, target = await self._linked_pair()

        target_again = await self.model.get(eyedee=target.eyedee)
        self.assertEqual(await target_again.rel_from, [source])

        source_again = await self.model.get(eyedee=source.eyedee)
        self.assertEqual(await source_again.rel_to, [target])

    async def test_get_m2m_reverse_await(self):
        source = await self.model.create(chars="aaa")
        target = await self.model.create(chars="bbb")
        # Link from the reverse side this time.
        await target.rel_from.add(source)

        # Mix ``pk`` and ``eyedee`` on purpose: both must address the same column.
        target_again = await self.model.get(pk=target.eyedee)
        self.assertEqual(await target_again.rel_from, [source])

        source_again = await self.model.get(eyedee=source.pk)
        self.assertEqual(await source_again.rel_to, [target])

    async def test_get_m2m_filter(self):
        source, target = await self._linked_pair()

        self.assertEqual(await self.model.filter(rel_from=source), [target])
        self.assertEqual(await self.model.filter(rel_to=target), [source])

    async def test_get_m2m_forward_fetch_related(self):
        source, target = await self._linked_pair()

        target_again = await self.model.get(eyedee=target.eyedee)
        await target_again.fetch_related("rel_from")
        self.assertEqual(list(target_again.rel_from), [source])

    async def test_get_m2m_reverse_fetch_related(self):
        source, target = await self._linked_pair()

        source_again = await self.model.get(eyedee=source.eyedee)
        await source_again.fetch_related("rel_to")
        self.assertEqual(list(source_again.rel_to), [target])

    async def test_get_m2m_forward_prefetch_related(self):
        source, target = await self._linked_pair()

        target_again = await self.model.get(eyedee=target.eyedee).prefetch_related("rel_from")
        self.assertEqual(list(target_again.rel_from), [source])

    async def test_get_m2m_reverse_prefetch_related(self):
        source, target = await self._linked_pair()

        source_again = await self.model.get(eyedee=source.eyedee).prefetch_related("rel_to")
        self.assertEqual(list(source_again.rel_to), [target])

    # ---- values / expressions -------------------------------------------------

    async def test_values_reverse_relation(self):
        parent = await self.model.create(chars="aaa")
        await self.model.create(chars="bbb", fk=parent)

        rows = await self.model.filter(chars="aaa").values("fkrev__chars")
        self.assertEqual(rows[0]["fkrev__chars"], "bbb")

    async def test_f_expression(self):
        created = await self.model.create(chars="aaa")
        await self.model.filter(eyedee=created.eyedee).update(chars=F("blip"))

        refetched = await self.model.get(eyedee=created.eyedee)
        self.assertEqual(refetched.chars, "BLIP")

    async def test_function(self):
        created = await self.model.create(chars=" aaa ")
        await self.model.filter(eyedee=created.eyedee).update(chars=Trim("chars"))

        refetched = await self.model.get(eyedee=created.eyedee)
        self.assertEqual(refetched.chars, "aaa")

    # ---- annotations ----------------------------------------------------------

    async def test_aggregation_with_filter(self):
        await self._parent_with_two_children()

        annotated = await (
            self.model.filter(chars="aaa")
            .annotate(
                all=Count("fkrev", _filter=Q(chars="aaa")),
                one=Count("fkrev", _filter=Q(fkrev__chars="bbb")),
                no=Count("fkrev", _filter=Q(fkrev__chars="aaa")),
            )
            .first()
        )

        self.assertEqual(annotated.all, 2)
        self.assertEqual(annotated.one, 1)
        self.assertEqual(annotated.no, 0)

    async def test_filter_by_aggregation_field_coalesce(self):
        await self.model.create(chars="aaa", nullable="null")
        await self.model.create(chars="bbb")

        matches = await self.model.annotate(null=Coalesce("nullable", "null")).filter(
            null="null"
        )
        self.assertEqual(len(matches), 2)
        self.assertSetEqual(
            {(row.chars, row.null) for row in matches},
            {("aaa", "null"), ("bbb", "null")},
        )

    async def test_filter_by_aggregation_field_count(self):
        await self.model.create(chars="aaa")
        await self.model.create(chars="bbb")

        matches = await self.model.annotate(chars_count=Count("chars")).filter(
            chars_count=1, chars="aaa"
        )
        self.assertEqual(len(matches), 1)
        self.assertEqual(matches[0].chars, "aaa")

    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_filter_by_aggregation_field_length(self):
        await self.model.create(chars="aaa")
        await self.model.create(chars="bbbbb")

        matches = await self.model.annotate(chars_length=Length("chars")).filter(chars_length=3)
        self.assertEqual(len(matches), 1)
        self.assertEqual(matches[0].chars_length, 3)

    async def test_filter_by_aggregation_field_lower(self):
        await self.model.create(chars="AaA")

        matches = await self.model.annotate(chars_lower=Lower("chars")).filter(
            chars_lower="aaa"
        )
        self.assertEqual(len(matches), 1)
        self.assertEqual(matches[0].chars_lower, "aaa")

    async def test_filter_by_aggregation_field_trim(self):
        await self.model.create(chars=" aaa ")

        matches = await self.model.annotate(chars_trim=Trim("chars")).filter(chars_trim="aaa")
        self.assertEqual(len(matches), 1)
        self.assertEqual(matches[0].chars_trim, "aaa")

    async def test_filter_by_aggregation_field_upper(self):
        await self.model.create(chars="aAa")

        matches = await self.model.annotate(chars_upper=Upper("chars")).filter(
            chars_upper="AAA"
        )
        self.assertEqual(len(matches), 1)
        self.assertEqual(matches[0].chars_upper, "AAA")

    async def test_values_by_fk(self):
        parent = await self.model.create(chars="aaa")
        await self.model.create(chars="bbb", fk=parent)

        rows = await self.model.filter(chars="bbb").values("fk__chars")
        self.assertEqual(rows, [{"fk__chars": "aaa"}])
class TestUpdate(test.TestCase):
    """Tests for ``QuerySet.update``, ``Model.bulk_update`` and ``refresh_from_db``.

    Values read back from the database are compared with ``assertEqual``.
    ``assertIs`` on ints/strs only passes by accident of CPython object caching.
    """

    async def test_update(self):
        await Tournament.create(name="1")
        await Tournament.create(name="3")
        rows_affected = await Tournament.all().update(name="2")
        self.assertEqual(rows_affected, 2)

        tournament = await Tournament.first()
        self.assertEqual(tournament.name, "2")

    async def test_bulk_update(self):
        objs = [await Tournament.create(name="1"), await Tournament.create(name="2")]
        objs[0].name = "0"
        objs[1].name = "1"

        rows_affected = await Tournament.bulk_update(objs, fields=["name"], batch_size=100)
        self.assertEqual(rows_affected, 2)
        self.assertEqual((await Tournament.get(pk=objs[0].pk)).name, "0")
        self.assertEqual((await Tournament.get(pk=objs[1].pk)).name, "1")

    async def test_bulk_update_datetime(self):
        objs = [
            await DatetimeFields.create(datetime=datetime(2021, 1, 1, tzinfo=pytz.utc)),
            await DatetimeFields.create(datetime=datetime(2021, 1, 1, tzinfo=pytz.utc)),
        ]
        t0 = datetime(2021, 1, 2, tzinfo=pytz.utc)
        t1 = datetime(2021, 1, 3, tzinfo=pytz.utc)
        objs[0].datetime = t0
        objs[1].datetime = t1

        rows_affected = await DatetimeFields.bulk_update(objs, fields=["datetime"])
        self.assertEqual(rows_affected, 2)
        self.assertEqual((await DatetimeFields.get(pk=objs[0].pk)).datetime, t0)
        self.assertEqual((await DatetimeFields.get(pk=objs[1].pk)).datetime, t1)

    async def test_bulk_update_pk_uuid(self):
        objs = [
            await UUIDFields.create(data=uuid.uuid4()),
            await UUIDFields.create(data=uuid.uuid4()),
        ]
        objs[0].data = uuid.uuid4()
        objs[1].data = uuid.uuid4()

        rows_affected = await UUIDFields.bulk_update(objs, fields=["data"])
        self.assertEqual(rows_affected, 2)
        self.assertEqual((await UUIDFields.get(pk=objs[0].pk)).data, objs[0].data)
        self.assertEqual((await UUIDFields.get(pk=objs[1].pk)).data, objs[1].data)

    async def test_bulk_update_json_value(self):
        objs = [
            await JSONFields.create(data={}),
            await JSONFields.create(data={}),
        ]
        # Mix a JSON list and a JSON object in the same bulk update.
        objs[0].data = [0]
        objs[1].data = {"a": 1}

        rows_affected = await JSONFields.bulk_update(objs, fields=["data"])
        self.assertEqual(rows_affected, 2)
        self.assertEqual((await JSONFields.get(pk=objs[0].pk)).data, objs[0].data)
        self.assertEqual((await JSONFields.get(pk=objs[1].pk)).data, objs[1].data)

    @test.requireCapability(dialect=NotEQ("mssql"))
    async def test_bulk_update_smallint_none(self):
        objs = [
            await SmallIntFields.create(smallintnum=1, smallintnum_null=1),
            await SmallIntFields.create(smallintnum=2, smallintnum_null=2),
        ]
        objs[0].smallintnum_null = None
        objs[1].smallintnum_null = None

        rows_affected = await SmallIntFields.bulk_update(objs, fields=["smallintnum_null"])
        self.assertEqual(rows_affected, 2)
        self.assertIsNone((await SmallIntFields.get(pk=objs[0].pk)).smallintnum_null)
        self.assertIsNone((await SmallIntFields.get(pk=objs[1].pk)).smallintnum_null)

    async def test_update_auto_now(self):
        obj = await DefaultUpdate.create()

        # Explicitly set updated_at to yesterday; the stored date must match it.
        now = datetime.now()
        updated_at = now - timedelta(days=1)
        await DefaultUpdate.filter(pk=obj.pk).update(updated_at=updated_at)

        obj1 = await DefaultUpdate.get(pk=obj.pk)
        self.assertEqual(obj1.updated_at.date(), updated_at.date())

    async def test_update_relation(self):
        tournament_first = await Tournament.create(name="1")
        tournament_second = await Tournament.create(name="2")

        await Event.create(name="1", tournament=tournament_first)
        await Event.all().update(tournament=tournament_second)

        event = await Event.first()
        self.assertEqual(event.tournament_id, tournament_second.id)

    @test.requireCapability(dialect=In("mysql", "sqlite"))
    async def test_update_with_custom_function(self):
        class JsonSet(Function):
            # Thin wrapper emitting a JSON_SET(field, path, value) SQL call.
            def __init__(self, field: F, expression: str, value: Any):
                super().__init__("JSON_SET", field, expression, value)

        # ``record`` rather than ``json`` so the local does not shadow the module name.
        record = await JSONFields.create(data={})
        self.assertEqual(record.data_default, {"a": 1})

        # Path 1: assign the function to the attribute and save().
        record.data_default = JsonSet(F("data_default"), "$.a", 2)
        await record.save()

        json_update = await JSONFields.get(pk=record.pk)
        self.assertEqual(json_update.data_default, {"a": 2})

        # Path 2: pass the function through QuerySet.update().
        await JSONFields.filter(pk=record.pk).update(
            data_default=JsonSet(F("data_default"), "$.a", 3)
        )
        json_update = await JSONFields.get(pk=record.pk)
        self.assertEqual(json_update.data_default, {"a": 3})

    async def test_refresh_from_db(self):
        int_field = await IntFields.create(intnum=1, intnum_null=2)
        int_field_in_db = await IntFields.get(pk=int_field.pk)

        # The attribute was assigned a query expression, not a plain value, so an
        # identity check is used to show it is not yet the resolved integer.
        int_field_in_db.intnum = F("intnum") + 1
        await int_field_in_db.save(update_fields=["intnum"])
        self.assertIsNot(int_field_in_db.intnum, 2)
        self.assertEqual(int_field_in_db.intnum_null, 2)

        # Refresh only ``intnum``; the new value comes back from the database.
        await int_field_in_db.refresh_from_db(fields=["intnum"])
        self.assertEqual(int_field_in_db.intnum, 2)
        self.assertEqual(int_field_in_db.intnum_null, 2)

        # Same again with a full save() and a full refresh_from_db().
        int_field_in_db.intnum = F("intnum") + 1
        await int_field_in_db.save()
        self.assertIsNot(int_field_in_db.intnum, 3)
        self.assertEqual(int_field_in_db.intnum_null, 2)

        await int_field_in_db.refresh_from_db()
        self.assertEqual(int_field_in_db.intnum, 3)
        self.assertEqual(int_field_in_db.intnum_null, 2)

    @test.requireCapability(support_update_limit_order_by=True)
    async def test_update_with_limit_ordering(self):
        await Tournament.create(name="1")
        t2 = await Tournament.create(name="1")

        # Only the newest row (highest id) should be updated.
        await Tournament.filter(name="1").limit(1).order_by("-id").update(name="2")
        self.assertEqual((await Tournament.get(pk=t2.pk)).name, "2")
        self.assertEqual(await Tournament.filter(name="1").count(), 1)