예제 #1
0
    def reset_job_manager(self):
        """Wipe the JobManager's job registries, then create it again.

        Both registries (``_running_jobs`` and ``_jobs_by_cell_id``) are set
        to fresh empty dicts on the current instance, ``self.jm`` is replaced
        by a new ``JobManager()``, and the test asserts that the new instance
        reports both registries as empty.
        NOTE(review): presumably this works because JobManager hands back a
        shared instance, so the cleared state carries over — confirm.
        """
        registries = ("_running_jobs", "_jobs_by_cell_id")
        # all jobs have been removed from the JobManager
        for registry in registries:
            setattr(self.jm, registry, {})
        self.jm = JobManager()

        for registry in registries:
            self.assertEqual(getattr(self.jm, registry), {})
예제 #2
0
 def __init__(self):
     """Set up the comm channel, the JobManager and the message dispatch table.

     Each of ``_comm``, ``_jm`` and ``_msg_map`` is only created while it is
     still ``None``; attributes that are already set are left untouched.
     """
     if self._comm is None:
         channel = Comm(target_name="KBaseJobs", data={})
         self._comm = channel
         # register the incoming-message handler on the new comm
         channel.on_msg(self._handle_comm_message)
     if self._jm is None:
         self._jm = JobManager()
     if self._msg_map is None:
         # (MESSAGE_TYPE key, handler) pairs, in dispatch-table order.
         # START_UPDATE and STOP_UPDATE are both routed to _modify_job_updates.
         routes = [
             ("CANCEL", self._cancel_jobs),
             ("CELL_JOB_STATUS", self._get_job_states_by_cell_id),
             ("INFO", self._get_job_info),
             ("LOGS", self._get_job_logs),
             ("RETRY", self._retry_jobs),
             ("START_UPDATE", self._modify_job_updates),
             ("STATUS", self._get_job_states),
             ("STATUS_ALL", self._get_all_job_states),
             ("STOP_UPDATE", self._modify_job_updates),
         ]
         self._msg_map = {
             MESSAGE_TYPE[msg_name]: handler for msg_name, handler in routes
         }
예제 #3
0
    def setUpClass(cls):
        """Create the shared managers and load the app-test settings.

        Every setting is read from the ``app_tests`` section of ``CONFIG``
        (``public_ws_id`` is converted to ``int``). Two parameter dicts are
        also built: the raw app inputs used by the tests and the parameters
        the tests expect to see after those inputs are processed.
        """
        cls.maxDiff = None
        cls.am = AppManager()
        # reload, because this class uses non-mocked data
        cls.am.reload()
        cls.jm = JobManager()

        def app_setting(option):
            # every option used below lives in the "app_tests" section
            return CONFIG.get("app_tests", option)

        cls.good_app_id = app_setting("good_app_id")
        cls.good_tag = app_setting("good_app_tag")
        cls.bad_app_id = app_setting("bad_app_id")
        cls.bad_tag = app_setting("bad_app_tag")
        cls.test_app_id = app_setting("test_app_id")
        # NOTE(review): hard-coded version hash whose origin is undocumented;
        # consider sourcing it from CONFIG like the other settings.
        cls.test_app_version = "056582c691c4df190110b059600d2dc2a3a8b80a"
        cls.test_app_module_name = app_setting("test_app_module_name")
        cls.test_app_method_name = app_setting("test_app_method_name")
        cls.test_job_id = app_setting("test_job_id")
        cls.test_tag = app_setting("test_app_tag")
        cls.public_ws = app_setting("public_ws_name")
        cls.ws_id = int(app_setting("public_ws_id"))
        cls.app_input_ref = app_setting("test_input_ref")
        cls.batch_app_id = app_setting("batch_app_id")
        cls.test_viewer_app_id = app_setting("test_viewer_app_id")

        # raw inputs as a user would supply them
        cls.test_app_params = dict(
            read_library_names=["rhodo.art.jgi.reads"],
            output_contigset_name="rhodo_contigs",
            recipe="auto",
            assembler="",
            pipeline="",
            min_contig_len=None,
        )

        # what the tests expect the inputs above to turn into
        cls.expected_app_params = dict(
            read_library_refs=["18836/5/1"],
            output_contigset_name="rhodo_contigs",
            recipe="auto",
            assembler=None,
            pipeline=None,
            min_contig_len=None,
            workspace_name=cls.public_ws,
        )
예제 #4
0
 def setUp(self) -> None:
     """Give each test a JobManager whose jobs have been initialized."""
     manager = JobManager()
     self.jm = manager
     manager.initialize_jobs()
