class TestTypeMagic(unittest.TestCase):
    """Tests for ``Mapper.rebrand``, which converts saved entities of one
    ``Entity`` subclass into another.

    Every test runs against a fresh set of tables created on a test
    connection in ``setUp`` and dropped again in ``tearDown``.
    """

    def setUp(self):
        """Open a test connection, create the tables and build a mapper."""
        self.connection = Connection.test()
        self.connection.create_tables()
        self.m = Mapper(self.connection)

    def tearDown(self):
        """Drop the tables and release the engine's connections."""
        try:
            self.connection.drop_tables()
        finally:
            # Need to dispose manually to avoid a "too many connections"
            # error; do it even when dropping the tables fails so a broken
            # test cannot leak connections into the following ones.
            self.connection.engine.dispose()

    def test_rebrand(self):
        """Rebranding E0 -> E1 detaches the old instance and keeps params."""

        class E0(Entity):
            declared_params = {"q": "integer"}

        class E1(Entity):
            declared_params = {"q": "integer", "q1": "integer"}

        e0 = E0(q=100)
        self.m.save(e0)

        self.m.rebrand(E0, E1)

        # Trying to load something from e0 now raises.
        self.assertRaises(DetachedInstanceError, lambda: e0.type)
        # params is lazy, so we call dict() to force the load.
        self.assertRaises(DetachedInstanceError, lambda: dict(e0.params))

        e1 = self.m.find_roots()[0]

        self.assertEqual(e1._type, E1.__name__)
        self.assertEqual(e1.params["q"], 100)

        # Both parameters declared on E1 must be assignable on the rebranded
        # entity; the test only checks that these assignments do not raise.
        e1.params["q"] = 0
        e1.params["q1"] = 0

    def test_multi_step_rebrand(self):
        """Rebrand MyEntity1 -> MyEntity3 through two intermediate types.

        The direct rebrand is rejected with ``ValueError`` (presumably
        because ``version`` changes from integer to string -- confirm
        against ``Mapper.rebrand``), so the test goes through MyEntity2a and
        MyEntity2b, using ``after`` callbacks to rewrite the parameters.
        """
        # Python 2 spells the common string base type ``basestring``;
        # fall back to ``str`` where that name does not exist.
        try:
            string_types = basestring
        except NameError:
            string_types = str

        class MyEntity1(Entity):
            declared_params = {
                "name": "string",
                "version": "integer",
            }

        class MyEntity2a(Entity):
            declared_params = {
                "name": "string",
                "version": "integer",
                "version_string": "string",
                "release_date": "date",
            }

        class MyEntity2b(Entity):
            declared_params = {
                "name": "string",
                "version_string": "string",
                "release_date": "date",
            }

        class MyEntity3(Entity):
            declared_params = {
                "name": "string",
                "version": "string",
                "release_date": "date",
            }

        orig_entities = [
            MyEntity1(name="abc", version=1),
            MyEntity1(name="def", version=2),
            MyEntity1(name="ghi", version=3),
            MyEntity1(name="jkl", version=4),
            MyEntity1(name="mno", version=5),
            MyEntity1(name="pqr", version=6),
            MyEntity1(name="stu", version=7),
        ]
        self.m.save(*orig_entities)

        def change_version_string_add_date(e, params):
            """Move ``version`` to ``version_string`` and add a date."""
            params["version_string"] = str(params["version"])
            del params["version"]
            params["release_date"] = datetime.date.today()
            return params

        def change_version_string(e, params):
            """Move ``version_string`` back to ``version``."""
            params["version"] = params["version_string"]
            del params["version_string"]
            return params

        self.assertEqual(len(self.m.find(MyEntity1).all()), 7)
        self.assertEqual(len(self.m.find(MyEntity3).all()), 0)
        version_1 = [e.params["version"] for e in self.m.find(MyEntity1)]
        self.assertEqual(len(version_1), 7)
        self.assertTrue(all(isinstance(v, int) for v in version_1))

        # We cannot rebrand directly:
        self.assertRaises(ValueError, self.m.rebrand, MyEntity1, MyEntity3)

        # Do the rebranding.
        self.m.rebrand(MyEntity1, MyEntity2a,
                       after=change_version_string_add_date)

        # Still won't work to do this directly.
        self.assertRaises(ValueError, self.m.rebrand, MyEntity2a, MyEntity3)

        self.m.rebrand(MyEntity2a, MyEntity2b)
        self.m.rebrand(MyEntity2b, MyEntity3, after=change_version_string)

        self.assertEqual(len(self.m.find(MyEntity1).all()), 0)
        self.assertEqual(len(self.m.find(MyEntity3).all()), 7)
        version_3 = [e.params["version"] for e in self.m.find(MyEntity3)]
        self.assertEqual(len(version_3), 7)
        self.assertTrue(all(isinstance(v, string_types) for v in version_3))