def perform_data_update(dbfile):
    """Run ``db_perform_data_update`` against the database stored in *dbfile*.

    A dedicated store is opened on the URI built by
    ``GLSettings.make_db_uri(dbfile)``. The work is committed on success.
    On any exception (including BaseException subclasses) it is rolled
    back and the exception is re-raised. The store is always closed.
    """
    db_uri = GLSettings.make_db_uri(dbfile)
    store = Store(create_database(db_uri))
    try:
        db_perform_data_update(store)
        store.commit()
    except BaseException:
        # Equivalent to a bare ``except:``: undo partial work, then propagate.
        store.rollback()
        raise
    finally:
        store.close()
def rollback(self):
    """Roll back the current unit of work.

    If ``self.transaction`` is truthy, it is reset to ``None`` and the
    result of ``self.rollback.commit()`` is returned. Otherwise the call is
    delegated to ``Store.rollback`` and its result is returned.
    """
    if self.transaction:
        self.transaction = None
        # NOTE(review): unless an instance attribute named ``rollback``
        # shadows this method, ``self.rollback`` is this bound method, which
        # has no ``commit`` attribute, so this line looks like it raises
        # AttributeError whenever a transaction is active. The intended
        # object (perhaps the transaction cleared on the line above) is not
        # visible here -- confirm and fix.
        return self.rollback.commit()
    result = Store.rollback(self)
    # NOTE(review): commented-out reset left by the original author.
    #Store.reset(self)
    return result
class SchemaTest(MockerTestCase):
    """Tests for L{Schema} create/drop/delete/upgrade against SQLite."""

    def setUp(self):
        super(SchemaTest, self).setUp()
        self.database = create_database("sqlite:///%s" % self.makeFile())
        self.store = Store(self.database)

        self._package_dirs = set()
        self._package_names = set()
        self.package = self.create_package(self.makeDir(), "patch_package")
        import patch_package

        creates = ["CREATE TABLE person (id INTEGER, name TEXT)"]
        drops = ["DROP TABLE person"]
        deletes = ["DELETE FROM person"]

        self.schema = Schema(creates, drops, deletes, patch_package)

    def tearDown(self):
        for package_dir in self._package_dirs:
            sys.path.remove(package_dir)

        # Unload every package created by create_package plus its
        # submodules. A tuple of prefixes is used instead of the former
        # ``filter(None, [...])`` test: on Python 3 filter() returns a lazy
        # iterator that is always truthy, which deleted *every* module from
        # sys.modules.
        prefixes = tuple("%s." % x for x in self._package_names)
        for name in list(sys.modules):
            if name in self._package_names or name.startswith(prefixes):
                del sys.modules[name]

        super(SchemaTest, self).tearDown()

    def create_package(self, base_dir, name, init_module=None):
        """Create a Python package.

        Packages created using this method will be removed from L{sys.path}
        and L{sys.modules} during L{tearDown}.

        @param base_dir: The directory in which to create the new package;
            it is appended to L{sys.path}.
        @param name: The name of the package.
        @param init_module: Optionally, the text to include in the
            __init__.py file.
        @return: A L{Package} instance that can be used to create modules.
        """
        package_dir = os.path.join(base_dir, name)
        self._package_names.add(name)
        os.makedirs(package_dir)

        with open(os.path.join(package_dir, "__init__.py"), "w") as init_file:
            if init_module:
                init_file.write(init_module)

        sys.path.append(base_dir)
        self._package_dirs.add(base_dir)

        return Package(package_dir, name)

    def test_create(self):
        """
        L{Schema.create} can be used to create the tables of a L{Store}.
        """
        self.assertRaises(StormError,
                          self.store.execute, "SELECT * FROM person")
        self.schema.create(self.store)
        self.assertEqual(list(self.store.execute("SELECT * FROM person")), [])

    def test_create_with_autocommit_off(self):
        """
        L{Schema.autocommit} can be used to turn automatic commits off.
        """
        self.schema.autocommit(False)
        self.schema.create(self.store)
        self.store.rollback()
        self.assertRaises(StormError, self.store.execute,
                          "SELECT * FROM patch")

    def test_drop(self):
        """
        L{Schema.drop} can be used to drop the tables of a L{Store}.
        """
        self.schema.create(self.store)
        self.assertEqual(list(self.store.execute("SELECT * FROM person")), [])
        self.schema.drop(self.store)
        self.assertRaises(StormError,
                          self.store.execute, "SELECT * FROM person")

    def test_delete(self):
        """
        L{Schema.delete} can be used to clear the tables of a L{Store}.
        """
        self.schema.create(self.store)
        self.store.execute("INSERT INTO person (id, name) VALUES (1, 'Jane')")
        self.assertEqual(list(self.store.execute("SELECT * FROM person")),
                         [(1, u"Jane")])
        self.schema.delete(self.store)
        self.assertEqual(list(self.store.execute("SELECT * FROM person")), [])

    def test_upgrade_creates_schema(self):
        """
        L{Schema.upgrade} creates a schema from scratch if none exists, and
        is effectively equivalent to L{Schema.create} in that case.
        """
        self.assertRaises(StormError,
                          self.store.execute, "SELECT * FROM person")
        self.schema.upgrade(self.store)
        self.assertEqual(list(self.store.execute("SELECT * FROM person")), [])

    def test_upgrade_marks_patches_applied(self):
        """
        L{Schema.upgrade} updates the patch table after applying the needed
        patches.
        """
        contents = """
def apply(store):
    store.execute('ALTER TABLE person ADD COLUMN phone TEXT')
"""
        self.package.create_module("patch_1.py", contents)
        statement = "SELECT * FROM patch"
        self.assertRaises(StormError, self.store.execute, statement)
        self.schema.upgrade(self.store)
        self.assertEqual(list(self.store.execute("SELECT * FROM patch")),
                         [(1,)])

    def test_upgrade_applies_patches(self):
        """
        L{Schema.upgrade} executes the needed patches, that typically modify
        the existing schema.
        """
        self.schema.create(self.store)
        contents = """
def apply(store):
    store.execute('ALTER TABLE person ADD COLUMN phone TEXT')
"""
        self.package.create_module("patch_1.py", contents)
        self.schema.upgrade(self.store)
        self.store.execute(
            "INSERT INTO person (id, name, phone) VALUES (1, 'Jane', '123')")
        self.assertEqual(list(self.store.execute("SELECT * FROM person")),
                         [(1, u"Jane", u"123")])
