Example #1
0
    def on_suite_start(self, is_restart):
        """Initialise data access objects.

        Ensure that:
        * private database file is private
        * public database is in sync with private database
        * the symlink named ``CylcSuiteDAO.OLD_DB_FILE_BASE_NAME``, located
          two directory levels above the public database file, targets
          ``log/<CylcSuiteDAO.DB_FILE_BASE_NAME>``

        Arguments:
            is_restart (bool): if false, whatever exists at the private
                database path is removed before the private DAO is opened.
        """
        if not is_restart:
            try:
                os.unlink(self.pri_path)
            except OSError:
                # Just in case the path is a directory!
                rmtree(self.pri_path, ignore_errors=True)
        self.pri_dao = self.get_pri_dao()
        # "0o600" is the octal spelling accepted by both Python 2.6+ and
        # Python 3 (the bare "0600" form is a syntax error on Python 3).
        os.chmod(self.pri_path, 0o600)
        self.pub_dao = CylcSuiteDAO(self.pub_path, is_public=True)
        self.copy_pri_to_pub()
        pub_db_path_symlink = os.path.join(
            os.path.dirname(os.path.dirname(self.pub_path)),
            CylcSuiteDAO.OLD_DB_FILE_BASE_NAME)
        try:
            orig_source = os.readlink(pub_db_path_symlink)
        except OSError:
            # Missing, or not a symlink at all.
            orig_source = None
        source = os.path.join('log', CylcSuiteDAO.DB_FILE_BASE_NAME)
        if orig_source != source:
            # Replace whatever is there with a symlink to the right target.
            try:
                os.unlink(pub_db_path_symlink)
            except OSError:
                pass
            os.symlink(source, pub_db_path_symlink)
Example #2
0
class TestRunDb(unittest.TestCase):
    """Tests for CylcSuiteDAO.select_task_job against a mocked connection."""

    # Argument sets (cycle, name, submit_num) passed to select_task_job.
    get_select_task_job = [
        ["cycle", "name", "NN"],
        ["cycle", "name", None],
        ["cycle", "name", "02"],
    ]

    def setUp(self):
        """Create an in-memory DAO whose connect() returns a mock."""
        self.dao = CylcSuiteDAO(':memory:')
        connection = mock.Mock()
        self.mocked_connection = connection
        self.dao.connect = mock.MagicMock(return_value=connection)

    def test_select_task_job(self):
        """Check select_task_job exposes each selected column by name."""
        task_jobs_table = self.dao.tables[CylcSuiteDAO.TABLE_TASK_JOBS]
        job_columns = task_jobs_table.columns[3:]
        # The mocked query yields one row holding 2 in every column.
        single_row = [2 for _ in job_columns]
        self.mocked_connection.execute.return_value = [single_row]

        # Run the same assertions for every argument set.
        for arg_set in self.get_select_task_job:
            cycle, name, submit_num = arg_set
            selected = self.dao.select_task_job(cycle, name, submit_num)
            for job_column in job_columns:
                self.assertEqual(2, selected[job_column.name])

    def test_select_task_job_sqlite_error(self):
        """Check select_task_job returns None when execute raises."""
        self.mocked_connection.execute.side_effect = DatabaseError

        outcome = self.dao.select_task_job("it'll", "raise", "an error!")
        self.assertIsNone(outcome)
Example #3
0
class TestRunDb(unittest.TestCase):
    """Tests for CylcSuiteDAO.select_task_job against a mocked connection."""

    # Argument sets (cycle, name, submit_num) passed to select_task_job.
    get_select_task_job = [
        ["cycle", "name", "NN"],
        ["cycle", "name", None],
        ["cycle", "name", "02"],
    ]

    def setUp(self):
        """Create an in-memory DAO whose connect() returns a mock."""
        self.dao = CylcSuiteDAO(':memory:')
        connection = mock.Mock()
        self.mocked_connection = connection
        self.dao.connect = mock.MagicMock(return_value=connection)

    def test_select_task_job(self):
        """Check select_task_job exposes each selected column by name."""
        task_jobs_table = self.dao.tables[CylcSuiteDAO.TABLE_TASK_JOBS]
        job_columns = task_jobs_table.columns[3:]
        # The mocked query yields one row holding 2 in every column.
        single_row = [2 for _ in job_columns]
        self.mocked_connection.execute.return_value = [single_row]

        # Run the same assertions for every argument set.
        for arg_set in self.get_select_task_job:
            cycle, name, submit_num = arg_set
            selected = self.dao.select_task_job(cycle, name, submit_num)
            for job_column in job_columns:
                self.assertEqual(2, selected[job_column.name])

    def test_select_task_job_sqlite_error(self):
        """Check select_task_job returns None when execute raises."""
        self.mocked_connection.execute.side_effect = DatabaseError

        outcome = self.dao.select_task_job("it'll", "raise", "an error!")
        self.assertIsNone(outcome)
Example #4
0
    def on_suite_start(self, is_restart):
        """Initialise data access objects.

        Ensure that:
        * private database file is private
        * public database is in sync with private database
        * the symlink named ``CylcSuiteDAO.OLD_DB_FILE_BASE_NAME``, located
          two directory levels above the public database file, targets
          ``log/<CylcSuiteDAO.DB_FILE_BASE_NAME>``

        Arguments:
            is_restart (bool): if false, whatever exists at the private
                database path is removed before the private DAO is opened.
        """
        if not is_restart:
            try:
                os.unlink(self.pri_path)
            except OSError:
                # Just in case the path is a directory!
                rmtree(self.pri_path, ignore_errors=True)
        self.pri_dao = self.get_pri_dao()
        # "0o600" is the octal spelling accepted by both Python 2.6+ and
        # Python 3 (the bare "0600" form is a syntax error on Python 3).
        os.chmod(self.pri_path, 0o600)
        self.pub_dao = CylcSuiteDAO(self.pub_path, is_public=True)
        self.copy_pri_to_pub()
        pub_db_path_symlink = os.path.join(
            os.path.dirname(os.path.dirname(self.pub_path)),
            CylcSuiteDAO.OLD_DB_FILE_BASE_NAME)
        try:
            orig_source = os.readlink(pub_db_path_symlink)
        except OSError:
            # Missing, or not a symlink at all.
            orig_source = None
        source = os.path.join('log', CylcSuiteDAO.DB_FILE_BASE_NAME)
        if orig_source != source:
            # Replace whatever is there with a symlink to the right target.
            try:
                os.unlink(pub_db_path_symlink)
            except OSError:
                pass
            os.symlink(source, pub_db_path_symlink)
Example #5
0
    def on_suite_start(self, is_restart):
        """Set up the private and public data access objects.

        Steps performed:
        * unless restarting, remove whatever sits at the private database
          path
        * open the private DAO and restrict its file to mode 0o600
        * create the public DAO and call ``copy_pri_to_pub``
        """
        fresh_start = not is_restart
        if fresh_start:
            private_path = self.pri_path
            try:
                os.unlink(private_path)
            except OSError:
                # unlink failed: the path may be a directory, so try rmtree.
                rmtree(private_path, ignore_errors=True)

        self.pri_dao = self.get_pri_dao()
        os.chmod(self.pri_path, 0o600)

        public_dao = CylcSuiteDAO(self.pub_path, is_public=True)
        self.pub_dao = public_dao
        self.copy_pri_to_pub()
