Example #1
0
    def schedule(self) -> task_scheduler.TaskSchedulerResult:
        """Imports the source artifact and returns it as the node's output.

        Returns:
          A `TaskSchedulerResult` with OK status and an `ImporterNodeOutput`
          carrying the generated output artifacts.
        """

        def _to_property_dict(proto_map) -> Dict[str, types.Property]:
            # Convert an MLMD proto map field into a plain property dict.
            return {key: data_types_utils.get_value(value)
                    for key, value in proto_map.items()}

        node = self.task.get_pipeline_node()
        # The importer node declares exactly one output under IMPORT_RESULT_KEY.
        artifact_spec = node.outputs.outputs[
            importer.IMPORT_RESULT_KEY].artifact_spec

        output_artifacts = importer.generate_output_dict(
            metadata_handler=self.mlmd_handle,
            uri=str(self.task.exec_properties[importer.SOURCE_URI_KEY]),
            properties=_to_property_dict(artifact_spec.additional_properties),
            custom_properties=_to_property_dict(
                artifact_spec.additional_custom_properties),
            reimport=bool(
                self.task.exec_properties[importer.REIMPORT_OPTION_KEY]),
            output_artifact_class=types.Artifact(artifact_spec.type).type,
            mlmd_artifact_type=artifact_spec.type)

        return task_scheduler.TaskSchedulerResult(
            status=status_lib.Status(code=status_lib.Code.OK),
            output=task_scheduler.ImporterNodeOutput(
                output_artifacts=output_artifacts))
Example #2
0
 def _process_exec_node_task(self, scheduler: ts.TaskScheduler,
                             task: task_lib.ExecNodeTask) -> None:
     """Runs `scheduler` for `task` and publishes the result to MLMD."""
     # `schedule()` blocks, potentially for a long time depending on the
     # scheduler type. Schedulers are expected to trap internal errors and
     # report them via the result status; if one still raises, record it as
     # an ABORTED execution so MLMD is updated accordingly.
     try:
         sched_result = scheduler.schedule()
     except Exception as exc:  # pylint: disable=broad-except
         logging.exception(
             'Exception raised by task scheduler; node uid: %s',
             task.node_uid)
         sched_result = ts.TaskSchedulerResult(status=status_lib.Status(
             code=status_lib.Code.ABORTED, message=str(exc)))
     logging.info(
         'For ExecNodeTask id: %s, task-scheduler result status: %s',
         task.task_id, sched_result.status)
     _publish_execution_results(mlmd_handle=self._mlmd_handle,
                                task=task,
                                result=sched_result)
     # Record when we last published to MLMD (read elsewhere under the same
     # lock), then release the scheduler slot and acknowledge the task.
     with self._publish_time_lock:
         self._last_mlmd_publish_time = time.time()
     with self._tm_lock:
         del self._scheduler_by_node_uid[task.node_uid]
         self._task_queue.task_done(task)
Example #3
0
    def test_executor_failure(self):
        """Executor returning a non-OK code must fail the MLMD execution."""
        # The scheduler reports OK, but the executor output it carries has a
        # FAILED_PRECONDITION execution result (i.e. the executor itself failed).
        self._register_task_scheduler(
            ts.TaskSchedulerResult(
                status=status_lib.Status(code=status_lib.Code.OK),
                output=ts.ExecutorNodeOutput(
                    executor_output=_make_executor_output(
                        self._task,
                        code=status_lib.Code.FAILED_PRECONDITION,
                        msg='foobar error'))))
        manager = self._run_task_manager()
        self.assertTrue(manager.done())
        self.assertIsNone(manager.exception())

        # The task was consumed and the execution recorded as FAILED with the
        # executor's error message attached.
        self.assertTrue(self._task_queue.is_empty())
        execution = self._get_execution()
        self.assertEqual(metadata_store_pb2.Execution.FAILED,
                         execution.last_known_state)
        recorded_msg = data_types_utils.get_metadata_value(
            execution.custom_properties[constants.EXECUTION_ERROR_MSG_KEY])
        self.assertEqual('foobar error', recorded_msg)

        # Failure cleanup: stateful working dir, tmp dir and the output
        # artifact URI must all have been removed.
        for path in (self._task.stateful_working_dir, self._task.tmp_dir,
                     self._output_artifact_uri):
            self.assertFalse(os.path.exists(path))