class PatchApplierTest(MockerTestCase):
    """Tests for L{PatchApplier} using a throw-away ``mypackage`` package."""

    def setUp(self):
        super(PatchApplierTest, self).setUp()

        self.patchdir = self.makeDir()
        self.pkgdir = os.path.join(self.patchdir, "mypackage")
        os.makedirs(self.pkgdir)

        with open(os.path.join(self.pkgdir, "__init__.py"), "w") as init_file:
            init_file.write("shared_data = []")

        # Order of creation here is important to try to screw up the
        # patch ordering, as os.listdir returns in order of mtime (or
        # something).
        for pname, data in [("patch_380.py", patch_test_1),
                            ("patch_42.py", patch_test_0)]:
            self.add_module(pname, data)

        sys.path.append(self.patchdir)

        self.filename = self.makeFile()
        self.uri = "sqlite:///%s" % self.filename
        self.store = Store(create_database(self.uri))

        self.store.execute("CREATE TABLE patch "
                           "(version INTEGER NOT NULL PRIMARY KEY)")

        self.assertFalse(self.store.get(Patch, 42))
        self.assertFalse(self.store.get(Patch, 380))

        import mypackage
        self.mypackage = mypackage
        self.patch_set = PatchSet(mypackage)

        # Create another connection just to keep track of the state of the
        # whole transaction manager. See the assertion functions below.
        self.another_store = Store(create_database("sqlite:"))
        self.another_store.execute("CREATE TABLE test (id INT)")
        self.another_store.commit()
        self.prepare_for_transaction_check()

        class Committer(object):

            def commit(committer):
                self.store.commit()
                self.another_store.commit()

            def rollback(committer):
                self.store.rollback()
                self.another_store.rollback()

        self.committer = Committer()
        self.patch_applier = PatchApplier(self.store, self.patch_set,
                                          self.committer)

    def tearDown(self):
        super(PatchApplierTest, self).tearDown()
        self.committer.rollback()
        # Release both database connections opened in setUp.
        self.store.close()
        self.another_store.close()
        sys.path.remove(self.patchdir)
        for name in list(sys.modules):
            if name == "mypackage" or name.startswith("mypackage."):
                del sys.modules[name]

    def add_module(self, module_filename, contents):
        """Write *contents* to *module_filename* inside the test package."""
        filename = os.path.join(self.pkgdir, module_filename)
        with open(filename, "w") as module_file:
            module_file.write(contents)

    def remove_all_modules(self):
        """Delete every file in the test package directory."""
        for filename in os.listdir(self.pkgdir):
            os.unlink(os.path.join(self.pkgdir, filename))

    def prepare_for_transaction_check(self):
        # Leave an uncommitted row in another_store: a commit keeps it and a
        # rollback removes it, which the assert_transaction_* helpers detect.
        self.another_store.execute("DELETE FROM test")
        self.another_store.execute("INSERT INTO test VALUES (1)")

    def assert_transaction_committed(self):
        self.another_store.rollback()
        result = self.another_store.execute("SELECT * FROM test").get_one()
        self.assertEqual(result, (1,),
                         "Transaction manager wasn't committed.")

    def assert_transaction_aborted(self):
        self.another_store.commit()
        result = self.another_store.execute("SELECT * FROM test").get_one()
        self.assertEqual(result, None,
                         "Transaction manager wasn't aborted.")

    def test_apply(self):
        """
        L{PatchApplier.apply} executes the patch with the given version.
        """
        self.patch_applier.apply(42)

        x = getattr(self.mypackage, "patch_42").x
        self.assertEqual(x, 42)
        self.assertTrue(self.store.get(Patch, 42))
        self.assertTrue("mypackage.patch_42" in sys.modules)

        self.assert_transaction_committed()

    def test_apply_with_patch_directory(self):
        """
        If the given L{PatchSet} uses sub-level patches, then the
        L{PatchApplier.apply} method will look at the per-patch directory and
        apply the relevant sub-level patch.
        """
        path = os.path.join(self.pkgdir, "patch_99")
        self.makeDir(path=path)
        self.makeFile(content="", path=os.path.join(path, "__init__.py"))
        self.makeFile(content=patch_test_0, path=os.path.join(path, "foo.py"))
        self.patch_set._sub_level = "foo"
        self.add_module("patch_99/foo.py", patch_test_0)
        self.patch_applier.apply(99)
        self.assertTrue(self.store.get(Patch, 99))

    def test_apply_all(self):
        """
        L{PatchApplier.apply_all} executes all unapplied patches.
        """
        self.patch_applier.apply_all()

        self.assertTrue("mypackage.patch_42" in sys.modules)
        self.assertTrue("mypackage.patch_380" in sys.modules)

        x = getattr(self.mypackage, "patch_42").x
        y = getattr(self.mypackage, "patch_380").y

        self.assertEqual(x, 42)
        self.assertEqual(y, 380)

        self.assert_transaction_committed()

    def test_apply_exploding_patch(self):
        """
        L{PatchApplier.apply} aborts the transaction if the patch fails.
        """
        self.remove_all_modules()
        self.add_module("patch_666.py", patch_explosion)
        self.assertRaises(StormError, self.patch_applier.apply, 666)

        self.assert_transaction_aborted()

    def test_wb_apply_all_exploding_patch(self):
        """
        When a patch explodes the store is rolled back to make sure
        that any changes the patch made to the database are removed.
        Any other patches that have been applied successfully before
        it should not be rolled back. Any patches pending after the
        exploding patch should remain unapplied.
        """
        self.add_module("patch_666.py", patch_explosion)
        self.add_module("patch_667.py", patch_after_explosion)
        self.assertEqual(list(self.patch_applier.get_unapplied_versions()),
                         [42, 380, 666, 667])
        self.assertRaises(StormError, self.patch_applier.apply_all)
        self.assertEqual(list(self.patch_applier.get_unapplied_versions()),
                         [666, 667])

    def test_mark_applied(self):
        """
        L{PatchApplier.mark_applied} marks a patch as applied by inserting a
        new row in the patch table.
        """
        self.patch_applier.mark_applied(42)

        self.assertFalse("mypackage.patch_42" in sys.modules)
        self.assertFalse("mypackage.patch_380" in sys.modules)

        self.assertTrue(self.store.get(Patch, 42))
        self.assertFalse(self.store.get(Patch, 380))

        self.assert_transaction_committed()

    def test_mark_applied_all(self):
        """
        L{PatchApplier.mark_applied_all} marks all pending patches as applied.
        """
        self.patch_applier.mark_applied_all()

        self.assertFalse("mypackage.patch_42" in sys.modules)
        self.assertFalse("mypackage.patch_380" in sys.modules)

        self.assertTrue(self.store.get(Patch, 42))
        self.assertTrue(self.store.get(Patch, 380))

        self.assert_transaction_committed()

    def test_application_order(self):
        """
        L{PatchApplier.apply_all} applies the patches in increasing version
        order.
        """
        self.patch_applier.apply_all()
        self.assertEqual(self.mypackage.shared_data, [42, 380])

    def test_has_pending_patches(self):
        """
        L{PatchApplier.has_pending_patches} returns C{True} if there are
        patches to be applied, C{False} otherwise.
        """
        self.assertTrue(self.patch_applier.has_pending_patches())
        self.patch_applier.apply_all()
        self.assertFalse(self.patch_applier.has_pending_patches())

    def test_abort_if_unknown_patches(self):
        """
        L{PatchApplier.apply_all} raises an error if the patch table
        contains patches without a matching file in the patch module.
        """
        self.patch_applier.mark_applied(381)
        self.assertRaises(UnknownPatchError, self.patch_applier.apply_all)

    def test_get_unknown_patch_versions(self):
        """
        L{PatchApplier.get_unknown_patch_versions} returns the versions of
        all patches recorded in the patch table that have no matching file
        in the patch module.
        """
        patches = [Patch(42), Patch(380), Patch(381)]
        my_store = MockPatchStore("database", patches=patches)
        patch_applier = PatchApplier(my_store, self.mypackage)
        self.assertEqual(set([381]),
                         patch_applier.get_unknown_patch_versions())

    def test_no_unknown_patch_versions(self):
        """
        L{PatchApplier.get_unknown_patch_versions} returns an empty set if
        every recorded patch has a matching file in the patch module.
        """
        patches = [Patch(42), Patch(380)]
        my_store = MockPatchStore("database", patches=patches)
        patch_applier = PatchApplier(my_store, self.mypackage)
        self.assertEqual(set(),
                         patch_applier.get_unknown_patch_versions())

    def test_patch_with_incorrect_apply(self):
        """
        L{PatchApplier.apply_all} raises an error as soon as one of the
        patches to be applied fails.
        """
        self.add_module("patch_999.py", patch_no_args_apply)
        try:
            self.patch_applier.apply_all()
        except BadPatchError as e:
            self.assertTrue("mypackage/patch_999.py" in str(e))
            # Python 2 says "takes no arguments"; Python 3 says
            # "takes 0 positional arguments but 1 was given".
            self.assertTrue("takes no arguments" in str(e) or
                            "takes 0 positional arguments" in str(e))
            self.assertTrue("TypeError" in str(e))
        else:
            self.fail("BadPatchError not raised")

    def test_patch_with_missing_apply(self):
        """
        L{PatchApplier.apply_all} raises an error if one of the patches to
        be applied has no 'apply' function defined.
        """
        self.add_module("patch_999.py", patch_missing_apply)
        try:
            self.patch_applier.apply_all()
        except BadPatchError as e:
            self.assertTrue("mypackage/patch_999.py" in str(e))
            self.assertTrue("no attribute" in str(e))
            self.assertTrue("AttributeError" in str(e))
        else:
            self.fail("BadPatchError not raised")

    def test_patch_with_syntax_error(self):
        """
        L{PatchApplier.apply_all} raises an error if one of the patches to
        be applied contains a syntax error.
        """
        self.add_module("patch_999.py", "that's not python")
        try:
            self.patch_applier.apply_all()
        except BadPatchError as e:
            self.assertTrue(" 999 " in str(e))
            self.assertTrue("SyntaxError" in str(e))
        else:
            self.fail("BadPatchError not raised")

    def test_patch_error_includes_traceback(self):
        """
        The exception raised by L{PatchApplier.apply_all} when a patch fails
        includes the relevant traceback from the patch.
        """
        self.add_module("patch_999.py", patch_name_error)
        try:
            self.patch_applier.apply_all()
        except BadPatchError as e:
            self.assertTrue("mypackage/patch_999.py" in str(e))
            self.assertTrue("NameError" in str(e))
            self.assertTrue("blah" in str(e))
            formatted = traceback.format_exc()
            self.assertTrue("# Comment" in formatted)
        else:
            self.fail("BadPatchError not raised")