예제 #5
0
class JobManagerTest(unittest.TestCase):
    """Tests for ``JobManager``.

    Most tests patch ``CLIENTS`` with ``get_mock_client`` so that client calls
    are answered with canned data (presumably by ``MockClients``, whose
    methods several tests patch directly). Expected results come from the
    shared fixtures such as ``ALL_RESPONSE_DATA``, ``TEST_JOBS`` and
    ``JOBS_BY_CELL_ID``.
    """

    @classmethod
    @mock.patch(CLIENTS, get_mock_client)
    def setUpClass(cls):
        # Set KB_WORKSPACE_ID to the test workspace name; show full diffs.
        config = ConfigTests()
        os.environ["KB_WORKSPACE_ID"] = config.get("jobs", "job_test_wsname")
        cls.maxDiff = None

    @mock.patch(CLIENTS, get_mock_client)
    def setUp(self) -> None:
        """Give each test a JobManager initialized from the mocked clients."""
        self.jm = JobManager()
        self.jm.initialize_jobs()

    def reset_job_manager(self):
        """Clear the JobManager's job registries and re-create it.

        The assertions confirm that the re-created manager starts with no jobs.
        NOTE(review): the registries are cleared on the old instance before
        ``JobManager()`` is called again. That presumably matters because
        JobManager returns a shared (singleton) instance — confirm.
        """
        # all jobs have been removed from the JobManager
        self.jm._running_jobs = {}
        self.jm._jobs_by_cell_id = {}
        self.jm = JobManager()

        self.assertEqual(self.jm._running_jobs, {})
        self.assertEqual(self.jm._jobs_by_cell_id, {})

    @mock.patch(CLIENTS, get_failing_mock_client)
    def test_initialize_jobs_ee2_fail(self):
        # init jobs should fail. specifically, ee2.check_workspace_jobs should error.
        with self.assertRaisesRegex(
            NarrativeException, re.escape("check_workspace_jobs failed")
        ):
            self.jm.initialize_jobs()

    @mock.patch(CLIENTS, get_mock_client)
    def test_initialize_jobs(self):
        self.reset_job_manager()

        # redo the initialise to make sure it worked correctly
        self.jm.initialize_jobs()
        terminal_ids = [
            job_id
            for job_id, d in self.jm._running_jobs.items()
            if d["job"].was_terminal()
        ]
        self.assertEqual(
            set(TERMINAL_JOBS),
            set(terminal_ids),
        )
        self.assertEqual(set(ALL_JOBS), set(self.jm._running_jobs.keys()))

        # terminal jobs are registered without refresh; the others refresh
        for job_id in TERMINAL_IDS:
            self.assertFalse(self.jm._running_jobs[job_id]["refresh"])
        for job_id in NON_TERMINAL_IDS:
            self.assertTrue(self.jm._running_jobs[job_id]["refresh"])

        self.assertEqual(self.jm._jobs_by_cell_id, JOBS_BY_CELL_ID)

    @mock.patch(CLIENTS, get_mock_client)
    def test_initialize_jobs__cell_ids(self):
        """
        Invoke initialize_jobs with cell_ids
        """
        cell_ids = list(JOBS_BY_CELL_ID.keys())
        # Iterate through all combinations of cell IDs
        for combo_len in range(len(cell_ids) + 1):
            for combo in itertools.combinations(cell_ids, combo_len):
                combo = list(combo)
                # Get jobs expected to be associated with the cell IDs
                exp_job_ids = [
                    job_id
                    for cell_id, job_ids in JOBS_BY_CELL_ID.items()
                    for job_id in job_ids
                    if cell_id in combo
                ]
                self.reset_job_manager()
                self.jm.initialize_jobs(cell_ids=combo)

                # only jobs in the requested cells whose REFRESH_STATE is
                # truthy should be refreshing
                for job_id, d in self.jm._running_jobs.items():
                    refresh = d["refresh"]

                    self.assertEqual(
                        job_id in exp_job_ids and REFRESH_STATE[job_id],
                        refresh,
                    )

    def test__check_job_list_fail(self):
        with self.assertRaisesRegex(
            JobRequestException, re.escape(f"{JOBS_TYPE_ERR}: {None}")
        ):
            self.jm._check_job_list(None)

        with self.assertRaisesRegex(
            JobRequestException, re.escape(f"{JOBS_MISSING_ERR}: {[]}")
        ):
            self.jm._check_job_list([])

        with self.assertRaisesRegex(
            JobRequestException,
            re.escape(f'{JOBS_MISSING_ERR}: {["", "", None]}'),
        ):
            self.jm._check_job_list(["", "", None])

    def test__check_job_list(self):
        """job list checker"""
        # _check_job_list returns (registered job IDs, unregistered job IDs),
        # de-duplicated and with empty values dropped

        job_a = JOB_CREATED
        job_b = JOB_COMPLETED
        job_c = "job_c"
        job_d = "job_d"
        self.assertEqual(
            self.jm._check_job_list([job_c]),
            (
                [],
                [job_c],
            ),
        )

        self.assertEqual(
            self.jm._check_job_list([job_c, None, "", job_c, job_c, None, job_d]),
            (
                [],
                [job_c, job_d],
            ),
        )

        self.assertEqual(
            self.jm._check_job_list([job_c, None, "", None, job_a, job_a, job_a]),
            (
                [job_a],
                [job_c],
            ),
        )

        self.assertEqual(
            self.jm._check_job_list([None, job_a, None, "", None, job_b]),
            (
                [job_a, job_b],
                [],
            ),
        )

    @mock.patch(CLIENTS, get_mock_client)
    def test__construct_job_output_state_set(self):
        self.assertEqual(
            self.jm._construct_job_output_state_set(ALL_JOBS),
            {
                job_id: ALL_RESPONSE_DATA[MESSAGE_TYPE["STATUS"]][job_id]
                for job_id in ALL_JOBS
            },
        )

    def test__construct_job_output_state_set__empty_list(self):
        self.assertEqual(self.jm._construct_job_output_state_set([]), {})

    @mock.patch(CLIENTS, get_mock_client)
    def test__construct_job_output_state_set__ee2_error(self):
        exc = Exception("Test exception")
        exc_message = str(exc)

        def mock_check_jobs(params):
            raise exc

        # every check_jobs call fails while the states are constructed
        with mock.patch.object(
            MockClients, "check_jobs", side_effect=mock_check_jobs
        ):
            job_states = self.jm._construct_job_output_state_set(ALL_JOBS)

        expected = {
            job_id: copy.deepcopy(ALL_RESPONSE_DATA[MESSAGE_TYPE["STATUS"]][job_id])
            for job_id in ALL_JOBS
        }

        for job_id in ACTIVE_JOBS:
            # expect there to be an error message added
            expected[job_id]["error"] = exc_message

        self.assertEqual(
            expected,
            job_states,
        )

    def test__create_jobs__empty_list(self):
        self.assertEqual(self.jm._create_jobs([]), {})

    def test__create_jobs__jobs_already_exist(self):
        job_list = self.jm._running_jobs.keys()
        self.assertEqual(self.jm._create_jobs(job_list), {})

    def test__get_job_good(self):
        job_id = ALL_JOBS[0]
        job = self.jm.get_job(job_id)
        self.assertEqual(job_id, job.job_id)
        self.assertIsInstance(job, Job)

    def test__get_job_fail(self):
        inputs = [None, "", JOB_NOT_FOUND]

        # subTest reports each bad input separately instead of stopping at
        # the first failure
        for job_id in inputs:
            with self.subTest(job_id=job_id):
                with self.assertRaisesRegex(
                    JobRequestException,
                    re.escape(f"{JOB_NOT_REG_ERR}: {job_id}"),
                ):
                    self.jm.get_job(job_id)

    @mock.patch(CLIENTS, get_mock_client)
    def test_list_jobs_html(self):
        jobs_html = self.jm.list_jobs()
        self.assertIsInstance(jobs_html, HTML)
        html = jobs_html.data

        # expected number of table cells per (column, value)
        # NOTE(review): "batch_id" is never populated below, so batch_id
        # cells are not checked by this test.
        counts = {
            "status": {},
            "app_id": {},
            "batch_id": {},
            "user": {},
        }

        n_not_started = 0
        n_incomplete = 0
        for job in TEST_JOBS.values():
            for param in ["status", "user"]:
                if param in job:
                    value = job[param]
                    counts[param][value] = counts[param].get(value, 0) + 1

            app_id = job["job_input"]["app_id"]
            counts["app_id"][app_id] = counts["app_id"].get(app_id, 0) + 1

            if "finished" not in job:
                n_incomplete += 1
            if "running" not in job:
                n_not_started += 1

        for job_id in ALL_JOBS:
            self.assertIn(f'<td class="job_id">{job_id}</td>', html)

        for param, value_counts in counts.items():
            for value, expected_count in value_counts.items():
                cell = f'<td class="{param}">{str(value)}</td>'
                self.assertIn(cell, html)
                self.assertEqual(expected_count, html.count(cell))

        if n_incomplete:
            incomplete_count = html.count('<td class="finish_time">Incomplete</td>')
            self.assertEqual(incomplete_count, n_incomplete)
        if n_not_started:
            not_started_count = html.count('<td class="run_time">Not started</td>')
            self.assertEqual(not_started_count, n_not_started)

    def test_list_jobs_twice(self):
        # with no jobs
        with mock.patch.object(self.jm, "_running_jobs", {}):
            expected = "No running jobs!"
            self.assertEqual(self.jm.list_jobs(), expected)
            self.assertEqual(self.jm.list_jobs(), expected)

        # with some jobs
        with mock.patch(CLIENTS, get_mock_client):
            jobs_html_0 = self.jm.list_jobs().data
            jobs_html_1 = self.jm.list_jobs().data

            try:
                self.assertEqual(jobs_html_0, jobs_html_1)
            except AssertionError:
                # Sometimes the time is off by a second.
                # Keep only the hour of each HH:MM:SS time before comparing;
                # this will still fail if the hour rolls over between calls.
                pattern = r"(\d\d:)\d\d:\d\d"
                sub = r"\1"
                jobs_html_0 = re.sub(pattern, sub, jobs_html_0)
                jobs_html_1 = re.sub(pattern, sub, jobs_html_1)
                self.assertEqual(jobs_html_0, jobs_html_1)

    def test_cancel_jobs__bad_inputs(self):
        with self.assertRaisesRegex(
            JobRequestException, re.escape(f"{JOBS_MISSING_ERR}: {[]}")
        ):
            self.jm.cancel_jobs([])

        with self.assertRaisesRegex(
            JobRequestException,
            re.escape(f'{JOBS_MISSING_ERR}: {["", "", None]}'),
        ):
            self.jm.cancel_jobs(["", "", None])

        job_states = self.jm.cancel_jobs([JOB_NOT_FOUND])
        self.assertEqual(
            {JOB_NOT_FOUND: ALL_RESPONSE_DATA[MESSAGE_TYPE["STATUS"]][JOB_NOT_FOUND]},
            job_states,
        )

    def test_cancel_jobs__job_already_finished(self):
        self.assertEqual(get_test_job(JOB_COMPLETED)["status"], "completed")
        self.assertEqual(get_test_job(JOB_TERMINATED)["status"], "terminated")
        self.assertTrue(self.jm.get_job(JOB_COMPLETED).was_terminal())
        self.assertTrue(self.jm.get_job(JOB_TERMINATED).was_terminal())
        job_id_list = [JOB_COMPLETED, JOB_TERMINATED]
        # terminal jobs must not trigger an actual cancel request
        with mock.patch(
            "biokbase.narrative.jobs.jobmanager.JobManager._cancel_job"
        ) as mock_cancel_job:
            canceled_jobs = self.jm.cancel_jobs(job_id_list)
            mock_cancel_job.assert_not_called()
            self.assertEqual(
                {
                    job_id: ALL_RESPONSE_DATA[MESSAGE_TYPE["STATUS"]][job_id]
                    for job_id in job_id_list
                },
                canceled_jobs,
            )

    @mock.patch(CLIENTS, get_mock_client)
    def test_cancel_jobs__run_ee2_cancel_job(self):
        """cancel a set of jobs that run cancel_job on ee2"""
        # jobs list:
        jobs = [
            None,
            JOB_CREATED,
            JOB_RUNNING,
            "",
            JOB_TERMINATED,
            JOB_COMPLETED,
            JOB_TERMINATED,
            None,
            JOB_NOT_FOUND,
        ]

        expected = {
            job_id: ALL_RESPONSE_DATA[MESSAGE_TYPE["STATUS"]][job_id]
            for job_id in jobs
            if job_id
        }
        self.jm._running_jobs[JOB_RUNNING]["refresh"] = 1
        self.jm._running_jobs[JOB_CREATED]["refresh"] = 1

        def check_state(arg):
            # while cancel_job is being called, the job must be flagged as
            # canceling and must not be refreshing
            self.assertFalse(self.jm._running_jobs[arg["job_id"]]["refresh"])
            self.assertEqual(self.jm._running_jobs[arg["job_id"]]["canceling"], True)

        # patch MockClients.cancel_job so we can test the input
        with mock.patch.object(
            MockClients,
            "cancel_job",
            mock.Mock(return_value={}, side_effect=check_state),
        ) as mock_cancel_job:
            results = self.jm.cancel_jobs(jobs)
            # afterwards the canceling flag is gone and refresh is restored
            for job in [JOB_RUNNING, JOB_CREATED]:
                self.assertNotIn("canceling", self.jm._running_jobs[job])
                self.assertEqual(self.jm._running_jobs[job]["refresh"], 1)
            self.assertEqual(results.keys(), expected.keys())
            self.assertEqual(results, expected)
            mock_cancel_job.assert_has_calls(
                [
                    mock.call({"job_id": JOB_RUNNING}),
                    mock.call({"job_id": JOB_CREATED}),
                ],
                any_order=True,
            )

    @mock.patch(CLIENTS, get_mock_client)
    def test_cancel_jobs(self):
        with assert_obj_method_called(self.jm, "cancel_jobs", True):
            self.jm.cancel_jobs([JOB_COMPLETED])

    def _check_retry_jobs(
        self,
        expected,
        retry_results,
    ):
        """Check retry results and the resulting JobManager registry.

        Results must equal ``expected``. The original and retry job IDs of
        every result without an ``error`` must be registered and have an
        accumulated state. Job IDs listed in ``BAD_JOBS`` must not be
        registered.
        """
        self.assertEqual(expected, retry_results)
        orig_ids = [
            result["job_id"]
            for result in retry_results.values()
            if "error" not in result
        ]
        retry_ids = [
            result["retry_id"]
            for result in retry_results.values()
            if "error" not in result
        ]
        dne_ids = [
            result["job_id"]
            for result in retry_results.values()
            if result["job_id"] in BAD_JOBS
        ]

        for job_id in orig_ids + retry_ids:
            job = self.jm.get_job(job_id)
            self.assertIn(job_id, self.jm._running_jobs)
            self.assertIsNotNone(job._acc_state)

        for job_id in dne_ids:
            self.assertNotIn(job_id, self.jm._running_jobs)

    @mock.patch(CLIENTS, get_mock_client)
    def test_retry_jobs__success(self):
        job_ids = [BATCH_TERMINATED_RETRIED]
        expected = {
            job_id: ALL_RESPONSE_DATA[MESSAGE_TYPE["RETRY"]][job_id]
            for job_id in job_ids
        }
        retry_results = self.jm.retry_jobs(job_ids)
        self._check_retry_jobs(expected, retry_results)

    @mock.patch(CLIENTS, get_mock_client)
    def test_retry_jobs__multi_success(self):
        job_ids = [BATCH_TERMINATED_RETRIED, BATCH_ERROR_RETRIED]
        expected = {
            job_id: ALL_RESPONSE_DATA[MESSAGE_TYPE["RETRY"]][job_id]
            for job_id in job_ids
        }
        retry_results = self.jm.retry_jobs(job_ids)
        self._check_retry_jobs(expected, retry_results)

    @mock.patch(CLIENTS, get_mock_client)
    def test_retry_jobs__success_error_dne(self):
        job_ids = [JOB_NOT_FOUND, BATCH_TERMINATED_RETRIED, JOB_COMPLETED]
        expected = {
            job_id: ALL_RESPONSE_DATA[MESSAGE_TYPE["RETRY"]][job_id]
            for job_id in job_ids
        }
        retry_results = self.jm.retry_jobs(job_ids)
        self._check_retry_jobs(expected, retry_results)

    @mock.patch(CLIENTS, get_mock_client)
    def test_retry_jobs__all_error(self):
        job_ids = [JOB_COMPLETED, JOB_CREATED, JOB_RUNNING]
        expected = {
            job_id: ALL_RESPONSE_DATA[MESSAGE_TYPE["RETRY"]][job_id]
            for job_id in job_ids
        }
        retry_results = self.jm.retry_jobs(job_ids)
        self._check_retry_jobs(expected, retry_results)

    @mock.patch(CLIENTS, get_mock_client)
    def test_retry_jobs__retry_already_terminal(self):
        job_ids = [JOB_COMPLETED]
        expected = {
            job_id: ALL_RESPONSE_DATA[MESSAGE_TYPE["RETRY"]][job_id]
            for job_id in job_ids
        }
        retry_results = self.jm.retry_jobs(job_ids)
        self._check_retry_jobs(expected, retry_results)

    @mock.patch(CLIENTS, get_mock_client)
    def test_retry_jobs__none_exist(self):
        job_ids = ["", "", None, BAD_JOB_ID]
        expected = {
            job_id: ALL_RESPONSE_DATA[MESSAGE_TYPE["RETRY"]][job_id]
            for job_id in job_ids
            if job_id
        }
        retry_results = self.jm.retry_jobs(job_ids)
        self._check_retry_jobs(expected, retry_results)

    def test_retry_jobs__bad_inputs(self):
        with self.assertRaisesRegex(
            JobRequestException, re.escape(f"{JOBS_MISSING_ERR}: {[]}")
        ):
            self.jm.retry_jobs([])

        with self.assertRaisesRegex(
            JobRequestException,
            re.escape(f'{JOBS_MISSING_ERR}: {["", "", None]}'),
        ):
            self.jm.retry_jobs(["", "", None])

    @mock.patch(CLIENTS, get_mock_client)
    def test_get_all_job_states(self):
        states = self.jm.get_all_job_states()
        # without ignore_refresh_flag only refreshing jobs are reported
        refreshing_jobs = [job_id for job_id, state in REFRESH_STATE.items() if state]
        self.assertEqual(set(refreshing_jobs), set(states.keys()))
        self.assertEqual(
            states,
            {
                job_id: ALL_RESPONSE_DATA[MESSAGE_TYPE["STATUS"]][job_id]
                for job_id in refreshing_jobs
            },
        )

    @mock.patch(CLIENTS, get_mock_client)
    def test_get_all_job_states__ignore_refresh_flag(self):
        states = self.jm.get_all_job_states(ignore_refresh_flag=True)
        self.assertEqual(set(ALL_JOBS), set(states.keys()))
        self.assertEqual(
            {
                job_id: ALL_RESPONSE_DATA[MESSAGE_TYPE["STATUS"]][job_id]
                for job_id in ALL_JOBS
            },
            states,
        )

    # get_job_states_by_cell_id
    @mock.patch(CLIENTS, get_mock_client)
    def test_get_job_states_by_cell_id__cell_id_list_None(self):
        with self.assertRaisesRegex(
            JobRequestException, re.escape(CELLS_NOT_PROVIDED_ERR)
        ):
            self.jm.get_job_states_by_cell_id(cell_id_list=None)

    @mock.patch(CLIENTS, get_mock_client)
    def test_get_job_states_by_cell_id__cell_id_list_empty(self):
        with self.assertRaisesRegex(
            JobRequestException, re.escape(CELLS_NOT_PROVIDED_ERR)
        ):
            self.jm.get_job_states_by_cell_id(cell_id_list=[])

    @mock.patch(CLIENTS, get_mock_client)
    def test_get_job_states_by_cell_id__cell_id_list_no_results(self):
        result = self.jm.get_job_states_by_cell_id(cell_id_list=["a", "b", "c"])
        self.assertEqual(
            {"jobs": {}, "mapping": {"a": set(), "b": set(), "c": set()}},
            result,
        )

    def check_get_job_states_by_cell_id_results(self, cell_ids, expected_ids):
        """Check get_job_states_by_cell_id(cell_ids) against the fixtures.

        ``jobs`` in the result must be exactly the STATUS responses for
        ``expected_ids``. ``mapping`` must have one key per requested cell ID,
        and each key must map to that cell's ``TEST_CELL_IDs`` entry.
        """
        status_data = ALL_RESPONSE_DATA[MESSAGE_TYPE["STATUS"]]
        expected_states = {
            job_id: status_data[job_id]
            for job_id in status_data
            if job_id in expected_ids
        }
        result = self.jm.get_job_states_by_cell_id(cell_id_list=cell_ids)
        self.assertEqual(set(expected_ids), set(result["jobs"].keys()))
        self.assertEqual(expected_states, result["jobs"])
        self.assertEqual(set(cell_ids), set(result["mapping"].keys()))
        for key in result["mapping"]:
            self.assertEqual(set(TEST_CELL_IDs[key]), set(result["mapping"][key]))

    @mock.patch(CLIENTS, get_mock_client)
    def test_get_job_states_by_cell_id__cell_id_list_all_results(self):
        cell_ids = TEST_CELL_ID_LIST
        self.check_get_job_states_by_cell_id_results(cell_ids, ALL_JOBS)

    @mock.patch(CLIENTS, get_mock_client)
    def test_get_job_states_by_cell_id__cell_id_list__batch_job__one_cell(self):
        cell_ids = [TEST_CELL_ID_LIST[2]]
        expected_ids = TEST_CELL_IDs[TEST_CELL_ID_LIST[2]]
        self.check_get_job_states_by_cell_id_results(cell_ids, expected_ids)

    @mock.patch(CLIENTS, get_mock_client)
    def test_get_job_states_by_cell_id__cell_id_list__batch_job__two_cells(self):
        cell_ids = [TEST_CELL_ID_LIST[2], TEST_CELL_ID_LIST[3]]
        expected_ids = (
            TEST_CELL_IDs[TEST_CELL_ID_LIST[2]] + TEST_CELL_IDs[TEST_CELL_ID_LIST[3]]
        )
        self.check_get_job_states_by_cell_id_results(cell_ids, expected_ids)

    @mock.patch(CLIENTS, get_mock_client)
    def test_get_job_states_by_cell_id__cell_id_list__batch_job__one_ok_one_invalid(
        self,
    ):
        cell_ids = [TEST_CELL_ID_LIST[1], TEST_CELL_ID_LIST[4]]
        expected_ids = TEST_CELL_IDs[TEST_CELL_ID_LIST[1]]
        self.check_get_job_states_by_cell_id_results(cell_ids, expected_ids)

    @mock.patch(CLIENTS, get_mock_client)
    def test_get_job_states_by_cell_id__cell_id_list__batch_and_other_job(self):
        cell_ids = [TEST_CELL_ID_LIST[0], TEST_CELL_ID_LIST[2]]
        expected_ids = (
            TEST_CELL_IDs[TEST_CELL_ID_LIST[0]] + TEST_CELL_IDs[TEST_CELL_ID_LIST[2]]
        )
        self.check_get_job_states_by_cell_id_results(cell_ids, expected_ids)

    @mock.patch(CLIENTS, get_mock_client)
    def test_get_job_states_by_cell_id__cell_id_list__batch_in_many_cells(self):
        cell_ids = [TEST_CELL_ID_LIST[0], TEST_CELL_ID_LIST[2], TEST_CELL_ID_LIST[3]]
        expected_ids = (
            TEST_CELL_IDs[TEST_CELL_ID_LIST[0]]
            + TEST_CELL_IDs[TEST_CELL_ID_LIST[2]]
            + TEST_CELL_IDs[TEST_CELL_ID_LIST[3]]
        )
        self.check_get_job_states_by_cell_id_results(cell_ids, expected_ids)

    @mock.patch(CLIENTS, get_mock_client)
    def test_get_job_states(self):
        # mix of empty values, duplicates, registered and unknown job IDs
        job_ids = [
            None,
            None,
            JOB_CREATED,
            JOB_NOT_FOUND,
            JOB_CREATED,
            JOB_RUNNING,
            JOB_TERMINATED,
            JOB_COMPLETED,
            BATCH_PARENT,
            "",
            JOB_NOT_FOUND,
        ]

        exp = {
            job_id: ALL_RESPONSE_DATA[MESSAGE_TYPE["STATUS"]][job_id]
            for job_id in job_ids
            if job_id
        }

        res = self.jm.get_job_states(job_ids)
        self.assertEqual(exp, res)

    def test_get_job_states__empty(self):
        with self.assertRaisesRegex(
            JobRequestException, re.escape(f"{JOBS_MISSING_ERR}: {[]}")
        ):
            self.jm.get_job_states([])

    def test_update_batch_job__dne(self):
        with self.assertRaisesRegex(
            JobRequestException, re.escape(f"{JOB_NOT_REG_ERR}: {JOB_NOT_FOUND}")
        ):
            self.jm.update_batch_job(JOB_NOT_FOUND)

    def test_update_batch_job__not_batch(self):
        with self.assertRaisesRegex(
            JobRequestException, re.escape(f"{JOB_NOT_BATCH_ERR}: {JOB_CREATED}")
        ):
            self.jm.update_batch_job(JOB_CREATED)

        with self.assertRaisesRegex(
            JobRequestException,
            re.escape(f"{JOB_NOT_BATCH_ERR}: {BATCH_TERMINATED}"),
        ):
            self.jm.update_batch_job(BATCH_TERMINATED)

    @mock.patch(CLIENTS, get_mock_client)
    def test_update_batch_job__no_change(self):
        job_ids = self.jm.update_batch_job(BATCH_PARENT)
        # the batch parent comes first, followed by its children
        self.assertEqual(BATCH_PARENT, job_ids[0])
        self.assertCountEqual(BATCH_CHILDREN, job_ids[1:])

    @mock.patch(CLIENTS, get_mock_client)
    def test_update_batch_job__change(self):
        """test child ids having changed"""
        # drop the first child and add JOB_CREATED plus an unknown job ID
        new_child_ids = BATCH_CHILDREN[1:] + [JOB_CREATED, JOB_NOT_FOUND]

        def mock_check_job(params):
            """Called from job.state()"""
            job_id = params["job_id"]
            if job_id == BATCH_PARENT:
                return {"child_jobs": new_child_ids}
            elif job_id in TEST_JOBS:
                return get_test_job(job_id)
            elif job_id == JOB_NOT_FOUND:
                return {"job_id": job_id, "status": generate_error(job_id, "not_found")}
            else:
                raise Exception()

        with mock.patch.object(
            MockClients, "check_job", side_effect=mock_check_job
        ) as m:
            job_ids = self.jm.update_batch_job(BATCH_PARENT)

        m.assert_has_calls(
            [
                mock.call(
                    {
                        "job_id": BATCH_PARENT,
                        "exclude_fields": EXCLUDED_JOB_STATE_FIELDS,
                    }
                ),
                mock.call(
                    {
                        "job_id": JOB_NOT_FOUND,
                        "exclude_fields": JOB_INIT_EXCLUDED_JOB_STATE_FIELDS,
                    }
                ),
            ]
        )

        self.assertEqual(BATCH_PARENT, job_ids[0])
        self.assertCountEqual(new_child_ids, job_ids[1:])

        batch_job = self.jm.get_job(BATCH_PARENT)
        reg_child_jobs = [
            self.jm.get_job(job_id) for job_id in batch_job._acc_state["child_jobs"]
        ]

        self.assertCountEqual(batch_job.children, reg_child_jobs)
        self.assertCountEqual(batch_job._acc_state["child_jobs"], new_child_ids)

        # keep check_job mocked while reading child_jobs, in case the property
        # looks job state up again
        with mock.patch.object(MockClients, "check_job", side_effect=mock_check_job):
            self.assertCountEqual(batch_job.child_jobs, new_child_ids)

    def test_modify_job_refresh(self):
        # toggle each job's refresh flag, checking that repeating the same
        # request leaves the flag unchanged
        for job_id, refreshing in REFRESH_STATE.items():
            self.assertEqual(self.jm._running_jobs[job_id]["refresh"], refreshing)
            self.jm.modify_job_refresh([job_id], False)  # stop
            self.assertEqual(self.jm._running_jobs[job_id]["refresh"], False)
            self.jm.modify_job_refresh([job_id], False)  # stop harder
            self.assertEqual(self.jm._running_jobs[job_id]["refresh"], False)
            self.jm.modify_job_refresh([job_id], True)  # start
            self.assertEqual(self.jm._running_jobs[job_id]["refresh"], True)
            self.jm.modify_job_refresh([job_id], True)  # start some more
            self.assertEqual(self.jm._running_jobs[job_id]["refresh"], True)
            self.jm.modify_job_refresh([job_id], False)  # stop
            self.assertEqual(self.jm._running_jobs[job_id]["refresh"], False)

    @mock.patch(CLIENTS, get_mock_client)
    def test_get_job_info(self):
        infos = self.jm.get_job_info(ALL_JOBS)
        self.assertCountEqual(ALL_JOBS, infos.keys())
        self.assertEqual(
            infos,
            {
                job_id: ALL_RESPONSE_DATA[MESSAGE_TYPE["INFO"]][job_id]
                for job_id in ALL_JOBS
            },
        )
