Esempio n. 1
0
 def __enter__(self) -> 'PipelineState':
     """Enters an atomic MLMD op on this pipeline's execution.

     The execution yielded by the atomic op is stored on `self._execution`,
     and the atomic-op context manager itself on
     `self._mlmd_execution_atomic_op_context` (presumably so that `__exit__`
     can close it — `__exit__` is not visible here).

     Returns:
       This `PipelineState` instance.
     """
     atomic_op = mlmd_state.mlmd_execution_atomic_op(
         mlmd_handle=self.mlmd_handle, execution_id=self.execution_id)
     entered_execution = atomic_op.__enter__()
     # Only record the context once it has been entered successfully.
     self._mlmd_execution_atomic_op_context = atomic_op
     self._execution = entered_execution
     return self
Esempio n. 2
0
 def test_mlmd_execution_update(self):
     """Mutations made inside the atomic op reach MLMD and in-memory state."""
     with self._mlmd_connection as m:
         written_execution = _write_test_execution(m)

         # Mutate the execution inside an atomic op.
         with mlmd_state.mlmd_execution_atomic_op(
                 m, written_execution.id) as execution:
             self.assertEqual(written_execution, execution)
             execution.last_known_state = metadata_store_pb2.Execution.CANCELED

         # The mutation must have been committed to MLMD ...
         [execution] = m.store.get_executions_by_id([execution.id])
         self.assertEqual(metadata_store_pb2.Execution.CANCELED,
                          execution.last_known_state)

         # ... and re-entering the atomic op must see the same state.
         with mlmd_state.mlmd_execution_atomic_op(
                 m, written_execution.id) as execution:
             self.assertEqual(metadata_store_pb2.Execution.CANCELED,
                              execution.last_known_state)
Esempio n. 3
0
    def test_pipeline_failure_strategies(self, fail_fast):
        """Tests pipeline failure strategies.

        Transform succeeds while example-validator fails. With `fail_fast`,
        the very next generation round yields an ABORTED finalize task.
        Without it, trainer and the chore nodes (not downstream of
        example-validator) still run first, and only then is the ABORTED
        finalize task generated.

        Args:
          fail_fast: Forwarded to every `_run_next` / `_generate` call.
        """
        test_utils.fake_example_gen_run(self._mlmd_connection,
                                        self._example_gen, 1, 1)

        self._run_next(False,
                       expect_nodes=[self._stats_gen],
                       fail_fast=fail_fast)
        self._run_next(False,
                       expect_nodes=[self._schema_gen],
                       fail_fast=fail_fast)

        # Both example-validator and transform are ready to execute.
        [example_validator_task,
         transform_task] = self._generate(False, True, fail_fast=fail_fast)
        self.assertEqual(self._example_validator.node_info.id,
                         example_validator_task.node_uid.node_id)
        self.assertEqual(self._transform.node_info.id,
                         transform_task.node_uid.node_id)

        # Simulate Transform success.
        self._finish_node_execution(False, transform_task)

        # But fail example-validator.
        with self._mlmd_connection as m:
            with mlmd_state.mlmd_execution_atomic_op(
                    m, example_validator_task.execution_id) as ev_exec:
                # Fail example-validator execution.
                ev_exec.last_known_state = metadata_store_pb2.Execution.FAILED
                data_types_utils.set_metadata_value(
                    ev_exec.custom_properties[
                        constants.EXECUTION_ERROR_MSG_KEY],
                    'example-validator error')

        if not fail_fast:
            # Trainer and downstream nodes can execute as transform has finished.
            # example-validator failure does not impact them as it is not upstream.
            self._run_next(False,
                           expect_nodes=[self._trainer],
                           fail_fast=fail_fast)
            self._run_next(False,
                           expect_nodes=[self._chore_a],
                           fail_fast=fail_fast)
            self._run_next(False,
                           expect_nodes=[self._chore_b],
                           fail_fast=fail_fast)

        # With fail_fast the pipeline run fails immediately because
        # example-validator failed; otherwise it fails once no more progress
        # can be made. Either way the outcome is a single ABORTED finalize task.
        [finalize_task] = self._generate(False, True, fail_fast=fail_fast)
        self.assertTrue(task_lib.is_finalize_pipeline_task(finalize_task))
        self.assertEqual(status_lib.Code.ABORTED, finalize_task.status.code)