class SchemaTest(MockerTestCase):
    """Tests for L{Schema} check/create/drop/delete/upgrade/advance."""

    def setUp(self):
        super(SchemaTest, self).setUp()
        self.database = create_database("sqlite:///%s" % self.makeFile())
        self.store = Store(self.database)

        self._package_dirs = set()
        self._package_names = set()
        self.package = self.create_package(self.makeDir(), "patch_package")
        import patch_package

        creates = ["CREATE TABLE person (id INTEGER, name TEXT)"]
        drops = ["DROP TABLE person"]
        deletes = ["DELETE FROM person"]

        self.schema = Schema(creates, drops, deletes, patch_package)

    def tearDown(self):
        for package_dir in self._package_dirs:
            sys.path.remove(package_dir)

        # Unload every package created by create_package plus its
        # submodules; startswith() accepts a tuple of prefixes, which
        # replaces the list-comprehension truthiness test.
        prefixes = tuple("%s." % x for x in self._package_names)
        for name in list(sys.modules):
            if name in self._package_names or name.startswith(prefixes):
                del sys.modules[name]

        super(SchemaTest, self).tearDown()

    def create_package(self, base_dir, name, init_module=None):
        """Create a Python package.

        Packages created using this method will be removed from L{sys.path}
        and L{sys.modules} during L{tearDown}.

        @param base_dir: The directory in which to create the new package;
            it is appended to L{sys.path}.
        @param name: The name of the package.
        @param init_module: Optionally, the text to include in the
            __init__.py file.
        @return: A L{Package} instance that can be used to create modules.
        """
        package_dir = os.path.join(base_dir, name)
        self._package_names.add(name)
        os.makedirs(package_dir)

        with open(os.path.join(package_dir, "__init__.py"), "w") as init_file:
            if init_module:
                init_file.write(init_module)

        sys.path.append(base_dir)
        self._package_dirs.add(base_dir)

        return Package(package_dir, name)

    def test_check_with_missing_schema(self):
        """
        L{Schema.check} raises an exception if the given store is completely
        pristine and no schema has been applied yet. The transaction doesn't
        get rolled back so it's still usable.
        """
        self.store.execute("CREATE TABLE foo (bar INT)")
        self.assertRaises(SchemaMissingError, self.schema.check, self.store)
        self.assertIsNone(self.store.execute("SELECT 1 FROM foo").get_one())

    def test_check_with_unapplied_patches(self):
        """
        L{Schema.check} raises an exception if the given store has unapplied
        schema patches.
        """
        self.schema.create(self.store)
        contents = """
def apply(store):
    pass
"""
        self.package.create_module("patch_1.py", contents)
        self.assertRaises(UnappliedPatchesError,
                          self.schema.check, self.store)

    def test_create(self):
        """
        L{Schema.create} can be used to create the tables of a L{Store}.
        """
        self.assertRaises(StormError,
                          self.store.execute, "SELECT * FROM person")
        self.schema.create(self.store)
        self.assertEqual(list(self.store.execute("SELECT * FROM person")), [])
        # By default changes are committed
        store2 = Store(self.database)
        self.assertEqual(list(store2.execute("SELECT * FROM person")), [])

    def test_create_with_autocommit_off(self):
        """
        L{Schema.autocommit} can be used to turn automatic commits off.
        """
        self.schema.autocommit(False)
        self.schema.create(self.store)
        self.store.rollback()
        self.assertRaises(StormError, self.store.execute,
                          "SELECT * FROM patch")

    def test_drop(self):
        """
        L{Schema.drop} can be used to drop the tables of a L{Store}.
        """
        self.schema.create(self.store)
        self.assertEqual(list(self.store.execute("SELECT * FROM person")), [])
        self.schema.drop(self.store)
        self.assertRaises(StormError,
                          self.store.execute, "SELECT * FROM person")

    def test_drop_with_missing_patch_table(self):
        """
        L{Schema.drop} works fine even if the user's supplied statements end up
        dropping the patch table that we created.
        """
        import patch_package
        schema = Schema([], ["DROP TABLE patch"], [], patch_package)
        schema.create(self.store)
        schema.drop(self.store)
        self.assertRaises(StormError,
                          self.store.execute, "SELECT * FROM patch")

    def test_delete(self):
        """
        L{Schema.delete} can be used to clear the tables of a L{Store}.
        """
        self.schema.create(self.store)
        self.store.execute("INSERT INTO person (id, name) VALUES (1, 'Jane')")
        self.assertEqual(list(self.store.execute("SELECT * FROM person")),
                         [(1, u"Jane")])
        self.schema.delete(self.store)
        self.assertEqual(list(self.store.execute("SELECT * FROM person")), [])

    def test_upgrade_creates_schema(self):
        """
        L{Schema.upgrade} creates a schema from scratch if none exists, and
        is effectively equivalent to L{Schema.create} in that case.
        """
        self.assertRaises(StormError,
                          self.store.execute, "SELECT * FROM person")
        self.schema.upgrade(self.store)
        self.assertEqual(list(self.store.execute("SELECT * FROM person")), [])

    def test_upgrade_marks_patches_applied(self):
        """
        L{Schema.upgrade} updates the patch table after applying the needed
        patches.
        """
        contents = """
def apply(store):
    store.execute('ALTER TABLE person ADD COLUMN phone TEXT')
"""
        self.package.create_module("patch_1.py", contents)
        statement = "SELECT * FROM patch"
        self.assertRaises(StormError, self.store.execute, statement)
        self.schema.upgrade(self.store)
        self.assertEqual(list(self.store.execute("SELECT * FROM patch")),
                         [(1,)])

    def test_upgrade_applies_patches(self):
        """
        L{Schema.upgrade} executes the needed patches, that typically modify
        the existing schema.
        """
        self.schema.create(self.store)
        contents = """
def apply(store):
    store.execute('ALTER TABLE person ADD COLUMN phone TEXT')
"""
        self.package.create_module("patch_1.py", contents)
        self.schema.upgrade(self.store)
        self.store.execute(
            "INSERT INTO person (id, name, phone) VALUES (1, 'Jane', '123')")
        self.assertEqual(list(self.store.execute("SELECT * FROM person")),
                         [(1, u"Jane", u"123")])

    def test_advance(self):
        """
        L{Schema.advance} executes the given patch version.
        """
        self.schema.create(self.store)
        contents1 = """
def apply(store):
    store.execute('ALTER TABLE person ADD COLUMN phone TEXT')
"""
        contents2 = """
def apply(store):
    store.execute('ALTER TABLE person ADD COLUMN address TEXT')
"""
        self.package.create_module("patch_1.py", contents1)
        self.package.create_module("patch_2.py", contents2)
        self.schema.advance(self.store, 1)
        self.store.execute(
            "INSERT INTO person (id, name, phone) VALUES (1, 'Jane', '123')")
        self.assertEqual(list(self.store.execute("SELECT * FROM person")),
                         [(1, u"Jane", u"123")])
