示例#1
0
def perform_data_update(dbfile):
    """Apply db_perform_data_update to the database file *dbfile*.

    A dedicated store is opened for the update.  Its changes are committed
    only if the update succeeds; on any exception they are rolled back and
    the exception is re-raised.  The store is closed in every case.
    """
    db_uri = GLSettings.make_db_uri(dbfile)
    store = Store(create_database(db_uri))
    try:
        try:
            db_perform_data_update(store)
            store.commit()
        except:  # deliberately bare: roll back on anything, then re-raise
            store.rollback()
            raise
    finally:
        store.close()
 def rollback(self):
     """Roll back pending work.

     When a transaction object is attached as ``self.transaction`` it is
     detached and that transaction is rolled back; otherwise the call is
     delegated to ``Store.rollback``.  Returns whatever the chosen rollback
     returns.
     """
     transaction = self.transaction
     if transaction:
         self.transaction = None
         # Previously ``return self.rollback.commit()``: a bound method has
         # no ``commit`` attribute, so this branch always raised
         # AttributeError (and ``self.transaction`` had already been cleared,
         # so the transaction object was lost).
         # NOTE(review): rolling back the detached transaction is the presumed
         # intent of a ``rollback`` method -- confirm against callers.
         return transaction.rollback()
     return Store.rollback(self)
示例#3
0
文件: schema.py 项目: petrhosek/storm
class SchemaTest(MockerTestCase):
    """Tests for L{Schema} create/drop/delete/upgrade against a SQLite store."""

    def setUp(self):
        super(SchemaTest, self).setUp()
        self.database = create_database("sqlite:///%s" % self.makeFile())
        self.store = Store(self.database)

        self._package_dirs = set()
        self._package_names = set()
        self.package = self.create_package(self.makeDir(), "patch_package")
        import patch_package

        creates = ["CREATE TABLE person (id INTEGER, name TEXT)"]
        drops = ["DROP TABLE person"]
        deletes = ["DELETE FROM person"]

        self.schema = Schema(creates, drops, deletes, patch_package)

    def tearDown(self):
        for package_dir in self._package_dirs:
            sys.path.remove(package_dir)

        # Forget the packages made by create_package and all their
        # submodules.  The previous ``elif filter(None, [...])`` test is
        # always true on Python 3 (filter() returns a truthy lazy iterator),
        # so it removed *every* module from sys.modules.
        submodule_prefixes = tuple("%s." % x for x in self._package_names)
        for name in list(sys.modules):
            if (name in self._package_names
                    or name.startswith(submodule_prefixes)):
                del sys.modules[name]

        super(SchemaTest, self).tearDown()

    def create_package(self, base_dir, name, init_module=None):
        """Create a Python package.

        Packages created using this method will be removed from L{sys.path}
        and L{sys.modules} during L{tearDown}.

        @param base_dir: The directory in which to create the new package;
            it is appended to L{sys.path} so the package is importable.
        @param name: The name of the package.
        @param init_module: Optionally, the text to include in the __init__.py
            file.
        @return: A L{Package} instance that can be used to create modules.
        """
        package_dir = os.path.join(base_dir, name)
        self._package_names.add(name)
        os.makedirs(package_dir)

        with open(os.path.join(package_dir, "__init__.py"), "w") as init_file:
            if init_module:
                init_file.write(init_module)
        sys.path.append(base_dir)
        self._package_dirs.add(base_dir)

        return Package(package_dir, name)

    def test_create(self):
        """
        L{Schema.create} can be used to create the tables of a L{Store}.
        """
        self.assertRaises(StormError,
                          self.store.execute, "SELECT * FROM person")
        self.schema.create(self.store)
        self.assertEqual(list(self.store.execute("SELECT * FROM person")), [])

    def test_create_with_autocommit_off(self):
        """
        L{Schema.autocommit} can be used to turn automatic commits off.
        """
        self.schema.autocommit(False)
        self.schema.create(self.store)
        self.store.rollback()
        self.assertRaises(StormError, self.store.execute,
                          "SELECT * FROM patch")

    def test_drop(self):
        """
        L{Schema.drop} can be used to drop the tables of a L{Store}.
        """
        self.schema.create(self.store)
        self.assertEqual(list(self.store.execute("SELECT * FROM person")), [])
        self.schema.drop(self.store)
        self.assertRaises(StormError,
                          self.store.execute, "SELECT * FROM person")

    def test_delete(self):
        """
        L{Schema.delete} can be used to clear the tables of a L{Store}.
        """
        self.schema.create(self.store)
        self.store.execute("INSERT INTO person (id, name) VALUES (1, 'Jane')")
        self.assertEqual(list(self.store.execute("SELECT * FROM person")),
                         [(1, u"Jane")])
        self.schema.delete(self.store)
        self.assertEqual(list(self.store.execute("SELECT * FROM person")), [])

    def test_upgrade_creates_schema(self):
        """
        L{Schema.upgrade} creates a schema from scratch if none exists, and is
        effectively equivalent to L{Schema.create} in that case.
        """
        self.assertRaises(StormError,
                          self.store.execute, "SELECT * FROM person")
        self.schema.upgrade(self.store)
        self.assertEqual(list(self.store.execute("SELECT * FROM person")), [])

    def test_upgrade_marks_patches_applied(self):
        """
        L{Schema.upgrade} updates the patch table after applying the needed
        patches.
        """
        contents = """
def apply(store):
    store.execute('ALTER TABLE person ADD COLUMN phone TEXT')
"""
        self.package.create_module("patch_1.py", contents)
        statement = "SELECT * FROM patch"
        self.assertRaises(StormError, self.store.execute, statement)
        self.schema.upgrade(self.store)
        self.assertEqual(list(self.store.execute("SELECT * FROM patch")),
                         [(1,)])

    def test_upgrade_applies_patches(self):
        """
        L{Schema.upgrade} executes the needed patches, that typically modify
        the existing schema.
        """
        self.schema.create(self.store)
        contents = """
def apply(store):
    store.execute('ALTER TABLE person ADD COLUMN phone TEXT')
"""
        self.package.create_module("patch_1.py", contents)
        self.schema.upgrade(self.store)
        self.store.execute(
            "INSERT INTO person (id, name, phone) VALUES (1, 'Jane', '123')")
        self.assertEqual(list(self.store.execute("SELECT * FROM person")),
                         [(1, u"Jane", u"123")])