Esempio n. 4
0
  def test_triggering_upon_exec_properties_change(self):
    """A failed node is re-triggered only once its exec properties change."""
    test_utils.fake_example_gen_run(self._mlmd_connection, self._example_gen, 1,
                                    1)

    # The first round registers a single transform execution.
    [transform_task] = self._generate_and_test(
        False,
        num_initial_executions=1,
        num_tasks_generated=1,
        num_new_executions=1,
        num_active_executions=1,
        expected_exec_nodes=[self._transform],
        ignore_update_node_state_tasks=True)

    # Mark that execution as FAILED.
    with self._mlmd_connection as mlmd_handle:
      with mlmd_state.mlmd_execution_atomic_op(
          mlmd_handle, transform_task.execution_id) as transform_exec:
        transform_exec.last_known_state = metadata_store_pb2.Execution.FAILED

    # Nothing changed since the failed run, so nothing should be triggered.
    self._generate_and_test(
        False,
        num_initial_executions=2,
        num_tasks_generated=0,
        num_new_executions=0,
        num_active_executions=0,
        ignore_update_node_state_tasks=True)

    # Alter an exec property recorded on the failed run.
    with self._mlmd_connection as mlmd_handle:
      with mlmd_state.mlmd_execution_atomic_op(
          mlmd_handle, transform_task.execution_id) as transform_exec:
        transform_exec.custom_properties['a_param'].int_value = 20

    # The properties now differ from the last run, so transform re-triggers.
    self._generate_and_test(
        False,
        num_initial_executions=2,
        num_tasks_generated=1,
        num_new_executions=1,
        num_active_executions=1,
        expected_exec_nodes=[self._transform],
        ignore_update_node_state_tasks=True)
Esempio n. 5
0
def _update_execution_state_in_mlmd(
        mlmd_handle: metadata.Metadata, execution_id: int,
        new_state: metadata_store_pb2.Execution.State, error_msg: str) -> None:
    """Atomically sets an execution's state and, optionally, its error message.

    Args:
      mlmd_handle: A handle to the MLMD db.
      execution_id: Id of the execution to update.
      new_state: Value written to the execution's `last_known_state`.
      error_msg: Stored under `constants.EXECUTION_ERROR_MSG_KEY` when
        truthy; an empty message leaves that custom property untouched.
    """
    with mlmd_state.mlmd_execution_atomic_op(
            mlmd_handle=mlmd_handle, execution_id=execution_id) as execution:
        execution.last_known_state = new_state
        if not error_msg:
            return
        error_msg_value = execution.custom_properties[
            constants.EXECUTION_ERROR_MSG_KEY]
        data_types_utils.set_metadata_value(error_msg_value, error_msg)
Esempio n. 6
0
def resume_manual_node(mlmd_handle: metadata.Metadata,
                       node_uid: task_lib.NodeUid) -> None:
    """Resumes a manual node by marking its active execution COMPLETED.

    Args:
      mlmd_handle: A handle to the MLMD db.
      node_uid: Uid of the manual node to be resumed.

    Raises:
      status_lib.StatusNotOkError: Failure to resume a manual node. The code
        is NOT_FOUND when the node (or an active execution of it) cannot be
        found, INVALID_ARGUMENT when the node is not a manual node, and
        INTERNAL when more than one active execution exists.
    """
    logging.info('Received request to resume manual node; node uid: %s',
                 node_uid)
    with pstate.PipelineState.load(mlmd_handle,
                                   node_uid.pipeline_uid) as pipeline_state:
        matching_nodes = [
            candidate for candidate in pstate.get_all_pipeline_nodes(
                pipeline_state.pipeline)
            if candidate.node_info.id == node_uid.node_id
        ]
        if len(matching_nodes) != 1:
            raise status_lib.StatusNotOkError(
                code=status_lib.Code.NOT_FOUND,
                message=(f'Unable to find manual node to resume: {node_uid}'))
        node = matching_nodes[0]
        if node.node_info.type.name != constants.MANUAL_NODE_TYPE:
            raise status_lib.StatusNotOkError(
                code=status_lib.Code.INVALID_ARGUMENT,
                message=('Unable to resume a non-manual node. '
                         f'Got non-manual node id: {node_uid}'))

    active_executions = [
        candidate_exec
        for candidate_exec in task_gen_utils.get_executions(mlmd_handle, node)
        if execution_lib.is_execution_active(candidate_exec)
    ]
    if not active_executions:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.NOT_FOUND,
            message=(
                f'Unable to find active manual node to resume: {node_uid}'))
    if len(active_executions) > 1:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.INTERNAL,
            message=(f'Unexpected multiple active executions for manual node: '
                     f'{node_uid}'))
    [active_execution] = active_executions

    # Flip the node-state custom property to COMPLETED; the manual task
    # scheduler reads this property when polling.
    with mlmd_state.mlmd_execution_atomic_op(
            mlmd_handle=mlmd_handle,
            execution_id=active_execution.id) as execution:
        completed_state = manual_task_scheduler.ManualNodeState(
            state=manual_task_scheduler.ManualNodeState.COMPLETED)
        node_state_value = execution.custom_properties.get_or_create(
            manual_task_scheduler.NODE_STATE_PROPERTY_KEY)
        completed_state.set_mlmd_value(node_state_value)