Example #4
0
 def schedule(self):
     """Records the scheduled task; blocked nodes wait for the stop event."""
     logging.info('_FakeTaskScheduler: scheduling task: %s', self.task)
     self._collector.add_scheduled_task(self.task)
     is_blocked = self.task.node_uid.node_id in self._block_nodes
     if is_blocked:
         # Simulate a long-running scheduler until the test signals stop.
         self._stop_event.wait()
     return ts.TaskSchedulerResult(
         executor_output=execution_result_pb2.ExecutorOutput())
Example #5
0
    def schedule(self) -> task_scheduler.TaskSchedulerResult:
        """Polls MLMD until the manual node completes or the task is cancelled."""
        while True:
            # Sleep for one polling interval; `wait` returns True iff the
            # cancel event was set during the wait.
            if self._cancel.wait(_POLLING_INTERVAL_SECS):
                return task_scheduler.TaskSchedulerResult(
                    status=status_lib.Status(code=status_lib.Code.CANCELLED),
                    output=task_scheduler.ExecutorNodeOutput())
            with mlmd_state.mlmd_execution_atomic_op(
                    mlmd_handle=self.mlmd_handle,
                    execution_id=self.task.execution_id) as execution:
                raw_state = execution.custom_properties.get(
                    NODE_STATE_PROPERTY_KEY)
                node_state = ManualNodeState.from_mlmd_value(raw_state)
                if node_state.state == ManualNodeState.COMPLETED:
                    return task_scheduler.TaskSchedulerResult(
                        status=status_lib.Status(code=status_lib.Code.OK),
                        output=task_scheduler.ExecutorNodeOutput())
Example #6
0
 def schedule(self):
     """Records the task; blocked nodes wait for cancellation before returning."""
     logging.info('_FakeTaskScheduler: scheduling task: %s', self.task)
     self._collector.add_scheduled_task(self.task)
     is_blocked = self.task.node_uid.node_id in self._block_nodes
     if is_blocked:
         # Simulate a long-running task until the manager cancels it.
         self._cancel.wait()
     return ts.TaskSchedulerResult(status=status_lib.Status(
         code=(status_lib.Code.CANCELLED if is_blocked else status_lib.Code.OK),
         message='_FakeTaskScheduler result'))
Example #7
0
 def schedule(self) -> ts.TaskSchedulerResult:
     """Returns an OK result whose ExecutorOutput echoes the task's outputs."""
     logging.info('Processing ExecNodeTask: %s', self.task)
     out = execution_result_pb2.ExecutorOutput()
     out.execution_result.code = status_lib.Code.OK
     # Copy every declared output artifact into the executor output proto.
     for key, artifact_list in self.task.output_artifacts.items():
         proto_artifacts = out.output_artifacts[key].artifacts
         for artifact in artifact_list:
             proto_artifacts.add().CopyFrom(artifact.mlmd_artifact)
     result = ts.TaskSchedulerResult(
         status=status_lib.Status(code=status_lib.Code.OK),
         output=ts.ExecutorNodeOutput(executor_output=out))
     logging.info('Result: %s', result)
     return result
Example #8
0
    def test_exceptions_are_surfaced(self, mock_publish):
        """A publish failure must surface via TasksProcessingError."""

        def _publish(**kwargs):
            # Fail publishing only for the Transform node.
            published_task = kwargs['task']
            assert task_lib.is_exec_node_task(published_task)
            if published_task.node_uid.node_id == 'Transform':
                raise ValueError('test error')
            return mock.DEFAULT

        mock_publish.side_effect = _publish
        collector = _Collector()

        # Register a fake task scheduler with no blocked nodes.
        ts.TaskSchedulerRegistry.register(
            self._type_url,
            functools.partial(_FakeTaskScheduler,
                              block_nodes={},
                              collector=collector))

        task_queue = tq.TaskQueue()
        with self._task_manager(task_queue) as task_manager:
            transform_task = _test_exec_node_task('Transform',
                                                  'test-pipeline',
                                                  pipeline=self._pipeline)
            trainer_task = _test_exec_node_task('Trainer',
                                                'test-pipeline',
                                                pipeline=self._pipeline)
            task_queue.enqueue(transform_task)
            task_queue.enqueue(trainer_task)

        self.assertTrue(task_manager.done())
        exception = task_manager.exception()
        self.assertIsNotNone(exception)
        self.assertIsInstance(exception, tm.TasksProcessingError)
        # Only the Transform publish raised; exactly one error is collected.
        self.assertLen(exception.errors, 1)
        self.assertEqual('test error', str(exception.errors[0]))

        # Both tasks were still scheduled and both publishes were attempted.
        self.assertCountEqual([transform_task, trainer_task],
                              collector.scheduled_tasks)
        result_ok = ts.TaskSchedulerResult(status=status_lib.Status(
            code=status_lib.Code.OK, message='_FakeTaskScheduler result'))
        expected_calls = [
            mock.call(
                mlmd_handle=mock.ANY, task=transform_task, result=result_ok),
            mock.call(
                mlmd_handle=mock.ANY, task=trainer_task, result=result_ok),
        ]
        mock_publish.assert_has_calls(expected_calls, any_order=True)