示例#4
0
class PatchApplierTest(MockerTestCase):
    """Tests for L{PatchApplier} using a throwaway C{mypackage} patch package."""

    def setUp(self):
        super(PatchApplierTest, self).setUp()

        self.patchdir = self.makeDir()
        self.pkgdir = os.path.join(self.patchdir, "mypackage")
        os.makedirs(self.pkgdir)

        with open(os.path.join(self.pkgdir, "__init__.py"), "w") as f:
            f.write("shared_data = []")

        # Order of creation here is important to try to screw up the
        # patch ordering, since os.listdir makes no ordering guarantee.
        for pname, data in [("patch_380.py", patch_test_1),
                            ("patch_42.py", patch_test_0)]:
            self.add_module(pname, data)

        sys.path.append(self.patchdir)

        self.filename = self.makeFile()
        self.uri = "sqlite:///%s" % self.filename
        self.store = Store(create_database(self.uri))

        self.store.execute("CREATE TABLE patch "
                           "(version INTEGER NOT NULL PRIMARY KEY)")

        self.assertFalse(self.store.get(Patch, 42))
        self.assertFalse(self.store.get(Patch, 380))

        import mypackage
        self.mypackage = mypackage
        self.patch_set = PatchSet(mypackage)

        # Create another connection just to keep track of the state of the
        # whole transaction manager.  See the assertion functions below.
        self.another_store = Store(create_database("sqlite:"))
        self.another_store.execute("CREATE TABLE test (id INT)")
        self.another_store.commit()
        self.prepare_for_transaction_check()

        class Committer(object):
            def commit(committer):
                self.store.commit()
                self.another_store.commit()

            def rollback(committer):
                self.store.rollback()
                self.another_store.rollback()

        self.committer = Committer()
        self.patch_applier = PatchApplier(self.store, self.patch_set,
                                          self.committer)

    def tearDown(self):
        super(PatchApplierTest, self).tearDown()
        self.committer.rollback()
        sys.path.remove(self.patchdir)
        for name in list(sys.modules):
            if name == "mypackage" or name.startswith("mypackage."):
                del sys.modules[name]

    def add_module(self, module_filename, contents):
        """Write C{contents} to C{module_filename} inside the package dir."""
        filename = os.path.join(self.pkgdir, module_filename)
        with open(filename, "w") as module_file:
            module_file.write(contents)

    def remove_all_modules(self):
        """Unlink every entry in the package directory (incl. __init__.py)."""
        for filename in os.listdir(self.pkgdir):
            os.unlink(os.path.join(self.pkgdir, filename))

    def prepare_for_transaction_check(self):
        # Leave an uncommitted row in another_store; whether it survives
        # tells us if the committer committed or rolled back.
        self.another_store.execute("DELETE FROM test")
        self.another_store.execute("INSERT INTO test VALUES (1)")

    def assert_transaction_committed(self):
        self.another_store.rollback()
        result = self.another_store.execute("SELECT * FROM test").get_one()
        self.assertEqual(result, (1, ),
                         "Transaction manager wasn't committed.")

    def assert_transaction_aborted(self):
        self.another_store.commit()
        result = self.another_store.execute("SELECT * FROM test").get_one()
        self.assertEqual(result, None, "Transaction manager wasn't aborted.")

    def test_apply(self):
        """
        L{PatchApplier.apply} executes the patch with the given version.
        """
        self.patch_applier.apply(42)

        x = self.mypackage.patch_42.x
        self.assertEqual(x, 42)
        self.assertTrue(self.store.get(Patch, 42))
        self.assertIn("mypackage.patch_42", sys.modules)

        self.assert_transaction_committed()

    def test_apply_with_patch_directory(self):
        """
        If the given L{PatchSet} uses sub-level patches, then the
        L{PatchApplier.apply} method will look at the per-patch directory and
        apply the relevant sub-level patch.
        """
        path = os.path.join(self.pkgdir, "patch_99")
        self.makeDir(path=path)
        self.makeFile(content="", path=os.path.join(path, "__init__.py"))
        self.makeFile(content=patch_test_0, path=os.path.join(path, "foo.py"))
        self.patch_set._sub_level = "foo"
        self.add_module("patch_99/foo.py", patch_test_0)
        self.patch_applier.apply(99)
        self.assertTrue(self.store.get(Patch, 99))

    def test_apply_all(self):
        """
        L{PatchApplier.apply_all} executes all unapplied patches.
        """
        self.patch_applier.apply_all()

        self.assertIn("mypackage.patch_42", sys.modules)
        self.assertIn("mypackage.patch_380", sys.modules)

        x = self.mypackage.patch_42.x
        y = self.mypackage.patch_380.y

        self.assertEqual(x, 42)
        self.assertEqual(y, 380)

        self.assert_transaction_committed()

    def test_apply_exploding_patch(self):
        """
        L{PatchApplier.apply} aborts the transaction if the patch fails.
        """
        self.remove_all_modules()
        self.add_module("patch_666.py", patch_explosion)
        self.assertRaises(StormError, self.patch_applier.apply, 666)

        self.assert_transaction_aborted()

    def test_wb_apply_all_exploding_patch(self):
        """
        When a patch explodes the store is rolled back to make sure
        that any changes the patch made to the database are removed.
        Any other patches that have been applied successfully before
        it should not be rolled back.  Any patches pending after the
        exploding patch should remain unapplied.
        """
        self.add_module("patch_666.py", patch_explosion)
        self.add_module("patch_667.py", patch_after_explosion)
        self.assertEqual(list(self.patch_applier.get_unapplied_versions()),
                         [42, 380, 666, 667])
        self.assertRaises(StormError, self.patch_applier.apply_all)
        self.assertEqual(list(self.patch_applier.get_unapplied_versions()),
                         [666, 667])

    def test_mark_applied(self):
        """
        L{PatchApplier.mark_applied} marks a patch as applied by inserting a
        new row in the patch table, without importing the patch module.
        """
        self.patch_applier.mark_applied(42)

        self.assertNotIn("mypackage.patch_42", sys.modules)
        self.assertNotIn("mypackage.patch_380", sys.modules)

        self.assertTrue(self.store.get(Patch, 42))
        self.assertFalse(self.store.get(Patch, 380))

        self.assert_transaction_committed()

    def test_mark_applied_all(self):
        """
        L{PatchApplier.mark_applied_all} marks all pending patches as applied.
        """
        self.patch_applier.mark_applied_all()

        self.assertNotIn("mypackage.patch_42", sys.modules)
        self.assertNotIn("mypackage.patch_380", sys.modules)

        self.assertTrue(self.store.get(Patch, 42))
        self.assertTrue(self.store.get(Patch, 380))

        self.assert_transaction_committed()

    def test_application_order(self):
        """
        L{PatchApplier.apply_all} applies the patches in increasing version
        order.
        """
        self.patch_applier.apply_all()
        self.assertEqual(self.mypackage.shared_data, [42, 380])

    def test_has_pending_patches(self):
        """
        L{PatchApplier.has_pending_patches} returns C{True} if there are
        patches to be applied, C{False} otherwise.
        """
        self.assertTrue(self.patch_applier.has_pending_patches())
        self.patch_applier.apply_all()
        self.assertFalse(self.patch_applier.has_pending_patches())

    def test_abort_if_unknown_patches(self):
        """
        L{PatchApplier.apply_all} raises an error if the patch table
        contains patches without a matching file in the patch module.
        """
        self.patch_applier.mark_applied(381)
        self.assertRaises(UnknownPatchError, self.patch_applier.apply_all)

    def test_get_unknown_patch_versions(self):
        """
        L{PatchApplier.get_unknown_patch_versions} returns the versions of all
        patches recorded in the store that have no matching patch file.
        """
        patches = [Patch(42), Patch(380), Patch(381)]
        my_store = MockPatchStore("database", patches=patches)
        patch_applier = PatchApplier(my_store, self.mypackage)
        self.assertEqual(set([381]),
                         patch_applier.get_unknown_patch_versions())

    def test_no_unknown_patch_versions(self):
        """
        L{PatchApplier.get_unknown_patch_versions} returns an empty set if
        every patch recorded in the store has a matching patch file.
        """
        patches = [Patch(42), Patch(380)]
        my_store = MockPatchStore("database", patches=patches)
        patch_applier = PatchApplier(my_store, self.mypackage)
        self.assertEqual(set(), patch_applier.get_unknown_patch_versions())

    def test_patch_with_incorrect_apply(self):
        """
        L{PatchApplier.apply_all} raises an error as soon as one of the patches
        to be applied fails.
        """
        self.add_module("patch_999.py", patch_no_args_apply)
        try:
            self.patch_applier.apply_all()
        except BadPatchError as e:
            self.assertIn("mypackage/patch_999.py", str(e))
            self.assertIn("takes no arguments", str(e))
            self.assertIn("TypeError", str(e))
        else:
            self.fail("BadPatchError not raised")

    def test_patch_with_missing_apply(self):
        """
        L{PatchApplier.apply_all} raises an error if one of the patches
        to be applied has no 'apply' function defined.
        """
        self.add_module("patch_999.py", patch_missing_apply)
        try:
            self.patch_applier.apply_all()
        except BadPatchError as e:
            self.assertIn("mypackage/patch_999.py", str(e))
            self.assertIn("no attribute", str(e))
            self.assertIn("AttributeError", str(e))
        else:
            self.fail("BadPatchError not raised")

    def test_patch_with_syntax_error(self):
        """
        L{PatchApplier.apply_all} raises an error if one of the patches
        to be applied contains a syntax error.
        """
        self.add_module("patch_999.py", "that's not python")
        try:
            self.patch_applier.apply_all()
        except BadPatchError as e:
            self.assertIn(" 999 ", str(e))
            self.assertIn("SyntaxError", str(e))
        else:
            self.fail("BadPatchError not raised")

    def test_patch_error_includes_traceback(self):
        """
        The exception raised by L{PatchApplier.apply_all} when a patch fails
        includes the relevant traceback from the patch.
        """
        self.add_module("patch_999.py", patch_name_error)
        try:
            self.patch_applier.apply_all()
        except BadPatchError as e:
            self.assertIn("mypackage/patch_999.py", str(e))
            self.assertIn("NameError", str(e))
            self.assertIn("blah", str(e))
            formatted = traceback.format_exc()
            self.assertIn("# Comment", formatted)
        else:
            self.fail("BadPatchError not raised")