Esempio n. 7
0
    def test_restart_node_cancelled_due_to_stopping(self):
        """Tests that a node previously cancelled due to stopping can be restarted.

        The stats-gen execution is marked CANCELED, as happens when a node is
        stopped mid-execution. The node state is then set to STARTING. The
        next round must create a fresh execution and move the node to RUNNING.
        """
        test_utils.fake_example_gen_run(self._mlmd_connection,
                                        self._example_gen, 1, 1)

        [first_stats_gen_task] = self._generate_and_test(
            False,
            num_initial_executions=1,
            num_tasks_generated=1,
            num_new_executions=1,
            num_active_executions=1,
            ignore_update_node_state_tasks=True)
        stats_gen_uid = task_lib.NodeUid.from_pipeline_node(self._pipeline,
                                                            self._stats_gen)
        self.assertEqual(stats_gen_uid, first_stats_gen_task.node_uid)

        # Stopping a node mid-execution cancels that execution; simulate it.
        with self._mlmd_connection as m:
            with mlmd_state.mlmd_execution_atomic_op(
                    m, first_stats_gen_task.execution_id) as cancelled_exec:
                cancelled_exec.last_known_state = (
                    metadata_store_pb2.Execution.CANCELED)
                data_types_utils.set_metadata_value(
                    cancelled_exec.custom_properties[
                        constants.EXECUTION_ERROR_MSG_KEY], 'manually stopped')

        # Request a restart by moving the node to STARTING.
        with self._mlmd_connection as m:
            pipeline_state = test_utils.get_or_create_pipeline_state(
                m, self._pipeline)
            with pipeline_state:
                with pipeline_state.node_state_update_context(
                        stats_gen_uid) as node_state:
                    node_state.update(pstate.NodeState.STARTING)

        # A STARTING node whose previous execution was cancelled gets a new
        # execution, together with a task moving the node to RUNNING.
        [update_node_state_task, restarted_stats_gen_task] = (
            self._generate_and_test(False,
                                    num_initial_executions=2,
                                    num_tasks_generated=2,
                                    num_new_executions=1,
                                    num_active_executions=1))
        self.assertTrue(
            task_lib.is_update_node_state_task(update_node_state_task))
        self.assertEqual(stats_gen_uid, update_node_state_task.node_uid)
        self.assertEqual(pstate.NodeState.RUNNING,
                         update_node_state_task.state)
        self.assertEqual(stats_gen_uid, restarted_stats_gen_task.node_uid)
Esempio n. 8
0
 def test_mlmd_execution_update(self):
     """Tests commit, cache sync and the `on_commit` callback of an update."""
     committed = threading.Event()
     with self._mlmd_connection as m:
         written_execution = _write_test_execution(m)

         # Mutate the execution; the callback must not fire inside the context.
         with mlmd_state.mlmd_execution_atomic_op(
                 m, written_execution.id,
                 on_commit=committed.set) as execution:
             self.assertEqual(written_execution, execution)
             execution.last_known_state = metadata_store_pb2.Execution.CANCELED
             self.assertFalse(committed.is_set())  # not yet invoked.

         # The mutation is persisted in MLMD ...
         [execution] = m.store.get_executions_by_id([execution.id])
         self.assertEqual(metadata_store_pb2.Execution.CANCELED,
                          execution.last_known_state)
         # ... mirrored in the in-memory execution cache ...
         self.assertEqual(execution,
                          mlmd_state._execution_cache._cache[execution.id])
         # ... and signalled through the on_commit callback.
         self.assertTrue(committed.is_set())

         # Re-entering the atomic op yields the updated execution.
         with mlmd_state.mlmd_execution_atomic_op(
                 m, written_execution.id) as execution_again:
             self.assertEqual(execution, execution_again)