Example #6
0
 def get_suite_broadcast_states(self, user_name, suite_name):
     """Return broadcast states of a suite.

     The result is a list of 4-item lists:
     [[point, name, key, value], ...]
     """
     build_select = CylcSuiteDAO.pre_select_broadcast_states
     dao = self._db_init(user_name, suite_name)
     # Only the first item of the pre-select result is used as statement.
     stmt = build_select(dao, order="ASC")[0]
     rows = self._db_exec(user_name, suite_name, stmt)
     return [
         [point, namespace, key, value]
         for point, namespace, key, value in rows
     ]
Example #7
0
 def get_suite_broadcast_events(self, user_name, suite_name):
     """Return broadcast events of a suite.

     The result is a list of 6-item tuples:
     [(time, change, point, name, key, value), ...]
     """
     build_select = CylcSuiteDAO.pre_select_broadcast_events
     dao = self._db_init(user_name, suite_name)
     # Only the first item of the pre-select result is used as statement.
     stmt = build_select(dao, order="DESC")[0]
     rows = self._db_exec(user_name, suite_name, stmt)
     return [
         (time_, change, point, namespace, key, value)
         for time_, change, point, namespace, key, value in rows
     ]
Example #8
0
 def _db_init(self, user_name, suite_name):
     """Return the cached CylcSuiteDAO for (user_name, suite_name).

     On first use, candidate database files under the user's
     "cylc-run/<suite_name>" directory are tried in order; the loop stops
     at the first one that exists, otherwise the DAO for the last
     candidate is kept.
     """
     cache_key = (user_name, suite_name)
     if cache_key in self.daos:
         return self.daos[cache_key]

     # "~" is the current user's home, "~name" is a named user's home.
     home = ("~" + user_name) if user_name else "~"
     candidates = [os.path.join("log", "db"), "cylc-suite.db"]
     for base_name in candidates:
         relative_path = os.path.join("cylc-run", suite_name, base_name)
         db_f_name = os.path.expanduser(os.path.join(home, relative_path))
         self.daos[cache_key] = CylcSuiteDAO(db_f_name, is_public=True)
         if os.path.exists(db_f_name):
             break
     return self.daos[cache_key]