示例#5
0
class SchemaTest(MockerTestCase):
    """Tests for L{Schema} check/create/drop/delete/upgrade/advance."""

    def setUp(self):
        super(SchemaTest, self).setUp()
        self.database = create_database("sqlite:///%s" % self.makeFile())
        self.store = Store(self.database)

        self._package_dirs = set()
        self._package_names = set()
        self.package = self.create_package(self.makeDir(), "patch_package")
        import patch_package

        creates = ["CREATE TABLE person (id INTEGER, name TEXT)"]
        drops = ["DROP TABLE person"]
        deletes = ["DELETE FROM person"]

        self.schema = Schema(creates, drops, deletes, patch_package)

    def tearDown(self):
        for package_dir in self._package_dirs:
            sys.path.remove(package_dir)

        # Forget the packages made by create_package and all their
        # submodules (str.startswith accepts a tuple of prefixes; an empty
        # tuple matches nothing).
        submodule_prefixes = tuple("%s." % x for x in self._package_names)
        for name in list(sys.modules):
            if (name in self._package_names
                    or name.startswith(submodule_prefixes)):
                del sys.modules[name]

        super(SchemaTest, self).tearDown()

    def create_package(self, base_dir, name, init_module=None):
        """Create a Python package.

        Packages created using this method will be removed from L{sys.path}
        and L{sys.modules} during L{tearDown}.

        @param base_dir: The directory in which to create the new package;
            it is appended to L{sys.path} so the package is importable.
        @param name: The name of the package.
        @param init_module: Optionally, the text to include in the __init__.py
            file.
        @return: A L{Package} instance that can be used to create modules.
        """
        package_dir = os.path.join(base_dir, name)
        self._package_names.add(name)
        os.makedirs(package_dir)

        with open(os.path.join(package_dir, "__init__.py"), "w") as init_file:
            if init_module:
                init_file.write(init_module)
        sys.path.append(base_dir)
        self._package_dirs.add(base_dir)

        return Package(package_dir, name)

    def test_check_with_missing_schema(self):
        """
        L{Schema.check} raises an exception if the given store is completely
        pristine and no schema has been applied yet. The transaction doesn't
        get rolled back so it's still usable.
        """
        self.store.execute("CREATE TABLE foo (bar INT)")
        self.assertRaises(SchemaMissingError, self.schema.check, self.store)
        self.assertIsNone(self.store.execute("SELECT 1 FROM foo").get_one())

    def test_check_with_unapplied_patches(self):
        """
        L{Schema.check} raises an exception if the given store has unapplied
        schema patches.
        """
        self.schema.create(self.store)
        contents = """
def apply(store):
    pass
"""
        self.package.create_module("patch_1.py", contents)
        self.assertRaises(UnappliedPatchesError, self.schema.check, self.store)

    def test_create(self):
        """
        L{Schema.create} can be used to create the tables of a L{Store}.
        """
        self.assertRaises(StormError, self.store.execute,
                          "SELECT * FROM person")
        self.schema.create(self.store)
        self.assertEqual(list(self.store.execute("SELECT * FROM person")), [])
        # By default changes are committed, so a second store sees them.
        store2 = Store(self.database)
        try:
            self.assertEqual(list(store2.execute("SELECT * FROM person")), [])
        finally:
            store2.close()

    def test_create_with_autocommit_off(self):
        """
        L{Schema.autocommit} can be used to turn automatic commits off.
        """
        self.schema.autocommit(False)
        self.schema.create(self.store)
        self.store.rollback()
        self.assertRaises(StormError, self.store.execute,
                          "SELECT * FROM patch")

    def test_drop(self):
        """
        L{Schema.drop} can be used to drop the tables of a L{Store}.
        """
        self.schema.create(self.store)
        self.assertEqual(list(self.store.execute("SELECT * FROM person")), [])
        self.schema.drop(self.store)
        self.assertRaises(StormError, self.store.execute,
                          "SELECT * FROM person")

    def test_drop_with_missing_patch_table(self):
        """
        L{Schema.drop} works fine even if the user's supplied statements end up
        dropping the patch table that we created.
        """
        import patch_package
        schema = Schema([], ["DROP TABLE patch"], [], patch_package)
        schema.create(self.store)
        schema.drop(self.store)
        self.assertRaises(StormError, self.store.execute,
                          "SELECT * FROM patch")

    def test_delete(self):
        """
        L{Schema.delete} can be used to clear the tables of a L{Store}.
        """
        self.schema.create(self.store)
        self.store.execute("INSERT INTO person (id, name) VALUES (1, 'Jane')")
        self.assertEqual(list(self.store.execute("SELECT * FROM person")),
                         [(1, u"Jane")])
        self.schema.delete(self.store)
        self.assertEqual(list(self.store.execute("SELECT * FROM person")), [])

    def test_upgrade_creates_schema(self):
        """
        L{Schema.upgrade} creates a schema from scratch if none exists, and is
        effectively equivalent to L{Schema.create} in that case.
        """
        self.assertRaises(StormError, self.store.execute,
                          "SELECT * FROM person")
        self.schema.upgrade(self.store)
        self.assertEqual(list(self.store.execute("SELECT * FROM person")), [])

    def test_upgrade_marks_patches_applied(self):
        """
        L{Schema.upgrade} updates the patch table after applying the needed
        patches.
        """
        contents = """
def apply(store):
    store.execute('ALTER TABLE person ADD COLUMN phone TEXT')
"""
        self.package.create_module("patch_1.py", contents)
        statement = "SELECT * FROM patch"
        self.assertRaises(StormError, self.store.execute, statement)
        self.schema.upgrade(self.store)
        self.assertEqual(list(self.store.execute("SELECT * FROM patch")),
                         [(1, )])

    def test_upgrade_applies_patches(self):
        """
        L{Schema.upgrade} executes the needed patches, that typically modify
        the existing schema.
        """
        self.schema.create(self.store)
        contents = """
def apply(store):
    store.execute('ALTER TABLE person ADD COLUMN phone TEXT')
"""
        self.package.create_module("patch_1.py", contents)
        self.schema.upgrade(self.store)
        self.store.execute(
            "INSERT INTO person (id, name, phone) VALUES (1, 'Jane', '123')")
        self.assertEqual(list(self.store.execute("SELECT * FROM person")),
                         [(1, u"Jane", u"123")])

    def test_advance(self):
        """
        L{Schema.advance} executes the given patch version.
        """
        self.schema.create(self.store)
        contents1 = """
def apply(store):
    store.execute('ALTER TABLE person ADD COLUMN phone TEXT')
"""
        contents2 = """
def apply(store):
    store.execute('ALTER TABLE person ADD COLUMN address TEXT')
"""
        self.package.create_module("patch_1.py", contents1)
        self.package.create_module("patch_2.py", contents2)
        self.schema.advance(self.store, 1)
        self.store.execute(
            "INSERT INTO person (id, name, phone) VALUES (1, 'Jane', '123')")
        self.assertEqual(list(self.store.execute("SELECT * FROM person")),
                         [(1, u"Jane", u"123")])