예제 #6
0
"""
Tests for job management
"""
__author__ = "Bill Riehl <*****@*****.**>"

# NOTE(review): unittest and mock are not used anywhere in this snippet;
# confirm before removing them.
import unittest
import mock
from biokbase.narrative.jobs.jobmanager import JobManager

# Import-time side effect: a module-level JobManager is created, and
# initialize_jobs() runs as soon as this module is imported.
jm = JobManager()
jm.initialize_jobs()
예제 #7
0
class JobComm:
    """
    The main JobComm channel. This is the kernel-side of the connection, and routes
    requests for job information from various app cells (or the front end in general)
    to the right function.

    This has a handle on the JobManager, which does the work of fetching job information
    and statuses.

    The JobComm officially exposes the channel for other things to use. Anything that
    needs to send messages about Jobs to the front end should use JobComm.send_comm_message.

    It also maintains the lookup loop thread. This is a threading.Timer that, every
    LOOKUP_TIMER_INTERVAL seconds, looks up the status of all jobs and sends it to the
    front end. If no job states come back, or the loop has been stopped, it does not
    reschedule itself.

    Allowed messages (MESSAGE_TYPE keys, dispatched through self._msg_map):
    * STATUS - return the job states of the requested jobs
    * STATUS_ALL - return job states for all jobs in this Narrative
    * CELL_JOB_STATUS - return job states for the jobs associated with the given cell IDs
    * INFO - return basic job info for the requested jobs
    * START_UPDATE - tells the update loop to include the requested jobs when updating
    * STOP_UPDATE - has the update loop not include the requested jobs when updating
    * CANCEL - cancels the requested jobs, if they haven't otherwise terminated
    * RETRY - retries the requested jobs
    * LOGS - sends job logs for the requested jobs back over the comm channel

    "The requested jobs" are resolved by _get_job_ids: the request's batch ID if it has
    one, otherwise its job ID list.
    """

    # An instance of this class. It's meant to be a singleton, so this just gets created and
    # returned once.
    __instance = None

    # The kernel job comm channel that talks to the front end.
    _comm = None

    # The JobManager that actually manages things.
    _jm = None

    # Maps MESSAGE_TYPE values to handler methods; built on first __init__.
    _msg_map = None
    # True while the status lookup loop is allowed to keep rescheduling itself.
    _running_lookup_loop = False
    # The pending threading.Timer for the next lookup, or None when no lookup is scheduled.
    _lookup_timer = None
    _log = kblogging.get_logger(__name__)

    def __new__(cls):
        if JobComm.__instance is None:
            JobComm.__instance = object.__new__(cls)
        return JobComm.__instance

    def __init__(self):
        # __init__ runs on every JobComm() call, but since __new__ always returns the
        # same instance, these "is None" guards make each piece get built only once.
        if self._comm is None:
            self._comm = Comm(target_name="KBaseJobs", data={})
            self._comm.on_msg(self._handle_comm_message)
        if self._jm is None:
            self._jm = JobManager()
        if self._msg_map is None:
            self._msg_map = {
                MESSAGE_TYPE["CANCEL"]: self._cancel_jobs,
                MESSAGE_TYPE["CELL_JOB_STATUS"]: self._get_job_states_by_cell_id,
                MESSAGE_TYPE["INFO"]: self._get_job_info,
                MESSAGE_TYPE["LOGS"]: self._get_job_logs,
                MESSAGE_TYPE["RETRY"]: self._retry_jobs,
                MESSAGE_TYPE["START_UPDATE"]: self._modify_job_updates,
                MESSAGE_TYPE["STATUS"]: self._get_job_states,
                MESSAGE_TYPE["STATUS_ALL"]: self._get_all_job_states,
                MESSAGE_TYPE["STOP_UPDATE"]: self._modify_job_updates,
            }

    def _get_job_ids(self, req: JobRequest = None):
        """
        Resolve the list of job IDs that a request refers to.

        :param req: the JobRequest being handled
        :returns: if the request has a batch ID, whatever JobManager.update_batch_job
            returns for it (used by callers as the job ID list); otherwise the
            request's job_id_list
        :raises JobRequestException: if reading the request's job_id_list fails
        """
        if req.has_batch_id():
            return self._jm.update_batch_job(req.batch_id)

        try:
            return req.job_id_list
        except Exception as ex:
            raise JobRequestException(ONE_INPUT_TYPE_ONLY_ERR) from ex

    def start_job_status_loop(
        self,
        init_jobs: bool = False,
        cell_list: List[str] = None,
    ) -> None:
        """
        Starts the job status lookup loop. This runs every LOOKUP_TIMER_INTERVAL seconds.

        :param init_jobs: If init_jobs=True, this attempts to (re-)initialize
            the JobManager's list of known jobs from the workspace.
        :param cell_list: from FE, the list of extant cell IDs
        """
        self._running_lookup_loop = True
        if init_jobs:
            try:
                self._jm.initialize_jobs(cell_list)
            except Exception as e:
                # Build the error payload from whatever attributes the exception has;
                # plain Python exceptions lack most of these, hence the defaults.
                error = {
                    "error": "Unable to get initial jobs list",
                    "message": getattr(e, "message", UNKNOWN_REASON),
                    "code": getattr(e, "code", -1),
                    "source": getattr(e, "source", "jobmanager"),
                    "name": getattr(e, "name", type(e).__name__),
                }
                self.send_comm_message(MESSAGE_TYPE["ERROR"], error)
                # if job init failed, set the lookup loop var back to False and return
                self._running_lookup_loop = False
                return
        # Only kick off a lookup if one isn't already scheduled.
        if self._lookup_timer is None:
            self._lookup_job_status_loop()

    def stop_job_status_loop(self, *args, **kwargs) -> None:
        """
        Stops the job status lookup loop if it's running. Otherwise, this effectively
        does nothing.
        """
        if self._lookup_timer:
            self._lookup_timer.cancel()
            self._lookup_timer = None
        self._running_lookup_loop = False

    def _lookup_job_status_loop(self) -> None:
        """
        Run a loop that will look up job info. After running, this spawns a Timer thread on
        a loop to run itself again. LOOKUP_TIMER_INTERVAL sets the frequency at which the loop runs.
        The loop stops when no job states come back or _running_lookup_loop is False.
        """
        all_job_states = self._get_all_job_states()
        if not all_job_states or not self._running_lookup_loop:
            self.stop_job_status_loop()
        else:
            self._lookup_timer = threading.Timer(
                LOOKUP_TIMER_INTERVAL, self._lookup_job_status_loop
            )
            self._lookup_timer.start()

    def _get_all_job_states(
        self, req: JobRequest = None, ignore_refresh_flag: bool = False
    ) -> dict:
        """
        Fetches status of all jobs in the current workspace and sends them to the front end.
        req can be None, as it's not used.

        :param ignore_refresh_flag: passed straight through to
            JobManager.get_all_job_states
        :returns: the job states dict that was sent
        """
        all_job_states = self._jm.get_all_job_states(
            ignore_refresh_flag=ignore_refresh_flag
        )
        self.send_comm_message(MESSAGE_TYPE["STATUS_ALL"], all_job_states)
        return all_job_states

    def _get_job_states_by_cell_id(self, req: JobRequest = None) -> dict:
        """
        Fetches status of all jobs associated with the given cell ID(s)
        :param req: a JobRequest with the cell_id_list of interest
        :returns: dict in the form
        {
            "jobs": {
                # dict with job IDs as keys and job states as values
                "job_one": { ... },
                "job_two": { ... },
            },
            "mapping": {
                # dict with cell IDs as keys and values being the set of job IDs associated
                # with that cell
                "cell_one": [ "job_one", "job_two", ... ],
                "cell_two": [ ... ],
            }
        }
        """
        cell_job_states = self._jm.get_job_states_by_cell_id(
            cell_id_list=req.cell_id_list
        )
        self.send_comm_message(MESSAGE_TYPE["CELL_JOB_STATUS"], cell_job_states)
        return cell_job_states

    def _get_job_info(self, req: JobRequest) -> dict:
        """
        Look up job info. This is just some high-level generic information about the running
        job, including the app id, name, and job parameters.
        :param req: a JobRequest with the job_id_list of interest
        :returns: a dict keyed with job IDs and with values of dicts with the following keys:
            - app_id - str - module/name,
            - app_name - str - name of the app as it shows up in the Narrative interface
            - batch_id - str - the batch parent ID (if appropriate)
            - job_id - str - just re-reporting the id string
            - job_params - dict - the params that were passed to that particular job
        """
        job_id_list = self._get_job_ids(req)
        job_info = self._jm.get_job_info(job_id_list)
        self.send_comm_message(MESSAGE_TYPE["INFO"], job_info)
        return job_info

    def __get_job_states(self, job_id_list) -> dict:
        """
        Look up job states and send them to the front end as a STATUS message.

        Returns a dictionary of job state information indexed by job ID.
        """
        output_states = self._jm.get_job_states(job_id_list)
        self.send_comm_message(MESSAGE_TYPE["STATUS"], output_states)
        return output_states

    def get_job_state(self, job_id: str) -> dict:
        """
        Public single-job variant of _get_job_states: it takes a job_id string
        instead of a JobRequest. Like _get_job_states, it sends a STATUS message
        and returns the job states dict indexed by job ID.
        """
        return self.__get_job_states([job_id])

    def _get_job_states(self, req: JobRequest) -> dict:
        """
        Handler for STATUS requests: look up and send the states of the jobs the
        request refers to (see _get_job_ids), returning them indexed by job ID.
        """
        job_id_list = self._get_job_ids(req)
        return self.__get_job_states(job_id_list)

    def _modify_job_updates(self, req: JobRequest) -> dict:
        """
        Modifies how many things want to listen to a job update.
        If this is a request to start a job update, then this starts the update loop that
        returns update messages across the job channel.
        If this is a request to stop a job update, then this sends that request to the
        JobManager, which might have the side effect of shutting down the update loop if there's
        no longer anything requesting job status.

        If the given job_id in the request doesn't exist in the current Narrative, or is None,
        this raises a JobRequestException.

        Afterwards the current states of the affected jobs are sent as a STATUS message
        and returned.
        """
        job_id_list = self._get_job_ids(req)
        update_type = req.request_type
        if update_type == MESSAGE_TYPE["START_UPDATE"]:
            update_refresh = True
        elif update_type == MESSAGE_TYPE["STOP_UPDATE"]:
            update_refresh = False
        else:
            # this should be impossible
            raise JobRequestException("Unknown request")

        self._jm.modify_job_refresh(job_id_list, update_refresh)

        if update_refresh:
            self.start_job_status_loop()

        return self.__get_job_states(job_id_list)

    def _cancel_jobs(self, req: JobRequest) -> dict:
        """
        This cancels a running job.
        If there are no valid jobs, this raises a JobRequestException.
        If there's an error while attempting to cancel, this raises a NarrativeError.
        In the end, after a successful cancel, this finishes up by fetching and returning the
        job state with the new status.
        """
        job_id_list = self._get_job_ids(req)
        cancel_results = self._jm.cancel_jobs(job_id_list)
        self.send_comm_message(MESSAGE_TYPE["STATUS"], cancel_results)
        return cancel_results

    def _retry_jobs(self, req: JobRequest) -> dict:
        """
        Handler for RETRY requests: ask the JobManager to retry the jobs the request
        refers to, send the results as a RETRY message, and return them.
        """
        job_id_list = self._get_job_ids(req)
        retry_results = self._jm.retry_jobs(job_id_list)
        self.send_comm_message(MESSAGE_TYPE["RETRY"], retry_results)
        return retry_results

    def _get_job_logs(self, req: JobRequest) -> dict:
        """
        This returns a set of job logs based on the info in the request.
        Optional request data: num_lines (default None), first_line (default 0),
        latest (default False); all are passed through to the JobManager.
        """
        job_id_list = self._get_job_ids(req)
        log_output = self._jm.get_job_logs_for_list(
            job_id_list,
            num_lines=req.rq_data.get("num_lines", None),
            first_line=req.rq_data.get("first_line", 0),
            latest=req.rq_data.get("latest", False),
        )
        self.send_comm_message(MESSAGE_TYPE["LOGS"], log_output)
        return log_output

    def _handle_comm_message(self, msg: dict) -> dict:
        """
        Handles comm messages that come in from the other end of the KBaseJobs channel.
        Messages get translated into one or more JobRequest objects, which are then
        passed to the right handler, based on the request.

        A handler dictionary is created on JobComm creation.

        Any unknown request is returned over the channel with message type 'job_error', and a
        JobRequestException is raised.
        """
        # exc_to_msg wraps request parsing and dispatch so failures are reported
        # with reference to the original msg.
        with exc_to_msg(msg):
            request = JobRequest(msg)

            kblogging.log_event(
                self._log, "handle_comm_message", {"msg": request.request_type}
            )
            if request.request_type not in self._msg_map:
                raise JobRequestException(
                    f"Unknown KBaseJobs message '{request.request_type}'"
                )

            return self._msg_map[request.request_type](request)

    def send_comm_message(self, msg_type: str, content: dict) -> None:
        """
        Sends a ipykernel.Comm message to the KBaseJobs channel with the given msg_type
        and content. These just get encoded into the message itself.
        """
        msg = {"msg_type": msg_type, "content": content}
        self._comm.send(msg)

    def send_error_message(
        self, req: Union[JobRequest, dict, str, None], content: dict = None
    ) -> None:
        """
        Sends a comm message over the KBaseJobs channel as an error. This will have msg_type set to
        ERROR ('job_error'), and include the original request in the message content as
        "source".

        req can be the original request message or its JobRequest form.
        Since the latter is made from the former, they have the same information.
        It can also be a string, or None when the error did not come from a JobComm request.

        This sends a packet that looks like:
        {
            request: the original JobRequest data object, function params, or function name
            source: the function request that spawned the error
            other fields about the error, dependent on the content.
        }
        Fields in content are merged in last, so they override request/source.
        """
        error_content = {}
        if isinstance(req, JobRequest):
            error_content["request"] = req.rq_data
            error_content["source"] = req.request_type
        elif isinstance(req, dict):
            data = req.get("content", {}).get("data", {})
            error_content["request"] = data
            error_content["source"] = data.get("request_type")
        elif isinstance(req, str) or req is None:
            error_content["request"] = req
            error_content["source"] = req

        if content is not None:
            error_content.update(content)

        self.send_comm_message(MESSAGE_TYPE["ERROR"], error_content)
예제 #8
0
"""
Tests for job management
"""
__author__ = "Bill Riehl <*****@*****.**>"

import unittest
import mock
from biokbase.narrative.jobs.jobmanager import JobManager

# Module-level JobManager, presumably shared by the tests in this module.
# NOTE(review): both lines run at import time, so JobManager construction and
# initialize_jobs() happen before any test executes; a failure here surfaces as
# an import/collection error rather than a test failure. Consider moving this
# into a setUpClass — TODO confirm nothing imports `jm` from this module.
jm = JobManager()
jm.initialize_jobs()