def _sql_file_version(path):
    """Return the integer version encoded in a ``<dir>/<name>-<N>.sql`` path."""
    return int(path.split(".")[0].split("-")[-1])


def _read_sql_statements(path):
    """Return the non-empty ``;``-separated statements of the SQL file *path*.

    Newlines are replaced by spaces rather than removed, so that tokens on
    adjacent lines (``SELECT a\\nFROM b``) are not glued together.
    """
    with open(path, "r") as fd:
        chunks = fd.read().replace("\n", " ").split(";")
    return [chunk.strip() for chunk in chunks if chunk.strip()]


def _execute_sql_statements(db_store, statements):
    """Execute each statement on *db_store*, then commit via ``db_commit``."""
    for statement in statements:
        db_store.execute("%s;" % statement)
    db_commit(db_store)


def db_update_schema(database):
    """
    Check for pending database schema updates. If any are found, apply them
    and bump the version.

    If the schema table cannot be queried, the newest
    ``schema/schema-<N>.sql`` file is deployed first. Then every
    ``schema/update-<N>.sql`` whose version is greater than the current
    ``MAX(DBSchema.version)`` is applied, in increasing numeric order.

    :param database: The database to open a ``Store`` on.
    :return: ``False`` if deployment failed, ``None`` otherwise.
    """
    db_store = Store(database)
    try:
        return _update_schema_with_store(db_store)
    finally:
        # Always release the connection, including on the failure paths.
        db_store.close()