示例#6
0
文件: db.py 项目: scarpentier/askgod
def _sql_file_version(path):
    """
        Return the version token of a schema/update SQL file path: the text
        after the last '-' in the part before the first '.'
        ("schema/update-12.sql" -> "12").
    """
    return path.split(".")[0].split("-")[-1]


def _sorted_by_version(paths):
    """
        Sort SQL file paths by numeric version.

        A plain sorted() is lexicographic and puts "update-10.sql" before
        "update-2.sql".  Paths whose version token is not purely numeric sort
        after the numeric ones, in lexicographic order.
    """
    def sort_key(path):
        version = _sql_file_version(path)
        if version.isdigit():
            return (0, int(version), path)
        return (1, 0, path)

    return sorted(paths, key=sort_key)


def _run_sql_file(db_store, fd):
    """
        Execute and commit, one at a time, each non-empty ';'-separated
        statement read from the open file fd.

        Newlines are turned into spaces rather than deleted, so a statement
        written over several lines keeps its token boundaries ("SELECT a"
        and "FROM b" on two lines no longer collapse into "SELECT aFROM b").
    """
    for statement in fd.read().replace("\n", " ").split(";"):
        statement = statement.strip()
        if not statement:
            continue
        db_store.execute("%s;" % statement)
        db_commit(db_store)


def db_update_schema(database):
    """
        Check for pending database schema updates.
        If any are found, apply them and bump the version.

        Returns True on success and False when the schema could not be
        created, has no version, or an update failed.  The store opened
        here is closed on every exit path.
    """

    # Connect to the database
    db_store = Store(database)
    try:
        # Check if the DB schema has been loaded
        db_exists = False
        try:
            db_store.execute(Select(DBSchema.version))
            db_exists = True
        except Exception:
            db_store.rollback()
            logging.debug("Failed to query schema table.")

        if not db_exists:
            logging.info("Creating database")
            schema_files = _sorted_by_version(
                glob.glob("schema/schema-*.sql"))
            if not schema_files:
                logging.critical("No schema file found in 'schema/'")
                return False
            # Highest version wins (numeric, not lexicographic, order).
            schema_file = schema_files[-1]
            schema_version = _sql_file_version(schema_file)
            logging.debug("Using '%s' to deploy schema '%s'",
                          schema_file, schema_version)
            with open(schema_file, "r") as fd:
                try:
                    _run_sql_file(db_store, fd)
                    logging.info("Database created")
                except Exception:
                    logging.critical("Failed to initialize the database",
                                     exc_info=True)
                    return False

        # Get schema version
        version = db_store.execute(Select(Max(DBSchema.version))).get_one()[0]
        if not version:
            logging.critical("No schema version.")
            return False

        # Apply updates newer than the current version, oldest first
        for update_file in _sorted_by_version(
                glob.glob("schema/update-*.sql")):
            update_version = _sql_file_version(update_file)
            if int(update_version) <= version:
                continue
            logging.info("Using '%s' to deploy update '%s'",
                         update_file, update_version)
            with open(update_file, "r") as fd:
                try:
                    _run_sql_file(db_store, fd)
                except Exception:
                    logging.critical("Failed to load schema update",
                                     exc_info=True)
                    return False

        # Get schema version
        new_version = db_store.execute(
            Select(Max(DBSchema.version))).get_one()[0]
        if new_version > version:
            logging.info(
                "Database schema successfuly updated from '%s' to '%s'",
                version, new_version)
        return True
    finally:
        # The early ``return False`` exits used to leak the connection.
        db_store.close()