Example #9
0
class SuiteDatabaseManager(object):
    """Manage the suite runtime private and public databases.

    Database operations are queued in ``db_deletes_map``,
    ``db_inserts_map`` and ``db_updates_map`` (all keyed by table name) by
    the ``put_*`` methods, and are handed to both the private and the public
    DAO by ``process_queued_ops``.
    """

    TABLE_BROADCAST_EVENTS = CylcSuiteDAO.TABLE_BROADCAST_EVENTS
    TABLE_BROADCAST_STATES = CylcSuiteDAO.TABLE_BROADCAST_STATES
    TABLE_CHECKPOINT_ID = CylcSuiteDAO.TABLE_CHECKPOINT_ID
    TABLE_INHERITANCE = CylcSuiteDAO.TABLE_INHERITANCE
    TABLE_SUITE_PARAMS = CylcSuiteDAO.TABLE_SUITE_PARAMS
    TABLE_SUITE_TEMPLATE_VARS = CylcSuiteDAO.TABLE_SUITE_TEMPLATE_VARS
    TABLE_TASK_ACTION_TIMERS = CylcSuiteDAO.TABLE_TASK_ACTION_TIMERS
    TABLE_TASK_POOL = CylcSuiteDAO.TABLE_TASK_POOL
    TABLE_TASK_OUTPUTS = CylcSuiteDAO.TABLE_TASK_OUTPUTS
    TABLE_TASK_STATES = CylcSuiteDAO.TABLE_TASK_STATES
    TABLE_TASK_TIMEOUT_TIMERS = CylcSuiteDAO.TABLE_TASK_TIMEOUT_TIMERS
    TABLE_XTRIGGERS = CylcSuiteDAO.TABLE_XTRIGGERS

    def __init__(self, pri_d=None, pub_d=None):
        """Set database paths and create empty operation queues.

        Arguments:
            pri_d (str): directory of the private database file, if any.
            pub_d (str): directory of the public database file, if any.
        """
        self.pri_path = None
        if pri_d:
            self.pri_path = os.path.join(pri_d, CylcSuiteDAO.DB_FILE_BASE_NAME)
        self.pub_path = None
        if pub_d:
            self.pub_path = os.path.join(pub_d, CylcSuiteDAO.DB_FILE_BASE_NAME)
        # DAOs are created later, by on_suite_start.
        self.pri_dao = None
        self.pub_dao = None

        # Queued operations: {table_name: [item, ...], ...}
        self.db_deletes_map = {
            self.TABLE_BROADCAST_STATES: [],
            self.TABLE_SUITE_PARAMS: [],
            self.TABLE_TASK_POOL: [],
            self.TABLE_TASK_ACTION_TIMERS: [],
            self.TABLE_TASK_OUTPUTS: [],
            self.TABLE_TASK_TIMEOUT_TIMERS: [],
            self.TABLE_XTRIGGERS: []
        }
        self.db_inserts_map = {
            self.TABLE_BROADCAST_EVENTS: [],
            self.TABLE_BROADCAST_STATES: [],
            self.TABLE_INHERITANCE: [],
            self.TABLE_SUITE_PARAMS: [],
            self.TABLE_SUITE_TEMPLATE_VARS: [],
            self.TABLE_CHECKPOINT_ID: [],
            self.TABLE_TASK_POOL: [],
            self.TABLE_TASK_ACTION_TIMERS: [],
            self.TABLE_TASK_OUTPUTS: [],
            self.TABLE_TASK_TIMEOUT_TIMERS: [],
            self.TABLE_XTRIGGERS: []
        }
        self.db_updates_map = {}

    def checkpoint(self, name):
        """Checkpoint the task pool, etc."""
        return self.pri_dao.take_checkpoints(name, other_daos=[self.pub_dao])

    def copy_pri_to_pub(self):
        """Copy content of primary database file to public database file.

        Use temporary file to ensure that we do not end up with a partial file.
        The mode of the public database file is captured before the copy and
        re-applied after the temporary file has been renamed over it.

        Raise:
            IOError, OSError: if any file operation fails; the temporary
                file, if still present, is removed before re-raising.
        """
        temp_pub_db_file_name = None
        self.pub_dao.close()
        try:
            self.pub_dao.conn = None  # reset connection
            open(self.pub_dao.db_file_name, "a").close()  # touch
            st_mode = os.stat(self.pub_dao.db_file_name).st_mode
            temp_fd, temp_pub_db_file_name = mkstemp(
                prefix=self.pub_dao.DB_FILE_BASE_NAME,
                dir=os.path.dirname(self.pub_dao.db_file_name))
            # mkstemp returns an open OS-level descriptor; only the file name
            # is needed here, so close it straight away to avoid a leak.
            os.close(temp_fd)
            copy(self.pri_dao.db_file_name, temp_pub_db_file_name)
            os.rename(temp_pub_db_file_name, self.pub_dao.db_file_name)
            os.chmod(self.pub_dao.db_file_name, st_mode)
        except (IOError, OSError):
            # The temporary file no longer exists once the rename has
            # succeeded; only remove it if it is still there so that the
            # original error is the one that propagates.
            if (temp_pub_db_file_name
                    and os.path.exists(temp_pub_db_file_name)):
                os.unlink(temp_pub_db_file_name)
            raise

    def get_pri_dao(self):
        """Return the primary DAO."""
        return CylcSuiteDAO(self.pri_path)

    @staticmethod
    def _namedtuple2json(obj):
        """Convert namedtuple obj to a JSON string.

        Arguments:
            obj (namedtuple): input object to serialize to JSON.

        Return (str):
            Serialized JSON string of input object in the form "[type, list]".
            If obj is None, the JSON serialization of None is returned.
        """
        if obj is None:
            return json.dumps(None)
        else:
            return json.dumps([type(obj).__name__, obj.__getnewargs__()])

    def on_suite_start(self, is_restart):
        """Initialise data access objects.

        Ensure that:
        * private database file is private
        * public database is in sync with private database
        * the symlink named ``CylcSuiteDAO.OLD_DB_FILE_BASE_NAME``, located
          two directory levels above the public database file, targets
          ``log/<CylcSuiteDAO.DB_FILE_BASE_NAME>``

        Arguments:
            is_restart (bool): if false, whatever exists at the private
                database path is removed before the private DAO is opened.
        """
        if not is_restart:
            try:
                os.unlink(self.pri_path)
            except OSError:
                # Just in case the path is a directory!
                rmtree(self.pri_path, ignore_errors=True)
        self.pri_dao = self.get_pri_dao()
        # "0o600" is the octal spelling accepted by both Python 2.6+ and
        # Python 3 (the bare "0600" form is a syntax error on Python 3).
        os.chmod(self.pri_path, 0o600)
        self.pub_dao = CylcSuiteDAO(self.pub_path, is_public=True)
        self.copy_pri_to_pub()
        pub_db_path_symlink = os.path.join(
            os.path.dirname(os.path.dirname(self.pub_path)),
            CylcSuiteDAO.OLD_DB_FILE_BASE_NAME)
        try:
            orig_source = os.readlink(pub_db_path_symlink)
        except OSError:
            # Missing, or not a symlink at all.
            orig_source = None
        source = os.path.join('log', CylcSuiteDAO.DB_FILE_BASE_NAME)
        if orig_source != source:
            # Replace whatever is there with a symlink to the right target.
            try:
                os.unlink(pub_db_path_symlink)
            except OSError:
                pass
            os.symlink(source, pub_db_path_symlink)

    def on_suite_shutdown(self):
        """Close data access objects."""
        if self.pri_dao:
            self.pri_dao.close()
            self.pri_dao = None
        if self.pub_dao:
            self.pub_dao.close()
            self.pub_dao = None

    def process_queued_ops(self):
        """Handle queued db operations for each task proxy.

        Queued deletes, inserts and updates (in that order) are drained from
        the maps, passed to both DAOs, then executed. Nothing is done if the
        private DAO has not been set up.
        """
        if self.pri_dao is None:
            return
        # Record suite parameters and tasks in pool
        # Record any broadcast settings to be dumped out
        if any(self.db_deletes_map.values()):
            for table_name, db_deletes in sorted(self.db_deletes_map.items()):
                while db_deletes:
                    where_args = db_deletes.pop(0)
                    self.pri_dao.add_delete_item(table_name, where_args)
                    self.pub_dao.add_delete_item(table_name, where_args)
        if any(self.db_inserts_map.values()):
            for table_name, db_inserts in sorted(self.db_inserts_map.items()):
                while db_inserts:
                    db_insert = db_inserts.pop(0)
                    self.pri_dao.add_insert_item(table_name, db_insert)
                    self.pub_dao.add_insert_item(table_name, db_insert)
        if (hasattr(self, 'db_updates_map')
                and any(self.db_updates_map.values())):
            for table_name, db_updates in sorted(self.db_updates_map.items()):
                while db_updates:
                    set_args, where_args = db_updates.pop(0)
                    self.pri_dao.add_update_item(table_name, set_args,
                                                 where_args)
                    self.pub_dao.add_update_item(table_name, set_args,
                                                 where_args)

        # Previously, we used a separate thread for database writes. This has
        # now been removed. For the private database, there is no real
        # advantage in using a separate thread as it needs to be always in sync
        # with what is current. For the public database, which does not need to
        # be fully in sync, there is some advantage of using a separate
        # thread/process, if writing to it becomes a bottleneck. At the moment,
        # there is no evidence that this is a bottleneck, so it is better to
        # keep the logic simple.
        self.pri_dao.execute_queued_items()
        self.pub_dao.execute_queued_items()

    def put_broadcast(self, modified_settings, is_cancel=False):
        """Put or clear broadcasts in runtime database.

        Every change is queued as a broadcast event. The matching broadcast
        state is queued for deletion if is_cancel is true, or for insertion
        otherwise.
        """
        now = get_current_time_string(display_sub_seconds=True)
        for broadcast_change in (get_broadcast_change_iter(
                modified_settings, is_cancel)):
            broadcast_change["time"] = now
            self.db_inserts_map[self.TABLE_BROADCAST_EVENTS].append(
                broadcast_change)
            if is_cancel:
                self.db_deletes_map[self.TABLE_BROADCAST_STATES].append({
                    "point": broadcast_change["point"],
                    "namespace": broadcast_change["namespace"],
                    "key": broadcast_change["key"]
                })
                # Delete statements are currently executed before insert
                # statements, so we should clear out any insert statements that
                # are deleted here.
                # (Not the most efficient logic here, but unless we have a
                # large number of inserts, then this should not be a big
                # concern.)
                inserts = []
                for insert in self.db_inserts_map[self.TABLE_BROADCAST_STATES]:
                    if any(insert[key] != broadcast_change[key]
                           for key in ["point", "namespace", "key"]):
                        inserts.append(insert)
                self.db_inserts_map[self.TABLE_BROADCAST_STATES] = inserts
            else:
                self.db_inserts_map[self.TABLE_BROADCAST_STATES].append({
                    "point": broadcast_change["point"],
                    "namespace": broadcast_change["namespace"],
                    "key": broadcast_change["key"],
                    "value": broadcast_change["value"]
                })

    def put_runtime_inheritance(self, config):
        """Put task/family inheritance in runtime database."""
        for namespace in config.cfg['runtime']:
            value = config.runtime['linearized ancestors'][namespace]
            self.db_inserts_map[self.TABLE_INHERITANCE].append({
                "namespace": namespace,
                "inheritance": json.dumps(value)
            })

    def put_suite_params(self, schd):
        """Put various suite parameters from schd in runtime database.

        This method queues the relevant insert statements.

        Arguments:
            schd (cylc.scheduler.Scheduler): scheduler object.
        """
        if schd.final_point is None:
            # Store None as proper null value in database. No need to do this
            # for initial cycle point, which should never be None.
            final_point_str = None
        else:
            final_point_str = str(schd.final_point)
        self.db_inserts_map[self.TABLE_SUITE_PARAMS].extend([
            {"key": "uuid_str", "value": str(schd.uuid_str)},
            {"key": "run_mode", "value": schd.run_mode},
            {"key": "cylc_version", "value": CYLC_VERSION},
            {"key": "UTC_mode", "value": get_utc_mode()},
            {"key": "initial_point", "value": str(schd.initial_point)},
            {"key": "final_point", "value": final_point_str},
        ])
        # The remaining parameters are only recorded when they are set.
        if schd.config.cfg['cylc']['cycle point format']:
            self.db_inserts_map[self.TABLE_SUITE_PARAMS].append({
                "key": "cycle_point_format",
                "value": schd.config.cfg['cylc']['cycle point format']
            })
        if schd.pool.is_held:
            self.db_inserts_map[self.TABLE_SUITE_PARAMS].append({
                "key": "is_held",
                "value": 1
            })
        if schd.cli_start_point_string:
            self.db_inserts_map[self.TABLE_SUITE_PARAMS].append({
                "key": "start_point",
                "value": schd.cli_start_point_string
            })

    def put_suite_template_vars(self, template_vars):
        """Put template_vars in runtime database.

        This method queues the relevant insert statements.
        """
        for key, value in template_vars.items():
            self.db_inserts_map[self.TABLE_SUITE_TEMPLATE_VARS].append({
                "key": key,
                "value": value
            })

    def put_task_event_timers(self, task_events_mgr):
        """Put statements to update the task_action_timers table."""
        if task_events_mgr.event_timers:
            # An empty "where" dict is queued before the inserts.
            self.db_deletes_map[self.TABLE_TASK_ACTION_TIMERS].append({})
            for key, timer in task_events_mgr.event_timers.items():
                key1, point, name, submit_num = key
                self.db_inserts_map[self.TABLE_TASK_ACTION_TIMERS].append({
                    "name": name,
                    "cycle": point,
                    "ctx_key": json.dumps((key1, submit_num,)),
                    "ctx": self._namedtuple2json(timer.ctx),
                    "delays": json.dumps(timer.delays),
                    "num": timer.num,
                    "delay": timer.delay,
                    "timeout": timer.timeout
                })

    def put_xtriggers(self, sat_xtrig):
        """Put statements to update external triggers table."""
        self.db_deletes_map[self.TABLE_XTRIGGERS].append({})
        for sig, res in sat_xtrig.items():
            self.db_inserts_map[self.TABLE_XTRIGGERS].append({
                "signature": sig,
                "results": json.dumps(res)
            })

    def put_task_pool(self, pool):
        """Put statements to update the task_pool table in runtime database.

        Update the task_pool table and the task_action_timers table.
        Queue delete (everything) statements to wipe the tables, and queue the
        relevant insert statements for the current tasks in the pool.
        """
        self.db_deletes_map[self.TABLE_TASK_POOL].append({})
        # No need to do:
        # self.db_deletes_map[self.TABLE_TASK_ACTION_TIMERS].append({})
        # Should already be done by self.put_task_event_timers above.
        self.db_deletes_map[self.TABLE_TASK_TIMEOUT_TIMERS].append({})
        for itask in pool.get_all_tasks():
            self.db_inserts_map[self.TABLE_TASK_POOL].append({
                "name": itask.tdef.name,
                "cycle": str(itask.point),
                "spawned": int(itask.has_spawned),
                "status": itask.state.status,
                "hold_swap": itask.state.hold_swap
            })
            if itask.timeout is not None:
                self.db_inserts_map[self.TABLE_TASK_TIMEOUT_TIMERS].append({
                    "name": itask.tdef.name,
                    "cycle": str(itask.point),
                    "timeout": itask.timeout
                })
            if itask.poll_timer is not None:
                self.db_inserts_map[self.TABLE_TASK_ACTION_TIMERS].append({
                    "name": itask.tdef.name,
                    "cycle": str(itask.point),
                    "ctx_key": json.dumps("poll_timer"),
                    "ctx": self._namedtuple2json(itask.poll_timer.ctx),
                    "delays": json.dumps(itask.poll_timer.delays),
                    "num": itask.poll_timer.num,
                    "delay": itask.poll_timer.delay,
                    "timeout": itask.poll_timer.timeout
                })
            for ctx_key_1, timer in itask.try_timers.items():
                if timer is None:
                    continue
                self.db_inserts_map[self.TABLE_TASK_ACTION_TIMERS].append({
                    "name": itask.tdef.name,
                    "cycle": str(itask.point),
                    "ctx_key": json.dumps(("try_timers", ctx_key_1)),
                    "ctx": self._namedtuple2json(timer.ctx),
                    "delays": json.dumps(timer.delays),
                    "num": timer.num,
                    "delay": timer.delay,
                    "timeout": timer.timeout
                })
            if itask.state.time_updated:
                set_args = {
                    "time_updated": itask.state.time_updated,
                    "submit_num": itask.submit_num,
                    "try_num": itask.get_try_num(),
                    "status": itask.state.status
                }
                where_args = {
                    "cycle": str(itask.point),
                    "name": itask.tdef.name,
                }
                self.db_updates_map.setdefault(self.TABLE_TASK_STATES, [])
                self.db_updates_map[self.TABLE_TASK_STATES].append(
                    (set_args, where_args))
                # Reset so the same update is not queued again next time.
                itask.state.time_updated = None

        self.db_inserts_map[self.TABLE_CHECKPOINT_ID].append({
            # id = -1 for latest
            "id": CylcSuiteDAO.CHECKPOINT_LATEST_ID,
            "time": get_current_time_string(),
            "event": CylcSuiteDAO.CHECKPOINT_LATEST_EVENT
        })

    def put_insert_task_events(self, itask, args):
        """Put INSERT statement for task_events table."""
        self._put_insert_task_x(CylcSuiteDAO.TABLE_TASK_EVENTS, itask, args)

    def put_insert_task_late_flags(self, itask):
        """If itask is late, put INSERT statement to task_late_flags table."""
        if itask.is_late:
            self._put_insert_task_x(CylcSuiteDAO.TABLE_TASK_LATE_FLAGS, itask,
                                    {"value": True})

    def put_insert_task_jobs(self, itask, args):
        """Put INSERT statement for task_jobs table."""
        self._put_insert_task_x(CylcSuiteDAO.TABLE_TASK_JOBS, itask, args)

    def put_insert_task_states(self, itask, args):
        """Put INSERT statement for task_states table."""
        self._put_insert_task_x(CylcSuiteDAO.TABLE_TASK_STATES, itask, args)

    def put_insert_task_outputs(self, itask):
        """Reset custom outputs for a task."""
        self._put_insert_task_x(CylcSuiteDAO.TABLE_TASK_OUTPUTS, itask, {})

    def _put_insert_task_x(self, table_name, itask, args):
        """Put INSERT statement for a task_* table.

        NOTE: args is updated in place with the task "name" and "cycle", and
        with "submit_num" if not already present, then queued as is.
        """
        args.update({"name": itask.tdef.name, "cycle": str(itask.point)})
        if "submit_num" not in args:
            args["submit_num"] = itask.submit_num
        self.db_inserts_map.setdefault(table_name, [])
        self.db_inserts_map[table_name].append(args)

    def put_update_task_jobs(self, itask, set_args):
        """Put UPDATE statement for task_jobs table."""
        self._put_update_task_x(CylcSuiteDAO.TABLE_TASK_JOBS, itask, set_args)

    def put_update_task_outputs(self, itask):
        """Put UPDATE statement for task_outputs table."""
        items = {}
        for trigger, message in itask.state.outputs.get_completed_customs():
            items[trigger] = message
        self._put_update_task_x(CylcSuiteDAO.TABLE_TASK_OUTPUTS, itask,
                                {"outputs": json.dumps(items)})

    def _put_update_task_x(self, table_name, itask, set_args):
        """Put UPDATE statement for a task_* table.

        The row is matched on the task cycle and name, and also on the task
        submit number unless "submit_num" is itself one of the set_args.
        """
        where_args = {"cycle": str(itask.point), "name": itask.tdef.name}
        if "submit_num" not in set_args:
            where_args["submit_num"] = itask.submit_num
        self.db_updates_map.setdefault(table_name, [])
        self.db_updates_map[table_name].append((set_args, where_args))

    def recover_pub_from_pri(self):
        """Recover public database from private database.

        Only acts once the public DAO has reached its maximum number of
        tries; the try counter is reset after the copy.
        """
        if self.pub_dao.n_tries >= self.pub_dao.MAX_TRIES:
            self.copy_pri_to_pub()
            LOG.warning(
                "%(pub_db_name)s: recovered from %(pri_db_name)s" % {
                    "pub_db_name": self.pub_dao.db_file_name,
                    "pri_db_name": self.pri_dao.db_file_name
                })
            self.pub_dao.n_tries = 0

    def restart_upgrade(self):
        """Vacuum/upgrade runtime DB on restart."""
        # Backward compat, upgrade database with state file if necessary
        suite_run_d = os.path.dirname(os.path.dirname(self.pub_path))
        old_pri_db_path = os.path.join(suite_run_d, 'state',
                                       CylcSuiteDAO.OLD_DB_FILE_BASE_NAME)
        old_pri_db_path_611 = os.path.join(
            suite_run_d, CylcSuiteDAO.OLD_DB_FILE_BASE_NAME_611[0])
        old_state_file_path = os.path.join(suite_run_d, "state", "state")
        if (os.path.exists(old_pri_db_path)
                and os.path.exists(old_state_file_path)
                and not os.path.exists(self.pri_path)):
            # Upgrade pre-6.11.X runtime database + state file
            copy(old_pri_db_path, self.pri_path)
            pri_dao = self.get_pri_dao()
            pri_dao.upgrade_with_state_file(old_state_file_path)
            target = os.path.join(suite_run_d, "state.tar.gz")
            cmd = ["tar", "-C", suite_run_d, "-czf", target, "state"]
            # Close the null device handle once the command has finished.
            with open(os.devnull) as devnull:
                ret_code = call(cmd, stdin=devnull)
            if ret_code == 0:
                rmtree(os.path.join(suite_run_d, "state"), ignore_errors=True)
            else:
                # Do not leave a possibly incomplete archive behind.
                try:
                    os.unlink(target)
                except OSError:
                    pass
                LOG.error("cannot tar-gzip + remove old state/ directory")
            # Remove old files as well
            try:
                os.unlink(os.path.join(suite_run_d, "cylc-suite-env"))
            except OSError:
                pass
        elif (os.path.exists(old_pri_db_path_611)
              and not os.path.exists(self.pri_path)):
            # Upgrade 6.11.X runtime database
            os.rename(old_pri_db_path_611, self.pri_path)
            pri_dao = self.get_pri_dao()
            pri_dao.upgrade_from_611()
            # Remove old files as well
            for name in [
                    CylcSuiteDAO.OLD_DB_FILE_BASE_NAME_611[1], "cylc-suite-env"
            ]:
                try:
                    os.unlink(os.path.join(suite_run_d, name))
                except OSError:
                    pass
        else:
            pri_dao = self.get_pri_dao()
            pri_dao.upgrade_pickle_to_json()

        # Vacuum the primary/private database file
        pri_dao.vacuum()
        pri_dao.close()