def _update_schema_with_store(db_store):
    """Body of ``db_update_schema``; the caller owns and closes *db_store*."""
    # Check if the DB schema has been loaded
    try:
        db_store.execute(Select(DBSchema.version))
        db_exists = True
    except Exception:
        db_store.rollback()
        logging.debug("Failed to query schema table.")
        db_exists = False

    if not db_exists:
        logging.info("Creating database")
        schema_files = glob.glob("schema/schema-*.sql")
        if not schema_files:
            logging.critical("Failed to initialize the database")
            return False
        # Numeric, not lexicographic, comparison: schema-10 beats schema-9.
        schema_file = max(schema_files, key=_sql_file_version)
        schema_version = schema_file.split(".")[0].split("-")[-1]
        logging.debug("Using '%s' to deploy schema '%s'",
                      schema_file, schema_version)
        statements = _read_sql_statements(schema_file)
        try:
            _execute_sql_statements(db_store, statements)
        except Exception:
            logging.critical("Failed to initialize the database")
            return False
        logging.info("Database created")

    # Get schema version
    version = db_store.execute(Select(Max(DBSchema.version))).get_one()[0]
    if not version:
        logging.critical("No schema version.")
        return False

    # Apply updates in numeric order so update-10.sql runs after update-9.sql.
    update_files = sorted(glob.glob("schema/update-*.sql"),
                          key=_sql_file_version)
    for update_file in update_files:
        update_version = update_file.split(".")[0].split("-")[-1]
        if int(update_version) <= version:
            continue
        logging.info("Using '%s' to deploy update '%s'",
                     update_file, update_version)
        statements = _read_sql_statements(update_file)
        try:
            _execute_sql_statements(db_store, statements)
        except Exception:
            logging.critical("Failed to load schema update")
            return False

    # Get schema version
    new_version = db_store.execute(Select(Max(DBSchema.version))).get_one()[0]
    if new_version > version:
        logging.info("Database schema successfuly updated from '%s' to '%s'",
                     version, new_version)
class TestChangeTracker(object):
    """Tests for ChangeTracker / ChangeHistory on an in-memory SQLite store."""

    class A(object):
        __storm_table__ = 'testob'
        changehistory = ChangeHistory.configure("history")
        clt = ChangeTracker(changehistory)
        id = Int(primary=1)
        textval = Unicode(validator=clt)
        intval = Int(validator=clt)

    def setUp(self):
        database = create_database('sqlite:')
        self.store = Store(database)
        self.store.execute("""
            CREATE table history (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                ref_class VARCHAR(200),
                ref_pk VARCHAR(200),
                ref_attr VARCHAR(200),
                new_value VARCHAR(200),
                ctime DATETIME,
                cuser INT
            )
        """)
        self.store.execute("""
            CREATE TABLE testob (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                textval VARCHAR(200),
                intval INT,
                dateval DATETIME
            )""")

    def tearDown(self):
        self.store.rollback()

    def test_calls_next_validator(self):
        clt = ChangeTracker(ChangeHistory.configure("history"),
                            next_validator=lambda ob, attr, v: v * 2)

        class B(self.A):
            textval = Unicode(validator=clt)

        b = B()
        b.textval = u'bork'
        assert b.textval == u'borkbork'

    def test_adds_log_entries(self):
        class B(self.A):
            clt = ChangeTracker(ChangeHistory.configure("history"))
            textval = Unicode(validator=clt)

        b = self.store.add(B())
        b.textval = u'pointless'
        b.textval = u'aimless'
        changes = list(self.store.find(b.clt.change_cls))
        assert_equal(len(changes), 2)
        assert_equal(changes[0].new_value, 'pointless')
        assert_equal(changes[1].new_value, 'aimless')

    def test_value_type_preserved(self):
        a = self.store.add(self.A())
        a.textval = u'one'
        a.intval = 1
        changes = list(self.store.find(a.clt.change_cls))
        assert_equal(type(changes[0].new_value), unicode)
        assert_equal(type(changes[1].new_value), int)

    def test_ctime_set(self):
        start = datetime.now()
        a = self.store.add(self.A())
        a.textval = u'x'
        changes = list(self.store.find(a.clt.change_cls))
        assert_equal(type(changes[0].ctime), datetime)
        # Inclusive bounds: on clocks with coarse resolution the recorded
        # ctime can equal ``start`` or the final ``datetime.now()``, which
        # made the strict ``<`` comparison flaky.
        assert start <= changes[0].ctime <= datetime.now()

    def test_cuser_set(self):
        def getuser():
            return u'Fred'

        history = ChangeHistory.configure("history", getuser=getuser,
                                          usertype=Unicode)

        class B(self.A):
            textval = Unicode(validator=ChangeTracker(history))

        b = self.store.add(B())
        b.textval = u'foo'
        changes = self.store.find(history)
        assert_equal(changes[0].cuser, u'Fred')

    def test_changes_for_returns_change_history(self):
        a = self.store.add(self.A())
        b = self.store.add(self.A())
        a.id = 1
        a.textval = u'one'
        a.textval = u'two'
        b.id = 2
        b.textval = u'ein'
        b.textval = u'zwei'

        assert_equal([c.new_value for c in a.changehistory.changes_for(a)],
                     [u'one', u'two'])
        assert_equal([c.new_value for c in a.changehistory.changes_for(b)],
                     [u'ein', u'zwei'])