示例#7
0
class TestChangeTracker(object):
    """Tests for ChangeTracker/ChangeHistory on an in-memory SQLite store."""

    class A(object):
        __storm_table__ = 'testob'
        changehistory = ChangeHistory.configure("history")
        clt = ChangeTracker(changehistory)
        id = Int(primary=1)
        textval = Unicode(validator=clt)
        intval = Int(validator=clt)

    def setUp(self):
        database = create_database('sqlite:')
        self.store = Store(database)
        self.store.execute("""
            CREATE table history (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                ref_class VARCHAR(200),
                ref_pk VARCHAR(200),
                ref_attr VARCHAR(200),
                new_value VARCHAR(200),
                ctime DATETIME,
                cuser INT
            )
        """)
        self.store.execute("""
            CREATE TABLE testob (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                textval VARCHAR(200),
                intval INT,
                dateval DATETIME
            )""")

    def tearDown(self):
        self.store.rollback()
        # Release the connection too; setUp opens a fresh store per test.
        self.store.close()

    def test_calls_next_validator(self):
        """The tracker passes assigned values through next_validator."""
        clt = ChangeTracker(ChangeHistory.configure("history"), next_validator = lambda ob, attr, v: v*2)

        class B(self.A):
            textval = Unicode(validator=clt)

        b = B()
        b.textval = u'bork'
        assert b.textval == u'borkbork'

    def test_adds_log_entries(self):
        """Each assignment to a tracked attribute adds one history row."""

        class B(self.A):
            clt = ChangeTracker(ChangeHistory.configure("history"))
            textval = Unicode(validator=clt)

        b = self.store.add(B())
        b.textval = u'pointless'
        b.textval = u'aimless'
        changes = list(self.store.find(b.clt.change_cls))
        assert_equal(len(changes), 2)
        assert_equal(changes[0].new_value, 'pointless')
        assert_equal(changes[1].new_value, 'aimless')

    def test_value_type_preserved(self):
        """Recorded new_value keeps the Python type that was assigned."""
        a = self.store.add(self.A())
        a.textval = u'one'
        a.intval = 1
        changes = list(self.store.find(a.clt.change_cls))
        assert_equal(type(changes[0].new_value), unicode)
        assert_equal(type(changes[1].new_value), int)

    def test_ctime_set(self):
        """History rows get a datetime ctime taken during the assignment."""
        start = datetime.now()
        a = self.store.add(self.A())
        a.textval = u'x'
        changes = list(self.store.find(a.clt.change_cls))
        assert_equal(type(changes[0].ctime), datetime)
        # Inclusive bounds: datetime.now() can have coarse resolution, so
        # consecutive calls may return equal values; strict '<' was flaky.
        assert start <= changes[0].ctime <= datetime.now()

    def test_cuser_set(self):
        """The configured getuser callback supplies the cuser column."""
        def getuser():
            return u'Fred'

        history = ChangeHistory.configure("history", getuser=getuser, usertype=Unicode)
        class B(self.A):
            textval = Unicode(validator=ChangeTracker(history))

        b = self.store.add(B())
        b.textval = u'foo'
        changes = self.store.find(history)
        assert_equal(changes[0].cuser, u'Fred')

    def test_changes_for_returns_change_history(self):
        """changes_for returns only the history rows of the given object."""
        a = self.store.add(self.A())
        b = self.store.add(self.A())
        a.id = 1
        a.textval = u'one'
        a.textval = u'two'
        b.id = 2
        b.textval = u'ein'
        b.textval = u'zwei'

        assert_equal([c.new_value for c in a.changehistory.changes_for(a)], [u'one', u'two'])
        assert_equal([c.new_value for c in a.changehistory.changes_for(b)], [u'ein', u'zwei'])
