Example #1
    def _run_next(self,
                  use_task_queue,
                  expect_nodes,
                  finish_nodes=None,
                  artifact_custom_properties=None,
                  fail_fast=False):
        """Runs a complete cycle of task generation and simulating their completion.

    Args:
      use_task_queue: Whether to use task queue.
      expect_nodes: List of nodes whose task generation is expected.
      finish_nodes: List of nodes whose completion should be simulated. If
        `None` (default), all of `expect_nodes` will be finished.
      artifact_custom_properties: A dict of custom properties to attach to the
        output artifacts.
      fail_fast: If `True`, pipeline is aborted immediately if any node fails.
    """
        tasks = self._generate(use_task_queue, True, fail_fast=fail_fast)
        for task in tasks:
            self.assertTrue(task_lib.is_exec_node_task(task))
        expected_node_ids = [n.node_info.id for n in expect_nodes]
        task_node_ids = [task.node_uid.node_id for task in tasks]
        self.assertCountEqual(expected_node_ids, task_node_ids)
        finish_node_ids = set([n.node_info.id for n in finish_nodes] if
                              finish_nodes is not None else expected_node_ids)
        for task in tasks:
            if task.node_uid.node_id in finish_node_ids:
                self._finish_node_execution(
                    use_task_queue,
                    task,
                    artifact_custom_properties=artifact_custom_properties)
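    # Usage sketch: with fixtures like those in the later examples (e.g.,
    # self._stats_gen, self._schema_gen), a test can drive the pipeline one
    # layer at a time:
    #
    #   self._run_next(False, expect_nodes=[self._stats_gen])
    #   self._run_next(False, expect_nodes=[self._schema_gen])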
Example #2
    def test_service_job_success(self):
        """Tests task generation when example-gen service job succeeds."""
        test_utils.fake_example_gen_run(self._mlmd_connection,
                                        self._example_gen, 1, 1)

        [
            eg_update_node_state_task, sg_update_node_state_task,
            sg_exec_node_task
        ] = self._generate_and_test(True,
                                    num_initial_executions=1,
                                    num_tasks_generated=3,
                                    num_new_executions=1,
                                    num_active_executions=1,
                                    expected_exec_nodes=[self._stats_gen])
        self.assertTrue(
            task_lib.is_update_node_state_task(eg_update_node_state_task))
        self.assertEqual('my_example_gen',
                         eg_update_node_state_task.node_uid.node_id)
        self.assertEqual(pstate.NodeState.COMPLETE,
                         eg_update_node_state_task.state)
        self.assertTrue(
            task_lib.is_update_node_state_task(sg_update_node_state_task))
        self.assertEqual('my_statistics_gen',
                         sg_update_node_state_task.node_uid.node_id)
        self.assertEqual(pstate.NodeState.RUNNING,
                         sg_update_node_state_task.state)
        self.assertTrue(task_lib.is_exec_node_task(sg_exec_node_task))
Example #3
  def _generate_and_test(self, use_task_queue, num_initial_executions,
                         num_tasks_generated, num_new_executions,
                         num_active_executions):
    """Generates tasks and tests the effects."""
   with self._mlmd_connection as m:
     executions = m.store.get_executions()
     self.assertLen(
         executions, num_initial_executions,
         'Expected {} execution(s) in MLMD.'.format(num_initial_executions))
     pipeline_state = pstate.PipelineState.new(m, self._pipeline)
     task_gen = sptg.SyncPipelineTaskGenerator(
         m, pipeline_state, self._task_queue.contains_task_id,
         self._mock_service_job_manager)
     tasks = task_gen.generate()
     self.assertLen(
         tasks, num_tasks_generated,
         'Expected {} task(s) to be generated.'.format(num_tasks_generated))
     executions = m.store.get_executions()
     num_total_executions = num_initial_executions + num_new_executions
     self.assertLen(
         executions, num_total_executions,
         'Expected {} execution(s) in MLMD.'.format(num_total_executions))
     active_executions = [
         e for e in executions
         if e.last_known_state == metadata_store_pb2.Execution.RUNNING
     ]
     self.assertLen(
         active_executions, num_active_executions,
         'Expected {} active execution(s) in MLMD.'.format(
             num_active_executions))
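      # If the task queue is in use, enqueue the generated ExecNodeTasks so
      # that a subsequent generation can skip tasks already tracked in the
      # queue (via the `contains_task_id` callback passed to the generator).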
     if use_task_queue:
       for task in tasks:
         if task_lib.is_exec_node_task(task):
           self._task_queue.enqueue(task)
     return tasks, active_executions
Example #4
def run_generator(mlmd_connection,
                  generator_class,
                  pipeline,
                  task_queue,
                  use_task_queue,
                  service_job_manager,
                  ignore_update_node_state_tasks=False,
                  fail_fast=None):
    """Generates tasks for testing."""
    with mlmd_connection as m:
        pipeline_state = get_or_create_pipeline_state(m, pipeline)
        generator_params = dict(
            mlmd_handle=m,
            is_task_id_tracked_fn=task_queue.contains_task_id,
            service_job_manager=service_job_manager)
        if fail_fast is not None:
            generator_params['fail_fast'] = fail_fast
        task_gen = generator_class(**generator_params)
        tasks = task_gen.generate(pipeline_state)
        if use_task_queue:
            for task in tasks:
                if task_lib.is_exec_node_task(task):
                    task_queue.enqueue(task)
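        # Apply UpdateNodeStateTasks to the pipeline state so that subsequent
        # generations observe the updated node states.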
        for task in tasks:
            if task_lib.is_update_node_state_task(task):
                with pipeline_state:
                    with pipeline_state.node_state_update_context(
                            task.node_uid) as node_state:
                        node_state.update(task.state, task.status)
    if ignore_update_node_state_tasks:
        tasks = [t for t in tasks if not task_lib.is_update_node_state_task(t)]
    return tasks
Example #5
  def _handle_task(self, task: task_lib.Task) -> None:
    """Dispatches the task to the task-specific handler."""
    if task_lib.is_exec_node_task(task):
      self._handle_exec_node_task(typing.cast(task_lib.ExecNodeTask, task))
    elif task_lib.is_cancel_node_task(task):
      self._handle_cancel_node_task(typing.cast(task_lib.CancelNodeTask, task))
    else:
      raise RuntimeError('Cannot dispatch bad task: {}'.format(task))
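  # A minimal usage sketch, assuming a TaskQueue (`task_queue`) like the one
  # used in the other examples: dequeue each task, dispatch it, then mark it
  # done.
  #
  #   while not task_queue.is_empty():
  #     task = task_queue.dequeue()
  #     self._handle_task(task)
  #     task_queue.task_done(task)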
Example #6
def _make_executor_output(task, code=status_lib.Code.OK):
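  """Creates an ExecutorOutput proto from the task's output artifacts."""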
  assert task_lib.is_exec_node_task(task)
  executor_output = execution_result_pb2.ExecutorOutput()
  for key, artifacts in task.output_artifacts.items():
    for artifact in artifacts:
      executor_output.output_artifacts[key].artifacts.add().CopyFrom(
          artifact.mlmd_artifact)
  executor_output.execution_result.code = code
  return executor_output
Example #7
def run_generator_and_test(test_case,
                           mlmd_connection,
                           generator_class,
                           pipeline,
                           task_queue,
                           use_task_queue,
                           service_job_manager,
                           num_initial_executions,
                           num_tasks_generated,
                           num_new_executions,
                           num_active_executions,
                           expected_exec_nodes=None,
                           ignore_node_ids=None):
    """Runs generator.generate() and tests the effects."""
    if service_job_manager is None:
        service_job_manager = service_jobs.DummyServiceJobManager()
    with mlmd_connection as m:
        executions = m.store.get_executions()
        test_case.assertLen(
            executions, num_initial_executions,
            f'Expected {num_initial_executions} execution(s) in MLMD.')
        pipeline_state = pstate.PipelineState.new(m, pipeline)
        generator_params = dict(
            mlmd_handle=m,
            pipeline_state=pipeline_state,
            is_task_id_tracked_fn=task_queue.contains_task_id,
            service_job_manager=service_job_manager)
        if generator_class == asptg.AsyncPipelineTaskGenerator:
            generator_params['ignore_node_ids'] = ignore_node_ids
        task_gen = generator_class(**generator_params)
        tasks = task_gen.generate()
        test_case.assertLen(
            tasks, num_tasks_generated,
            f'Expected {num_tasks_generated} task(s) to be generated.')
        executions = m.store.get_executions()
        num_total_executions = num_initial_executions + num_new_executions
        test_case.assertLen(
            executions, num_total_executions,
            f'Expected {num_total_executions} execution(s) in MLMD.')
        active_executions = [
            e for e in executions if execution_lib.is_execution_active(e)
        ]
        test_case.assertLen(
            active_executions, num_active_executions,
            f'Expected {num_active_executions} active execution(s) in MLMD.')
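        # When verifying, each generated task is paired positionally with the
        # corresponding expected exec node and active execution.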
        if expected_exec_nodes:
            for i, task in enumerate(tasks):
                _verify_exec_node_task(test_case, pipeline,
                                       expected_exec_nodes[i],
                                       active_executions[i].id, task)
        if use_task_queue:
            for task in tasks:
                if task_lib.is_exec_node_task(task):
                    task_queue.enqueue(task)
        return tasks
Example #8
def run_generator_and_test(test_case,
                           mlmd_connection,
                           generator_class,
                           pipeline,
                           task_queue,
                           use_task_queue,
                           service_job_manager,
                           num_initial_executions,
                           num_tasks_generated,
                           num_new_executions,
                           num_active_executions,
                           expected_exec_nodes=None,
                           ignore_update_node_state_tasks=False,
                           fail_fast=None):
    """Runs generator.generate() and tests the effects."""
    if service_job_manager is None:
        service_job_manager = service_jobs.DummyServiceJobManager()
    with mlmd_connection as m:
        executions = get_non_orchestrator_executions(m)
        test_case.assertLen(
            executions, num_initial_executions,
            f'Expected {num_initial_executions} execution(s) in MLMD.')
    tasks = run_generator(
        mlmd_connection,
        generator_class,
        pipeline,
        task_queue,
        use_task_queue,
        service_job_manager,
        ignore_update_node_state_tasks=ignore_update_node_state_tasks,
        fail_fast=fail_fast)
    with mlmd_connection as m:
        test_case.assertLen(
            tasks, num_tasks_generated,
            f'Expected {num_tasks_generated} task(s) to be generated.')
        executions = get_non_orchestrator_executions(m)
        num_total_executions = num_initial_executions + num_new_executions
        test_case.assertLen(
            executions, num_total_executions,
            f'Expected {num_total_executions} execution(s) in MLMD.')
        active_executions = [
            e for e in executions if execution_lib.is_execution_active(e)
        ]
        test_case.assertLen(
            active_executions, num_active_executions,
            f'Expected {num_active_executions} active execution(s) in MLMD.')
        if expected_exec_nodes:
            for i, task in enumerate(t for t in tasks
                                     if task_lib.is_exec_node_task(t)):
                _verify_exec_node_task(test_case, pipeline,
                                       expected_exec_nodes[i],
                                       active_executions[i].id, task)
        return tasks
Example #9
    def create_task_scheduler(cls: Type[T], mlmd_handle: metadata.Metadata,
                              pipeline: pipeline_pb2.Pipeline,
                              task: task_lib.Task) -> TaskScheduler:
        """Creates a task scheduler for the given task.

    The task is matched as follows:
    1. The node type name of the node associated with the task is looked up in
       the registry and a scheduler is instantiated if present.
    2. Next, the executor spec url of the node (if one exists) is looked up in
       the registry and a scheduler is instantiated if present. This assumes
       deployment_config packed in the pipeline IR is of type
       `IntermediateDeploymentConfig`.
    3. Lastly, a ValueError is raised if no match can be found.

    Args:
      mlmd_handle: A handle to the MLMD db.
      pipeline: The pipeline IR.
      task: The task that needs to be scheduled.

    Returns:
      An instance of `TaskScheduler` for the given task.

    Raises:
      NotImplementedError: Raised if not an `ExecNodeTask`.
      ValueError: If a scheduler could not be found in the registry for the
        given task.
    """

        if not task_lib.is_exec_node_task(task):
            raise NotImplementedError(
                'Can create a task scheduler only for an `ExecNodeTask`.')
        task = typing.cast(task_lib.ExecNodeTask, task)

        try:
            scheduler_class = cls._scheduler_class_for_node_type(task)
        except ValueError as e1:
            try:
                scheduler_class = cls._scheduler_class_for_executor_spec(
                    pipeline, task)
            except ValueError as e2:
                raise ValueError(
                    f'No task scheduler found: {e1}, {e2}') from None

        return scheduler_class(mlmd_handle=mlmd_handle,
                               pipeline=pipeline,
                               task=task)
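    # Usage sketch (hypothetical names; `schedule()` is assumed to be part of
    # the TaskScheduler contract):
    #
    #   scheduler = TaskSchedulerRegistry.create_task_scheduler(
    #       mlmd_handle, pipeline, task)
    #   result = scheduler.schedule()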
Example #10
  def test_task_generation_when_node_stopped(self, stop_transform):
    """Tests stopped nodes are ignored when generating tasks."""
    # Simulate that ExampleGen has already completed successfully.
    test_utils.fake_example_gen_run(self._mlmd_connection, self._example_gen, 1,
                                    1)

    # Generate once.
    num_initial_executions = 1
    if stop_transform:
      num_tasks_generated = 1
      num_new_executions = 0
      num_active_executions = 0
      with self._mlmd_connection as m:
        pipeline_state = test_utils.get_or_create_pipeline_state(
            m, self._pipeline)
        with pipeline_state:
          with pipeline_state.node_state_update_context(
              task_lib.NodeUid.from_pipeline_node(
                  self._pipeline, self._transform)) as node_state:
            node_state.update(pstate.NodeState.STOPPING,
                              status_lib.Status(code=status_lib.Code.CANCELLED))
    else:
      num_tasks_generated = 3
      num_new_executions = 1
      num_active_executions = 1
    tasks = self._generate_and_test(
        True,
        num_initial_executions=num_initial_executions,
        num_tasks_generated=num_tasks_generated,
        num_new_executions=num_new_executions,
        num_active_executions=num_active_executions)
    self.assertLen(tasks, num_tasks_generated)

    if stop_transform:
      self.assertTrue(task_lib.is_update_node_state_task(tasks[0]))
      self.assertEqual(pstate.NodeState.RUNNING, tasks[0].state)
    else:
      self.assertTrue(task_lib.is_update_node_state_task(tasks[0]))
      self.assertEqual(pstate.NodeState.RUNNING, tasks[0].state)
      self.assertTrue(task_lib.is_update_node_state_task(tasks[1]))
      self.assertEqual(pstate.NodeState.RUNNING, tasks[1].state)
      self.assertTrue(task_lib.is_exec_node_task(tasks[2]))
Example #11
    def test_node_success(self):
        """Tests task generation when a node execution succeeds."""
        test_utils.fake_example_gen_run(self._mlmd_connection,
                                        self._example_gen, 1, 1)

        [stats_gen_task] = self._generate_and_test(
            False,
            num_initial_executions=1,
            num_tasks_generated=1,
            num_new_executions=1,
            num_active_executions=1,
            ignore_update_node_state_tasks=True)

        # Finish stats-gen execution.
        self._finish_node_execution(False, stats_gen_task)

        [
            stats_gen_update_node_state_task,
            schema_gen_update_node_state_task, schema_gen_exec_node_task
        ] = self._generate_and_test(False,
                                    num_initial_executions=2,
                                    num_tasks_generated=3,
                                    num_new_executions=1,
                                    num_active_executions=1,
                                    expected_exec_nodes=[self._schema_gen])
        self.assertTrue(
            task_lib.is_update_node_state_task(
                stats_gen_update_node_state_task))
        self.assertEqual('my_statistics_gen',
                         stats_gen_update_node_state_task.node_uid.node_id)
        self.assertEqual(pstate.NodeState.COMPLETE,
                         stats_gen_update_node_state_task.state)
        self.assertTrue(
            task_lib.is_update_node_state_task(
                schema_gen_update_node_state_task))
        self.assertEqual('my_schema_gen',
                         schema_gen_update_node_state_task.node_uid.node_id)
        self.assertEqual(pstate.NodeState.RUNNING,
                         schema_gen_update_node_state_task.state)
        self.assertTrue(task_lib.is_exec_node_task(schema_gen_exec_node_task))
Example #12
    def create_task_scheduler(cls: Type[T], mlmd_handle: metadata.Metadata,
                              pipeline: pipeline_pb2.Pipeline,
                              task: task_lib.Task) -> TaskScheduler:
        """Creates a task scheduler for the given task.

    Note that this assumes deployment_config packed in the pipeline IR is of
    type `IntermediateDeploymentConfig`. This detail may change in the future.

    Args:
      mlmd_handle: A handle to the MLMD db.
      pipeline: The pipeline IR.
      task: The task that needs to be scheduled.

    Returns:
      An instance of `TaskScheduler` for the given task.

    Raises:
      NotImplementedError: Raised if not an `ExecNodeTask`.
      ValueError: Deployment config not present in the IR proto or if executor
        spec for the node corresponding to `task` not configured in the IR.
    """
        if not task_lib.is_exec_node_task(task):
            raise NotImplementedError(
                'Can create a task scheduler only for an `ExecNodeTask`.')
        task = typing.cast(task_lib.ExecNodeTask, task)
        # TODO(b/170383494): Decide which DeploymentConfig to use.
        if not pipeline.deployment_config.Is(
                pipeline_pb2.IntermediateDeploymentConfig.DESCRIPTOR):
            raise ValueError('No deployment config found in pipeline IR.')
        depl_config = pipeline_pb2.IntermediateDeploymentConfig()
        pipeline.deployment_config.Unpack(depl_config)
        node_id = task.node_uid.node_id
        if node_id not in depl_config.executor_specs:
            raise ValueError(
                f'Executor spec for node id `{node_id}` not found in '
                'pipeline IR.')
        executor_spec_type_url = depl_config.executor_specs[node_id].type_url
        return cls._task_scheduler_registry[executor_spec_type_url](
            mlmd_handle=mlmd_handle, pipeline=pipeline, task=task)
Example #13
    def test_handling_finalize_node_task(self, task_gen):
        with self._mlmd_connection as m:
            pipeline = _test_pipeline('pipeline1')
            pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
            pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
            pipeline_ops.initiate_pipeline_start(m, pipeline)
            pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
            finalize_reason = status_lib.Status(code=status_lib.Code.ABORTED,
                                                message='foo bar')
            task_gen.return_value.generate.side_effect = [
                [
                    test_utils.create_exec_node_task(
                        task_lib.NodeUid(pipeline_uid=pipeline_uid,
                                         node_id='Transform')),
                    task_lib.FinalizeNodeTask(
                        node_uid=task_lib.NodeUid(
                            pipeline_uid=pipeline_uid, node_id='Trainer'),
                        status=finalize_reason)
                ],
            ]

            task_queue = tq.TaskQueue()
            pipeline_ops.orchestrate(m, task_queue,
                                     service_jobs.DummyServiceJobManager())
            task_gen.return_value.generate.assert_called_once()
            task = task_queue.dequeue()
            task_queue.task_done(task)
            self.assertTrue(task_lib.is_exec_node_task(task))
            self.assertEqual(
                test_utils.create_node_uid('pipeline1', 'Transform'),
                task.node_uid)

            # Load pipeline state and verify node stop initiation.
            with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
                self.assertEqual(
                    finalize_reason,
                    pipeline_state.node_stop_initiated_reason(
                        task_lib.NodeUid(pipeline_uid=pipeline_uid,
                                         node_id='Trainer')))
Example #14
def _publish(**kwargs):
  task = kwargs['task']
  assert task_lib.is_exec_node_task(task)
  if task.node_uid.node_id == 'Transform':
    raise ValueError('test error')
  return mock.DEFAULT
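# Usage sketch: wire `_publish` in as the side effect of a mocked publish
# function (`mock_publish` is a hypothetical mock object). Returning
# `mock.DEFAULT` lets calls for nodes other than 'Transform' fall through to
# the mock's configured return value.
#
#   mock_publish.side_effect = _publish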
Example #15
  def test_task_generation(self, use_task_queue):
    """Tests async pipeline task generation.

    Args:
      use_task_queue: If task queue is enabled, new tasks are only generated if
        a task with the same task_id does not already exist in the queue.
        `use_task_queue=False` is useful to test task generation when the
        task queue is empty (e.g., due to an orchestrator restart).
    """
    # Simulate that ExampleGen has already completed successfully.
    test_utils.fake_example_gen_run(self._mlmd_connection, self._example_gen, 1,
                                    1)

    # Generate once.
    [update_example_gen_task, update_transform_task,
     exec_transform_task] = self._generate_and_test(
         use_task_queue,
         num_initial_executions=1,
         num_tasks_generated=3,
         num_new_executions=1,
         num_active_executions=1,
         expected_exec_nodes=[self._transform])
    self.assertTrue(task_lib.is_update_node_state_task(update_example_gen_task))
    self.assertEqual(pstate.NodeState.RUNNING, update_example_gen_task.state)
    self.assertTrue(task_lib.is_update_node_state_task(update_transform_task))
    self.assertEqual(pstate.NodeState.RUNNING, update_transform_task.state)
    self.assertTrue(task_lib.is_exec_node_task(exec_transform_task))

    self._mock_service_job_manager.ensure_node_services.assert_has_calls([
        mock.call(mock.ANY, self._example_gen.node_info.id),
        mock.call(mock.ANY, self._transform.node_info.id)
    ])

    # No new effects if generate called again.
    tasks = self._generate_and_test(
        use_task_queue,
        num_initial_executions=2,
        num_tasks_generated=1 if use_task_queue else 3,
        num_new_executions=0,
        num_active_executions=1,
        expected_exec_nodes=[] if use_task_queue else [self._transform])
    if not use_task_queue:
      exec_transform_task = tasks[2]

    # Mark transform execution complete.
    self._finish_node_execution(use_task_queue, exec_transform_task)

    # Trainer execution task should be generated next.
    [
        update_example_gen_task, update_transform_task, update_trainer_task,
        exec_trainer_task
    ] = self._generate_and_test(
        use_task_queue,
        num_initial_executions=2,
        num_tasks_generated=4,
        num_new_executions=1,
        num_active_executions=1,
        expected_exec_nodes=[self._trainer])
    self.assertTrue(task_lib.is_update_node_state_task(update_example_gen_task))
    self.assertEqual(pstate.NodeState.RUNNING, update_example_gen_task.state)
    self.assertTrue(task_lib.is_update_node_state_task(update_transform_task))
    self.assertEqual(pstate.NodeState.STARTED, update_transform_task.state)
    self.assertTrue(task_lib.is_update_node_state_task(update_trainer_task))
    self.assertEqual(pstate.NodeState.RUNNING, update_trainer_task.state)
    self.assertTrue(task_lib.is_exec_node_task(exec_trainer_task))

    # Mark the trainer execution complete.
    self._finish_node_execution(use_task_queue, exec_trainer_task)

    # Only UpdateNodeStateTask are generated as there are no new inputs.
    tasks = self._generate_and_test(
        use_task_queue,
        num_initial_executions=3,
        num_tasks_generated=3,
        num_new_executions=0,
        num_active_executions=0)
    for task in tasks:
      self.assertTrue(task_lib.is_update_node_state_task(task))
      self.assertEqual(pstate.NodeState.RUNNING, update_example_gen_task.state)

    # Fake another ExampleGen run.
    test_utils.fake_example_gen_run(self._mlmd_connection, self._example_gen, 1,
                                    1)

    # Both transform and trainer tasks should be generated as they both find
    # new inputs.
    [
        update_example_gen_task, update_transform_task, exec_transform_task,
        update_trainer_task, exec_trainer_task
    ] = self._generate_and_test(
        use_task_queue,
        num_initial_executions=4,
        num_tasks_generated=5,
        num_new_executions=2,
        num_active_executions=2,
        expected_exec_nodes=[self._transform, self._trainer])
    self.assertTrue(task_lib.is_update_node_state_task(update_example_gen_task))
    self.assertEqual(pstate.NodeState.RUNNING, update_example_gen_task.state)
    self.assertTrue(task_lib.is_update_node_state_task(update_transform_task))
    self.assertEqual(pstate.NodeState.RUNNING, update_transform_task.state)
    self.assertTrue(task_lib.is_exec_node_task(exec_transform_task))
    self.assertTrue(task_lib.is_update_node_state_task(update_trainer_task))
    self.assertEqual(pstate.NodeState.RUNNING, update_trainer_task.state)
    self.assertTrue(task_lib.is_exec_node_task(exec_trainer_task))

    # Re-generation will produce the same tasks when task queue disabled.
    tasks = self._generate_and_test(
        use_task_queue,
        num_initial_executions=6,
        num_tasks_generated=1 if use_task_queue else 5,
        num_new_executions=0,
        num_active_executions=2,
        expected_exec_nodes=[]
        if use_task_queue else [self._transform, self._trainer])
    if not use_task_queue:
      self.assertTrue(task_lib.is_update_node_state_task(tasks[0]))
      self.assertEqual(pstate.NodeState.RUNNING, update_example_gen_task.state)
      self.assertTrue(task_lib.is_update_node_state_task(tasks[1]))
      self.assertEqual(pstate.NodeState.RUNNING, update_example_gen_task.state)
      self.assertTrue(task_lib.is_exec_node_task(tasks[2]))
      self.assertTrue(task_lib.is_update_node_state_task(tasks[3]))
      self.assertEqual(pstate.NodeState.RUNNING, update_example_gen_task.state)
      self.assertTrue(task_lib.is_exec_node_task(tasks[4]))
      exec_transform_task = tasks[2]
      exec_trainer_task = tasks[4]
    else:
      self.assertTrue(task_lib.is_update_node_state_task(tasks[0]))
      self.assertEqual(pstate.NodeState.RUNNING, update_example_gen_task.state)

    # Mark transform execution complete.
    self._finish_node_execution(use_task_queue, exec_transform_task)

    # Mark the trainer execution complete.
    self._finish_node_execution(use_task_queue, exec_trainer_task)

    # Trainer should be triggered again due to transform producing new output.
    [
        update_example_gen_task, update_transform_task, update_trainer_task,
        exec_trainer_task
    ] = self._generate_and_test(
        use_task_queue,
        num_initial_executions=6,
        num_tasks_generated=4,
        num_new_executions=1,
        num_active_executions=1,
        expected_exec_nodes=[self._trainer])
    self.assertTrue(task_lib.is_update_node_state_task(update_example_gen_task))
    self.assertEqual(pstate.NodeState.RUNNING, update_example_gen_task.state)
    self.assertTrue(task_lib.is_update_node_state_task(update_transform_task))
    self.assertEqual(pstate.NodeState.STARTED, update_transform_task.state)
    self.assertTrue(task_lib.is_update_node_state_task(update_trainer_task))
    self.assertEqual(pstate.NodeState.RUNNING, update_trainer_task.state)
    self.assertTrue(task_lib.is_exec_node_task(exec_trainer_task))

    # Finally, no new tasks once trainer completes.
    self._finish_node_execution(use_task_queue, exec_trainer_task)
    [update_example_gen_task, update_transform_task,
     update_trainer_task] = self._generate_and_test(
         use_task_queue,
         num_initial_executions=7,
         num_tasks_generated=3,
         num_new_executions=0,
         num_active_executions=0)
    self.assertTrue(task_lib.is_update_node_state_task(update_example_gen_task))
    self.assertEqual(pstate.NodeState.RUNNING, update_example_gen_task.state)
    self.assertTrue(task_lib.is_update_node_state_task(update_transform_task))
    self.assertEqual(pstate.NodeState.STARTED, update_transform_task.state)
    self.assertTrue(task_lib.is_update_node_state_task(update_trainer_task))
    self.assertEqual(pstate.NodeState.STARTED, update_trainer_task.state)

    if use_task_queue:
      self.assertTrue(self._task_queue.is_empty())
Example #16
    def test_stop_initiated_async_pipelines(self, mock_gen_task_from_active,
                                            mock_async_task_gen,
                                            mock_sync_task_gen):
        with self._mlmd_connection as m:
            pipeline1 = _test_pipeline('pipeline1')
            pipeline1.nodes.add().pipeline_node.node_info.id = 'Transform'
            pipeline1.nodes.add().pipeline_node.node_info.id = 'Trainer'
            pipeline1.nodes.add().pipeline_node.node_info.id = 'Evaluator'
            pipeline_ops.initiate_pipeline_start(m, pipeline1)
            pipeline1_execution = pipeline_ops._initiate_pipeline_stop(
                m, task_lib.PipelineUid.from_pipeline(pipeline1))

            task_queue = tq.TaskQueue()

            # For the stop-initiated pipeline, "Transform" execution task is in queue,
            # "Trainer" has an active execution in MLMD but no task in queue,
            # "Evaluator" has no active execution.
            task_queue.enqueue(
                test_utils.create_exec_node_task(
                    node_uid=task_lib.NodeUid(
                        pipeline_uid=task_lib.PipelineUid(
                            pipeline_id='pipeline1', pipeline_run_id=None),
                        node_id='Transform')))
            # Simulates the task being processed.
            transform_task = task_queue.dequeue()
            mock_gen_task_from_active.side_effect = [
                test_utils.create_exec_node_task(
                    node_uid=task_lib.NodeUid(
                        pipeline_uid=task_lib.PipelineUid(
                            pipeline_id='pipeline1', pipeline_run_id=None),
                        node_id='Trainer'),
                    is_cancelled=True),
                None, None, None, None,
            ]

            pipeline_ops.generate_tasks(m, task_queue)

            # There are no active pipelines so these shouldn't be called.
            mock_async_task_gen.assert_not_called()
            mock_sync_task_gen.assert_not_called()

            # Simulate finishing the "Transform" ExecNodeTask.
            task_queue.task_done(transform_task)

            # CancelNodeTask for the "Transform" ExecNodeTask should be next.
            task = task_queue.dequeue()
            task_queue.task_done(task)
            self.assertTrue(task_lib.is_cancel_node_task(task))
            self.assertEqual('Transform', task.node_uid.node_id)

            # ExecNodeTask for "Trainer" is next.
            task = task_queue.dequeue()
            task_queue.task_done(task)
            self.assertTrue(task_lib.is_exec_node_task(task))
            self.assertEqual('Trainer', task.node_uid.node_id)

            self.assertTrue(task_queue.is_empty())

            mock_gen_task_from_active.assert_has_calls([
                mock.call(m,
                          pipeline1,
                          pipeline1.nodes[1].pipeline_node,
                          mock.ANY,
                          is_cancelled=True),
                mock.call(m,
                          pipeline1,
                          pipeline1.nodes[2].pipeline_node,
                          mock.ANY,
                          is_cancelled=True)
            ])
            self.assertEqual(2, mock_gen_task_from_active.call_count)

            # Pipeline execution should continue to be active since active node
            # executions were found in the last call to `generate_tasks`.
            [execution] = m.store.get_executions_by_id(
                [pipeline1_execution.id])
            self.assertTrue(execution_lib.is_execution_active(execution))

            # Call `generate_tasks` again; this time there are no more active node
            # executions so the pipeline should be marked as cancelled.
            pipeline_ops.generate_tasks(m, task_queue)
            self.assertTrue(task_queue.is_empty())
            [execution] = m.store.get_executions_by_id(
                [pipeline1_execution.id])
            self.assertEqual(metadata_store_pb2.Execution.CANCELED,
                             execution.last_known_state)
Example #17
    def setUp(self):
        super().setUp()
        pipeline_root = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self.id())

        # Makes sure multiple connections within a test always connect to the same
        # MLMD instance.
        metadata_path = os.path.join(pipeline_root, 'metadata', 'metadata.db')
        self._metadata_path = metadata_path
        connection_config = metadata.sqlite_metadata_connection_config(
            metadata_path)
        connection_config.sqlite.SetInParent()
        self._mlmd_connection = metadata.Metadata(
            connection_config=connection_config)

        # Sets up the pipeline.
        pipeline = test_async_pipeline.create_pipeline()

        # Extracts components.
        self._example_gen = pipeline.nodes[0].pipeline_node
        self._transform = pipeline.nodes[1].pipeline_node
        self._trainer = pipeline.nodes[2].pipeline_node

        # Pack deployment config for testing.
        deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
        executor_spec = pipeline_pb2.ExecutorSpec.PythonClassExecutorSpec(
            class_path='fake.ClassPath')
        deployment_config.executor_specs[self._trainer.node_info.id].Pack(
            executor_spec)
        deployment_config.executor_specs[self._transform.node_info.id].Pack(
            executor_spec)
        self._type_url = deployment_config.executor_specs[
            self._trainer.node_info.id].type_url
        pipeline.deployment_config.Pack(deployment_config)
        self._pipeline = pipeline
        self._pipeline_info = pipeline.pipeline_info
        self._pipeline_runtime_spec = pipeline.runtime_spec
        self._pipeline_runtime_spec.pipeline_root.field_value.string_value = (
            pipeline_root)

        ts.TaskSchedulerRegistry.clear()
        self._task_queue = tq.TaskQueue()

        # Run fake example-gen to prepare downstream component triggers.
        test_utils.fake_example_gen_run(self._mlmd_connection,
                                        self._example_gen, 1, 1)

        # Task generator should produce two tasks for transform. The first one is
        # UpdateNodeStateTask and the second one is ExecNodeTask.
        with self._mlmd_connection as m:
            pipeline_state = pstate.PipelineState.new(m, self._pipeline)
            tasks = asptg.AsyncPipelineTaskGenerator(
                m, self._task_queue.contains_task_id,
                service_jobs.DummyServiceJobManager()).generate(pipeline_state)
        self.assertLen(tasks, 2)
        self.assertTrue(task_lib.is_update_node_state_task(tasks[0]))
        self.assertEqual(pstate.NodeState.RUNNING, tasks[0].state)
        self.assertEqual('my_transform', tasks[0].node_uid.node_id)
        self.assertTrue(task_lib.is_exec_node_task(tasks[1]))
        self.assertEqual('my_transform', tasks[1].node_uid.node_id)
        self.assertTrue(os.path.exists(tasks[1].stateful_working_dir))
        self.assertTrue(os.path.exists(tasks[1].tmp_dir))
        self._task = tasks[1]
        self._output_artifact_uri = self._task.output_artifacts[
            'transform_graph'][0].uri
        self.assertTrue(os.path.exists(self._output_artifact_uri))
        self._task_queue.enqueue(self._task)

        # There should be 1 active execution in MLMD.
        with self._mlmd_connection as m:
            executions = m.store.get_executions()
        active_executions = [
            e for e in executions
            if e.last_known_state == metadata_store_pb2.Execution.RUNNING
        ]
        self.assertLen(active_executions, 1)

        # Active execution id.
        self._execution_id = active_executions[0].id
Example #18
    def test_generate_tasks_async_active_pipelines(self, mock_async_task_gen,
                                                   mock_sync_task_gen):
        with self._mlmd_connection as m:
            # One active pipeline.
            pipeline1 = _test_pipeline('pipeline1')
            pipeline_ops.initiate_pipeline_start(m, pipeline1)

            # Another active pipeline (with previously completed execution).
            pipeline2 = _test_pipeline('pipeline2')
            execution2 = pipeline_ops.initiate_pipeline_start(m, pipeline2)
            execution2.last_known_state = metadata_store_pb2.Execution.COMPLETE
            m.store.put_executions([execution2])
            execution2 = pipeline_ops.initiate_pipeline_start(m, pipeline2)

            # Inactive pipelines should be ignored.
            pipeline3 = _test_pipeline('pipeline3')
            execution3 = pipeline_ops.initiate_pipeline_start(m, pipeline3)
            execution3.last_known_state = metadata_store_pb2.Execution.COMPLETE
            m.store.put_executions([execution3])

            # For active pipelines pipeline1 and pipeline2, there are a couple of
            # active executions.
            def _exec_node_tasks():
                for pipeline_id in ('pipeline1', 'pipeline2'):
                    yield [
                        test_utils.create_exec_node_task(
                            node_uid=task_lib.NodeUid(
                                pipeline_uid=task_lib.PipelineUid(
                                    pipeline_id=pipeline_id,
                                    pipeline_run_id=None),
                                node_id='Transform')),
                        test_utils.create_exec_node_task(
                            node_uid=task_lib.NodeUid(
                                pipeline_uid=task_lib.PipelineUid(
                                    pipeline_id=pipeline_id,
                                    pipeline_run_id=None),
                                node_id='Trainer')),
                    ]

            mock_async_task_gen.return_value.generate.side_effect = (
                _exec_node_tasks())

            task_queue = tq.TaskQueue()
            pipeline_ops.generate_tasks(m, task_queue)

            self.assertEqual(
                2, mock_async_task_gen.return_value.generate.call_count)
            mock_sync_task_gen.assert_not_called()

            # Verify that tasks are enqueued in the expected order.
            for node_id in ('Transform', 'Trainer'):
                task = task_queue.dequeue()
                task_queue.task_done(task)
                self.assertTrue(task_lib.is_exec_node_task(task))
                self.assertEqual(node_id, task.node_uid.node_id)
                self.assertEqual('pipeline1',
                                 task.node_uid.pipeline_uid.pipeline_id)
            for node_id in ('Transform', 'Trainer'):
                task = task_queue.dequeue()
                task_queue.task_done(task)
                self.assertTrue(task_lib.is_exec_node_task(task))
                self.assertEqual(node_id, task.node_uid.node_id)
                self.assertEqual('pipeline2',
                                 task.node_uid.pipeline_uid.pipeline_id)
            self.assertTrue(task_queue.is_empty())
Example #19
def _orchestrate_active_pipeline(
        mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
        service_job_manager: service_jobs.ServiceJobManager,
        pipeline_state: pstate.PipelineState) -> None:
    """Orchestrates active pipeline."""
    pipeline = pipeline_state.pipeline
    execution = pipeline_state.execution
    assert execution.last_known_state in (metadata_store_pb2.Execution.NEW,
                                          metadata_store_pb2.Execution.RUNNING)
    if execution.last_known_state != metadata_store_pb2.Execution.RUNNING:
        updated_execution = copy.deepcopy(execution)
        updated_execution.last_known_state = metadata_store_pb2.Execution.RUNNING
        mlmd_handle.store.put_executions([updated_execution])

    # Initialize task generator for the pipeline.
    if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
        generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
            mlmd_handle, pipeline_state, task_queue.contains_task_id,
            service_job_manager)
    elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
        # Create cancellation tasks for stop-initiated nodes if necessary.
        stop_initiated_nodes = _get_stop_initiated_nodes(pipeline_state)
        for node in stop_initiated_nodes:
            if service_job_manager.is_pure_service_node(
                    pipeline_state, node.node_info.id):
                service_job_manager.stop_node_services(pipeline_state,
                                                       node.node_info.id)
            elif _maybe_enqueue_cancellation_task(mlmd_handle, pipeline, node,
                                                  task_queue):
                pass
            elif service_job_manager.is_mixed_service_node(
                    pipeline_state, node.node_info.id):
                service_job_manager.stop_node_services(pipeline_state,
                                                       node.node_info.id)
        generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
            mlmd_handle, pipeline_state, task_queue.contains_task_id,
            service_job_manager,
            set(n.node_info.id for n in stop_initiated_nodes))
    else:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.FAILED_PRECONDITION,
            message=(
                f'Only SYNC and ASYNC pipeline execution modes supported; '
                f'found pipeline with execution mode: {pipeline.execution_mode}'
            ))

    tasks = generator.generate()

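    # Enqueue ExecNodeTasks for the task manager to process; finalization
    # tasks are handled here by mutating the pipeline state directly.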
    with pipeline_state:
        for task in tasks:
            if task_lib.is_exec_node_task(task):
                task = typing.cast(task_lib.ExecNodeTask, task)
                task_queue.enqueue(task)
            elif task_lib.is_finalize_node_task(task):
                assert pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC
                task = typing.cast(task_lib.FinalizeNodeTask, task)
                pipeline_state.initiate_node_stop(task.node_uid, task.status)
            else:
                assert task_lib.is_finalize_pipeline_task(task)
                assert pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
                assert len(tasks) == 1
                task = typing.cast(task_lib.FinalizePipelineTask, task)
                if task.status.code == status_lib.Code.OK:
                    logging.info('Pipeline run successful; pipeline uid: %s',
                                 pipeline_state.pipeline_uid)
                else:
                    logging.info('Pipeline run failed; pipeline uid: %s',
                                 pipeline_state.pipeline_uid)
                pipeline_state.initiate_stop(task.status)
Example #20
    def test_conditional_execution(self, evaluate):
        """Tests conditionals in the pipeline.

    Args:
      evaluate: Whether to run the conditional evaluator.
    """
        # Check the expected terminal nodes.
        layers = sptg._topsorted_layers(self._pipeline)
        self.assertEqual(
            {
                self._example_validator.node_info.id,
                self._chore_b.node_info.id,
                self._evaluator.node_info.id,
            }, sptg._terminal_node_ids(layers))

        # Start executing the pipeline:

        test_utils.fake_example_gen_run(self._mlmd_connection,
                                        self._example_gen, 1, 1)

        self._run_next(False, expect_nodes=[self._stats_gen])
        self._run_next(False, expect_nodes=[self._schema_gen])
        self._run_next(False,
                       expect_nodes=[self._example_validator, self._transform])

        # Evaluator is run conditionally based on whether the Model artifact
        # produced by the trainer has a custom property evaluate=1.
        self._run_next(
            False,
            expect_nodes=[self._trainer],
            artifact_custom_properties={'evaluate': 1} if evaluate else None)

        tasks = self._generate(False)
        [evaluator_update_node_state_task] = [
            t for t in tasks if task_lib.is_update_node_state_task(t)
            and t.node_uid.node_id == 'my_evaluator'
        ]
        self.assertEqual(
            pstate.NodeState.RUNNING if evaluate else pstate.NodeState.SKIPPED,
            evaluator_update_node_state_task.state)

        exec_node_tasks = [t for t in tasks if task_lib.is_exec_node_task(t)]
        if evaluate:
            [chore_a_exec_node_task,
             evaluator_exec_node_task] = exec_node_tasks
            self.assertEqual('chore_a',
                             chore_a_exec_node_task.node_uid.node_id)
            self.assertEqual('my_evaluator',
                             evaluator_exec_node_task.node_uid.node_id)
            self._finish_node_execution(False, chore_a_exec_node_task)
            self._finish_node_execution(False, evaluator_exec_node_task)
        else:
            [chore_a_exec_node_task] = exec_node_tasks
            self.assertEqual('chore_a',
                             chore_a_exec_node_task.node_uid.node_id)
            self._finish_node_execution(False, chore_a_exec_node_task)

        self._run_next(False, expect_nodes=[self._chore_b])

        # All nodes executed, finalization task should be produced.
        [finalize_task] = self._generate(False, True)
        self.assertTrue(task_lib.is_finalize_pipeline_task(finalize_task))
Example #21
    def __call__(self) -> List[task_lib.Task]:
        layers = _topsorted_layers(self._pipeline)
        terminal_node_ids = _terminal_node_ids(layers)
        exec_node_tasks = []
        update_node_state_tasks = []
        successful_node_ids = set()
        failed_nodes_dict: Dict[str, status_lib.Status] = {}
        finalize_pipeline_task = None
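        # Visit nodes layer by layer in topologically sorted order so that a
        # node is considered only after all of its upstream nodes.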
        for layer_nodes in layers:
            for node in layer_nodes:
                node_id = node.node_info.id
                node_uid = task_lib.NodeUid.from_pipeline_node(
                    self._pipeline, node)
                node_state = self._node_states_dict[node_uid]
                if node_state.is_success():
                    successful_node_ids.add(node_id)
                    continue
                if node_state.is_failure():
                    failed_nodes_dict[node_id] = node_state.status
                    continue
                if not self._upstream_nodes_successful(node,
                                                       successful_node_ids):
                    continue
                tasks = self._generate_tasks_for_node(node)
                for task in tasks:
                    if task_lib.is_update_node_state_task(task):
                        task = typing.cast(task_lib.UpdateNodeStateTask, task)
                        if pstate.is_node_state_success(task.state):
                            successful_node_ids.add(node_id)
                        elif pstate.is_node_state_failure(task.state):
                            failed_nodes_dict[node_id] = task.status
                            if self._fail_fast:
                                finalize_pipeline_task = self._abort_task(
                                    task.status.message)
                        update_node_state_tasks.append(task)
                    elif task_lib.is_exec_node_task(task):
                        exec_node_tasks.append(task)

                if finalize_pipeline_task:
                    break

            if finalize_pipeline_task:
                break

        if not self._fail_fast and failed_nodes_dict:
            assert not finalize_pipeline_task
            node_by_id = _node_by_id(self._pipeline)
            # Collect nodes that cannot be run because they have a failed ancestor.
            unrunnable_node_ids = set()
            for node_id in failed_nodes_dict:
                unrunnable_node_ids |= _descendants(node_by_id, node_id)
            # Nodes that are still runnable have neither succeeded nor failed, and
            # don't have a failed ancestor.
            runnable_node_ids = node_by_id.keys() - (unrunnable_node_ids
                                                     | successful_node_ids |
                                                     failed_nodes_dict.keys())
            # If there are no runnable nodes, we can abort the pipeline.
            if not runnable_node_ids:
                finalize_pipeline_task = self._abort_task(
                    f'Cannot make progress due to node failures: {failed_nodes_dict}'
                )

        result = update_node_state_tasks
        if finalize_pipeline_task:
            result.append(finalize_pipeline_task)
        elif terminal_node_ids <= successful_node_ids:
            # If all terminal nodes are successful, the pipeline can be finalized.
            result.append(
                task_lib.FinalizePipelineTask(
                    pipeline_uid=self._pipeline_uid,
                    status=status_lib.Status(code=status_lib.Code.OK)))
        else:
            result.extend(exec_node_tasks)
        return result
Example #22
  def test_orchestrate_active_pipelines(self, mock_async_task_gen,
                                        mock_sync_task_gen):
    with self._mlmd_connection as m:
      # Sync and async active pipelines.
      async_pipelines = [
          _test_pipeline('pipeline1'),
          _test_pipeline('pipeline2'),
      ]
      sync_pipelines = [
          _test_pipeline('pipeline3', pipeline_pb2.Pipeline.SYNC),
          _test_pipeline('pipeline4', pipeline_pb2.Pipeline.SYNC),
      ]

      for pipeline in async_pipelines + sync_pipelines:
        pipeline_ops.initiate_pipeline_start(m, pipeline)

      # Active executions for active async pipelines.
      mock_async_task_gen.return_value.generate.side_effect = [
          [
              test_utils.create_exec_node_task(
                  task_lib.NodeUid(
                      pipeline_uid=task_lib.PipelineUid.from_pipeline(
                          async_pipelines[0]),
                      node_id='Transform'))
          ],
          [
              test_utils.create_exec_node_task(
                  task_lib.NodeUid(
                      pipeline_uid=task_lib.PipelineUid.from_pipeline(
                          async_pipelines[1]),
                      node_id='Trainer'))
          ],
      ]

      # Active executions for active sync pipelines.
      mock_sync_task_gen.return_value.generate.side_effect = [
          [
              test_utils.create_exec_node_task(
                  task_lib.NodeUid(
                      pipeline_uid=task_lib.PipelineUid.from_pipeline(
                          sync_pipelines[0]),
                      node_id='Trainer'))
          ],
          [
              test_utils.create_exec_node_task(
                  task_lib.NodeUid(
                      pipeline_uid=task_lib.PipelineUid.from_pipeline(
                          sync_pipelines[1]),
                      node_id='Validator'))
          ],
      ]

      task_queue = tq.TaskQueue()
      pipeline_ops.orchestrate(m, task_queue,
                               service_jobs.DummyServiceJobManager())

      self.assertEqual(2, mock_async_task_gen.return_value.generate.call_count)
      self.assertEqual(2, mock_sync_task_gen.return_value.generate.call_count)

      # Verify that tasks are enqueued in the expected order.
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_exec_node_task(task))
      self.assertEqual(
          test_utils.create_node_uid('pipeline1', 'Transform'), task.node_uid)
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_exec_node_task(task))
      self.assertEqual(
          test_utils.create_node_uid('pipeline2', 'Trainer'), task.node_uid)
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_exec_node_task(task))
      self.assertEqual(
          test_utils.create_node_uid('pipeline3', 'Trainer'), task.node_uid)
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_exec_node_task(task))
      self.assertEqual(
          test_utils.create_node_uid('pipeline4', 'Validator'), task.node_uid)
      self.assertTrue(task_queue.is_empty())
Example #23
  def test_stop_initiated_pipelines(self, pipeline, mock_gen_task_from_active,
                                    mock_async_task_gen, mock_sync_task_gen):
    with self._mlmd_connection as m:
      pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
      pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
      pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
      pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'

      mock_service_job_manager = mock.create_autospec(
          service_jobs.ServiceJobManager, instance=True)
      mock_service_job_manager.is_pure_service_node.side_effect = (
          lambda _, node_id: node_id == 'ExampleGen')
      mock_service_job_manager.is_mixed_service_node.side_effect = (
          lambda _, node_id: node_id == 'Transform')

      pipeline_ops.initiate_pipeline_start(m, pipeline)
      with pstate.PipelineState.load(
          m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
        pipeline_state.initiate_stop(
            status_lib.Status(code=status_lib.Code.CANCELLED))
      pipeline_execution = pipeline_state.execution

      task_queue = tq.TaskQueue()

      # For the stop-initiated pipeline, "Transform" execution task is in queue,
      # "Trainer" has an active execution in MLMD but no task in queue,
      # "Evaluator" has no active execution.
      task_queue.enqueue(
          test_utils.create_exec_node_task(
              task_lib.NodeUid(
                  pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline),
                  node_id='Transform')))
      transform_task = task_queue.dequeue()  # simulates task being processed
      mock_gen_task_from_active.side_effect = [
          test_utils.create_exec_node_task(
              node_uid=task_lib.NodeUid(
                  pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline),
                  node_id='Trainer'),
              is_cancelled=True), None, None, None, None
      ]

      pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)

      # There are no active pipelines so these shouldn't be called.
      mock_async_task_gen.assert_not_called()
      mock_sync_task_gen.assert_not_called()

      # stop_node_services should be called for ExampleGen which is a pure
      # service node.
      mock_service_job_manager.stop_node_services.assert_called_once_with(
          mock.ANY, 'ExampleGen')
      mock_service_job_manager.reset_mock()

      task_queue.task_done(transform_task)  # Pop out transform task.

      # CancelNodeTask for the "Transform" ExecNodeTask should be next.
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_cancel_node_task(task))
      self.assertEqual('Transform', task.node_uid.node_id)

      # ExecNodeTask (with is_cancelled=True) for "Trainer" is next.
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_exec_node_task(task))
      self.assertEqual('Trainer', task.node_uid.node_id)
      self.assertTrue(task.is_cancelled)

      self.assertTrue(task_queue.is_empty())

      mock_gen_task_from_active.assert_has_calls([
          mock.call(
              m,
              pipeline_state.pipeline,
              pipeline.nodes[2].pipeline_node,
              mock.ANY,
              is_cancelled=True),
          mock.call(
              m,
              pipeline_state.pipeline,
              pipeline.nodes[3].pipeline_node,
              mock.ANY,
              is_cancelled=True)
      ])
      self.assertEqual(2, mock_gen_task_from_active.call_count)

      # Pipeline execution should continue to be active since active node
      # executions were found in the last call to `orchestrate`.
      [execution] = m.store.get_executions_by_id([pipeline_execution.id])
      self.assertTrue(execution_lib.is_execution_active(execution))

      # Call `orchestrate` again; this time there are no more active node
      # executions so the pipeline should be marked as cancelled.
      pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)
      self.assertTrue(task_queue.is_empty())
      [execution] = m.store.get_executions_by_id([pipeline_execution.id])
      self.assertEqual(metadata_store_pb2.Execution.CANCELED,
                       execution.last_known_state)

      # stop_node_services should be called on both ExampleGen and Transform
      # which are service nodes.
      mock_service_job_manager.stop_node_services.assert_has_calls(
          [mock.call(mock.ANY, 'ExampleGen'),
           mock.call(mock.ANY, 'Transform')],
          any_order=True)
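The test above drives `orchestrate` entirely through mock.create_autospec doubles whose behavior is scripted with side_effect. A minimal self-contained sketch of the two side_effect flavors used (a toy ServiceManager class and stdlib unittest.mock only; not the TFX service_jobs API) is shown below:

from unittest import mock


class ServiceManager:
  """Toy collaborator with the two call styles the test doubles."""

  def is_pure_service_node(self, state, node_id):
    raise NotImplementedError

  def gen_task(self, node_id):
    raise NotImplementedError


mock_manager = mock.create_autospec(ServiceManager, instance=True)

# A callable side_effect computes the return value from the arguments.
mock_manager.is_pure_service_node.side_effect = (
    lambda _, node_id: node_id == 'ExampleGen')
assert mock_manager.is_pure_service_node(None, 'ExampleGen')
assert not mock_manager.is_pure_service_node(None, 'Trainer')

# A list side_effect is consumed one element per call, which is how the
# test returns a cancelled task first and None afterwards.
mock_manager.gen_task.side_effect = ['task-1', None]
assert mock_manager.gen_task('Trainer') == 'task-1'
assert mock_manager.gen_task('Evaluator') is None

# Autospec preserves signatures, so call assertions stay precise.
mock_manager.is_pure_service_node.assert_any_call(None, 'ExampleGen')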
Example #24
def _orchestrate_active_pipeline(
        mlmd_handle: metadata.Metadata, task_queue: tq.TaskQueue,
        service_job_manager: service_jobs.ServiceJobManager,
        pipeline_state: pstate.PipelineState) -> None:
    """Orchestrates active pipeline."""
    pipeline = pipeline_state.pipeline
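    # Entering the `pipeline_state` context acquires the state for exclusive
    # access; mutations made inside the block are committed when it exits.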
    with pipeline_state:
        assert pipeline_state.is_active()
        if pipeline_state.get_pipeline_execution_state() != (
                metadata_store_pb2.Execution.RUNNING):
            pipeline_state.set_pipeline_execution_state(
                metadata_store_pb2.Execution.RUNNING)
        orchestration_options = pipeline_state.get_orchestration_options()
        logging.info('Orchestration options: %s', orchestration_options)
        deadline_secs = orchestration_options.deadline_secs
        elapsed_secs = (
            time.time() -
            pipeline_state.pipeline_creation_time_secs_since_epoch())
        if (pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
                and deadline_secs > 0 and elapsed_secs > deadline_secs):
            logging.error(
                'Aborting pipeline due to exceeding deadline (%s secs); '
                'pipeline uid: %s', deadline_secs, pipeline_state.pipeline_uid)
            pipeline_state.initiate_stop(
                status_lib.Status(
                    code=status_lib.Code.DEADLINE_EXCEEDED,
                    message=('Pipeline aborted due to exceeding deadline '
                             f'({deadline_secs} secs)')))
            return

    def _filter_by_state(node_infos: List[_NodeInfo],
                         state_str: str) -> List[_NodeInfo]:
        """Returns the nodes in `node_infos` whose state matches `state_str`."""
        return [n for n in node_infos if n.state.state == state_str]

    node_infos = _get_node_infos(pipeline_state)
    stopping_node_infos = _filter_by_state(node_infos,
                                           pstate.NodeState.STOPPING)

    # Tracks nodes stopped in the current iteration.
    stopped_node_infos: List[_NodeInfo] = []

    # Create cancellation tasks for nodes in state STOPPING.
    for node_info in stopping_node_infos:
        if service_job_manager.is_pure_service_node(
                pipeline_state, node_info.node.node_info.id):
            if service_job_manager.stop_node_services(
                    pipeline_state, node_info.node.node_info.id):
                stopped_node_infos.append(node_info)
        elif _maybe_enqueue_cancellation_task(mlmd_handle, pipeline,
                                              node_info.node, task_queue):
            pass
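            # A cancellation task was enqueued; the node stays in STOPPING
            # until a later orchestration pass observes the cancelled
            # execution.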
        elif service_job_manager.is_mixed_service_node(
                pipeline_state, node_info.node.node_info.id):
            if service_job_manager.stop_node_services(
                    pipeline_state, node_info.node.node_info.id):
                stopped_node_infos.append(node_info)
        else:
            stopped_node_infos.append(node_info)

    # Change the state of stopped nodes from STOPPING to STOPPED.
    if stopped_node_infos:
        with pipeline_state:
            for node_info in stopped_node_infos:
                node_uid = task_lib.NodeUid.from_pipeline_node(
                    pipeline, node_info.node)
                with pipeline_state.node_state_update_context(
                        node_uid) as node_state:
                    node_state.update(pstate.NodeState.STOPPED,
                                      node_state.status)

    # Initialize task generator for the pipeline.
    if pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC:
        generator = sync_pipeline_task_gen.SyncPipelineTaskGenerator(
            mlmd_handle,
            task_queue.contains_task_id,
            service_job_manager,
            fail_fast=orchestration_options.fail_fast)
    elif pipeline.execution_mode == pipeline_pb2.Pipeline.ASYNC:
        generator = async_pipeline_task_gen.AsyncPipelineTaskGenerator(
            mlmd_handle, task_queue.contains_task_id, service_job_manager)
    else:
        raise status_lib.StatusNotOkError(
            code=status_lib.Code.FAILED_PRECONDITION,
            message=(
                f'Only SYNC and ASYNC pipeline execution modes supported; '
                f'found pipeline with execution mode: {pipeline.execution_mode}'
            ))

    # One generation pass over the pipeline; the result may contain
    # UpdateNodeStateTasks, ExecNodeTasks and, for sync pipelines, a
    # FinalizePipelineTask.
    tasks = generator.generate(pipeline_state)

    with pipeline_state:
        # Handle all the UpdateNodeStateTasks by updating node states.
        for task in tasks:
            if task_lib.is_update_node_state_task(task):
                task = typing.cast(task_lib.UpdateNodeStateTask, task)
                with pipeline_state.node_state_update_context(
                        task.node_uid) as node_state:
                    node_state.update(task.state, task.status)

        tasks = [t for t in tasks if not task_lib.is_update_node_state_task(t)]

        # If there are still nodes in state STARTING, change them to STARTED.
        for node in pstate.get_all_pipeline_nodes(pipeline_state.pipeline):
            node_uid = task_lib.NodeUid.from_pipeline_node(
                pipeline_state.pipeline, node)
            with pipeline_state.node_state_update_context(
                    node_uid) as node_state:
                if node_state.state == pstate.NodeState.STARTING:
                    node_state.update(pstate.NodeState.STARTED)

        for task in tasks:
            if task_lib.is_exec_node_task(task):
                task = typing.cast(task_lib.ExecNodeTask, task)
                task_queue.enqueue(task)
            else:
                assert task_lib.is_finalize_pipeline_task(task)
                assert pipeline.execution_mode == pipeline_pb2.Pipeline.SYNC
                assert len(tasks) == 1
                task = typing.cast(task_lib.FinalizePipelineTask, task)
                if task.status.code == status_lib.Code.OK:
                    logging.info('Pipeline run successful; pipeline uid: %s',
                                 pipeline_state.pipeline_uid)
                else:
                    logging.info('Pipeline run failed; pipeline uid: %s',
                                 pipeline_state.pipeline_uid)
                pipeline_state.initiate_stop(task.status)
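The deadline guard near the top of `_orchestrate_active_pipeline` reduces to a small pure predicate. A self-contained restatement (hypothetical `deadline_exceeded` helper; an illustration of the arithmetic, not the TFX implementation) makes it easy to test in isolation:

import time
from typing import Optional


def deadline_exceeded(creation_time_secs: float,
                      deadline_secs: float,
                      now_secs: Optional[float] = None) -> bool:
  """Returns True if a sync pipeline has outlived its deadline.

  A non-positive `deadline_secs` disables the check, mirroring the
  `deadline_secs > 0` guard in the orchestration code above.
  """
  if deadline_secs <= 0:
    return False
  if now_secs is None:
    now_secs = time.time()
  return now_secs - creation_time_secs > deadline_secs


# Started 90 seconds ago against a 60-second deadline: aborted.
assert deadline_exceeded(creation_time_secs=0, deadline_secs=60, now_secs=90)
# A deadline of 0 means "no deadline".
assert not deadline_exceeded(creation_time_secs=0, deadline_secs=0, now_secs=90)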