class PatchTest(MockerTestCase): def setUp(self): super(PatchTest, self).setUp() self.patchdir = self.makeDir() self.pkgdir = os.path.join(self.patchdir, "mypackage") os.makedirs(self.pkgdir) f = open(os.path.join(self.pkgdir, "__init__.py"), "w") f.write("shared_data = []") f.close() # Order of creation here is important to try to screw up the # patch ordering, as os.listdir returns in order of mtime (or # something). for pname, data in [("patch_380.py", patch_test_1), ("patch_42.py", patch_test_0)]: self.add_module(pname, data) sys.path.append(self.patchdir) self.filename = self.makeFile() self.uri = "sqlite:///%s" % self.filename self.store = Store(create_database(self.uri)) self.store.execute("CREATE TABLE patch " "(version INTEGER NOT NULL PRIMARY KEY)") self.assertFalse(self.store.get(Patch, (42))) self.assertFalse(self.store.get(Patch, (380))) import mypackage self.mypackage = mypackage # Create another connection just to keep track of the state of the # whole transaction manager. See the assertion functions below. 
self.another_store = Store(create_database("sqlite:")) self.another_store.execute("CREATE TABLE test (id INT)") self.another_store.commit() self.prepare_for_transaction_check() class Committer(object): def commit(committer): self.store.commit() self.another_store.commit() def rollback(committer): self.store.rollback() self.another_store.rollback() self.committer = Committer() self.patch_applier = PatchApplier(self.store, self.mypackage, self.committer) def tearDown(self): super(PatchTest, self).tearDown() self.committer.rollback() sys.path.remove(self.patchdir) for name in list(sys.modules): if name == "mypackage" or name.startswith("mypackage."): del sys.modules[name] def add_module(self, module_filename, contents): filename = os.path.join(self.pkgdir, module_filename) file = open(filename, "w") file.write(contents) file.close() def remove_all_modules(self): for filename in os.listdir(self.pkgdir): os.unlink(os.path.join(self.pkgdir, filename)) def prepare_for_transaction_check(self): self.another_store.execute("DELETE FROM test") self.another_store.execute("INSERT INTO test VALUES (1)") def assert_transaction_committed(self): self.another_store.rollback() result = self.another_store.execute("SELECT * FROM test").get_one() self.assertEquals(result, (1,), "Transaction manager wasn't committed.") def assert_transaction_aborted(self): self.another_store.commit() result = self.another_store.execute("SELECT * FROM test").get_one() self.assertEquals(result, None, "Transaction manager wasn't aborted.") def test_apply(self): """ L{PatchApplier.apply} executes the patch with the given version. """ self.patch_applier.apply(42) x = getattr(self.mypackage, "patch_42").x self.assertEquals(x, 42) self.assertTrue(self.store.get(Patch, (42))) self.assertTrue("mypackage.patch_42" in sys.modules) self.assert_transaction_committed() def test_apply_all(self): """ L{PatchApplier.apply_all} executes all unapplied patches. 
""" self.patch_applier.apply_all() self.assertTrue("mypackage.patch_42" in sys.modules) self.assertTrue("mypackage.patch_380" in sys.modules) x = getattr(self.mypackage, "patch_42").x y = getattr(self.mypackage, "patch_380").y self.assertEquals(x, 42) self.assertEquals(y, 380) self.assert_transaction_committed() def test_apply_exploding_patch(self): """ L{PatchApplier.apply} aborts the transaction if the patch fails. """ self.remove_all_modules() self.add_module("patch_666.py", patch_explosion) self.assertRaises(StormError, self.patch_applier.apply, 666) self.assert_transaction_aborted() def test_wb_apply_all_exploding_patch(self): """ When a patch explodes the store is rolled back to make sure that any changes the patch made to the database are removed. Any other patches that have been applied successfully before it should not be rolled back. Any patches pending after the exploding patch should remain unapplied. """ self.add_module("patch_666.py", patch_explosion) self.add_module("patch_667.py", patch_after_explosion) self.assertEquals(list(self.patch_applier._get_unapplied_versions()), [42, 380, 666, 667]) self.assertRaises(StormError, self.patch_applier.apply_all) self.assertEquals(list(self.patch_applier._get_unapplied_versions()), [666, 667]) def test_mark_applied(self): """ L{PatchApplier.mark} marks a patch has applied by inserting a new row in the patch table. """ self.patch_applier.mark_applied(42) self.assertFalse("mypackage.patch_42" in sys.modules) self.assertFalse("mypackage.patch_380" in sys.modules) self.assertTrue(self.store.get(Patch, 42)) self.assertFalse(self.store.get(Patch, 380)) self.assert_transaction_committed() def test_mark_applied_all(self): """ L{PatchApplier.mark_applied_all} marks all pending patches as applied. 
""" self.patch_applier.mark_applied_all() self.assertFalse("mypackage.patch_42" in sys.modules) self.assertFalse("mypackage.patch_380" in sys.modules) self.assertTrue(self.store.get(Patch, 42)) self.assertTrue(self.store.get(Patch, 380)) self.assert_transaction_committed() def test_application_order(self): """ L{PatchApplier.apply_all} applies the patches in increasing version order. """ self.patch_applier.apply_all() self.assertEquals(self.mypackage.shared_data, [42, 380]) def test_has_pending_patches(self): """ L{PatchApplier.has_pending_patches} returns C{True} if there are patches to be applied, C{False} otherwise. """ self.assertTrue(self.patch_applier.has_pending_patches()) self.patch_applier.apply_all() self.assertFalse(self.patch_applier.has_pending_patches()) def test_abort_if_unknown_patches(self): """ L{PatchApplier.mark_applied} raises and error if the patch table contains patches without a matching file in the patch module. """ self.patch_applier.mark_applied(381) self.assertRaises(UnknownPatchError, self.patch_applier.apply_all) def test_get_unknown_patch_versions(self): """ L{PatchApplier.get_unknown_patch_versions} returns the versions of all unapplied patches. """ patches = [Patch(42), Patch(380), Patch(381)] my_store = MockPatchStore("database", patches=patches) patch_applier = PatchApplier(my_store, self.mypackage) self.assertEqual(set([381]), patch_applier.get_unknown_patch_versions()) def test_no_unknown_patch_versions(self): """ L{PatchApplier.get_unknown_patch_versions} returns an empty set if no patches are unapplied. """ patches = [Patch(42), Patch(380)] my_store = MockPatchStore("database", patches=patches) patch_applier = PatchApplier(my_store, self.mypackage) self.assertEqual(set(), patch_applier.get_unknown_patch_versions()) def test_patch_with_incorrect_apply(self): """ L{PatchApplier.apply_all} raises an error as soon as one of the patches to be applied fails. 
""" self.add_module("patch_999.py", patch_no_args_apply) try: self.patch_applier.apply_all() except BadPatchError, e: self.assertTrue("mypackage/patch_999.py" in str(e)) self.assertTrue("takes no arguments" in str(e)) self.assertTrue("TypeError" in str(e)) else:
class SchemaTest(MockerTestCase):
    """Tests for L{Schema} against a throwaway SQLite database.

    Each test gets a fresh SQLite database file and a fresh, empty
    C{patch_package} package placed on C{sys.path}.  The package (and any of
    its submodules) is removed from C{sys.path} and C{sys.modules} again in
    L{tearDown}.
    """

    def setUp(self):
        super(SchemaTest, self).setUp()
        self.database = create_database("sqlite:///%s" % self.makeFile())
        self.store = Store(self.database)
        self._package_dirs = set()
        self._package_names = set()
        self.package = self.create_package(self.makeDir(), "patch_package")
        import patch_package
        creates = ["CREATE TABLE person (id INTEGER, name TEXT)"]
        drops = ["DROP TABLE person"]
        deletes = ["DELETE FROM person"]
        self.schema = Schema(creates, drops, deletes, patch_package)

    def tearDown(self):
        for package_dir in self._package_dirs:
            sys.path.remove(package_dir)
        # Forget the created packages and all of their submodules.
        # str.startswith() accepts a tuple of prefixes (an empty tuple matches
        # nothing), and unlike a filter() object it yields a real boolean on
        # both Python 2 and 3.
        prefixes = tuple("%s." % x for x in self._package_names)
        for name in list(sys.modules):
            if name in self._package_names or name.startswith(prefixes):
                del sys.modules[name]
        super(SchemaTest, self).tearDown()

    def create_package(self, base_dir, name, init_module=None):
        """Create a Python package.

        Packages created using this method will be removed from L{sys.path}
        and L{sys.modules} during L{tearDown}.

        @param base_dir: The directory in which to create the new package.
        @param name: The name of the package.
        @param init_module: Optionally, the text to include in the __init__.py
            file.
        @return: A L{Package} instance that can be used to create modules.
        """
        package_dir = os.path.join(base_dir, name)
        self._package_names.add(name)
        os.makedirs(package_dir)
        # "with" guarantees the handle is closed even if the write fails.
        with open(os.path.join(package_dir, "__init__.py"), "w") as init_file:
            if init_module:
                init_file.write(init_module)
        sys.path.append(base_dir)
        self._package_dirs.add(base_dir)
        return Package(package_dir, name)

    def test_create(self):
        """
        L{Schema.create} can be used to create the tables of a L{Store}.
        """
        self.assertRaises(StormError,
                          self.store.execute, "SELECT * FROM person")
        self.schema.create(self.store)
        self.assertEqual(list(self.store.execute("SELECT * FROM person")), [])

    def test_create_with_autocommit_off(self):
        """
        L{Schema.autocommit} can be used to turn automatic commits off.
        """
        self.schema.autocommit(False)
        self.schema.create(self.store)
        self.store.rollback()
        self.assertRaises(StormError, self.store.execute,
                          "SELECT * FROM patch")

    def test_drop(self):
        """
        L{Schema.drop} can be used to drop the tables of a L{Store}.
        """
        self.schema.create(self.store)
        self.assertEqual(list(self.store.execute("SELECT * FROM person")), [])
        self.schema.drop(self.store)
        self.assertRaises(StormError,
                          self.store.execute, "SELECT * FROM person")

    def test_delete(self):
        """
        L{Schema.delete} can be used to clear the tables of a L{Store}.
        """
        self.schema.create(self.store)
        self.store.execute("INSERT INTO person (id, name) VALUES (1, 'Jane')")
        self.assertEqual(list(self.store.execute("SELECT * FROM person")),
                         [(1, u"Jane")])
        self.schema.delete(self.store)
        self.assertEqual(list(self.store.execute("SELECT * FROM person")), [])

    def test_upgrade_creates_schema(self):
        """
        L{Schema.upgrade} creates a schema from scratch if none exists, and
        is effectively equivalent to L{Schema.create} in such case.
        """
        self.assertRaises(StormError,
                          self.store.execute, "SELECT * FROM person")
        self.schema.upgrade(self.store)
        self.assertEqual(list(self.store.execute("SELECT * FROM person")), [])

    def test_upgrade_marks_patches_applied(self):
        """
        L{Schema.upgrade} updates the patch table after applying the needed
        patches.
        """
        contents = """
def apply(store):
    store.execute('ALTER TABLE person ADD COLUMN phone TEXT')
"""
        self.package.create_module("patch_1.py", contents)
        statement = "SELECT * FROM patch"
        self.assertRaises(StormError, self.store.execute, statement)
        self.schema.upgrade(self.store)
        self.assertEqual(list(self.store.execute("SELECT * FROM patch")),
                         [(1, )])

    def test_upgrade_applies_patches(self):
        """
        L{Schema.upgrade} executes the needed patches, that typically modify
        the existing schema.
        """
        self.schema.create(self.store)
        contents = """
def apply(store):
    store.execute('ALTER TABLE person ADD COLUMN phone TEXT')
"""
        self.package.create_module("patch_1.py", contents)
        self.schema.upgrade(self.store)
        self.store.execute(
            "INSERT INTO person (id, name, phone) VALUES (1, 'Jane', '123')")
        self.assertEqual(list(self.store.execute("SELECT * FROM person")),
                         [(1, u"Jane", u"123")])