示例#8
0
class PatchTest(MockerTestCase):

    def setUp(self):
        """Create a temporary patch package, a patch store and an applier.

        ``mypackage`` is written to a fresh directory that is appended to
        sys.path. It contains patch_380 and patch_42. A file-backed SQLite
        store gets an empty ``patch`` table. A second in-memory store holds
        an uncommitted sentinel row, which the assert_transaction_* helpers
        use to tell whether the committer committed or rolled back.
        """
        super(PatchTest, self).setUp()

        self.patchdir = self.makeDir()
        self.pkgdir = os.path.join(self.patchdir, "mypackage")
        os.makedirs(self.pkgdir)

        # try/finally guarantees the handle is closed even if write() fails.
        init_file = open(os.path.join(self.pkgdir, "__init__.py"), "w")
        try:
            init_file.write("shared_data = []")
        finally:
            init_file.close()

        # Order of creation here is important to try to screw up the
        # patch ordering, as os.listdir returns in order of mtime (or
        # something).
        for pname, data in [("patch_380.py", patch_test_1),
                            ("patch_42.py", patch_test_0)]:
            self.add_module(pname, data)

        sys.path.append(self.patchdir)

        self.filename = self.makeFile()
        self.uri = "sqlite:///%s" % self.filename
        self.store = Store(create_database(self.uri))

        self.store.execute("CREATE TABLE patch "
                           "(version INTEGER NOT NULL PRIMARY KEY)")

        # Neither patch may be recorded yet. The old ``(42)`` spelling was
        # just a parenthesised int, not a tuple.
        self.assertFalse(self.store.get(Patch, 42))
        self.assertFalse(self.store.get(Patch, 380))

        import mypackage
        self.mypackage = mypackage

        # Create another connection just to keep track of the state of the
        # whole transaction manager.  See the assertion functions below.
        self.another_store = Store(create_database("sqlite:"))
        self.another_store.execute("CREATE TABLE test (id INT)")
        self.another_store.commit()
        self.prepare_for_transaction_check()

        class Committer(object):
            # Commits / rolls back both stores together so the sentinel
            # store reflects what the applier asked for.

            def commit(committer):
                self.store.commit()
                self.another_store.commit()

            def rollback(committer):
                self.store.rollback()
                self.another_store.rollback()

        self.committer = Committer()
        self.patch_applier = PatchApplier(self.store, self.mypackage,
                                          self.committer)

    def tearDown(self):
        """Undo setUp: roll back and close both stores, then unload mypackage.

        The test's own cleanup now runs before the base-class tearDown, which
        may remove the temporary files created through makeFile/makeDir.
        The finally clause makes sure sys.path, sys.modules and the
        base-class cleanup are still handled if the rollback raises.
        """
        try:
            self.committer.rollback()
        finally:
            # Release the SQLite connections, which were previously leaked.
            self.store.close()
            self.another_store.close()
            sys.path.remove(self.patchdir)
            for name in list(sys.modules):
                if name == "mypackage" or name.startswith("mypackage."):
                    del sys.modules[name]
            super(PatchTest, self).tearDown()

    def add_module(self, module_filename, contents):
        """Write ``contents`` to ``module_filename`` inside the test package.

        @param module_filename: File name relative to C{self.pkgdir},
            e.g. C{"patch_42.py"}.
        @param contents: Python source text to write.
        """
        filename = os.path.join(self.pkgdir, module_filename)
        # try/finally closes the handle even if write() fails. The old code
        # leaked it in that case and also shadowed the builtin ``file``.
        module_file = open(filename, "w")
        try:
            module_file.write(contents)
        finally:
            module_file.close()

    def remove_all_modules(self):
        """Unlink every entry currently present in the package directory."""
        pkgdir = self.pkgdir
        doomed = [os.path.join(pkgdir, entry) for entry in os.listdir(pkgdir)]
        for path in doomed:
            os.unlink(path)

    def prepare_for_transaction_check(self):
        """Leave a single uncommitted row (1) in C{another_store}'s table.

        The assert_transaction_* helpers later inspect whether this pending
        change was committed or rolled back.
        """
        for statement in ("DELETE FROM test", "INSERT INTO test VALUES (1)"):
            self.another_store.execute(statement)

    def assert_transaction_committed(self):
        """Fail unless the pending sentinel row was committed.

        Rolling back first discards anything uncommitted, so the row only
        survives if the committer already committed it.
        """
        sentinel_store = self.another_store
        sentinel_store.rollback()
        row = sentinel_store.execute("SELECT * FROM test").get_one()
        self.assertEquals(row, (1,), "Transaction manager wasn't committed.")

    def assert_transaction_aborted(self):
        """Fail unless the pending sentinel row was rolled back.

        Committing first makes the current state permanent. The table is
        only empty if the committer already undid the DELETE + INSERT.
        """
        sentinel_store = self.another_store
        sentinel_store.commit()
        row = sentinel_store.execute("SELECT * FROM test").get_one()
        self.assertEquals(row, None, "Transaction manager wasn't aborted.")

    def test_apply(self):
        """
        Applying a single version runs that patch module, records it in the
        patch table and commits.
        """
        self.patch_applier.apply(42)

        patch_module = self.mypackage.patch_42
        self.assertEquals(patch_module.x, 42)
        self.assertTrue(self.store.get(Patch, 42))
        self.assertTrue("mypackage.patch_42" in sys.modules)

        self.assert_transaction_committed()

    def test_apply_all(self):
        """
        C{apply_all} imports and runs every patch not yet recorded, then
        commits.
        """
        self.patch_applier.apply_all()

        for version in (42, 380):
            self.assertTrue("mypackage.patch_%d" % version in sys.modules)

        package = self.mypackage
        x, y = package.patch_42.x, package.patch_380.y
        self.assertEquals(x, 42)
        self.assertEquals(y, 380)

        self.assert_transaction_committed()

    def test_apply_exploding_patch(self):
        """
        A patch that raises makes C{apply} fail and the committer roll back.
        """
        self.remove_all_modules()
        self.add_module("patch_666.py", patch_explosion)

        apply_patch = self.patch_applier.apply
        self.assertRaises(StormError, apply_patch, 666)

        self.assert_transaction_aborted()

    def test_wb_apply_all_exploding_patch(self):
        """
        If one patch raises during C{apply_all}, the patches before it stay
        applied. The failing patch and every later one remain pending,
        because the store is rolled back to discard the failing patch's
        partial changes.
        """
        for module_name, source in [("patch_666.py", patch_explosion),
                                    ("patch_667.py", patch_after_explosion)]:
            self.add_module(module_name, source)

        pending = self.patch_applier._get_unapplied_versions
        self.assertEquals(list(pending()), [42, 380, 666, 667])
        self.assertRaises(StormError, self.patch_applier.apply_all)
        self.assertEquals(list(pending()), [666, 667])

    def test_mark_applied(self):
        """
        L{PatchApplier.mark_applied} marks a patch as applied by inserting a
        new row in the patch table, without importing or running the patch
        module.
        """
        self.patch_applier.mark_applied(42)

        # Marking must not import any patch module.
        for module_name in ("mypackage.patch_42", "mypackage.patch_380"):
            self.assertFalse(module_name in sys.modules)

        self.assertTrue(self.store.get(Patch, 42))
        self.assertFalse(self.store.get(Patch, 380))

        self.assert_transaction_committed()

    def test_mark_applied_all(self):
        """
        C{mark_applied_all} records every pending version in the patch table
        without importing any patch module, then commits.
        """
        self.patch_applier.mark_applied_all()

        for module_name in ("mypackage.patch_42", "mypackage.patch_380"):
            self.assertFalse(module_name in sys.modules)

        for version in (42, 380):
            self.assertTrue(self.store.get(Patch, version))

        self.assert_transaction_committed()

    def test_application_order(self):
        """
        Patches run from the lowest version to the highest, whatever the
        order of their files on disk.
        """
        self.patch_applier.apply_all()
        expected = [42, 380]
        self.assertEquals(self.mypackage.shared_data, expected)

    def test_has_pending_patches(self):
        """
        C{has_pending_patches} is true before C{apply_all} and false after it.
        """
        applier = self.patch_applier
        self.assertTrue(applier.has_pending_patches())
        applier.apply_all()
        self.assertFalse(applier.has_pending_patches())

    def test_abort_if_unknown_patches(self):
        """
        L{PatchApplier.apply_all} raises L{UnknownPatchError} if the patch
        table contains a version without a matching file in the patch
        package. Here version 381 is recorded via C{mark_applied} although
        no C{patch_381.py} exists.
        """
        self.patch_applier.mark_applied(381)
        self.assertRaises(UnknownPatchError, self.patch_applier.apply_all)

    def test_get_unknown_patch_versions(self):
        """
        L{PatchApplier.get_unknown_patch_versions} returns the versions that
        are recorded in the store but have no matching module in the patch
        package (here 381, since only patch_42 and patch_380 exist).
        """
        patches = [Patch(42), Patch(380), Patch(381)]
        my_store = MockPatchStore("database", patches=patches)
        patch_applier = PatchApplier(my_store, self.mypackage)
        self.assertEqual(set([381]),
                         patch_applier.get_unknown_patch_versions())

    def test_no_unknown_patch_versions(self):
        """
        L{PatchApplier.get_unknown_patch_versions} returns an empty set when
        every version recorded in the store has a matching patch module.
        """
        patches = [Patch(42), Patch(380)]
        my_store = MockPatchStore("database", patches=patches)
        patch_applier = PatchApplier(my_store, self.mypackage)
        self.assertEqual(set(), patch_applier.get_unknown_patch_versions())

    def test_patch_with_incorrect_apply(self):
        """
        L{PatchApplier.apply_all} raises an error as soon as one of the patches
        to be applied fails.
        """
        self.add_module("patch_999.py", patch_no_args_apply)
        try:
            self.patch_applier.apply_all()
        except BadPatchError, e:
            self.assertTrue("mypackage/patch_999.py" in str(e))
            self.assertTrue("takes no arguments" in str(e))
            self.assertTrue("TypeError" in str(e))
        else:
示例#9
0
文件: schema.py 项目: saoili/storm
class SchemaTest(MockerTestCase):
    """Tests for L{Schema} create/drop/delete/upgrade against SQLite."""

    def setUp(self):
        super(SchemaTest, self).setUp()
        self.database = create_database("sqlite:///%s" % self.makeFile())
        self.store = Store(self.database)

        self._package_dirs = set()
        self._package_names = set()
        self.package = self.create_package(self.makeDir(), "patch_package")
        import patch_package

        creates = ["CREATE TABLE person (id INTEGER, name TEXT)"]
        drops = ["DROP TABLE person"]
        deletes = ["DELETE FROM person"]

        self.schema = Schema(creates, drops, deletes, patch_package)

    def tearDown(self):
        """Remove the packages created by L{create_package}.

        Their base directories are taken off sys.path, and the packages plus
        all of their submodules are evicted from sys.modules.
        """
        for package_dir in self._package_dirs:
            sys.path.remove(package_dir)

        # str.startswith accepts a tuple, so a single call checks every
        # package. This replaces ``filter(None, [...])``, whose result is
        # always truthy on Python 3 and would have evicted every loaded
        # module.
        prefixes = tuple("%s." % package for package in self._package_names)
        for name in list(sys.modules):
            if name in self._package_names or name.startswith(prefixes):
                del sys.modules[name]

        super(SchemaTest, self).tearDown()

    def create_package(self, base_dir, name, init_module=None):
        """Create a Python package.

        Packages created using this method will be removed from L{sys.path}
        and L{sys.modules} during L{tearDown}.

        @param base_dir: The directory in which to create the new package;
            it is appended to L{sys.path} (once) so the package is importable.
        @param name: The name of the package.
        @param init_module: Optionally, the text to include in the __init__.py
            file.
        @return: A L{Package} instance that can be used to create modules.
        """
        package_dir = os.path.join(base_dir, name)
        self._package_names.add(name)
        os.makedirs(package_dir)

        # try/finally closes the handle even if write() fails.
        init_file = open(os.path.join(package_dir, "__init__.py"), "w")
        try:
            if init_module:
                init_file.write(init_module)
        finally:
            init_file.close()

        # Add each base directory only once. tearDown removes every entry of
        # _package_dirs a single time, so a duplicate would leak on sys.path.
        if base_dir not in self._package_dirs:
            sys.path.append(base_dir)
            self._package_dirs.add(base_dir)

        return Package(package_dir, name)

    def test_create(self):
        """
        L{Schema.create} can be used to create the tables of a L{Store}.
        """
        self.assertRaises(StormError, self.store.execute,
                          "SELECT * FROM person")
        self.schema.create(self.store)
        self.assertEquals(list(self.store.execute("SELECT * FROM person")), [])

    def test_create_with_autocommit_off(self):
        """
        L{Schema.autocommit} can be used to turn automatic commits off.
        """
        self.schema.autocommit(False)
        self.schema.create(self.store)
        self.store.rollback()
        self.assertRaises(StormError, self.store.execute,
                          "SELECT * FROM patch")

    def test_drop(self):
        """
        L{Schema.drop} can be used to drop the tables of a L{Store}.
        """
        self.schema.create(self.store)
        self.assertEquals(list(self.store.execute("SELECT * FROM person")), [])
        self.schema.drop(self.store)
        self.assertRaises(StormError, self.store.execute,
                          "SELECT * FROM person")

    def test_delete(self):
        """
        L{Schema.delete} can be used to clear the tables of a L{Store}.
        """
        self.schema.create(self.store)
        self.store.execute("INSERT INTO person (id, name) VALUES (1, 'Jane')")
        self.assertEquals(list(self.store.execute("SELECT * FROM person")),
                          [(1, u"Jane")])
        self.schema.delete(self.store)
        self.assertEquals(list(self.store.execute("SELECT * FROM person")), [])

    def test_upgrade_creates_schema(self):
        """
        L{Schema.upgrade} creates the schema from scratch if none exists, and
        is effectively equivalent to L{Schema.create} in such a case.
        """
        self.assertRaises(StormError, self.store.execute,
                          "SELECT * FROM person")
        self.schema.upgrade(self.store)
        self.assertEquals(list(self.store.execute("SELECT * FROM person")), [])

    def test_upgrade_marks_patches_applied(self):
        """
        L{Schema.upgrade} updates the patch table after applying the needed
        patches.
        """
        contents = """
def apply(store):
    store.execute('ALTER TABLE person ADD COLUMN phone TEXT')
"""
        self.package.create_module("patch_1.py", contents)
        statement = "SELECT * FROM patch"
        self.assertRaises(StormError, self.store.execute, statement)
        self.schema.upgrade(self.store)
        self.assertEquals(list(self.store.execute("SELECT * FROM patch")),
                          [(1, )])

    def test_upgrade_applies_patches(self):
        """
        L{Schema.upgrade} executes the needed patches, that typically modify
        the existing schema.
        """
        self.schema.create(self.store)
        contents = """
def apply(store):
    store.execute('ALTER TABLE person ADD COLUMN phone TEXT')
"""
        self.package.create_module("patch_1.py", contents)
        self.schema.upgrade(self.store)
        self.store.execute(
            "INSERT INTO person (id, name, phone) VALUES (1, 'Jane', '123')")
        self.assertEquals(list(self.store.execute("SELECT * FROM person")),
                          [(1, u"Jane", u"123")])