Example #10
0
 def get_pri_dao(self):
     """Return a CylcSuiteDAO constructed from the private database path."""
     private_dao = CylcSuiteDAO(self.pri_path)
     return private_dao
Example #11
0
 def setUp(self):
     """Create an in-memory DAO whose connect() returns a mock."""
     self.dao = CylcSuiteDAO(':memory:')
     connection = mock.Mock()
     self.mocked_connection = connection
     self.dao.connect = mock.MagicMock(return_value=connection)
Example #12
0
class SuiteDatabaseManager(object):
    """Manage the suite runtime private and public databases.

    Changes are queued per table in ``db_deletes_map``, ``db_inserts_map``
    and ``db_updates_map`` by the ``put_*`` methods, and are handed to both
    the private and the public DAO by ``process_queued_ops``.
    """

    # Table names, aliased from CylcSuiteDAO; used as keys of the queues.
    TABLE_BROADCAST_EVENTS = CylcSuiteDAO.TABLE_BROADCAST_EVENTS
    TABLE_BROADCAST_STATES = CylcSuiteDAO.TABLE_BROADCAST_STATES
    TABLE_CHECKPOINT_ID = CylcSuiteDAO.TABLE_CHECKPOINT_ID
    TABLE_INHERITANCE = CylcSuiteDAO.TABLE_INHERITANCE
    TABLE_SUITE_PARAMS = CylcSuiteDAO.TABLE_SUITE_PARAMS
    TABLE_SUITE_TEMPLATE_VARS = CylcSuiteDAO.TABLE_SUITE_TEMPLATE_VARS
    TABLE_TASK_ACTION_TIMERS = CylcSuiteDAO.TABLE_TASK_ACTION_TIMERS
    TABLE_TASK_POOL = CylcSuiteDAO.TABLE_TASK_POOL
    TABLE_TASK_OUTPUTS = CylcSuiteDAO.TABLE_TASK_OUTPUTS
    TABLE_TASK_STATES = CylcSuiteDAO.TABLE_TASK_STATES
    TABLE_TASK_TIMEOUT_TIMERS = CylcSuiteDAO.TABLE_TASK_TIMEOUT_TIMERS

    def __init__(self, pri_d=None, pub_d=None):
        """Set up the database paths and the empty operation queues.

        Arguments:
            pri_d (str): directory of the private database. If not given,
                ``self.pri_path`` is left as None.
            pub_d (str): directory of the public database. If not given,
                ``self.pub_path`` is left as None.
        """
        self.pri_path = (
            os.path.join(pri_d, CylcSuiteDAO.DB_FILE_BASE_NAME)
            if pri_d else None)
        self.pub_path = (
            os.path.join(pub_d, CylcSuiteDAO.DB_FILE_BASE_NAME)
            if pub_d else None)
        # The DAOs are created later, by on_suite_start.
        self.pri_dao = None
        self.pub_dao = None

        # Each queue maps a table name to its own list of pending items.
        self.db_deletes_map = dict((table, []) for table in (
            self.TABLE_BROADCAST_STATES,
            self.TABLE_SUITE_PARAMS,
            self.TABLE_TASK_POOL,
            self.TABLE_TASK_ACTION_TIMERS,
            self.TABLE_TASK_OUTPUTS,
            self.TABLE_TASK_TIMEOUT_TIMERS,
        ))
        self.db_inserts_map = dict((table, []) for table in (
            self.TABLE_BROADCAST_EVENTS,
            self.TABLE_BROADCAST_STATES,
            self.TABLE_INHERITANCE,
            self.TABLE_SUITE_PARAMS,
            self.TABLE_SUITE_TEMPLATE_VARS,
            self.TABLE_CHECKPOINT_ID,
            self.TABLE_TASK_POOL,
            self.TABLE_TASK_ACTION_TIMERS,
            self.TABLE_TASK_OUTPUTS,
            self.TABLE_TASK_TIMEOUT_TIMERS,
        ))
        # Update queues are created on demand, keyed by table name.
        self.db_updates_map = {}

    def checkpoint(self, name):
        """Take a checkpoint called ``name`` via the private DAO.

        The public DAO is passed along as ``other_daos``.

        Return:
            Whatever the private DAO's ``take_checkpoints`` returns.
        """
        mirrors = [self.pub_dao]
        return self.pri_dao.take_checkpoints(name, other_daos=mirrors)

    def copy_pri_to_pub(self):
        """Copy content of primary database file to public database file.

        The private file is copied to a temporary file in the directory of
        the public database, which is then renamed over the public file, to
        ensure that we do not end up with a partial file. The permission
        bits that the public file had before the copy are re-applied.

        Raise:
            IOError, OSError: if a file operation fails. A left-over
                temporary file is removed before the error is re-raised.
        """
        temp_pub_db_file_name = None
        self.pub_dao.close()
        try:
            self.pub_dao.conn = None  # reset connection
            open(self.pub_dao.db_file_name, "a").close()  # touch
            st_mode = os.stat(self.pub_dao.db_file_name).st_mode
            temp_fd, temp_pub_db_file_name = mkstemp(
                prefix=self.pub_dao.DB_FILE_BASE_NAME,
                dir=os.path.dirname(self.pub_dao.db_file_name))
            # mkstemp returns an open OS-level descriptor. Only the name is
            # needed (copy re-opens the file by path), so close it here to
            # avoid leaking one descriptor per call.
            os.close(temp_fd)
            copy(self.pri_dao.db_file_name, temp_pub_db_file_name)
            os.rename(temp_pub_db_file_name, self.pub_dao.db_file_name)
            # The temporary name is gone after a successful rename: forget
            # it, so that a failing chmod is reported as such instead of
            # being masked by a failing unlink of a missing file.
            temp_pub_db_file_name = None
            os.chmod(self.pub_dao.db_file_name, st_mode)
        except (IOError, OSError):
            if temp_pub_db_file_name:
                os.unlink(temp_pub_db_file_name)
            raise

    def get_pri_dao(self):
        """Build a new DAO for the private database.

        Return:
            A CylcSuiteDAO constructed from ``self.pri_path``.
        """
        primary_dao = CylcSuiteDAO(self.pri_path)
        return primary_dao

    @staticmethod
    def _namedtuple2json(obj):
        """Serialize a namedtuple object as a JSON string.

        Arguments:
            obj (namedtuple): input object to serialize, or None.

        Return (str):
            The JSON text for None if obj is None, otherwise the JSON text
            of the 2-item list "[type name, constructor arguments]".
        """
        if obj is None:
            payload = None
        else:
            payload = [type(obj).__name__, obj.__getnewargs__()]
        return json.dumps(payload)

    def on_suite_start(self, is_restart):
        """Initialise data access objects.

        Ensure that:
        * private database file is private
        * public database is in sync with private database
        * the symbolic link at the old public database location points to
          the current public database file

        Arguments:
            is_restart (bool): if not set, any existing private database is
                removed before the data access objects are created.
        """
        if not is_restart:
            try:
                os.unlink(self.pri_path)
            except OSError:
                # Just in case the path is a directory!
                rmtree(self.pri_path, ignore_errors=True)
        self.pri_dao = self.get_pri_dao()
        # Owner read/write only. The "0o" prefix is understood by both
        # Python 2.6+ and Python 3, unlike the legacy "0600" spelling.
        os.chmod(self.pri_path, 0o600)
        self.pub_dao = CylcSuiteDAO(self.pub_path, is_public=True)
        self.copy_pri_to_pub()
        # The link lives 2 levels above the public database file.
        pub_db_path_symlink = os.path.join(
            os.path.dirname(os.path.dirname(self.pub_path)),
            CylcSuiteDAO.OLD_DB_FILE_BASE_NAME)
        try:
            orig_source = os.readlink(pub_db_path_symlink)
        except OSError:
            # Not a symbolic link, or does not exist
            orig_source = None
        source = os.path.join('log', CylcSuiteDAO.DB_FILE_BASE_NAME)
        if orig_source != source:
            # Replace whatever is there with a link to the expected target
            try:
                os.unlink(pub_db_path_symlink)
            except OSError:
                pass
            os.symlink(source, pub_db_path_symlink)

    def on_suite_shutdown(self):
        """Close and forget the private and public data access objects."""
        # Private first, then public.
        for attr_name in ("pri_dao", "pub_dao"):
            dao = getattr(self, attr_name)
            if dao:
                dao.close()
                setattr(self, attr_name, None)

    def process_queued_ops(self):
        """Hand all queued database operations to both DAOs and run them.

        Return immediately if the private DAO is not set. Otherwise, drain
        the delete, insert and update queues (in that order, tables in
        sorted name order), adding each item to the private DAO and then to
        the public DAO, and finally execute the queued items of each DAO.
        """
        if self.pri_dao is None:
            return

        # Deletes: each queued item is the "where" arguments of a delete.
        if any(self.db_deletes_map.values()):
            for table_name in sorted(self.db_deletes_map):
                pending = self.db_deletes_map[table_name]
                while pending:
                    where_args = pending.pop(0)
                    self.pri_dao.add_delete_item(table_name, where_args)
                    self.pub_dao.add_delete_item(table_name, where_args)

        # Inserts: each queued item is the content of a row to insert.
        if any(self.db_inserts_map.values()):
            for table_name in sorted(self.db_inserts_map):
                pending = self.db_inserts_map[table_name]
                while pending:
                    row = pending.pop(0)
                    self.pri_dao.add_insert_item(table_name, row)
                    self.pub_dao.add_insert_item(table_name, row)

        # Updates: each queued item is a (set_args, where_args) pair.
        # Tolerate objects that do not have the update queue at all.
        updates_map = getattr(self, 'db_updates_map', {})
        if any(updates_map.values()):
            for table_name in sorted(updates_map):
                pending = updates_map[table_name]
                while pending:
                    set_args, where_args = pending.pop(0)
                    self.pri_dao.add_update_item(
                        table_name, set_args, where_args)
                    self.pub_dao.add_update_item(
                        table_name, set_args, where_args)

        # Database writes used to be done in a separate thread, which has
        # been removed. The private database must always be in sync with
        # what is current, so a separate thread gives it no real advantage.
        # The public database does not need to be fully in sync, so a
        # separate thread/process may help if writing to it ever becomes a
        # bottleneck. There is no evidence of that at the moment, so the
        # logic is kept simple.
        self.pri_dao.execute_queued_items()
        self.pub_dao.execute_queued_items()

    def put_broadcast(self, modified_settings, is_cancel=False):
        """Queue the database changes for a set or a clear of broadcasts.

        Every change is recorded, with the current time, as a broadcast
        event. The broadcast states queue gets a delete for a cancelled
        setting, or an insert for a new one.

        Arguments:
            modified_settings: the settings that were changed; passed as is
                to get_broadcast_change_iter.
            is_cancel (bool): the settings were cleared instead of set.
        """
        timestamp = get_current_time_string(display_sub_seconds=True)
        states_table = self.TABLE_BROADCAST_STATES
        id_keys = ["point", "namespace", "key"]
        for change in get_broadcast_change_iter(modified_settings, is_cancel):
            change["time"] = timestamp
            self.db_inserts_map[self.TABLE_BROADCAST_EVENTS].append(change)
            state = dict((key, change[key]) for key in id_keys)
            if not is_cancel:
                state["value"] = change["value"]
                self.db_inserts_map[states_table].append(state)
                continue
            self.db_deletes_map[states_table].append(state)
            # Delete statements are currently executed before insert
            # statements, so a queued insert of the setting being cancelled
            # would survive the delete: drop it from the queue.
            # (A linear scan per change; fine unless the number of queued
            # inserts is large.)
            self.db_inserts_map[states_table] = [
                queued for queued in self.db_inserts_map[states_table]
                if any(queued[key] != change[key] for key in id_keys)]

    def put_runtime_inheritance(self, config):
        """Queue one inheritance row for each runtime namespace.

        Arguments:
            config: suite configuration object. This method reads its
                ``cfg['runtime']`` for the namespaces and its
                ``runtime['linearized ancestors']`` for the value of each
                namespace, which is stored as JSON.
        """
        queue = self.db_inserts_map[self.TABLE_INHERITANCE]
        for namespace in config.cfg['runtime']:
            ancestors = config.runtime['linearized ancestors'][namespace]
            row = {
                "namespace": namespace,
                "inheritance": json.dumps(ancestors),
            }
            queue.append(row)

    def put_suite_params(self, schd):
        """Queue insert statements for suite parameters taken from schd.

        The rows "uuid_str", "run_mode", "cylc_version", "UTC_mode",
        "initial_point" and "final_point" are always queued. The rows
        "cycle_point_format", "is_held" and "start_point" are queued only
        if the relevant setting of the scheduler is set.

        Arguments:
            schd (cylc.scheduler.Scheduler): scheduler object.
        """
        params = self.db_inserts_map[self.TABLE_SUITE_PARAMS]
        params.extend([
            {"key": "uuid_str",
             "value": schd.task_job_mgr.task_remote_mgr.uuid_str},
            {"key": "run_mode", "value": schd.run_mode},
            {"key": "cylc_version", "value": CYLC_VERSION},
            {"key": "UTC_mode", "value": cylc.flags.utc},
            {"key": "initial_point", "value": str(schd.initial_point)},
            {"key": "final_point", "value": str(schd.final_point)},
        ])
        # Optional rows
        if schd.config.cfg['cylc']['cycle point format']:
            point_format = schd.config.cfg['cylc']['cycle point format']
            params.append({"key": "cycle_point_format", "value": point_format})
        if schd.pool.is_held:
            params.append({"key": "is_held", "value": 1})
        if schd.cli_start_point_string:
            params.append(
                {"key": "start_point", "value": schd.cli_start_point_string})

    def put_suite_template_vars(self, template_vars):
        """Queue one insert statement for each template variable.

        Arguments:
            template_vars (dict): template variable values by name.
        """
        queue = self.db_inserts_map[self.TABLE_SUITE_TEMPLATE_VARS]
        for var_name, var_value in template_vars.items():
            row = {"key": var_name, "value": var_value}
            queue.append(row)

    def put_task_event_timers(self, task_events_mgr):
        """Queue statements to rewrite the task_action_timers table.

        Do nothing if the manager has no event timers. Otherwise, queue a
        delete with no condition, followed by one insert for each timer.

        Arguments:
            task_events_mgr: object with an ``event_timers`` mapping, from
                4-item keys ``(key1, point, name, submit_num)`` to timers.
        """
        timers = task_events_mgr.event_timers
        if not timers:
            return
        self.db_deletes_map[self.TABLE_TASK_ACTION_TIMERS].append({})
        rows = self.db_inserts_map[self.TABLE_TASK_ACTION_TIMERS]
        for (key1, point, name, submit_num), timer in timers.items():
            rows.append({
                "name": name,
                "cycle": point,
                # The rest of the key is stored as a JSON list
                "ctx_key": json.dumps((key1, submit_num,)),
                "ctx": self._namedtuple2json(timer.ctx),
                "delays": json.dumps(timer.delays),
                "num": timer.num,
                "delay": timer.delay,
                "timeout": timer.timeout})

    def put_task_pool(self, pool):
        """Queue statements to rewrite the task pool tables.

        Queue delete (everything) statements for the task_pool table and
        the task timeout timers table, then for each task in the pool:
        * an insert of its task_pool row,
        * an insert of its timeout timer, if it has one for its status,
        * an insert for each of its poll/try timers that is not None,
        * an update of its task states row, if its state has an updated
          time, which is then reset to None.
        Finally, queue an insert of the "latest" checkpoint ID row.

        Arguments:
            pool: task pool object; its ``get_all_tasks()`` is iterated.
        """
        for table_name in (
                self.TABLE_TASK_POOL, self.TABLE_TASK_TIMEOUT_TIMERS):
            self.db_deletes_map[table_name].append({})
        # The delete for TABLE_TASK_ACTION_TIMERS is not queued here, as
        # that should already be done by self.put_task_event_timers.
        pool_rows = self.db_inserts_map[self.TABLE_TASK_POOL]
        timeout_rows = self.db_inserts_map[self.TABLE_TASK_TIMEOUT_TIMERS]
        action_rows = self.db_inserts_map[self.TABLE_TASK_ACTION_TIMERS]
        for itask in pool.get_all_tasks():
            name = itask.tdef.name
            cycle = str(itask.point)
            status = itask.state.status
            pool_rows.append({
                "name": name,
                "cycle": cycle,
                "spawned": int(itask.has_spawned),
                "status": status,
                "hold_swap": itask.state.hold_swap})
            if status in itask.timeout_timers:
                timeout_rows.append({
                    "name": name,
                    "cycle": cycle,
                    "timeout": itask.timeout_timers[status]})
            for timer_group in ("poll_timers", "try_timers"):
                for timer_key, timer in getattr(itask, timer_group).items():
                    if timer is not None:
                        action_rows.append({
                            "name": name,
                            "cycle": cycle,
                            "ctx_key": json.dumps((timer_group, timer_key)),
                            "ctx": self._namedtuple2json(timer.ctx),
                            "delays": json.dumps(timer.delays),
                            "num": timer.num,
                            "delay": timer.delay,
                            "timeout": timer.timeout})
            if itask.state.time_updated:
                state_update = (
                    {"time_updated": itask.state.time_updated,
                     "submit_num": itask.submit_num,
                     "try_num": itask.get_try_num(),
                     "status": status},
                    {"cycle": cycle,
                     "name": name})
                self.db_updates_map.setdefault(
                    self.TABLE_TASK_STATES, []).append(state_update)
                # Mark the state as recorded
                itask.state.time_updated = None

        self.db_inserts_map[self.TABLE_CHECKPOINT_ID].append({
            # id = -1 for latest
            "id": CylcSuiteDAO.CHECKPOINT_LATEST_ID,
            "time": get_current_time_string(),
            "event": CylcSuiteDAO.CHECKPOINT_LATEST_EVENT})

    def put_insert_task_events(self, itask, args):
        """Queue an INSERT statement for the task_events table.

        Arguments:
            itask: task proxy that the row is about.
            args (dict): column values of the row.
        """
        table_name = CylcSuiteDAO.TABLE_TASK_EVENTS
        self._put_insert_task_x(table_name, itask, args)

    def put_insert_task_late_flags(self, itask):
        """Queue an INSERT statement for the task_late_flags table.

        Nothing is queued unless ``itask.is_late`` is set.
        """
        if not itask.is_late:
            return
        row = {"value": True}
        self._put_insert_task_x(
            CylcSuiteDAO.TABLE_TASK_LATE_FLAGS, itask, row)

    def put_insert_task_jobs(self, itask, args):
        """Queue an INSERT statement for the task_jobs table.

        Arguments:
            itask: task proxy that the row is about.
            args (dict): column values of the row.
        """
        table_name = CylcSuiteDAO.TABLE_TASK_JOBS
        self._put_insert_task_x(table_name, itask, args)

    def put_insert_task_states(self, itask, args):
        """Queue an INSERT statement for the task_states table.

        Arguments:
            itask: task proxy that the row is about.
            args (dict): column values of the row.
        """
        table_name = CylcSuiteDAO.TABLE_TASK_STATES
        self._put_insert_task_x(table_name, itask, args)

    def put_insert_task_outputs(self, itask):
        """Reset custom outputs for a task.

        Queue an INSERT statement for the task outputs table that carries
        no column values other than those identifying the task.
        """
        blank_row = {}
        self._put_insert_task_x(
            CylcSuiteDAO.TABLE_TASK_OUTPUTS, itask, blank_row)

    def _put_insert_task_x(self, table_name, itask, args):
        """Queue an INSERT statement for a task_* table.

        Arguments:
            table_name (str): name of the table to insert into.
            itask: task proxy; provides the "name" and "cycle" values, and
                the "submit_num" value if args does not have one.
            args (dict): column values of the row. The dict is modified in
                place, and is the object that ends up in the queue.
        """
        name, cycle = itask.tdef.name, str(itask.point)
        args["name"] = name
        args["cycle"] = cycle
        if "submit_num" not in args:
            args["submit_num"] = itask.submit_num
        # The queue for the table may not exist yet.
        self.db_inserts_map.setdefault(table_name, []).append(args)

    def put_update_task_jobs(self, itask, set_args):
        """Queue an UPDATE statement for the task_jobs table.

        Arguments:
            itask: task proxy that identifies the row(s) to update.
            set_args (dict): column values to set.
        """
        table_name = CylcSuiteDAO.TABLE_TASK_JOBS
        self._put_update_task_x(table_name, itask, set_args)

    def put_update_task_outputs(self, itask):
        """Queue an UPDATE statement for the task_outputs table.

        The "outputs" column is set to a JSON object that maps the trigger
        of each completed custom output of the task to its message.
        """
        completed = dict(itask.state.outputs.get_completed_customs())
        set_args = {"outputs": json.dumps(completed)}
        self._put_update_task_x(
            CylcSuiteDAO.TABLE_TASK_OUTPUTS, itask, set_args)

    def _put_update_task_x(self, table_name, itask, set_args):
        """Queue an UPDATE statement for a task_* table.

        Rows are selected by the cycle and the name of the task. The submit
        number of the task is added to the selection, unless "submit_num"
        is one of the values to set.

        Arguments:
            table_name (str): name of the table to update.
            itask: task proxy that identifies the row(s) to update.
            set_args (dict): column values to set.
        """
        selection = {"cycle": str(itask.point), "name": itask.tdef.name}
        if "submit_num" not in set_args:
            selection["submit_num"] = itask.submit_num
        # The queue for the table may not exist yet.
        self.db_updates_map.setdefault(table_name, []).append(
            (set_args, selection))

    def recover_pub_from_pri(self):
        """Recover public database from private database.

        Only act once the try counter of the public DAO has reached its
        MAX_TRIES: copy the private database over the public one, log a
        warning and reset the counter.
        """
        if self.pub_dao.n_tries < self.pub_dao.MAX_TRIES:
            return
        self.copy_pri_to_pub()
        names = {
            "pub_db_name": self.pub_dao.db_file_name,
            "pri_db_name": self.pri_dao.db_file_name}
        LOG.warning(
            "%(pub_db_name)s: recovered from %(pri_db_name)s" % names)
        self.pub_dao.n_tries = 0

    def restart_upgrade(self):
        """Vacuum/upgrade runtime DB on restart.

        One of 3 upgrade paths is taken, none of the first 2 being
        considered if the current private database file already exists:
        * An old database and an old state file both exist under "state/"
          of the suite run directory: copy the old database into place and
          upgrade it with the state file, then tar-gzip and remove the
          "state/" directory.
        * A 6.11.X database exists: move it into place and upgrade it.
        * Otherwise: upgrade pickle values to JSON in the current database.

        The private database is vacuumed and closed at the end.
        """
        # Backward compat, upgrade database with state file if necessary
        suite_run_d = os.path.dirname(os.path.dirname(self.pub_path))
        old_pri_db_path = os.path.join(
            suite_run_d, 'state', CylcSuiteDAO.OLD_DB_FILE_BASE_NAME)
        old_pri_db_path_611 = os.path.join(
            suite_run_d, CylcSuiteDAO.OLD_DB_FILE_BASE_NAME_611[0])
        old_state_file_path = os.path.join(suite_run_d, "state", "state")
        if (os.path.exists(old_pri_db_path) and
                os.path.exists(old_state_file_path) and
                not os.path.exists(self.pri_path)):
            # Upgrade pre-6.11.X runtime database + state file
            copy(old_pri_db_path, self.pri_path)
            pri_dao = self.get_pri_dao()
            pri_dao.upgrade_with_state_file(old_state_file_path)
            target = os.path.join(suite_run_d, "state.tar.gz")
            cmd = ["tar", "-C", suite_run_d, "-czf", target, "state"]
            # Close the null device handle as soon as the command is done,
            # instead of leaving it to the garbage collector.
            with open(os.devnull) as devnull:
                ret_code = call(cmd, stdin=devnull)
            if ret_code == 0:
                rmtree(os.path.join(suite_run_d, "state"), ignore_errors=True)
            else:
                # Do not leave a partial archive behind
                try:
                    os.unlink(target)
                except OSError:
                    pass
                ERR.error("cannot tar-gzip + remove old state/ directory")
            # Remove old files as well
            try:
                os.unlink(os.path.join(suite_run_d, "cylc-suite-env"))
            except OSError:
                pass
        elif (os.path.exists(old_pri_db_path_611) and
                not os.path.exists(self.pri_path)):
            # Upgrade 6.11.X runtime database
            os.rename(old_pri_db_path_611, self.pri_path)
            pri_dao = self.get_pri_dao()
            pri_dao.upgrade_from_611()
            # Remove old files as well
            for name in [
                    CylcSuiteDAO.OLD_DB_FILE_BASE_NAME_611[1],
                    "cylc-suite-env"]:
                try:
                    os.unlink(os.path.join(suite_run_d, name))
                except OSError:
                    pass
        else:
            pri_dao = self.get_pri_dao()
            pri_dao.upgrade_pickle_to_json()

        # Vacuum the primary/private database file
        pri_dao.vacuum()
        pri_dao.close()
Example #13
0
 def setUp(self):
     """Create an in-memory DAO whose ``connect`` returns a mock."""
     dao = CylcSuiteDAO(':memory:')
     connection = mock.Mock()
     # Any call to dao.connect() now hands back the mocked connection.
     dao.connect = mock.MagicMock(return_value=connection)
     self.dao = dao
     self.mocked_connection = connection