class StormManager(Singleton):
    """Facade over a Storm L{Store} for the IcepapCMS database.

    Opens the database described by L{ConfigManager}. That is either a SQLite
    file, whose schema is created on first use, or a server URI (MySQL or
    another Storm backend). The class then exposes small helpers to add,
    find and delete locations, icepap systems and drivers.
    """

    log = logging.getLogger('{}.StormManager'.format(__name__))

    def __init__(self):
        # Real initialisation happens in init().
        # NOTE(review): presumably driven by the Singleton base - confirm.
        pass

    @loggingInfo
    def init(self, *args):
        """Mark the database as not ready and open it.

        Extra positional arguments are accepted and ignored.
        """
        self.dbOK = False
        self.openDB()

    @loggingInfo
    def reset(self):
        """Close the current store and open a new one."""
        self.closeDB()
        self.openDB()

    @loggingInfo
    def openDB(self):
        """Open the configured database and create a L{Store} on it.

        For SQLite, the file is C{icepapcms.db} inside the configured folder.
        If that file does not exist, the folder is created when missing and
        the schema is loaded via L{createSqliteDB}.

        For other backends a C{<db>://user:password@server/icepapcms} URI is
        built. C{mysql} goes through L{MySQL}; anything else goes through
        C{create_database}.

        Sets C{self.dbOK} to the outcome. Any error is logged rather than
        raised.
        """
        try:
            self._config = ConfigManager()
            db_config = self._config.config[self._config.database]
            self.db = db_config["database"]
            create_db = False
            if self.db == self._config.Sqlite:
                folder = db_config["folder"]
                loc = os.path.join(folder, 'icepapcms.db')
                print("Using Sqlite database at %s" % loc)
                create_db = not os.path.exists(loc)
                if create_db:
                    print("No database file found, creating it")
                    if not os.path.exists(folder):
                        # makedirs also creates missing parent directories.
                        os.makedirs(folder)
                self._database = create_database("%s:%s" % (self.db, loc))
            else:
                server = db_config["server"]
                user = db_config["user"]
                pwd = db_config["password"]
                # NOTE(review): user/password are not URL-quoted; characters
                # such as '@', ':' or '/' in them would corrupt the URI.
                scheme = "{}://{}:{}@{}/icepapcms".format(
                    self.db, user, pwd, server)
                if self.db == 'mysql':
                    self._database = MySQL(scheme)
                else:
                    self._database = create_database(scheme)
            self._store = Store(self._database)
            self.dbOK = self.createSqliteDB() if create_db else True
        except Exception as e:
            self.log.error("Unexpected error on openDB: %s", e)
            self.dbOK = False

    @loggingInfo
    def createSqliteDB(self):
        """Load the bundled SQLite schema into the freshly opened store.

        Reads C{creates_sqlite.sql} from the C{icepapcms.db} package
        resources. The script is split into statements at every C{;} that
        ends a line. C{--} comments are stripped and each non-empty statement
        is executed. Finally the transaction is committed.

        @return: C{True} on success, C{False} if anything failed (the error
            is logged).
        """
        try:
            sql_file = resource_filename('icepapcms.db', 'creates_sqlite.sql')
            with open(sql_file) as f:
                sql_script = f.read()
            statement_end = re.compile(r";[ \t]*$", re.M)
            # A "--" comment runs to the end of its line or of the text.
            # (\Z must be outside a character class to mean end-of-string.)
            # NOTE: not aware of string literals, so a "--" inside a quoted
            # SQL value would be stripped as well.
            comment = re.compile(r"--.*(?:\n|\Z)")
            for statement in statement_end.split(sql_script):
                statement = comment.sub("", statement)
                if statement.strip():
                    self._store.execute(statement + ";")
            self._store.commit()
            return True
        except Exception as e:
            self.log.error("Unexpected error on createSqliteDB: %s", e)
            return False

    @loggingInfo
    def closeDB(self):
        """Close the store if the database was opened successfully.

        @return: C{True} unless closing raised. On error it is logged,
            C{self.dbOK} is cleared and C{False} is returned.
        """
        try:
            if self.dbOK:
                self._store.close()
            return True
        except Exception as e:
            # The "%s" placeholder is required for logging to render the
            # exception instead of reporting a formatting error.
            self.log.error("Unexpected error on closeDB: %s", e)
            self.dbOK = False
            return False

    @loggingInfo
    def store(self, obj):
        """Add C{obj} to the store (not committed)."""
        self._store.add(obj)

    @loggingInfo
    def remove(self, obj):
        """Remove C{obj} from the store (not committed)."""
        self._store.remove(obj)

    @loggingInfo
    def addIcepapSystem(self, icepap_system):
        """Add an icepap system to the store and commit.

        @return: C{True} if it was added and committed, C{False} otherwise.
        """
        try:
            self._store.add(icepap_system)
        except Exception as e:
            self.log.error(
                "some exception trying to store the icepap system "
                "%s: %s", icepap_system, e)
            return False
        # commitTransaction() reports failure via its return value instead of
        # raising, so propagate it rather than always claiming success.
        return self.commitTransaction()

    @loggingInfo
    def deleteLocation(self, location):
        """Remove a location and commit.

        On SQLite the location's systems are deleted explicitly first, via
        L{deleteIcepapSystem}.
        NOTE(review): other backends presumably cascade in the schema.
        """
        if self.db == self._config.Sqlite:
            # Snapshot first: each deletion commits, so avoid iterating a
            # live result set while it is being modified.
            for system in list(location.systems):
                self.deleteIcepapSystem(system)
        self._store.remove(location)
        self.commitTransaction()

    @loggingInfo
    def deleteIcepapSystem(self, icepap_system):
        """Remove an icepap system and commit.

        On SQLite its drivers are deleted explicitly first, via
        L{deleteDriver}.
        """
        if self.db == self._config.Sqlite:
            # Snapshot first: deleteDriver() commits on every call.
            for driver in list(icepap_system.drivers):
                self.deleteDriver(driver)
        self._store.remove(icepap_system)
        self.commitTransaction()

    @loggingInfo
    def deleteDriver(self, driver):
        """Remove a driver with all its historic configurations and their
        parameters, then commit."""
        for cfg in list(driver.historic_cfgs):
            for par in list(cfg.parameters):
                self._store.remove(par)
            self._store.remove(cfg)
        self._store.remove(driver)
        self.commitTransaction()

    @loggingInfo
    def getAllLocations(self):
        """Return a C{{name: Location}} dict of every location.

        An empty dict is returned (and the error logged) on failure.
        """
        try:
            return {location.name: location
                    for location in self._store.find(Location)}
        except Exception as e:
            self.log.error("Unexpected error on getAllLocations: %s", e)
            return {}

    @loggingInfo
    def getLocation(self, name):
        """Return the L{Location} whose primary key is C{name}, or C{None}."""
        return self._store.get(Location, name)

    @loggingInfo
    def getIcepapSystem(self, icepap_name):
        """Return the L{IcepapSystem} with primary key C{icepap_name}, or
        C{None}."""
        return self._store.get(IcepapSystem, icepap_name)

    @loggingInfo
    def existsDriver(self, mydriver, id):
        """Find another driver whose configuration has parameter C{ID == id}.

        @return: The first matching driver whose C{addr} differs from
            C{mydriver.addr}, or C{None} if there is none.
        """
        drivers = self._store.find(
            IcepapDriver,
            IcepapDriver.addr == IcepapDriverCfg.driver_addr,
            IcepapDriverCfg.id == CfgParameter.cfg_id,
            CfgParameter.name == "ID",
            CfgParameter.value == id)
        for driver in drivers:
            if driver.addr != mydriver.addr:
                return driver
        return None

    @loggingInfo
    def getLocationIcepapSystem(self, location):
        """Return a C{{name: IcepapSystem}} dict of the systems in C{location}.

        Systems are inserted in name order. An empty dict is returned (and
        the error logged) on failure.
        """
        try:
            icepaps = self._store.find(
                IcepapSystem,
                IcepapSystem.location_name == location).order_by(
                    IcepapSystem.name)
            return {ipap_sys.name: ipap_sys for ipap_sys in icepaps}
        except Exception as e:
            self.log.error(
                "Unexpected error on getLocationIcepapSystem: "
                "%s", e)
            return {}

    @loggingInfo
    def rollback(self):
        """Roll back the current transaction."""
        self._store.rollback()

    @loggingInfo
    def commitTransaction(self):
        """Commit the current transaction.

        @return: C{True} on success, C{False} if the commit raised (the error
            is logged rather than propagated).
        """
        try:
            self._store.commit()
            return True
        except Exception as e:
            self.log.error("Unexpected error on commitTransaction: %s", e)
            return False