示例#10
0
class StormManager(Singleton):
    log = logging.getLogger('{}.StormManager'.format(__name__))

    def __init__(self):
        """Do nothing; database state is set up by init() / openDB()."""

    @loggingInfo
    def init(self, *args):
        """Mark the database as not ready, then attempt to open it.

        Extra positional arguments are accepted and ignored.
        """
        self.dbOK = False  # openDB() sets this to the real outcome
        self.openDB()

    @loggingInfo
    def reset(self):
        """Close the current database connection and open a fresh one."""
        self.closeDB()  # its return value is deliberately not checked
        self.openDB()

    @loggingInfo
    def openDB(self):
        """Open the configured database and create C{self._store}.

        For the SQLite backend the database file is
        ``<folder>/icepapcms.db``. If that file does not exist, the folder
        (including any missing parent directories) is created and the schema
        is populated with createSqliteDB(). Other backends are reached
        through a ``<db>://<user>:<password>@<server>/icepapcms`` URI, with
        ``mysql`` going through the MySQL database class.

        Sets C{self.dbOK} to True on success and to False on any error.
        Errors are logged, not raised.
        """
        try:
            self._config = ConfigManager()
            db_config = self._config.config[self._config.database]
            self.db = db_config["database"]
            create_db = False
            if self.db == self._config.Sqlite:
                folder = db_config["folder"]
                loc = os.path.join(folder, 'icepapcms.db')
                print("Using Sqlite database at %s" % loc)
                create_db = not os.path.exists(loc)
                if create_db:
                    print("No database file found, creating it")
                    if not os.path.exists(folder):
                        # makedirs (unlike mkdir) also creates missing
                        # parent directories.
                        os.makedirs(folder)
                self._database = create_database("%s:%s" % (self.db, loc))
            else:
                server = db_config["server"]
                user = db_config["user"]
                pwd = db_config["password"]
                scheme = "{}://{}:{}@{}/icepapcms".format(
                    self.db, user, pwd, server)

                if self.db == 'mysql':
                    self._database = MySQL(scheme)
                else:
                    self._database = create_database(scheme)

            self._store = Store(self._database)
            if create_db:
                self.dbOK = self.createSqliteDB()
            else:
                self.dbOK = True
        except Exception as e:
            self.log.error("Unexpected error on openDB: %s", e)
            self.dbOK = False

    @loggingInfo
    def createSqliteDB(self):
        """Populate a fresh SQLite database from the bundled SQL script.

        ``creates_sqlite.sql`` (package resource of ``icepapcms.db``) is
        split into statements at a ``;`` that ends a line. ``--`` comments
        are stripped, each non-empty statement is executed, and everything
        is committed once at the end.

        @return: True on success, False if anything failed (the error is
            logged).
        """
        # Statement terminator: ';' followed only by blanks up to end of line.
        statement_end = re.compile(r";[ \t]*$", re.M)
        # A '--' comment runs to the end of its line (newline included) or to
        # the end of the chunk. The former pattern r"--.*[\n\\Z]" placed \Z
        # inside a character class, where it means a literal backslash or
        # 'Z'. A comment at the end of a chunk was therefore kept, or cut at
        # its last 'Z', leaving broken SQL behind.
        # NOTE(review): as before, '--' inside a string literal is also
        # treated as the start of a comment.
        comment = re.compile(r"--[^\n]*(?:\n|\Z)")
        try:
            sql_file = resource_filename('icepapcms.db', 'creates_sqlite.sql')
            with open(sql_file) as f:
                sql_script = f.read()

            for statement in statement_end.split(sql_script):
                statement = comment.sub("", statement)
                if statement.strip():
                    self._store.execute(statement + ";")
            self._store.commit()
            return True
        except Exception as e:
            self.log.error("Unexpected error on createSqliteDB: %s", e)
            return False

    @loggingInfo
    def closeDB(self):
        """Close the Storm store if the database was opened successfully.

        @return: True on success (or when there was nothing to close),
            False if closing raised. In that case dbOK is also cleared.
        """
        try:
            if self.dbOK:
                self._store.close()
            return True
        except Exception as e:
            # The message needs a %s placeholder for the argument. Without
            # it, logging reports a formatting error and the actual
            # exception is lost.
            self.log.error("Unexpected error on closeDB: %s", e)
            self.dbOK = False
            return False

    @loggingInfo
    def store(self, obj):
        """Add obj to the Storm store; no commit is performed here."""
        storm_store = self._store
        storm_store.add(obj)

    @loggingInfo
    def remove(self, obj):
        """Mark obj for removal from the Storm store; no commit here."""
        storm_store = self._store
        storm_store.remove(obj)

    @loggingInfo
    def addIcepapSystem(self, icepap_system):
        """Add icepap_system to the store and commit.

        @return: True only if the object was added and the commit succeeded,
            False otherwise. commitTransaction() reports failure by
            returning False rather than raising. The old code ignored that
            value and returned True even when the commit failed.
        """
        try:
            self._store.add(icepap_system)
            committed = self.commitTransaction()
        except Exception as e:
            self.log.error(
                "some exception trying to store the icepap system "
                "%s: %s", icepap_system, e)
            return False
        return committed

    @loggingInfo
    def deleteLocation(self, location):
        """Remove location from the database and commit.

        With the SQLite backend each of the location's systems is deleted
        first through deleteIcepapSystem(), which commits on its own.
        Presumably other backends cascade these deletions themselves;
        TODO confirm.
        """
        children = location.systems if self.db == self._config.Sqlite else ()
        for icepap_system in children:
            self.deleteIcepapSystem(icepap_system)
        self._store.remove(location)
        self.commitTransaction()

    @loggingInfo
    def deleteIcepapSystem(self, icepap_system):
        """Remove icepap_system from the database and commit.

        With the SQLite backend its drivers are deleted first through
        deleteDriver(), which commits on its own. Presumably other backends
        cascade this themselves; TODO confirm.
        """
        is_sqlite = self.db == self._config.Sqlite
        children = icepap_system.drivers if is_sqlite else ()
        for driver in children:
            self.deleteDriver(driver)
        self._store.remove(icepap_system)
        self.commitTransaction()

    @loggingInfo
    def deleteDriver(self, driver):
        """Remove a driver, its historic configs and their parameters.

        Children are removed before their parents, and a single commit
        follows.
        """
        discard = self._store.remove
        for cfg in driver.historic_cfgs:
            for parameter in cfg.parameters:
                discard(parameter)
            discard(cfg)
        discard(driver)
        self.commitTransaction()

    @loggingInfo
    def getAllLocations(self):
        """Return a dict mapping location name to Location object.

        On any error the problem is logged and an empty dict is returned.
        """
        try:
            return dict((loc.name, loc) for loc in self._store.find(Location))
        except Exception as e:
            self.log.error("Unexpected error on getAllLocations: %s", e)
            return {}

    @loggingInfo
    def getLocation(self, name):
        """Look up a Location by primary key; None when absent."""
        location = self._store.get(Location, name)
        return location

    @loggingInfo
    def getIcepapSystem(self, icepap_name):
        """Look up an IcepapSystem by primary key; None when absent."""
        icepap_system = self._store.get(IcepapSystem, icepap_name)
        return icepap_system

    @loggingInfo
    def existsDriver(self, mydriver, id):
        """Return another driver whose configuration has an ``ID`` equal to id.

        Drivers are joined to their configurations and to a CfgParameter
        named ``"ID"`` whose value equals ``id``. The first such driver whose
        address differs from ``mydriver.addr`` is returned; None is returned
        if there is none.
        """
        # NOTE: ``id`` shadows the builtin but is kept for keyword callers.
        drivers = self._store.find(
            IcepapDriver, IcepapDriver.addr == IcepapDriverCfg.driver_addr,
            IcepapDriverCfg.id == CfgParameter.cfg_id,
            CfgParameter.name == "ID", CfgParameter.value == id)
        # The former ``if drivers: ... else: return None`` wrapper was
        # redundant: iterating an empty result already falls through to the
        # final ``return None``.
        for driver in drivers:
            if driver.addr != mydriver.addr:
                return driver
        return None

    @loggingInfo
    def getLocationIcepapSystem(self, location):
        """Return {system name: IcepapSystem} for systems in ``location``.

        ``location`` is compared against IcepapSystem.location_name. Results
        are requested ordered by name. On any error the problem is logged
        and an empty dict is returned.
        """
        try:
            systems = self._store.find(IcepapSystem,
                                       IcepapSystem.location_name == location)
            systems.order_by(IcepapSystem.name)
            return dict((system.name, system) for system in systems)
        except Exception as e:
            self.log.error(
                "Unexpected error on getLocationIcepapSystem: "
                "%s", e)
            return {}

    @loggingInfo
    def rollback(self):
        """Discard every uncommitted change in the Storm store."""
        storm_store = self._store
        storm_store.rollback()

    @loggingInfo
    def commitTransaction(self):
        """Commit the Storm store.

        @return: True if the commit succeeded, False otherwise. Failures are
            not raised, but they are now logged instead of being silently
            swallowed, so a False result can be diagnosed.
        """
        try:
            self._store.commit()
        except Exception as e:
            self.log.error("Unexpected error on commitTransaction: %s", e)
            return False
        return True