Esempio n. 9
0
  def test_triggering_upon_executor_spec_change(self):
    """A failed node is re-triggered only once its executor spec changes."""
    test_utils.fake_example_gen_run(self._mlmd_connection, self._example_gen, 1,
                                    1)

    # First round, faking the executor spec with `_fake_executor_spec(1)`.
    with mock.patch.object(task_gen_utils,
                           'get_executor_spec') as fake_get_executor_spec:
      fake_get_executor_spec.side_effect = _fake_executor_spec(1)
      [transform_task] = self._generate_and_test(
          False,
          num_initial_executions=1,
          num_tasks_generated=1,
          num_new_executions=1,
          num_active_executions=1,
          expected_exec_nodes=[self._transform],
          ignore_update_node_state_tasks=True)

    # Mark the registered execution as FAILED.
    with self._mlmd_connection as mlmd_handle:
      with mlmd_state.mlmd_execution_atomic_op(
          mlmd_handle, transform_task.execution_id) as transform_exec:
        transform_exec.last_known_state = metadata_store_pb2.Execution.FAILED

    # Same executor spec as the failed run: nothing should be triggered.
    with mock.patch.object(task_gen_utils,
                           'get_executor_spec') as fake_get_executor_spec:
      fake_get_executor_spec.side_effect = _fake_executor_spec(1)
      self._generate_and_test(
          False,
          num_initial_executions=2,
          num_tasks_generated=0,
          num_new_executions=0,
          num_active_executions=0,
          ignore_update_node_state_tasks=True)

    # A different executor spec must re-trigger transform.
    with mock.patch.object(task_gen_utils,
                           'get_executor_spec') as fake_get_executor_spec:
      fake_get_executor_spec.side_effect = _fake_executor_spec(2)
      self._generate_and_test(
          False,
          num_initial_executions=2,
          num_tasks_generated=1,
          num_new_executions=1,
          num_active_executions=1,
          expected_exec_nodes=[self._transform],
          ignore_update_node_state_tasks=True)
Esempio n. 10
0
    def schedule(self) -> task_scheduler.TaskSchedulerResult:
        """Polls the task's execution until the manual node is COMPLETED.

        Each poll waits `_POLLING_INTERVAL_SECS` on `self._cancel`. It then
        reads the node state stored under `NODE_STATE_PROPERTY_KEY` inside an
        atomic MLMD op.

        Returns:
          A result with status OK once the node state is COMPLETED, or with
          status CANCELLED if `self._cancel` is set first.
        """
        result_code = status_lib.Code.CANCELLED
        while not self._cancel.wait(_POLLING_INTERVAL_SECS):
            with mlmd_state.mlmd_execution_atomic_op(
                    mlmd_handle=self.mlmd_handle,
                    execution_id=self.task.execution_id) as execution:
                current_state = ManualNodeState.from_mlmd_value(
                    execution.custom_properties.get(NODE_STATE_PROPERTY_KEY))
                is_completed = (
                    current_state.state == ManualNodeState.COMPLETED)
            if is_completed:
                result_code = status_lib.Code.OK
                break

        return task_scheduler.TaskSchedulerResult(
            status=status_lib.Status(code=result_code),
            output=task_scheduler.ExecutorNodeOutput())