Example #9
0
  def test_scheduler_failure(self):
    """A non-OK scheduler status must mark the MLMD execution FAILED."""
    # Fake scheduler yields an ABORTED status and no executor output.
    self._register_task_scheduler(
        ts.TaskSchedulerResult(
            status=status_lib.Status(code=status_lib.Code.ABORTED),
            executor_output=None))
    manager = self._run_task_manager()
    self.assertTrue(manager.done())
    self.assertIsNone(manager.exception())

    # Task consumed; execution recorded as FAILED.
    self.assertTrue(self._task_queue.is_empty())
    self.assertEqual(metadata_store_pb2.Execution.FAILED,
                     self._get_execution().last_known_state)
Example #10
0
    def test_successful_execution_resulting_in_output_artifacts(self):
        """An OK result carrying output artifacts completes the execution."""
        # Fake scheduler returns OK status together with output artifacts.
        self._register_task_scheduler(
            ts.TaskSchedulerResult(
                status=status_lib.Status(code=status_lib.Code.OK),
                output_artifacts=self._task.output_artifacts))
        manager = self._run_task_manager()
        self.assertTrue(manager.done())
        self.assertIsNone(manager.exception())

        # Task consumed; execution recorded as COMPLETE.
        self.assertTrue(self._task_queue.is_empty())
        self.assertEqual(metadata_store_pb2.Execution.COMPLETE,
                         self._get_execution().last_known_state)
Example #11
0
    def test_scheduler_failure(self):
        """A non-OK status fails the execution and records its error message."""
        # Fake scheduler yields ABORTED with a message and no executor output.
        self._register_task_scheduler(
            ts.TaskSchedulerResult(status=status_lib.Status(
                code=status_lib.Code.ABORTED, message='foobar error'),
                                   executor_output=None))
        manager = self._run_task_manager()
        self.assertTrue(manager.done())
        self.assertIsNone(manager.exception())

        # Task consumed; execution recorded as FAILED with the error message
        # stored in custom properties.
        self.assertTrue(self._task_queue.is_empty())
        execution = self._get_execution()
        self.assertEqual(metadata_store_pb2.Execution.FAILED,
                         execution.last_known_state)
        recorded_msg = data_types_utils.get_metadata_value(
            execution.custom_properties[constants.EXECUTION_ERROR_MSG_KEY])
        self.assertEqual('foobar error', recorded_msg)
Example #12
0
    def test_successful_execution_resulting_in_output_artifacts(self):
        """An OK importer result completes the execution and keeps its outputs."""
        # Fake scheduler returns OK status with importer output artifacts.
        self._register_task_scheduler(
            ts.TaskSchedulerResult(
                status=status_lib.Status(code=status_lib.Code.OK),
                output=ts.ImporterNodeOutput(
                    output_artifacts=self._task.output_artifacts)))
        manager = self._run_task_manager()
        self.assertTrue(manager.done())
        self.assertIsNone(manager.exception())

        # Task consumed; execution recorded as COMPLETE.
        self.assertTrue(self._task_queue.is_empty())
        self.assertEqual(metadata_store_pb2.Execution.COMPLETE,
                         self._get_execution().last_known_state)

        # Scratch dirs are removed even on success...
        self.assertFalse(os.path.exists(self._task.stateful_working_dir))
        self.assertFalse(os.path.exists(self._task.tmp_dir))
        # ...but the output artifact URI is preserved.
        self.assertTrue(os.path.exists(self._output_artifact_uri))
