コード例 #1
0
ファイル: test_mapper.py プロジェクト: Debilski/xdapy
class TestTypeMagic(unittest.TestCase):
    """Tests for ``Mapper.rebrand``: migrating stored entities from one
    ``Entity`` subclass to another."""

    def setUp(self):
        """Create the test schema and a ``Mapper`` bound to it."""
        self.connection = Connection.test()
        self.connection.create_tables()
        self.m = Mapper(self.connection)

    def tearDown(self):
        """Drop the schema and always release the engine's connections."""
        try:
            self.connection.drop_tables()
        finally:
            # need to dispose manually to avoid too many connections error;
            # do it even when drop_tables() fails, otherwise the engine's
            # connections stay open and leak into subsequent tests
            self.connection.engine.dispose()

    def test_rebrand(self):
        """Rebranding E0 -> E1 detaches the old instance, keeps the stored
        param values and makes E1's additional params writable."""
        class E0(Entity):
            declared_params = {
                "q": "integer"
            }

        class E1(Entity):
            declared_params = {
                "q": "integer",
                "q1": "integer"
            }

        e0 = E0(q=100)
        self.m.save(e0)

        self.m.rebrand(E0, E1)

        # trying to load something from e0 now raises
        self.assertRaises(DetachedInstanceError, lambda: e0.type)
        self.assertRaises(DetachedInstanceError, lambda: dict(e0.params)) # lazy, so we call dict

        e1 = self.m.find_roots()[0]

        self.assertEqual(e1._type, E1.__name__)
        self.assertEqual(e1.params["q"], 100)

        # the rebranded entity must accept both the param it already had and
        # the one that only E1 declares
        e1.params["q"] = 0
        e1.params["q1"] = 0
        self.assertEqual(e1.params["q"], 0)
        self.assertEqual(e1.params["q1"], 0)

    def test_multi_step_rebrand(self):
        """Migrate MyEntity1 -> MyEntity2a -> MyEntity2b -> MyEntity3.

        ``version`` changes from an integer to a string param on the way.
        Rebranding directly into MyEntity3 raises ``ValueError`` while the
        source class still declares ``version`` as an integer (presumably
        because a param may not change its declared type in place -- TODO
        confirm against ``Mapper.rebrand``). So the value is moved through the
        intermediate ``version_string`` param via the ``after`` hooks.
        """
        class MyEntity1(Entity):
            declared_params = {
                "name": "string",
                "version": "integer"
            }

        class MyEntity2a(Entity):
            declared_params = {
                "name": "string",
                "version": "integer",
                "version_string": "string",
                "release_date": "date"
            }

        class MyEntity2b(Entity):
            declared_params = {
                "name": "string",
                "version_string": "string",
                "release_date": "date"
            }

        class MyEntity3(Entity):
            declared_params = {
                "name": "string",
                "version": "string",
                "release_date": "date"
            }

        orig_entities = [
            MyEntity1(name="abc", version=1),
            MyEntity1(name="def", version=2),
            MyEntity1(name="ghi", version=3),
            MyEntity1(name="jkl", version=4),
            MyEntity1(name="mno", version=5),
            MyEntity1(name="pqr", version=6),
            MyEntity1(name="stu", version=7)
        ]

        self.m.save(*orig_entities)

        def change_version_string_add_date(e, params):
            """``after`` hook for 1 -> 2a: store the integer version as a
            string under ``version_string`` and add today's release date.

            ``e`` is presumably the entity being rebranded; unused here.
            """
            params["version_string"] = str(params["version"])
            del params["version"]
            params["release_date"] = datetime.date.today()
            return params

        def change_version_string(e, params):
            """``after`` hook for 2b -> 3: move ``version_string`` back into
            the (now string-typed) ``version`` param."""
            params["version"] = params["version_string"]
            del params["version_string"]
            return params

        self.assertEqual(len(self.m.find(MyEntity1).all()), 7)
        self.assertEqual(len(self.m.find(MyEntity3).all()), 0)

        version_1 = [e.params["version"] for e in self.m.find(MyEntity1)]
        self.assertEqual(len(version_1), 7)
        # check each value individually so a failure names the bad value
        for v in version_1:
            self.assertIsInstance(v, int)

        # we cannot rebrand directly:
        self.assertRaises(ValueError, self.m.rebrand, MyEntity1, MyEntity3)

        # do the rebranding
        self.m.rebrand(MyEntity1, MyEntity2a, after=change_version_string_add_date)

        # still won't work to do this directly
        self.assertRaises(ValueError, self.m.rebrand, MyEntity2a, MyEntity3)
        self.m.rebrand(MyEntity2a, MyEntity2b)

        self.m.rebrand(MyEntity2b, MyEntity3, after=change_version_string)

        self.assertEqual(len(self.m.find(MyEntity1).all()), 0)
        self.assertEqual(len(self.m.find(MyEntity3).all()), 7)

        version_3 = [e.params["version"] for e in self.m.find(MyEntity3)]
        self.assertEqual(len(version_3), 7)
        # basestring: this module targets Python 2 (str or unicode accepted)
        for v in version_3:
            self.assertIsInstance(v, basestring)