def test_remove_columns():
    """Test workaround for dropping columns in sqlite3."""
    with create_temp_db() as (temp_db, conn):
        conn.execute(
            r'''
                CREATE TABLE foo (
                    bar,
                    baz,
                    pub
                )
            '''
        )
        conn.execute(
            r'''
                INSERT INTO foo
                VALUES (?,?,?)
            ''',
            ['BAR', 'BAZ', 'PUB']
        )
        conn.commit()
        conn.close()

        dao = CylcSuiteDAO(temp_db)
        dao.remove_columns('foo', ['bar', 'baz'])

        conn = dao.connect()
        data = [row for row in conn.execute(r'SELECT * from foo')]
        assert data == [('PUB',)]
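# For context: SQLite (before 3.35) has no "ALTER TABLE ... DROP COLUMN",
# so a workaround like the one tested above is typically implemented by
# rebuilding the table. A minimal sketch of that idiom; `_remove_columns`
# is illustrative, not the actual CylcSuiteDAO implementation:
import sqlite3


def _remove_columns(conn, table, to_drop):
    """Drop columns from an SQLite table by rebuilding it."""
    # the columns that survive the drop
    cols = ', '.join(
        row[1] for row in conn.execute(f'PRAGMA table_info({table})')
        if row[1] not in to_drop)
    # copy the surviving columns to a new table, then swap it in
    conn.execute(f'CREATE TABLE {table}_new AS SELECT {cols} FROM {table}')
    conn.execute(f'DROP TABLE {table}')
    conn.execute(f'ALTER TABLE {table}_new RENAME TO {table}')
    conn.commit()


# usage (in-memory example, mirroring the test above):
conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE foo (bar, baz, pub)')
conn.execute('INSERT INTO foo VALUES (?,?,?)', ['BAR', 'BAZ', 'PUB'])
_remove_columns(conn, 'foo', ['bar', 'baz'])
assert list(conn.execute('SELECT * FROM foo')) == [('PUB',)]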
class TestRunDb(unittest.TestCase):

    def setUp(self):
        self.dao = CylcSuiteDAO(':memory:')
        self.mocked_connection = mock.Mock()
        self.dao.connect = mock.MagicMock(
            return_value=self.mocked_connection)

    get_select_task_job = [
        ["cycle", "name", "NN"],
        ["cycle", "name", None],
        ["cycle", "name", "02"],
    ]

    def test_select_task_job(self):
        """Test the rundb CylcSuiteDAO select_task_job method"""
        columns = self.dao.tables[CylcSuiteDAO.TABLE_TASK_JOBS].columns[3:]
        expected_values = [[2 for _ in columns]]
        self.mocked_connection.execute.return_value = expected_values
        # parameterized test
        for cycle, name, submit_num in self.get_select_task_job:
            returned_values = self.dao.select_task_job(
                cycle, name, submit_num)
            for column in columns:
                self.assertEqual(2, returned_values[column.name])

    def test_select_task_job_sqlite_error(self):
        """Test that when the rundb CylcSuiteDAO select_task_job method
        raises a SQLite exception, the method returns None"""
        self.mocked_connection.execute.side_effect = sqlite3.DatabaseError
        r = self.dao.select_task_job("it'll", "raise", "an error!")
        self.assertIsNone(r)
def generate_dotcode():
    # get markdown
    with tempfile.NamedTemporaryFile() as tf_db:
        # is_public=False triggers the creation of tables
        CylcSuiteDAO(db_file_name=tf_db.name, is_public=False)
        schema, orphans = schema_to_markdown(db_name=tf_db.name)

    # graph prefix
    dotcode = [
        'graph {',
        'node [label = "\\N", shape = plaintext];',
        'edge [color = gray50, minlen = 2, style = dashed];',
        'rankdir = "LR";'
    ]

    # the database graph
    tables, relationships = all_to_intermediary(schema)
    dotcode.extend([x.to_dot() for x in tables])
    dotcode.extend([x.to_dot() for x in relationships])

    # group orphan nodes to cut down on clutter
    dotcode.extend(group_nodes(orphans))

    # use invisible graph edges to change the graph layout
    dotcode.append(
        '"task_pool_checkpoints" -- "inheritance"[style=invis];')

    # graph suffix
    dotcode += ['}']

    return dotcode
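# One way to render the generated graph, assuming Graphviz's `dot` is on
# $PATH (the output file names are illustrative):
import subprocess

with open('schema.dot', 'w') as handle:
    handle.write('\n'.join(generate_dotcode()))
subprocess.run(
    ['dot', '-Tsvg', '-o', 'schema.svg', 'schema.dot'], check=True)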
def get_task_job_attrs(suite_name, point, task, submit_num):
    """Return job (platform, job_runner_name, live_job_id).

    live_job_id is the job ID if job is running, else None.

    """
    suite_dao = CylcSuiteDAO(
        get_suite_run_pub_db_name(suite_name), is_public=True)
    task_job_data = suite_dao.select_task_job(point, task, submit_num)
    suite_dao.close()
    if task_job_data is None:
        return (None, None, None)
    job_runner_name = task_job_data["job_runner_name"]
    job_id = task_job_data["job_id"]
    if (not job_runner_name or not job_id
            or not task_job_data["time_run"]
            or task_job_data["time_run_exit"]):
        live_job_id = None
    else:
        live_job_id = job_id
    return (task_job_data["platform_name"], job_runner_name, live_job_id)
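# Illustrative call (the suite, cycle point and task names are made up):
platform, job_runner, live_job_id = get_task_job_attrs(
    'my.suite', '20200202T0000Z', 'my_task', 1)
print(platform, job_runner, live_job_id)  # e.g. ('hpc', 'pbs', None)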
def get_task_job_attrs(suite_name, point, task, submit_num):
    """Return job (user_at_host, batch_sys_name, live_job_id).

    live_job_id is batch system job ID if job is running, else None.

    """
    suite_dao = CylcSuiteDAO(
        get_suite_run_pub_db_name(suite_name), is_public=True)
    task_job_data = suite_dao.select_task_job(point, task, submit_num)
    suite_dao.close()
    if task_job_data is None:
        return (None, None, None)
    batch_sys_name = task_job_data["batch_sys_name"]
    batch_sys_job_id = task_job_data["batch_sys_job_id"]
    if (not batch_sys_name or not batch_sys_job_id
            or not task_job_data["time_run"]
            or task_job_data["time_run_exit"]):
        live_job_id = None
    else:
        live_job_id = batch_sys_job_id
    return (task_job_data["user_at_host"], batch_sys_name, live_job_id)
class SuiteDatabaseManager:
    """Manage the suite runtime private and public databases."""

    KEY_INITIAL_CYCLE_POINT = 'icp'
    KEY_INITIAL_CYCLE_POINT_COMPATS = (
        KEY_INITIAL_CYCLE_POINT, 'initial_point')
    KEY_START_CYCLE_POINT = 'startcp'
    KEY_START_CYCLE_POINT_COMPATS = (
        KEY_START_CYCLE_POINT, 'start_point', 'warm_point')
    KEY_FINAL_CYCLE_POINT = 'fcp'
    KEY_FINAL_CYCLE_POINT_COMPATS = (KEY_FINAL_CYCLE_POINT, 'final_point')
    KEY_STOP_CYCLE_POINT = 'stopcp'
    KEY_UUID_STR = 'uuid_str'
    KEY_CYLC_VERSION = 'cylc_version'
    KEY_UTC_MODE = 'UTC_mode'
    KEY_HOLD = 'is_held'
    KEY_HOLD_CYCLE_POINT = 'holdcp'
    KEY_NO_AUTO_SHUTDOWN = 'no_auto_shutdown'
    KEY_RUN_MODE = 'run_mode'
    KEY_STOP_CLOCK_TIME = 'stop_clock_time'
    KEY_STOP_TASK = 'stop_task'
    KEY_CYCLE_POINT_FORMAT = 'cycle_point_format'
    KEY_CYCLE_POINT_TIME_ZONE = 'cycle_point_tz'

    TABLE_BROADCAST_EVENTS = CylcSuiteDAO.TABLE_BROADCAST_EVENTS
    TABLE_BROADCAST_STATES = CylcSuiteDAO.TABLE_BROADCAST_STATES
    TABLE_CHECKPOINT_ID = CylcSuiteDAO.TABLE_CHECKPOINT_ID
    TABLE_INHERITANCE = CylcSuiteDAO.TABLE_INHERITANCE
    TABLE_SUITE_PARAMS = CylcSuiteDAO.TABLE_SUITE_PARAMS
    TABLE_SUITE_TEMPLATE_VARS = CylcSuiteDAO.TABLE_SUITE_TEMPLATE_VARS
    TABLE_TASK_ACTION_TIMERS = CylcSuiteDAO.TABLE_TASK_ACTION_TIMERS
    TABLE_TASK_POOL = CylcSuiteDAO.TABLE_TASK_POOL
    TABLE_TASK_OUTPUTS = CylcSuiteDAO.TABLE_TASK_OUTPUTS
    TABLE_TASK_STATES = CylcSuiteDAO.TABLE_TASK_STATES
    TABLE_TASK_TIMEOUT_TIMERS = CylcSuiteDAO.TABLE_TASK_TIMEOUT_TIMERS
    TABLE_XTRIGGERS = CylcSuiteDAO.TABLE_XTRIGGERS
    TABLE_ABS_OUTPUTS = CylcSuiteDAO.TABLE_ABS_OUTPUTS

    def __init__(self, pri_d=None, pub_d=None):
        self.pri_path = None
        if pri_d:
            self.pri_path = os.path.join(
                pri_d, CylcSuiteDAO.DB_FILE_BASE_NAME)
        self.pub_path = None
        if pub_d:
            self.pub_path = os.path.join(
                pub_d, CylcSuiteDAO.DB_FILE_BASE_NAME)
        self.pri_dao = None
        self.pub_dao = None
        self.db_deletes_map = {
            self.TABLE_BROADCAST_STATES: [],
            self.TABLE_SUITE_PARAMS: [],
            self.TABLE_TASK_POOL: [],
            self.TABLE_TASK_ACTION_TIMERS: [],
            self.TABLE_TASK_OUTPUTS: [],
            self.TABLE_TASK_TIMEOUT_TIMERS: [],
            self.TABLE_XTRIGGERS: []
        }
        self.db_inserts_map = {
            self.TABLE_BROADCAST_EVENTS: [],
            self.TABLE_BROADCAST_STATES: [],
            self.TABLE_INHERITANCE: [],
            self.TABLE_SUITE_PARAMS: [],
            self.TABLE_SUITE_TEMPLATE_VARS: [],
            self.TABLE_CHECKPOINT_ID: [],
            self.TABLE_TASK_POOL: [],
            self.TABLE_TASK_ACTION_TIMERS: [],
            self.TABLE_TASK_OUTPUTS: [],
            self.TABLE_TASK_TIMEOUT_TIMERS: [],
            self.TABLE_XTRIGGERS: [],
            self.TABLE_ABS_OUTPUTS: []
        }
        self.db_updates_map = {}

    def checkpoint(self, name):
        """Checkpoint the task pool, etc."""
        return self.pri_dao.take_checkpoints(name, other_daos=[self.pub_dao])

    def copy_pri_to_pub(self):
        """Copy content of primary database file to public database file.

        Use temporary file to ensure that we do not end up with a partial
        file.

        """
        temp_pub_db_file_name = None
        self.pub_dao.close()
        try:
            self.pub_dao.conn = None  # reset connection
            open(self.pub_dao.db_file_name, "a").close()  # touch
            st_mode = os.stat(self.pub_dao.db_file_name).st_mode
            temp_pub_db_file_name = mkstemp(
                prefix=self.pub_dao.DB_FILE_BASE_NAME,
                dir=os.path.dirname(self.pub_dao.db_file_name))[1]
            copy(self.pri_dao.db_file_name, temp_pub_db_file_name)
            os.rename(temp_pub_db_file_name, self.pub_dao.db_file_name)
            os.chmod(self.pub_dao.db_file_name, st_mode)
        except (IOError, OSError):
            if temp_pub_db_file_name:
                os.unlink(temp_pub_db_file_name)
            raise

    def delete_suite_params(self, *keys):
        """Schedule deletion of rows from suite_params table by keys."""
        for key in keys:
            self.db_deletes_map[self.TABLE_SUITE_PARAMS].append({'key': key})

    def delete_suite_hold(self):
        """Delete suite hold flag and hold cycle point."""
        self.delete_suite_params(self.KEY_HOLD, self.KEY_HOLD_CYCLE_POINT)

    def delete_suite_stop_clock_time(self):
        """Delete suite stop clock time from suite_params table."""
        self.delete_suite_params(self.KEY_STOP_CLOCK_TIME)

    def delete_suite_stop_cycle_point(self):
        """Delete suite stop cycle point from suite_params table."""
        self.delete_suite_params(self.KEY_STOP_CYCLE_POINT)

    def delete_suite_stop_task(self):
        """Delete suite stop task from suite_params table."""
        self.delete_suite_params(self.KEY_STOP_TASK)

    def get_pri_dao(self):
        """Return the primary DAO."""
        return CylcSuiteDAO(self.pri_path)

    @staticmethod
    def _namedtuple2json(obj):
        """Convert namedtuple obj to a JSON string.

        Arguments:
            obj (namedtuple): input object to serialize to JSON.

        Return (str):
            Serialized JSON string of input object in the form
            "[type, list]".
        """
        if obj is None:
            return json.dumps(None)
        else:
            return json.dumps([type(obj).__name__, obj.__getnewargs__()])

    def on_suite_start(self, is_restart):
        """Initialise data access objects.

        Ensure that:
        * private database file is private
        * public database is in sync with private database
        """
        if not is_restart:
            try:
                os.unlink(self.pri_path)
            except OSError:
                # Just in case the path is a directory!
                rmtree(self.pri_path, ignore_errors=True)
        self.pri_dao = self.get_pri_dao()
        os.chmod(self.pri_path, 0o600)
        self.pub_dao = CylcSuiteDAO(self.pub_path, is_public=True)
        self.copy_pri_to_pub()

    def on_suite_shutdown(self):
        """Close data access objects."""
        if self.pri_dao:
            self.pri_dao.close()
            self.pri_dao = None
        if self.pub_dao:
            self.pub_dao.close()
            self.pub_dao = None

    def process_queued_ops(self):
        """Handle queued db operations for each task proxy."""
        if self.pri_dao is None:
            return
        # Record suite parameters and tasks in pool
        # Record any broadcast settings to be dumped out
        if any(self.db_deletes_map.values()):
            for table_name, db_deletes in sorted(
                    self.db_deletes_map.items()):
                while db_deletes:
                    where_args = db_deletes.pop(0)
                    self.pri_dao.add_delete_item(table_name, where_args)
                    self.pub_dao.add_delete_item(table_name, where_args)
        if any(self.db_inserts_map.values()):
            for table_name, db_inserts in sorted(
                    self.db_inserts_map.items()):
                while db_inserts:
                    db_insert = db_inserts.pop(0)
                    self.pri_dao.add_insert_item(table_name, db_insert)
                    self.pub_dao.add_insert_item(table_name, db_insert)
        if (hasattr(self, 'db_updates_map')
                and any(self.db_updates_map.values())):
            for table_name, db_updates in sorted(
                    self.db_updates_map.items()):
                while db_updates:
                    set_args, where_args = db_updates.pop(0)
                    self.pri_dao.add_update_item(
                        table_name, set_args, where_args)
                    self.pub_dao.add_update_item(
                        table_name, set_args, where_args)
        # Previously, we used a separate thread for database writes. This has
        # now been removed. For the private database, there is no real
        # advantage in using a separate thread as it needs to be always in
        # sync with what is current. For the public database, which does not
        # need to be fully in sync, there is some advantage of using a
        # separate thread/process, if writing to it becomes a bottleneck. At
        # the moment, there is no evidence that this is a bottleneck, so it
        # is better to keep the logic simple.
        self.pri_dao.execute_queued_items()
        self.pub_dao.execute_queued_items()

    def put_broadcast(self, modified_settings, is_cancel=False):
        """Put or clear broadcasts in runtime database."""
        now = get_current_time_string(display_sub_seconds=True)
        for broadcast_change in (
                get_broadcast_change_iter(modified_settings, is_cancel)):
            broadcast_change["time"] = now
            self.db_inserts_map[self.TABLE_BROADCAST_EVENTS].append(
                broadcast_change)
            if is_cancel:
                self.db_deletes_map[self.TABLE_BROADCAST_STATES].append({
                    "point": broadcast_change["point"],
                    "namespace": broadcast_change["namespace"],
                    "key": broadcast_change["key"]
                })
                # Delete statements are currently executed before insert
                # statements, so we should clear out any insert statements
                # that are deleted here.
                # (Not the most efficient logic here, but unless we have a
                # large number of inserts, then this should not be a big
                # concern.)
                inserts = []
                for insert in self.db_inserts_map[
                        self.TABLE_BROADCAST_STATES]:
                    if any(insert[key] != broadcast_change[key]
                           for key in ["point", "namespace", "key"]):
                        inserts.append(insert)
                self.db_inserts_map[self.TABLE_BROADCAST_STATES] = inserts
            else:
                self.db_inserts_map[self.TABLE_BROADCAST_STATES].append({
                    "point": broadcast_change["point"],
                    "namespace": broadcast_change["namespace"],
                    "key": broadcast_change["key"],
                    "value": broadcast_change["value"]
                })

    def put_runtime_inheritance(self, config):
        """Put task/family inheritance in runtime database."""
        for namespace in config.cfg['runtime']:
            value = config.runtime['linearized ancestors'][namespace]
            self.db_inserts_map[self.TABLE_INHERITANCE].append({
                "namespace": namespace,
                "inheritance": json.dumps(value)
            })

    def put_suite_params(self, schd):
        """Put various suite parameters from schd in runtime database.

        This method queues the relevant insert statements.

        Arguments:
            schd (cylc.flow.scheduler.Scheduler): scheduler object.
        """
        self.db_deletes_map[self.TABLE_SUITE_PARAMS].append({})
        self.db_inserts_map[self.TABLE_SUITE_PARAMS].extend([
            {"key": self.KEY_UUID_STR, "value": str(schd.uuid_str)},
            {"key": self.KEY_CYLC_VERSION, "value": CYLC_VERSION},
            {"key": self.KEY_UTC_MODE, "value": get_utc_mode()},
        ])
        if schd.config.cycle_point_dump_format is not None:
            self.db_inserts_map[self.TABLE_SUITE_PARAMS].append({
                "key": self.KEY_CYCLE_POINT_FORMAT,
                "value": schd.config.cycle_point_dump_format
            })
        if schd.pool.is_held:
            self.db_inserts_map[self.TABLE_SUITE_PARAMS].append({
                "key": self.KEY_HOLD, "value": 1
            })
        for key in (
                self.KEY_INITIAL_CYCLE_POINT,
                self.KEY_FINAL_CYCLE_POINT,
                self.KEY_START_CYCLE_POINT,
                self.KEY_STOP_CYCLE_POINT,
                self.KEY_RUN_MODE,
                self.KEY_CYCLE_POINT_TIME_ZONE):
            value = getattr(schd.options, key, None)
            if value is not None:
                self.db_inserts_map[self.TABLE_SUITE_PARAMS].append({
                    "key": key, "value": value
                })
        if schd.options.no_auto_shutdown is not None:
            self.db_inserts_map[self.TABLE_SUITE_PARAMS].append({
                "key": self.KEY_NO_AUTO_SHUTDOWN,
                "value": int(schd.options.no_auto_shutdown)
            })
        for key in (self.KEY_STOP_CLOCK_TIME, self.KEY_STOP_TASK):
            value = getattr(schd, key, None)
            if value is not None:
                self.db_inserts_map[self.TABLE_SUITE_PARAMS].append({
                    "key": key, "value": value
                })

    def put_suite_params_1(self, key, value):
        """Queue insertion of 1 key=value pair to the suite_params table."""
        self.db_inserts_map[self.TABLE_SUITE_PARAMS].append({
            "key": key, "value": value
        })

    def put_suite_hold(self):
        """Put suite hold flag to suite_params table."""
        self.put_suite_params_1(self.KEY_HOLD, 1)

    def put_suite_hold_cycle_point(self, value):
        """Put suite hold cycle point to suite_params table."""
        self.put_suite_params_1(self.KEY_HOLD_CYCLE_POINT, str(value))

    def put_suite_stop_clock_time(self, value):
        """Put suite stop clock time to suite_params table."""
        self.put_suite_params_1(self.KEY_STOP_CLOCK_TIME, value)

    def put_suite_stop_cycle_point(self, value):
        """Put suite stop cycle point to suite_params table."""
        self.put_suite_params_1(self.KEY_STOP_CYCLE_POINT, value)

    def put_suite_stop_task(self, value):
        """Put suite stop task to suite_params table."""
        self.put_suite_params_1(self.KEY_STOP_TASK, value)

    def put_suite_template_vars(self, template_vars):
        """Put template_vars in runtime database.

        This method queues the relevant insert statements.
        """
        for key, value in template_vars.items():
            self.db_inserts_map[self.TABLE_SUITE_TEMPLATE_VARS].append({
                "key": key, "value": value
            })

    def put_task_event_timers(self, task_events_mgr):
        """Put statements to update the task_action_timers table."""
        self.db_deletes_map[self.TABLE_TASK_ACTION_TIMERS].append({})
        for key, timer in task_events_mgr.event_timers.items():
            key1, point, name, submit_num = key
            self.db_inserts_map[self.TABLE_TASK_ACTION_TIMERS].append({
                "name": name,
                "cycle": point,
                "ctx_key": json.dumps((key1, submit_num,)),
                "ctx": self._namedtuple2json(timer.ctx),
                "delays": json.dumps(timer.delays),
                "num": timer.num,
                "delay": timer.delay,
                "timeout": timer.timeout
            })

    def put_xtriggers(self, sat_xtrig):
        """Put statements to update external triggers table."""
        self.db_deletes_map[self.TABLE_XTRIGGERS].append({})
        for sig, res in sat_xtrig.items():
            self.db_inserts_map[self.TABLE_XTRIGGERS].append({
                "signature": sig,
                "results": json.dumps(res)
            })

    def put_update_task_state(self, itask):
        """Update task_states table for current state of itask.

        For final event-driven update before removing finished tasks.
        No need to update task_pool table as finished tasks are immediately
        removed from the pool.
        """
        set_args = {
            "time_updated": itask.state.time_updated,
            "status": itask.state.status
        }
        where_args = {
            "cycle": str(itask.point),
            "name": itask.tdef.name,
            "flow_label": itask.flow_label,
            "submit_num": itask.submit_num,
        }
        self.db_updates_map.setdefault(self.TABLE_TASK_STATES, [])
        self.db_updates_map[self.TABLE_TASK_STATES].append(
            (set_args, where_args))

    def put_task_pool(self, pool):
        """Update various task tables for current pool, in runtime database.

        Queue delete (everything) statements to wipe the tables, and queue
        the relevant insert statements for the current tasks in the pool.
        """
        self.db_deletes_map[self.TABLE_TASK_POOL].append({})
        # No need to do:
        # self.db_deletes_map[self.TABLE_TASK_ACTION_TIMERS].append({})
        # Should already be done by self.put_task_event_timers above.
        self.db_deletes_map[self.TABLE_TASK_TIMEOUT_TIMERS].append({})
        for itask in pool.get_all_tasks():
            satisfied = {}
            for p in itask.state.prerequisites:
                for k, v in p.satisfied.items():
                    # need string key, not tuple for json.dumps
                    satisfied[json.dumps(k)] = v
            self.db_inserts_map[self.TABLE_TASK_POOL].append({
                "name": itask.tdef.name,
                "cycle": str(itask.point),
                "flow_label": itask.flow_label,
                "status": itask.state.status,
                "satisfied": json.dumps(satisfied),
                "is_held": itask.state.is_held
            })
            if itask.timeout is not None:
                self.db_inserts_map[self.TABLE_TASK_TIMEOUT_TIMERS].append({
                    "name": itask.tdef.name,
                    "cycle": str(itask.point),
                    "timeout": itask.timeout
                })
            if itask.poll_timer is not None:
                self.db_inserts_map[self.TABLE_TASK_ACTION_TIMERS].append({
                    "name": itask.tdef.name,
                    "cycle": str(itask.point),
                    "ctx_key": json.dumps("poll_timer"),
                    "ctx": self._namedtuple2json(itask.poll_timer.ctx),
                    "delays": json.dumps(itask.poll_timer.delays),
                    "num": itask.poll_timer.num,
                    "delay": itask.poll_timer.delay,
                    "timeout": itask.poll_timer.timeout
                })
            for ctx_key_1, timer in itask.try_timers.items():
                if timer is None:
                    continue
                self.db_inserts_map[self.TABLE_TASK_ACTION_TIMERS].append({
                    "name": itask.tdef.name,
                    "cycle": str(itask.point),
                    "ctx_key": json.dumps(("try_timers", ctx_key_1)),
                    "ctx": self._namedtuple2json(timer.ctx),
                    "delays": json.dumps(timer.delays),
                    "num": timer.num,
                    "delay": timer.delay,
                    "timeout": timer.timeout
                })
            if itask.state.time_updated:
                set_args = {
                    "time_updated": itask.state.time_updated,
                    "submit_num": itask.submit_num,
                    "try_num": itask.get_try_num(),
                    "status": itask.state.status
                }
                where_args = {
                    "cycle": str(itask.point),
                    "name": itask.tdef.name,
                    "flow_label": itask.flow_label
                }
                self.db_updates_map.setdefault(self.TABLE_TASK_STATES, [])
                self.db_updates_map[self.TABLE_TASK_STATES].append(
                    (set_args, where_args))
                itask.state.time_updated = None
        self.db_inserts_map[self.TABLE_CHECKPOINT_ID].append({
            # id = -1 for latest
            "id": CylcSuiteDAO.CHECKPOINT_LATEST_ID,
            "time": get_current_time_string(),
            "event": CylcSuiteDAO.CHECKPOINT_LATEST_EVENT
        })

    def put_insert_task_events(self, itask, args):
        """Put INSERT statement for task_events table."""
        self._put_insert_task_x(CylcSuiteDAO.TABLE_TASK_EVENTS, itask, args)

    def put_insert_task_late_flags(self, itask):
        """If itask is late, put INSERT statement to task_late_flags table."""
        if itask.is_late:
            self._put_insert_task_x(
                CylcSuiteDAO.TABLE_TASK_LATE_FLAGS, itask, {"value": True})

    def put_insert_task_jobs(self, itask, args):
        """Put INSERT statement for task_jobs table."""
        self._put_insert_task_x(CylcSuiteDAO.TABLE_TASK_JOBS, itask, args)

    def put_insert_task_states(self, itask, args):
        """Put INSERT statement for task_states table."""
        self._put_insert_task_x(CylcSuiteDAO.TABLE_TASK_STATES, itask, args)

    def put_insert_task_outputs(self, itask):
        """Reset custom outputs for a task."""
        self._put_insert_task_x(CylcSuiteDAO.TABLE_TASK_OUTPUTS, itask, {})

    def put_insert_abs_output(self, cycle, name, output):
        """Put INSERT statement for a new abs output."""
        args = {"cycle": str(cycle), "name": name, "output": output}
        self.db_inserts_map.setdefault(CylcSuiteDAO.TABLE_ABS_OUTPUTS, [])
        self.db_inserts_map[CylcSuiteDAO.TABLE_ABS_OUTPUTS].append(args)

    def _put_insert_task_x(self, table_name, itask, args):
        """Put INSERT statement for a task_* table."""
        args.update({"name": itask.tdef.name, "cycle": str(itask.point)})
        if "submit_num" not in args:
            args["submit_num"] = itask.submit_num
        self.db_inserts_map.setdefault(table_name, [])
        self.db_inserts_map[table_name].append(args)

    def put_update_task_jobs(self, itask, set_args):
        """Put UPDATE statement for task_jobs table."""
        self._put_update_task_x(CylcSuiteDAO.TABLE_TASK_JOBS, itask, set_args)

    def put_update_task_outputs(self, itask):
        """Put UPDATE statement for task_outputs table."""
        items = {}
        for trigger, message in itask.state.outputs.get_completed_customs():
            items[trigger] = message
        self._put_update_task_x(
            CylcSuiteDAO.TABLE_TASK_OUTPUTS, itask,
            {"outputs": json.dumps(items)})

    def _put_update_task_x(self, table_name, itask, set_args):
        """Put UPDATE statement for a task_* table."""
        where_args = {"cycle": str(itask.point), "name": itask.tdef.name}
        if "submit_num" not in set_args:
            where_args["submit_num"] = itask.submit_num
        if "flow_label" not in set_args:
            where_args["flow_label"] = itask.flow_label
        self.db_updates_map.setdefault(table_name, [])
        self.db_updates_map[table_name].append((set_args, where_args))

    def recover_pub_from_pri(self):
        """Recover public database from private database."""
        if self.pub_dao.n_tries >= self.pub_dao.MAX_TRIES:
            self.copy_pri_to_pub()
            LOG.warning(
                "%(pub_db_name)s: recovered from %(pri_db_name)s" % {
                    "pub_db_name": self.pub_dao.db_file_name,
                    "pri_db_name": self.pri_dao.db_file_name
                })
            self.pub_dao.n_tries = 0

    def restart_upgrade(self):
        """Vacuum/upgrade runtime DB on restart."""
        pri_dao = self.get_pri_dao()
        pri_dao.vacuum()
        # compat: <8.0
        pri_dao.upgrade_is_held()
        pri_dao.upgrade_retry_state()
        pri_dao.close()
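# The manager above is queue-based: the put_* and delete_* methods only
# stage DELETE/INSERT/UPDATE statements in the *_map dicts; nothing touches
# disk until process_queued_ops() flushes both DAOs. A minimal calling
# sketch (the run-directory paths are made up for illustration):
db_mgr = SuiteDatabaseManager(
    pri_d='/home/me/cylc-run/my.suite/.service',
    pub_d='/home/me/cylc-run/my.suite/log')
db_mgr.on_suite_start(is_restart=False)
db_mgr.put_suite_params_1('my_key', 'my_value')  # queued in memory only
db_mgr.process_queued_ops()  # flushed to private and public databases
db_mgr.on_suite_shutdown()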
def test_upgrade_hold_swap():
    """Pre Cylc8 DB upgrade compatibility test."""
    # test data
    initial_data = [
        # (name, cycle, status, hold_swap)
        ('foo', '1', 'waiting', ''),
        ('bar', '1', 'held', 'waiting'),
        ('baz', '1', 'held', 'running'),
        ('pub', '1', 'waiting', 'held')
    ]
    expected_data = [
        # (name, cycle, status, is_held)
        ('foo', '1', 'waiting', 0),
        ('bar', '1', 'waiting', 1),
        ('baz', '1', 'running', 1),
        ('pub', '1', 'waiting', 1)
    ]
    tables = [
        CylcSuiteDAO.TABLE_TASK_POOL,
        CylcSuiteDAO.TABLE_TASK_POOL_CHECKPOINTS
    ]

    with create_temp_db() as (temp_db, conn):
        # initialise tables
        for table in tables:
            conn.execute(
                rf'''
                    CREATE TABLE {table} (
                        name varchar(255),
                        cycle varchar(255),
                        status varchar(255),
                        hold_swap varchar(255)
                    )
                '''
            )
            conn.executemany(
                rf'''
                    INSERT INTO {table}
                    VALUES (?,?,?,?)
                ''',
                initial_data
            )

        # close database
        conn.commit()
        conn.close()

        # open database as cylc dao
        dao = CylcSuiteDAO(temp_db)
        conn = dao.connect()

        # check the initial data was correctly inserted
        for table in tables:
            dump = [x for x in conn.execute(rf'SELECT * FROM {table}')]
            assert dump == initial_data

        # upgrade
        assert dao.upgrade_is_held()

        # check the data was correctly upgraded
        for table in tables:
            dump = [x for x in conn.execute(rf'SELECT * FROM {table}')]
            assert dump == expected_data

        # make sure the upgrade is skipped on future runs
        assert not dao.upgrade_is_held()
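# What an upgrade along these lines has to do, inferred from the fixtures
# above (a sketch under that assumption, not the real CylcSuiteDAO method,
# which also detects whether the upgrade has already been applied):
def _upgrade_is_held_sketch(conn, table):
    """Rewrite pre-Cylc8 (status, hold_swap) rows as (status, is_held)."""
    rows = []
    for name, cycle, status, hold_swap in list(conn.execute(
            f'SELECT name, cycle, status, hold_swap FROM {table}')):
        if status == 'held':
            # the task's real status was parked in hold_swap
            rows.append((name, cycle, hold_swap, 1))
        else:
            rows.append((name, cycle, status, int(hold_swap == 'held')))
    conn.execute(f'DROP TABLE {table}')
    conn.execute(f'CREATE TABLE {table} (name, cycle, status, is_held)')
    conn.executemany(f'INSERT INTO {table} VALUES (?,?,?,?)', rows)
    conn.commit()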
def test_upgrade_retry_state():
    """Pre Cylc8 DB upgrade compatibility test."""
    initial_data = [
        # (name, cycle, status)
        ('foo', '1', 'waiting'),
        ('bar', '1', 'running'),
        ('baz', '1', 'retrying'),
        ('pub', '1', 'submit-retrying')
    ]
    expected_data = [
        # (name, cycle, status)
        ('foo', '1', 'waiting'),
        ('bar', '1', 'running'),
        ('baz', '1', 'waiting'),
        ('pub', '1', 'waiting')
    ]
    tables = [
        CylcSuiteDAO.TABLE_TASK_POOL,
        CylcSuiteDAO.TABLE_TASK_POOL_CHECKPOINTS
    ]

    with create_temp_db() as (temp_db, conn):
        # initialise tables
        for table in tables:
            conn.execute(
                rf'''
                    CREATE TABLE {table} (
                        name varchar(255),
                        cycle varchar(255),
                        status varchar(255)
                    )
                '''
            )
            conn.executemany(
                rf'''
                    INSERT INTO {table}
                    VALUES (?,?,?)
                ''',
                initial_data
            )

        # close database
        conn.commit()
        conn.close()

        # open database as cylc dao
        dao = CylcSuiteDAO(temp_db)
        conn = dao.connect()

        # check the initial data was correctly inserted
        for table in tables:
            dump = [x for x in conn.execute(rf'SELECT * FROM {table}')]
            assert dump == initial_data

        # upgrade
        assert dao.upgrade_retry_state() == [
            ('1', 'baz', 'retrying'),
            ('1', 'pub', 'submit-retrying')
        ]

        # check the data was correctly upgraded
        for table in tables:
            dump = [x for x in conn.execute(rf'SELECT * FROM {table}')]
            assert dump == expected_data
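# Similarly inferred from the fixtures (a hypothetical sketch, not the real
# method): the retired 'retrying' and 'submit-retrying' statuses collapse to
# 'waiting', and the affected (cycle, name, status) rows are returned.
def _upgrade_retry_state_sketch(conn, table):
    retrying = ('retrying', 'submit-retrying')
    affected = list(conn.execute(
        f'SELECT cycle, name, status FROM {table} WHERE status IN (?,?)',
        retrying))
    conn.execute(
        f'UPDATE {table} SET status = ? WHERE status IN (?,?)',
        ('waiting',) + retrying)
    conn.commit()
    return affected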
class SuiteDatabaseManager(object):
    """Manage the suite runtime private and public databases."""

    TABLE_BROADCAST_EVENTS = CylcSuiteDAO.TABLE_BROADCAST_EVENTS
    TABLE_BROADCAST_STATES = CylcSuiteDAO.TABLE_BROADCAST_STATES
    TABLE_CHECKPOINT_ID = CylcSuiteDAO.TABLE_CHECKPOINT_ID
    TABLE_INHERITANCE = CylcSuiteDAO.TABLE_INHERITANCE
    TABLE_SUITE_PARAMS = CylcSuiteDAO.TABLE_SUITE_PARAMS
    TABLE_SUITE_TEMPLATE_VARS = CylcSuiteDAO.TABLE_SUITE_TEMPLATE_VARS
    TABLE_TASK_ACTION_TIMERS = CylcSuiteDAO.TABLE_TASK_ACTION_TIMERS
    TABLE_TASK_POOL = CylcSuiteDAO.TABLE_TASK_POOL
    TABLE_TASK_OUTPUTS = CylcSuiteDAO.TABLE_TASK_OUTPUTS
    TABLE_TASK_STATES = CylcSuiteDAO.TABLE_TASK_STATES
    TABLE_TASK_TIMEOUT_TIMERS = CylcSuiteDAO.TABLE_TASK_TIMEOUT_TIMERS
    TABLE_XTRIGGERS = CylcSuiteDAO.TABLE_XTRIGGERS

    def __init__(self, pri_d=None, pub_d=None):
        self.pri_path = None
        if pri_d:
            self.pri_path = os.path.join(
                pri_d, CylcSuiteDAO.DB_FILE_BASE_NAME)
        self.pub_path = None
        if pub_d:
            self.pub_path = os.path.join(
                pub_d, CylcSuiteDAO.DB_FILE_BASE_NAME)
        self.pri_dao = None
        self.pub_dao = None
        self.db_deletes_map = {
            self.TABLE_BROADCAST_STATES: [],
            self.TABLE_SUITE_PARAMS: [],
            self.TABLE_TASK_POOL: [],
            self.TABLE_TASK_ACTION_TIMERS: [],
            self.TABLE_TASK_OUTPUTS: [],
            self.TABLE_TASK_TIMEOUT_TIMERS: [],
            self.TABLE_XTRIGGERS: []}
        self.db_inserts_map = {
            self.TABLE_BROADCAST_EVENTS: [],
            self.TABLE_BROADCAST_STATES: [],
            self.TABLE_INHERITANCE: [],
            self.TABLE_SUITE_PARAMS: [],
            self.TABLE_SUITE_TEMPLATE_VARS: [],
            self.TABLE_CHECKPOINT_ID: [],
            self.TABLE_TASK_POOL: [],
            self.TABLE_TASK_ACTION_TIMERS: [],
            self.TABLE_TASK_OUTPUTS: [],
            self.TABLE_TASK_TIMEOUT_TIMERS: [],
            self.TABLE_XTRIGGERS: []}
        self.db_updates_map = {}

    def checkpoint(self, name):
        """Checkpoint the task pool, etc."""
        return self.pri_dao.take_checkpoints(name, other_daos=[self.pub_dao])

    def copy_pri_to_pub(self):
        """Copy content of primary database file to public database file.

        Use temporary file to ensure that we do not end up with a partial
        file.

        """
        temp_pub_db_file_name = None
        self.pub_dao.close()
        try:
            self.pub_dao.conn = None  # reset connection
            open(self.pub_dao.db_file_name, "a").close()  # touch
            st_mode = os.stat(self.pub_dao.db_file_name).st_mode
            temp_pub_db_file_name = mkstemp(
                prefix=self.pub_dao.DB_FILE_BASE_NAME,
                dir=os.path.dirname(self.pub_dao.db_file_name))[1]
            copy(self.pri_dao.db_file_name, temp_pub_db_file_name)
            os.rename(temp_pub_db_file_name, self.pub_dao.db_file_name)
            os.chmod(self.pub_dao.db_file_name, st_mode)
        except (IOError, OSError):
            if temp_pub_db_file_name:
                os.unlink(temp_pub_db_file_name)
            raise

    def get_pri_dao(self):
        """Return the primary DAO."""
        return CylcSuiteDAO(self.pri_path)

    @staticmethod
    def _namedtuple2json(obj):
        """Convert namedtuple obj to a JSON string.

        Arguments:
            obj (namedtuple): input object to serialize to JSON.

        Return (str):
            Serialized JSON string of input object in the form
            "[type, list]".
        """
        if obj is None:
            return json.dumps(None)
        else:
            return json.dumps([type(obj).__name__, obj.__getnewargs__()])

    def on_suite_start(self, is_restart):
        """Initialise data access objects.

        Ensure that:
        * private database file is private
        * public database is in sync with private database
        """
        if not is_restart:
            try:
                os.unlink(self.pri_path)
            except OSError:
                # Just in case the path is a directory!
                rmtree(self.pri_path, ignore_errors=True)
        self.pri_dao = self.get_pri_dao()
        os.chmod(self.pri_path, 0o600)
        self.pub_dao = CylcSuiteDAO(self.pub_path, is_public=True)
        self.copy_pri_to_pub()

    def on_suite_shutdown(self):
        """Close data access objects."""
        if self.pri_dao:
            self.pri_dao.close()
            self.pri_dao = None
        if self.pub_dao:
            self.pub_dao.close()
            self.pub_dao = None

    def process_queued_ops(self):
        """Handle queued db operations for each task proxy."""
        if self.pri_dao is None:
            return
        # Record suite parameters and tasks in pool
        # Record any broadcast settings to be dumped out
        if any(self.db_deletes_map.values()):
            for table_name, db_deletes in sorted(
                    self.db_deletes_map.items()):
                while db_deletes:
                    where_args = db_deletes.pop(0)
                    self.pri_dao.add_delete_item(table_name, where_args)
                    self.pub_dao.add_delete_item(table_name, where_args)
        if any(self.db_inserts_map.values()):
            for table_name, db_inserts in sorted(
                    self.db_inserts_map.items()):
                while db_inserts:
                    db_insert = db_inserts.pop(0)
                    self.pri_dao.add_insert_item(table_name, db_insert)
                    self.pub_dao.add_insert_item(table_name, db_insert)
        if (hasattr(self, 'db_updates_map')
                and any(self.db_updates_map.values())):
            for table_name, db_updates in sorted(
                    self.db_updates_map.items()):
                while db_updates:
                    set_args, where_args = db_updates.pop(0)
                    self.pri_dao.add_update_item(
                        table_name, set_args, where_args)
                    self.pub_dao.add_update_item(
                        table_name, set_args, where_args)
        # Previously, we used a separate thread for database writes. This has
        # now been removed. For the private database, there is no real
        # advantage in using a separate thread as it needs to be always in
        # sync with what is current. For the public database, which does not
        # need to be fully in sync, there is some advantage of using a
        # separate thread/process, if writing to it becomes a bottleneck. At
        # the moment, there is no evidence that this is a bottleneck, so it
        # is better to keep the logic simple.
        self.pri_dao.execute_queued_items()
        self.pub_dao.execute_queued_items()

    def put_broadcast(self, modified_settings, is_cancel=False):
        """Put or clear broadcasts in runtime database."""
        now = get_current_time_string(display_sub_seconds=True)
        for broadcast_change in (
                get_broadcast_change_iter(modified_settings, is_cancel)):
            broadcast_change["time"] = now
            self.db_inserts_map[self.TABLE_BROADCAST_EVENTS].append(
                broadcast_change)
            if is_cancel:
                self.db_deletes_map[self.TABLE_BROADCAST_STATES].append({
                    "point": broadcast_change["point"],
                    "namespace": broadcast_change["namespace"],
                    "key": broadcast_change["key"]})
                # Delete statements are currently executed before insert
                # statements, so we should clear out any insert statements
                # that are deleted here.
                # (Not the most efficient logic here, but unless we have a
                # large number of inserts, then this should not be a big
                # concern.)
                inserts = []
                for insert in self.db_inserts_map[
                        self.TABLE_BROADCAST_STATES]:
                    if any(insert[key] != broadcast_change[key]
                           for key in ["point", "namespace", "key"]):
                        inserts.append(insert)
                self.db_inserts_map[self.TABLE_BROADCAST_STATES] = inserts
            else:
                self.db_inserts_map[self.TABLE_BROADCAST_STATES].append({
                    "point": broadcast_change["point"],
                    "namespace": broadcast_change["namespace"],
                    "key": broadcast_change["key"],
                    "value": broadcast_change["value"]})

    def put_runtime_inheritance(self, config):
        """Put task/family inheritance in runtime database."""
        for namespace in config.cfg['runtime']:
            value = config.runtime['linearized ancestors'][namespace]
            self.db_inserts_map[self.TABLE_INHERITANCE].append({
                "namespace": namespace,
                "inheritance": json.dumps(value)})

    def put_suite_params(self, schd):
        """Put various suite parameters from schd in runtime database.

        This method queues the relevant insert statements.

        Arguments:
            schd (cylc.flow.scheduler.Scheduler): scheduler object.
        """
        if schd.final_point is None:
            # Store None as proper null value in database. No need to do this
            # for initial cycle point, which should never be None.
            final_point_str = None
        else:
            final_point_str = str(schd.final_point)
        self.db_inserts_map[self.TABLE_SUITE_PARAMS].extend([
            {"key": "uuid_str", "value": str(schd.uuid_str)},
            {"key": "run_mode", "value": schd.run_mode},
            {"key": "cylc_version", "value": CYLC_VERSION},
            {"key": "UTC_mode", "value": get_utc_mode()},
            {"key": "initial_point", "value": str(schd.initial_point)},
            {"key": "final_point", "value": final_point_str},
        ])
        if schd.config.cfg['cylc']['cycle point format']:
            self.db_inserts_map[self.TABLE_SUITE_PARAMS].append({
                "key": "cycle_point_format",
                "value": schd.config.cfg['cylc']['cycle point format']})
        if schd.pool.is_held:
            self.db_inserts_map[self.TABLE_SUITE_PARAMS].append({
                "key": "is_held", "value": 1})
        if schd.cli_start_point_string:
            self.db_inserts_map[self.TABLE_SUITE_PARAMS].append({
                "key": "start_point",
                "value": schd.cli_start_point_string})

    def put_suite_template_vars(self, template_vars):
        """Put template_vars in runtime database.

        This method queues the relevant insert statements.
        """
        for key, value in template_vars.items():
            self.db_inserts_map[self.TABLE_SUITE_TEMPLATE_VARS].append(
                {"key": key, "value": value})

    def put_task_event_timers(self, task_events_mgr):
        """Put statements to update the task_action_timers table."""
        if task_events_mgr.event_timers:
            self.db_deletes_map[self.TABLE_TASK_ACTION_TIMERS].append({})
            for key, timer in task_events_mgr.event_timers.items():
                key1, point, name, submit_num = key
                self.db_inserts_map[self.TABLE_TASK_ACTION_TIMERS].append({
                    "name": name,
                    "cycle": point,
                    "ctx_key": json.dumps((key1, submit_num,)),
                    "ctx": self._namedtuple2json(timer.ctx),
                    "delays": json.dumps(timer.delays),
                    "num": timer.num,
                    "delay": timer.delay,
                    "timeout": timer.timeout})

    def put_xtriggers(self, sat_xtrig):
        """Put statements to update external triggers table."""
        self.db_deletes_map[self.TABLE_XTRIGGERS].append({})
        for sig, res in sat_xtrig.items():
            self.db_inserts_map[self.TABLE_XTRIGGERS].append({
                "signature": sig,
                "results": json.dumps(res)})

    def put_task_pool(self, pool):
        """Put statements to update the task_pool table in runtime database.

        Update the task_pool table and the task_action_timers table.
        Queue delete (everything) statements to wipe the tables, and queue
        the relevant insert statements for the current tasks in the pool.
        """
        self.db_deletes_map[self.TABLE_TASK_POOL].append({})
        # No need to do:
        # self.db_deletes_map[self.TABLE_TASK_ACTION_TIMERS].append({})
        # Should already be done by self.put_task_event_timers above.
        self.db_deletes_map[self.TABLE_TASK_TIMEOUT_TIMERS].append({})
        for itask in pool.get_all_tasks():
            self.db_inserts_map[self.TABLE_TASK_POOL].append({
                "name": itask.tdef.name,
                "cycle": str(itask.point),
                "spawned": int(itask.has_spawned),
                "status": itask.state.status,
                "hold_swap": itask.state.hold_swap})
            if itask.timeout is not None:
                self.db_inserts_map[self.TABLE_TASK_TIMEOUT_TIMERS].append({
                    "name": itask.tdef.name,
                    "cycle": str(itask.point),
                    "timeout": itask.timeout})
            if itask.poll_timer is not None:
                self.db_inserts_map[self.TABLE_TASK_ACTION_TIMERS].append({
                    "name": itask.tdef.name,
                    "cycle": str(itask.point),
                    "ctx_key": json.dumps("poll_timer"),
                    "ctx": self._namedtuple2json(itask.poll_timer.ctx),
                    "delays": json.dumps(itask.poll_timer.delays),
                    "num": itask.poll_timer.num,
                    "delay": itask.poll_timer.delay,
                    "timeout": itask.poll_timer.timeout})
            for ctx_key_1, timer in itask.try_timers.items():
                if timer is None:
                    continue
                self.db_inserts_map[self.TABLE_TASK_ACTION_TIMERS].append({
                    "name": itask.tdef.name,
                    "cycle": str(itask.point),
                    "ctx_key": json.dumps(("try_timers", ctx_key_1)),
                    "ctx": self._namedtuple2json(timer.ctx),
                    "delays": json.dumps(timer.delays),
                    "num": timer.num,
                    "delay": timer.delay,
                    "timeout": timer.timeout})
            if itask.state.time_updated:
                set_args = {
                    "time_updated": itask.state.time_updated,
                    "submit_num": itask.submit_num,
                    "try_num": itask.get_try_num(),
                    "status": itask.state.status}
                where_args = {
                    "cycle": str(itask.point),
                    "name": itask.tdef.name,
                }
                self.db_updates_map.setdefault(self.TABLE_TASK_STATES, [])
                self.db_updates_map[self.TABLE_TASK_STATES].append(
                    (set_args, where_args))
                itask.state.time_updated = None
        self.db_inserts_map[self.TABLE_CHECKPOINT_ID].append({
            # id = -1 for latest
            "id": CylcSuiteDAO.CHECKPOINT_LATEST_ID,
            "time": get_current_time_string(),
            "event": CylcSuiteDAO.CHECKPOINT_LATEST_EVENT})

    def put_insert_task_events(self, itask, args):
        """Put INSERT statement for task_events table."""
        self._put_insert_task_x(CylcSuiteDAO.TABLE_TASK_EVENTS, itask, args)

    def put_insert_task_late_flags(self, itask):
        """If itask is late, put INSERT statement to task_late_flags table."""
        if itask.is_late:
            self._put_insert_task_x(
                CylcSuiteDAO.TABLE_TASK_LATE_FLAGS, itask, {"value": True})

    def put_insert_task_jobs(self, itask, args):
        """Put INSERT statement for task_jobs table."""
        self._put_insert_task_x(CylcSuiteDAO.TABLE_TASK_JOBS, itask, args)

    def put_insert_task_states(self, itask, args):
        """Put INSERT statement for task_states table."""
        self._put_insert_task_x(CylcSuiteDAO.TABLE_TASK_STATES, itask, args)

    def put_insert_task_outputs(self, itask):
        """Reset custom outputs for a task."""
        self._put_insert_task_x(CylcSuiteDAO.TABLE_TASK_OUTPUTS, itask, {})

    def _put_insert_task_x(self, table_name, itask, args):
        """Put INSERT statement for a task_* table."""
        args.update({
            "name": itask.tdef.name,
            "cycle": str(itask.point)})
        if "submit_num" not in args:
            args["submit_num"] = itask.submit_num
        self.db_inserts_map.setdefault(table_name, [])
        self.db_inserts_map[table_name].append(args)

    def put_update_task_jobs(self, itask, set_args):
        """Put UPDATE statement for task_jobs table."""
        self._put_update_task_x(
            CylcSuiteDAO.TABLE_TASK_JOBS, itask, set_args)

    def put_update_task_outputs(self, itask):
        """Put UPDATE statement for task_outputs table."""
        items = {}
        for trigger, message in itask.state.outputs.get_completed_customs():
            items[trigger] = message
        self._put_update_task_x(
            CylcSuiteDAO.TABLE_TASK_OUTPUTS, itask,
            {"outputs": json.dumps(items)})

    def _put_update_task_x(self, table_name, itask, set_args):
        """Put UPDATE statement for a task_* table."""
        where_args = {
            "cycle": str(itask.point),
            "name": itask.tdef.name}
        if "submit_num" not in set_args:
            where_args["submit_num"] = itask.submit_num
        self.db_updates_map.setdefault(table_name, [])
        self.db_updates_map[table_name].append((set_args, where_args))

    def recover_pub_from_pri(self):
        """Recover public database from private database."""
        if self.pub_dao.n_tries >= self.pub_dao.MAX_TRIES:
            self.copy_pri_to_pub()
            LOG.warning(
                "%(pub_db_name)s: recovered from %(pri_db_name)s" % {
                    "pub_db_name": self.pub_dao.db_file_name,
                    "pri_db_name": self.pri_dao.db_file_name})
            self.pub_dao.n_tries = 0

    def restart_upgrade(self):
        """Vacuum/upgrade runtime DB on restart."""
        pri_dao = self.get_pri_dao()
        pri_dao.vacuum()
        pri_dao.close()
def _get_dao(suite):
    """Return the DAO (public) for suite."""
    return CylcSuiteDAO(get_suite_run_pub_db_name(suite), is_public=True)
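# Illustrative read-only use of the public DAO ('my.suite' and the task
# identifiers below are made up):
dao = _get_dao('my.suite')
try:
    task_job_data = dao.select_task_job('20200202T0000Z', 'my_task', 1)
finally:
    dao.close()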
def test_upgrade_to_platforms(mock_glbl_cfg):
    """Test upgrader logic for platforms in the database."""
    # Set up the global config
    mock_glbl_cfg('cylc.flow.rundb.glbl_cfg', GLOBAL_CONFIG)
    # task name, cycle, user_at_host, batch_system
    initial_data = [
        ('hpc_with_pbs', '1', 'hpcl1', 'pbs'),
        ('desktop_with_bg', '1', 'desktop01', 'background'),
        ('slurm_no_host', '1', '', 'slurm'),
        ('hpc_bg', '1', 'hpcl1', 'background'),
        ('username_given', '1', 'slartibartfast@hpcl1', 'pbs')
    ]
    # task name, cycle, user, platform
    expected_data = [
        ('hpc_with_pbs', '1', '', 'hpc'),
        ('desktop_with_bg', '1', '', 'desktop01'),
        ('slurm_no_host', '1', '', 'sugar'),
        ('hpc_bg', '1', '', 'hpcl1-bg'),
        ('username_given', '1', 'slartibartfast', 'hpc'),
    ]

    with create_temp_db() as (temp_db, conn):
        conn.execute(
            rf'''
                CREATE TABLE {CylcSuiteDAO.TABLE_TASK_JOBS} (
                    name varchar(255),
                    cycle varchar(255),
                    user_at_host varchar(255),
                    batch_system varchar(255)
                )
            '''
        )
        conn.executemany(
            rf'''
                INSERT INTO {CylcSuiteDAO.TABLE_TASK_JOBS}
                VALUES (?,?,?,?)
            ''',
            initial_data
        )

        # close database
        conn.commit()
        conn.close()

        # open database as cylc dao
        dao = CylcSuiteDAO(temp_db)
        conn = dao.connect()

        # check the initial data was correctly inserted
        dump = [
            x for x in conn.execute(
                rf'SELECT * FROM {CylcSuiteDAO.TABLE_TASK_JOBS}'
            )
        ]
        assert dump == initial_data

        # Upgrade function returns True?
        assert dao.upgrade_to_platforms()

        # check the data was correctly upgraded
        dump = [
            x for x in conn.execute(
                r'SELECT name, cycle, user, platform_name FROM task_jobs'
            )
        ]
        assert dump == expected_data

        # make sure the upgrade is skipped on future runs
        assert not dao.upgrade_to_platforms()
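# The upgrade above must first split the old 'user@host' value before the
# platform lookup (which depends on the mocked global config, elided here).
# A minimal sketch of that first step, under that assumption:
def _split_user_at_host(user_at_host):
    """Return (user, host) from a pre-Cylc8 'user@host' value."""
    if '@' in user_at_host:
        user, host = user_at_host.split('@', 1)
        return (user, host)
    return ('', user_at_host)


assert _split_user_at_host('slartibartfast@hpcl1') == (
    'slartibartfast', 'hpcl1')
assert _split_user_at_host('desktop01') == ('', 'desktop01')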