Example #13
0
 def schedule(self) -> task_scheduler.TaskSchedulerResult:
     """Resolver nodes succeed immediately with their resolved input artifacts."""
     resolver_output = task_scheduler.ResolverNodeOutput(
         resolved_input_artifacts=self.task.input_artifacts)
     return task_scheduler.TaskSchedulerResult(
         status=status_lib.Status(code=status_lib.Code.OK),
         output=resolver_output)
Example #14
0
 def schedule(self):
   """Returns immediately with an empty ExecutorOutput result."""
   empty_output = execution_result_pb2.ExecutorOutput()
   return ts.TaskSchedulerResult(executor_output=empty_output)
Example #15
0
    def test_task_handling(self, mock_publish):
        """Verifies exec/cancel task handling end-to-end through the manager.

        Trainer, Transform and Pusher block inside the fake scheduler until
        cancelled; Evaluator runs to completion. Pusher is cancelled with
        pause=True, so its result must not be published.
        """
        collector = _Collector()

        # Register a fake task scheduler.
        ts.TaskSchedulerRegistry.register(
            self._type_url,
            functools.partial(_FakeTaskScheduler,
                              block_nodes={'Trainer', 'Transform', 'Pusher'},
                              collector=collector))

        task_queue = tq.TaskQueue()

        # Enqueue some tasks.
        # Trainer's exec and cancel tasks are enqueued BEFORE the task manager
        # starts; the remaining tasks are enqueued while it is running.
        trainer_exec_task = _test_exec_node_task('Trainer',
                                                 'test-pipeline',
                                                 pipeline=self._pipeline)
        task_queue.enqueue(trainer_exec_task)
        task_queue.enqueue(_test_cancel_node_task('Trainer', 'test-pipeline'))

        with self._task_manager(task_queue) as task_manager:
            # Enqueue more tasks after task manager starts.
            transform_exec_task = _test_exec_node_task('Transform',
                                                       'test-pipeline',
                                                       pipeline=self._pipeline)
            task_queue.enqueue(transform_exec_task)
            # Evaluator is not in block_nodes, so it completes without a
            # cancellation task.
            evaluator_exec_task = _test_exec_node_task('Evaluator',
                                                       'test-pipeline',
                                                       pipeline=self._pipeline)
            task_queue.enqueue(evaluator_exec_task)
            task_queue.enqueue(
                _test_cancel_node_task('Transform', 'test-pipeline'))
            pusher_exec_task = _test_exec_node_task('Pusher',
                                                    'test-pipeline',
                                                    pipeline=self._pipeline)
            task_queue.enqueue(pusher_exec_task)
            # pause=True: cancel without publishing a final result.
            task_queue.enqueue(
                _test_cancel_node_task('Pusher', 'test-pipeline', pause=True))

        self.assertTrue(task_manager.done())
        self.assertIsNone(task_manager.exception())

        # Ensure that all exec and cancellation tasks were processed correctly.
        self.assertCountEqual([
            trainer_exec_task,
            transform_exec_task,
            evaluator_exec_task,
            pusher_exec_task,
        ], collector.scheduled_tasks)
        self.assertCountEqual([
            trainer_exec_task,
            transform_exec_task,
            pusher_exec_task,
        ], collector.cancelled_tasks)

        # Cancelled (blocked) nodes publish CANCELLED; Evaluator publishes OK.
        result_ok = ts.TaskSchedulerResult(status=status_lib.Status(
            code=status_lib.Code.OK, message='_FakeTaskScheduler result'))
        result_cancelled = ts.TaskSchedulerResult(
            status=status_lib.Status(code=status_lib.Code.CANCELLED,
                                     message='_FakeTaskScheduler result'))
        mock_publish.assert_has_calls([
            mock.call(mlmd_handle=mock.ANY,
                      task=trainer_exec_task,
                      result=result_cancelled),
            mock.call(mlmd_handle=mock.ANY,
                      task=transform_exec_task,
                      result=result_cancelled),
            mock.call(mlmd_handle=mock.ANY,
                      task=evaluator_exec_task,
                      result=result_ok),
        ],
                                      any_order=True)
        # It is expected that publish is not called for Pusher because it was
        # cancelled with pause=True so there must be only 3 calls.
        self.assertLen(mock_publish.mock_calls, 3)