Esempio n. 11
0
    def test_node_failed(self, fail_fast):
        """Tests task generation when a node registers a failed execution.

        After the stats-gen execution is marked FAILED, the next round must
        emit an update-node-state task (FAILED) for stats-gen and an ABORTED
        finalize-pipeline task. Both must carry the error message.

        Args:
          fail_fast: Forwarded to both `_generate_and_test` calls.
        """
        error_msg = 'foobar error'
        test_utils.fake_example_gen_run(self._mlmd_connection,
                                        self._example_gen, 1, 1)

        [stats_gen_task] = self._generate_and_test(
            False,
            num_initial_executions=1,
            num_tasks_generated=1,
            num_new_executions=1,
            num_active_executions=1,
            ignore_update_node_state_tasks=True,
            fail_fast=fail_fast)
        expected_uid = task_lib.NodeUid.from_pipeline_node(self._pipeline,
                                                           self._stats_gen)
        self.assertEqual(expected_uid, stats_gen_task.node_uid)

        # Fail the stats-gen execution with a recognizable error message.
        with self._mlmd_connection as m:
            with mlmd_state.mlmd_execution_atomic_op(
                    m, stats_gen_task.execution_id) as failed_exec:
                failed_exec.last_known_state = (
                    metadata_store_pb2.Execution.FAILED)
                data_types_utils.set_metadata_value(
                    failed_exec.custom_properties[
                        constants.EXECUTION_ERROR_MSG_KEY], error_msg)

        # The next round reports the node failure and finalizes the pipeline.
        [update_node_state_task, finalize_task] = self._generate_and_test(
            True,
            num_initial_executions=2,
            num_tasks_generated=2,
            num_new_executions=0,
            num_active_executions=0,
            fail_fast=fail_fast)

        self.assertTrue(
            task_lib.is_update_node_state_task(update_node_state_task))
        self.assertEqual('my_statistics_gen',
                         update_node_state_task.node_uid.node_id)
        self.assertEqual(pstate.NodeState.FAILED, update_node_state_task.state)
        self.assertRegexMatch(update_node_state_task.status.message,
                              [error_msg])

        self.assertTrue(task_lib.is_finalize_pipeline_task(finalize_task))
        self.assertEqual(status_lib.Code.ABORTED, finalize_task.status.code)
        self.assertRegexMatch(finalize_task.status.message, [error_msg])
Esempio n. 12
0
  def test_node_failed(self, use_task_queue):
    """Tests task generation when a node registers a failed execution.

    Once the stats-gen execution is marked FAILED, the next round must emit a
    single ABORTED finalize-pipeline task carrying the error message.

    Args:
      use_task_queue: Forwarded to the first `_generate_and_test` call. When
        set, one task is also dequeued from `self._task_queue` and marked done
        before the finalize round.
    """
    error_msg = 'foobar error'
    otu.fake_example_gen_run(self._mlmd_connection, self._example_gen, 1, 1)

    def _fake_ensure_node_services(unused_pipeline_state, node_id):
      # Only the example-gen node is expected to request services here.
      self.assertEqual(self._example_gen.node_info.id, node_id)
      return service_jobs.ServiceStatus.SUCCESS

    self._mock_service_job_manager.ensure_node_services.side_effect = (
        _fake_ensure_node_services)

    [stats_gen_task] = self._generate_and_test(
        use_task_queue,
        num_initial_executions=1,
        num_tasks_generated=1,
        num_new_executions=1,
        num_active_executions=1)
    expected_uid = task_lib.NodeUid.from_pipeline_node(self._pipeline,
                                                       self._stats_gen)
    self.assertEqual(expected_uid, stats_gen_task.node_uid)

    # Fail the stats-gen execution with a recognizable error message.
    with self._mlmd_connection as m:
      with mlmd_state.mlmd_execution_atomic_op(
          m, stats_gen_task.execution_id) as failed_exec:
        failed_exec.last_known_state = metadata_store_pb2.Execution.FAILED
        data_types_utils.set_metadata_value(
            failed_exec.custom_properties[constants.EXECUTION_ERROR_MSG_KEY],
            error_msg)

    if use_task_queue:
      queued_task = self._task_queue.dequeue()
      self._task_queue.task_done(queued_task)

    # The next round must finalize the pipeline as ABORTED.
    [finalize_task] = self._generate_and_test(
        True,
        num_initial_executions=2,
        num_tasks_generated=1,
        num_new_executions=0,
        num_active_executions=0)
    self.assertTrue(task_lib.is_finalize_pipeline_task(finalize_task))
    self.assertEqual(status_lib.Code.ABORTED, finalize_task.status.code)
    self.assertRegexMatch(finalize_task.status.message, [error_msg])
Esempio n. 13
0
 def test_mlmd_execution_absent(self):
     """Entering the atomic op for an absent execution id raises ValueError."""
     missing_execution_id = 1
     with self._mlmd_connection as m:
         with self.assertRaisesRegex(
                 ValueError, 'Execution not found for execution id'):
             with mlmd_state.mlmd_execution_atomic_op(m, missing_execution_id):
                 pass
Esempio n. 14
0
 def test_mlmd_execution_absent(self):
     """The atomic op yields None for an execution id absent from MLMD."""
     missing_execution_id = 1
     with self._mlmd_connection as m:
         with mlmd_state.mlmd_execution_atomic_op(
                 m, missing_execution_id) as absent_execution:
             self.assertIsNone(absent_execution)