Example #1
0
    def test_stop_pipeline_non_existent_or_inactive(self, pipeline):
        """Tests that `stop_pipeline` raises NOT_FOUND for missing or inactive pipelines."""
        with self._mlmd_connection as m:
            # Stop pipeline without creating one.
            with self.assertRaises(
                    status_lib.StatusNotOkError) as exception_context:
                pipeline_ops.stop_pipeline(
                    m, task_lib.PipelineUid.from_pipeline(pipeline))
            self.assertEqual(status_lib.Code.NOT_FOUND,
                             exception_context.exception.code)

            # Initiate pipeline start and mark it completed.
            pipeline_ops.initiate_pipeline_start(m, pipeline)
            pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
            with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
                pipeline_state.initiate_stop(
                    status_lib.Status(code=status_lib.Code.OK))
                # Marking the execution COMPLETE makes the pipeline inactive.
                pipeline_state.execution.last_known_state = (
                    metadata_store_pb2.Execution.COMPLETE)

            # Try to initiate stop again; an inactive pipeline is reported as
            # NOT_FOUND, same as a non-existent one.
            with self.assertRaises(
                    status_lib.StatusNotOkError) as exception_context:
                pipeline_ops.stop_pipeline(m, pipeline_uid)
            self.assertEqual(status_lib.Code.NOT_FOUND,
                             exception_context.exception.code)
Example #2
0
  def test_initiate_pipeline_start(self, pipeline):
    """Tests starting pipelines and the ALREADY_EXISTS guard for active ones."""
    with self._mlmd_connection as m:
      # Initiate a pipeline start.
      pipeline_state1 = pipeline_ops.initiate_pipeline_start(m, pipeline)
      self.assertProtoPartiallyEquals(
          pipeline, pipeline_state1.pipeline, ignored_fields=['runtime_spec'])
      self.assertEqual(metadata_store_pb2.Execution.NEW,
                       pipeline_state1.execution.last_known_state)

      # Initiate another pipeline start.
      pipeline2 = _test_pipeline('pipeline2')
      pipeline_state2 = pipeline_ops.initiate_pipeline_start(m, pipeline2)
      self.assertEqual(pipeline2, pipeline_state2.pipeline)
      self.assertEqual(metadata_store_pb2.Execution.NEW,
                       pipeline_state2.execution.last_known_state)

      # Error if attempted to initiate when old one is active.
      with self.assertRaises(status_lib.StatusNotOkError) as exception_context:
        pipeline_ops.initiate_pipeline_start(m, pipeline)
      self.assertEqual(status_lib.Code.ALREADY_EXISTS,
                       exception_context.exception.code)

      # Fine to initiate after the previous one is inactive.
      execution = pipeline_state1.execution
      execution.last_known_state = metadata_store_pb2.Execution.COMPLETE
      m.store.put_executions([execution])
      pipeline_state3 = pipeline_ops.initiate_pipeline_start(m, pipeline)
      self.assertEqual(metadata_store_pb2.Execution.NEW,
                       pipeline_state3.execution.last_known_state)
Example #3
0
 def test_initiate_pipeline_stop(self):
     """Verifies `_initiate_pipeline_stop` flags the pipeline state as stop-initiated."""
     with self._mlmd_connection as mlmd_handle:
         started_pipeline = _test_pipeline('pipeline1')
         pipeline_ops.initiate_pipeline_start(mlmd_handle, started_pipeline)
         uid = task_lib.PipelineUid.from_pipeline(started_pipeline)
         state = pipeline_ops._initiate_pipeline_stop(mlmd_handle, uid)
         self.assertTrue(state.is_stop_initiated())
Example #4
0
  def test_stop_pipeline_wait_for_inactivation_timeout(self, pipeline):
    """Tests that `stop_pipeline` raises DEADLINE_EXCEEDED on wait timeout."""
    with self._mlmd_connection as m:
      pipeline_ops.initiate_pipeline_start(m, pipeline)

      # The pipeline execution is never inactivated, so the 1-second wait in
      # `stop_pipeline` must time out.
      with self.assertRaisesRegex(
          status_lib.StatusNotOkError,
          'Timed out.*waiting for execution inactivation.'
      ) as exception_context:
        pipeline_ops.stop_pipeline(
            m, task_lib.PipelineUid.from_pipeline(pipeline), timeout_secs=1.0)
      self.assertEqual(status_lib.Code.DEADLINE_EXCEEDED,
                       exception_context.exception.code)
Example #5
0
    def test_initiate_pipeline_stop(self):
        """Tests that `_initiate_pipeline_stop` records stop-initiation in MLMD."""
        with self._mlmd_connection as m:
            pipeline1 = _test_pipeline('pipeline1')
            pipeline_ops.initiate_pipeline_start(m, pipeline1)
            pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline1)
            pipeline_ops._initiate_pipeline_stop(m, pipeline_uid)

            # Verify MLMD state.
            executions = m.store.get_executions_by_type(
                pipeline_ops._ORCHESTRATOR_RESERVED_ID)
            self.assertLen(executions, 1)
            execution = executions[0]
            # The stop-initiated marker is stored as an int custom property.
            self.assertEqual(
                1, execution.custom_properties[
                    pipeline_ops._STOP_INITIATED].int_value)
Example #6
0
def create_sample_pipeline(m: metadata.Metadata,
                           pipeline_id: str,
                           run_num: int,
                           export_ir_path: str = '',
                           external_ir_file: str = '',
                           deployment_config: Optional[message.Message] = None,
                           execute_nodes_func: Callable[
                               [metadata.Metadata, pipeline_pb2.Pipeline, int],
                               None] = _execute_nodes):
  """Creates a list of pipeline and node execution.

  Args:
    m: MLMD connection handle.
    pipeline_id: Id of the pipeline to create runs for.
    run_num: Number of pipeline runs to create.
    export_ir_path: If non-empty, directory to which each run's pipeline IR is
      exported as `<pipeline_id>_<run_id>.pbtxt`.
    external_ir_file: If non-empty, IR file path used to build the pipeline;
      node executions are not simulated in that case.
    deployment_config: Optional deployment config to embed in the pipeline.
    execute_nodes_func: Function used to simulate node executions for a run.
  """
  ir_path = _get_ir_path(external_ir_file)
  for i in range(run_num):
    run_id = 'run%02d' % i
    pipeline = _test_pipeline(ir_path, pipeline_id, run_id, deployment_config)
    if export_ir_path:
      output_path = os.path.join(export_ir_path,
                                 '%s_%s.pbtxt' % (pipeline_id, run_id))
      io_utils.write_pbtxt_file(output_path, pipeline)
    pipeline_state = pipeline_ops.initiate_pipeline_start(m, pipeline)
    if not external_ir_file:
      execute_nodes_func(m, pipeline, i)
    # All but the last run are marked COMPLETE so only the final run remains
    # active.
    if i < run_num - 1:
      with pipeline_state:
        pipeline_state.set_pipeline_execution_state(
            metadata_store_pb2.Execution.COMPLETE)
Example #7
0
    def test_stop_pipeline_non_existent(self):
        """Tests `stop_pipeline` errors for non-existent and already-stopped pipelines."""
        with self._mlmd_connection as m:
            # Stop pipeline without creating one.
            with self.assertRaises(
                    status_lib.StatusNotOkError) as exception_context:
                pipeline_ops.stop_pipeline(
                    m,
                    task_lib.PipelineUid(pipeline_id='foo',
                                         pipeline_run_id=None))
            self.assertEqual(status_lib.Code.NOT_FOUND,
                             exception_context.exception.code)

            # Initiate pipeline start and mark it completed.
            pipeline1 = _test_pipeline('pipeline1')
            execution = pipeline_ops.initiate_pipeline_start(m, pipeline1)
            pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline1)
            pipeline_ops._initiate_pipeline_stop(m, pipeline_uid)
            execution.last_known_state = metadata_store_pb2.Execution.COMPLETE
            m.store.put_executions([execution])

            # Try to initiate stop again; stop was already initiated above, so
            # this reports ALREADY_EXISTS.
            with self.assertRaises(
                    status_lib.StatusNotOkError) as exception_context:
                pipeline_ops.stop_pipeline(m, pipeline_uid)
            self.assertEqual(status_lib.Code.ALREADY_EXISTS,
                             exception_context.exception.code)
Example #8
0
  def test_save_and_remove_pipeline_property(self):
    """Tests saving and removing a custom property on the pipeline execution."""
    with self._mlmd_connection as m:
      pipeline1 = _test_pipeline('pipeline1')
      pipeline_state1 = pipeline_ops.initiate_pipeline_start(m, pipeline1)
      property_key = 'test_key'
      property_value = 'bala'
      self.assertIsNone(
          pipeline_state1.execution.custom_properties.get(property_key))
      pipeline_ops.save_pipeline_property(pipeline_state1.mlmd_handle,
                                          pipeline_state1.pipeline_uid,
                                          property_key, property_value)

      # Reload pipeline state to pick up the saved property.
      with pstate.PipelineState.load(
          m, pipeline_state1.pipeline_uid) as pipeline_state2:
        self.assertIsNotNone(
            pipeline_state2.execution.custom_properties.get(property_key))
        self.assertEqual(
            pipeline_state2.execution.custom_properties[property_key]
            .string_value, property_value)

      pipeline_ops.remove_pipeline_property(pipeline_state2.mlmd_handle,
                                            pipeline_state2.pipeline_uid,
                                            property_key)
      # Reload again to verify the property was removed.
      with pstate.PipelineState.load(
          m, pipeline_state2.pipeline_uid) as pipeline_state3:
        self.assertIsNone(
            pipeline_state3.execution.custom_properties.get(property_key))
Example #9
0
 def test_sync_pipeline_run_id_runtime_parameter(self):
     """Checks that starting a SYNC pipeline populates the run-id runtime parameter."""
     with self._mlmd_connection as mlmd_handle:
         sync_pipeline = _test_pipeline('pipeline1', pipeline_pb2.Pipeline.SYNC)
         state = pipeline_ops.initiate_pipeline_start(mlmd_handle, sync_pipeline)
         run_id_value = (
             state.pipeline.runtime_spec.pipeline_run_id.field_value.string_value)
         self.assertNotEmpty(run_id_value)
         self.assertEqual(
             task_lib.PipelineUid(pipeline_id='pipeline1'), state.pipeline_uid)
Example #10
0
def create_sample_pipeline(m: metadata.Metadata, pipeline_id: str,
                           run_num: int):
    """Creates `run_num` sample pipeline runs with simulated node executions."""
    for run_index in range(run_num):
        pipeline = _test_pipeline(pipeline_id, 'run%02d' % run_index)
        pipeline_state = pipeline_ops.initiate_pipeline_start(m, pipeline)
        _execute_nodes(m, pipeline, run_index)
        is_last_run = run_index == run_num - 1
        if not is_last_run:
            # Mark all but the last run COMPLETE so only the final run is
            # still active.
            completed = pipeline_state.execution
            completed.last_known_state = metadata_store_pb2.Execution.COMPLETE
            m.store.put_executions([completed])
Example #11
0
    def test_initiate_pipeline_start(self):
        """Tests that starting pipelines creates the expected MLMD contexts/executions."""
        with self._mlmd_connection as m:
            # Initiate a pipeline start.
            pipeline1 = _test_pipeline('pipeline1')
            pipeline_ops.initiate_pipeline_start(m, pipeline1)

            # Initiate another pipeline start.
            pipeline2 = _test_pipeline('pipeline2')
            pipeline_ops.initiate_pipeline_start(m, pipeline2)

            # No error raised => context/execution types exist.
            m.store.get_context_type(pipeline_ops._ORCHESTRATOR_RESERVED_ID)
            m.store.get_execution_type(pipeline_ops._ORCHESTRATOR_RESERVED_ID)

            # Verify MLMD state.
            contexts = m.store.get_contexts_by_type(
                pipeline_ops._ORCHESTRATOR_RESERVED_ID)
            self.assertLen(contexts, 2)
            self.assertCountEqual([
                pipeline_ops._orchestrator_context_name(
                    task_lib.PipelineUid.from_pipeline(pipeline1)),
                pipeline_ops._orchestrator_context_name(
                    task_lib.PipelineUid.from_pipeline(pipeline2))
            ], [c.name for c in contexts])

            for context in contexts:
                executions = m.store.get_executions_by_context(context.id)
                self.assertLen(executions, 1)
                self.assertEqual(metadata_store_pb2.Execution.NEW,
                                 executions[0].last_known_state)
                # The pipeline IR is stored base64-encoded in an execution
                # property; decode it and compare with the original pipeline.
                retrieved_pipeline = pipeline_pb2.Pipeline()
                retrieved_pipeline.ParseFromString(
                    base64.b64decode(executions[0].properties[
                        pipeline_ops._PIPELINE_IR].string_value))
                expected_pipeline_id = (
                    pipeline_ops._pipeline_uid_from_context(
                        context).pipeline_id)
                self.assertEqual(_test_pipeline(expected_pipeline_id),
                                 retrieved_pipeline)
Example #12
0
    def test_handling_finalize_node_task(self, task_gen):
        """Tests that a FinalizeNodeTask results in node stop-initiation."""
        with self._mlmd_connection as m:
            pipeline = _test_pipeline('pipeline1')
            pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
            pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
            pipeline_ops.initiate_pipeline_start(m, pipeline)
            pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
            finalize_reason = status_lib.Status(code=status_lib.Code.ABORTED,
                                                message='foo bar')
            # One orchestration round generates an ExecNodeTask for Transform
            # and a FinalizeNodeTask for Trainer.
            task_gen.return_value.generate.side_effect = [
                [
                    test_utils.create_exec_node_task(
                        task_lib.NodeUid(pipeline_uid=pipeline_uid,
                                         node_id='Transform')),
                    task_lib.FinalizeNodeTask(node_uid=task_lib.NodeUid(
                        pipeline_uid=pipeline_uid, node_id='Trainer'),
                                              status=finalize_reason)
                ],
            ]

            task_queue = tq.TaskQueue()
            pipeline_ops.orchestrate(m, task_queue,
                                     service_jobs.DummyServiceJobManager())
            task_gen.return_value.generate.assert_called_once()
            # The ExecNodeTask for Transform should have been enqueued.
            task = task_queue.dequeue()
            task_queue.task_done(task)
            self.assertTrue(task_lib.is_exec_node_task(task))
            self.assertEqual(
                test_utils.create_node_uid('pipeline1', 'Transform'),
                task.node_uid)

            # Load pipeline state and verify node stop initiation.
            with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
                self.assertEqual(
                    finalize_reason,
                    pipeline_state.node_stop_initiated_reason(
                        task_lib.NodeUid(pipeline_uid=pipeline_uid,
                                         node_id='Trainer')))
Example #13
0
    def test_initiate_pipeline_start_new_execution(self):
        """Tests that restarting an inactive pipeline creates a new execution."""
        with self._mlmd_connection as m:
            pipeline1 = _test_pipeline('pipeline1')
            pipeline_ops.initiate_pipeline_start(m, pipeline1)

            # Error if attempted to initiate when old one is active.
            with self.assertRaises(
                    status_lib.StatusNotOkError) as exception_context:
                pipeline_ops.initiate_pipeline_start(m, pipeline1)
            self.assertEqual(status_lib.Code.ALREADY_EXISTS,
                             exception_context.exception.code)

            # Fine to initiate after the previous one is inactive.
            executions = m.store.get_executions_by_type(
                pipeline_ops._ORCHESTRATOR_RESERVED_ID)
            self.assertLen(executions, 1)
            executions[
                0].last_known_state = metadata_store_pb2.Execution.COMPLETE
            m.store.put_executions(executions)
            execution = pipeline_ops.initiate_pipeline_start(m, pipeline1)
            self.assertEqual(metadata_store_pb2.Execution.NEW,
                             execution.last_known_state)

            # Verify MLMD state: both runs share one orchestrator context, with
            # one completed and one new execution under it.
            contexts = m.store.get_contexts_by_type(
                pipeline_ops._ORCHESTRATOR_RESERVED_ID)
            self.assertLen(contexts, 1)
            self.assertEqual(
                pipeline_ops._orchestrator_context_name(
                    task_lib.PipelineUid.from_pipeline(pipeline1)),
                contexts[0].name)
            executions = m.store.get_executions_by_context(contexts[0].id)
            self.assertLen(executions, 2)
            self.assertCountEqual([
                metadata_store_pb2.Execution.COMPLETE,
                metadata_store_pb2.Execution.NEW
            ], [e.last_known_state for e in executions])
Example #14
0
  def test_handling_finalize_pipeline_task(self, task_gen):
    """Tests that a FinalizePipelineTask initiates pipeline stop."""
    with self._mlmd_connection as m:
      pipeline = _test_pipeline('pipeline1', pipeline_pb2.Pipeline.SYNC)
      pipeline_ops.initiate_pipeline_start(m, pipeline)
      pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
      finalize_reason = status_lib.Status(
          code=status_lib.Code.ABORTED, message='foo bar')
      task_gen.return_value.generate.side_effect = [
          [
              task_lib.FinalizePipelineTask(
                  pipeline_uid=pipeline_uid, status=finalize_reason)
          ],
      ]

      task_queue = tq.TaskQueue()
      pipeline_ops.orchestrate(m, task_queue,
                               service_jobs.DummyServiceJobManager())
      task_gen.return_value.generate.assert_called_once()
      # The finalize task is consumed by the orchestrator, not enqueued.
      self.assertTrue(task_queue.is_empty())

      # Load pipeline state and verify stop initiation.
      with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
        self.assertEqual(finalize_reason,
                         pipeline_state.stop_initiated_reason())
Example #15
0
def create_sample_pipeline(m: metadata.Metadata,
                           pipeline_id: str,
                           run_num: int,
                           export_ir_path: str = ''):
    """Creates `run_num` sample pipeline runs, optionally exporting each run's IR."""
    for run_index in range(run_num):
        run_id = 'run%02d' % run_index
        pipeline = _test_pipeline(pipeline_id, run_id)
        if export_ir_path:
            # Export this run's pipeline IR as <pipeline_id>_<run_id>.pbtxt.
            ir_file = os.path.join(export_ir_path,
                                   '%s_%s.pbtxt' % (pipeline_id, run_id))
            io_utils.write_pbtxt_file(ir_file, pipeline)
        pipeline_state = pipeline_ops.initiate_pipeline_start(m, pipeline)
        _execute_nodes(m, pipeline, run_index)
        if run_index != run_num - 1:
            # Mark all but the last run COMPLETE so only the final run stays
            # active.
            with pipeline_state:
                pipeline_state.execution.last_known_state = (
                    metadata_store_pb2.Execution.COMPLETE)
Example #16
0
  def test_stop_pipeline_wait_for_inactivation(self, pipeline):
    """Tests that `stop_pipeline` returns once the execution becomes inactive."""
    with self._mlmd_connection as m:
      execution = pipeline_ops.initiate_pipeline_start(m, pipeline).execution

      def _inactivate(execution):
        # Sleep so that `stop_pipeline` has to wait before succeeding.
        time.sleep(2.0)
        with pipeline_ops._PIPELINE_OPS_LOCK:
          execution.last_known_state = metadata_store_pb2.Execution.COMPLETE
          m.store.put_executions([execution])

      # Inactivate the pipeline execution from a separate thread while
      # `stop_pipeline` is waiting.
      thread = threading.Thread(
          target=_inactivate, args=(copy.deepcopy(execution),))
      thread.start()

      pipeline_ops.stop_pipeline(
          m, task_lib.PipelineUid.from_pipeline(pipeline), timeout_secs=10.0)

      thread.join()
Example #17
0
    def test_stop_pipeline_wait_for_inactivation(self, pipeline):
        """Tests that `stop_pipeline` returns once the pipeline is inactivated."""
        with self._mlmd_connection as m:
            pipeline_state = pipeline_ops.initiate_pipeline_start(m, pipeline)

            def _inactivate(pipeline_state):
                # Sleep so that `stop_pipeline` has to wait before succeeding.
                time.sleep(2.0)
                with pipeline_ops._PIPELINE_OPS_LOCK:
                    with pipeline_state:
                        pipeline_state.set_pipeline_execution_state(
                            metadata_store_pb2.Execution.COMPLETE)

            # Inactivate the pipeline from a separate thread while
            # `stop_pipeline` is waiting.
            thread = threading.Thread(target=_inactivate,
                                      args=(pipeline_state, ))
            thread.start()

            pipeline_ops.stop_pipeline(
                m,
                task_lib.PipelineUid.from_pipeline(pipeline),
                timeout_secs=10.0)

            thread.join()
Example #18
0
  def test_active_pipelines_with_stop_initiated_nodes(self,
                                                      mock_gen_task_from_active,
                                                      mock_async_task_gen):
    """Tests orchestration of an active pipeline with stop-initiated nodes.

    Stops example-gen (a pure service node), trainer and evaluator, then
    verifies that the service job is stopped and that cancellation / exec
    tasks are enqueued in the expected order.
    """
    with self._mlmd_connection as m:
      pipeline = _test_pipeline('pipeline')
      pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
      pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
      pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
      pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'

      mock_service_job_manager = mock.create_autospec(
          service_jobs.ServiceJobManager, instance=True)
      # Only ExampleGen is a pure service node.
      mock_service_job_manager.is_pure_service_node.side_effect = (
          lambda _, node_id: node_id == 'ExampleGen')
      example_gen_node_uid = task_lib.NodeUid.from_pipeline_node(
          pipeline, pipeline.nodes[0].pipeline_node)

      transform_node_uid = task_lib.NodeUid.from_pipeline_node(
          pipeline, pipeline.nodes[1].pipeline_node)
      transform_task = test_utils.create_exec_node_task(
          node_uid=transform_node_uid)

      trainer_node_uid = task_lib.NodeUid.from_pipeline_node(
          pipeline, pipeline.nodes[2].pipeline_node)
      trainer_task = test_utils.create_exec_node_task(node_uid=trainer_node_uid)

      evaluator_node_uid = task_lib.NodeUid.from_pipeline_node(
          pipeline, pipeline.nodes[3].pipeline_node)
      evaluator_task = test_utils.create_exec_node_task(
          node_uid=evaluator_node_uid)

      pipeline_ops.initiate_pipeline_start(m, pipeline)
      with pstate.PipelineState.load(
          m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
        # Stop example-gen, trainer and evaluator.
        pipeline_state.initiate_node_stop(
            example_gen_node_uid,
            status_lib.Status(code=status_lib.Code.CANCELLED))
        pipeline_state.initiate_node_stop(
            trainer_node_uid, status_lib.Status(code=status_lib.Code.CANCELLED))
        pipeline_state.initiate_node_stop(
            evaluator_node_uid, status_lib.Status(code=status_lib.Code.ABORTED))

      task_queue = tq.TaskQueue()

      # Simulate a new transform execution being triggered.
      mock_async_task_gen.return_value.generate.return_value = [transform_task]
      # Simulate ExecNodeTask for trainer already present in the task queue.
      task_queue.enqueue(trainer_task)
      # Simulate Evaluator having an active execution in MLMD; the mocked
      # generator returns `evaluator_task` for it.
      mock_gen_task_from_active.side_effect = [evaluator_task]

      pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)
      self.assertEqual(1, mock_async_task_gen.return_value.generate.call_count)

      # stop_node_services should be called on example-gen which is a pure
      # service node.
      mock_service_job_manager.stop_node_services.assert_called_once_with(
          mock.ANY, 'ExampleGen')

      # Verify that tasks are enqueued in the expected order:

      # Pre-existing trainer task.
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertEqual(trainer_task, task)

      # CancelNodeTask for trainer.
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_cancel_node_task(task))
      self.assertEqual(trainer_node_uid, task.node_uid)

      # ExecNodeTask for evaluator, generated from its active execution.
      # NOTE: This was previously `assertTrue(cancelled_evaluator_task, task)`,
      # which is vacuous -- assertTrue's second positional argument is the
      # failure message, so the check always passed.
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertEqual(evaluator_task, task)

      # ExecNodeTask for newly triggered transform node.
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertEqual(transform_task, task)

      # No more tasks.
      self.assertTrue(task_queue.is_empty())
Example #19
0
  def test_stop_initiated_pipelines(self, pipeline, mock_gen_task_from_active,
                                    mock_async_task_gen, mock_sync_task_gen):
    """Tests orchestration of a stop-initiated pipeline until it is cancelled."""
    with self._mlmd_connection as m:
      pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
      pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
      pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
      pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'

      mock_service_job_manager = mock.create_autospec(
          service_jobs.ServiceJobManager, instance=True)
      # ExampleGen is a pure service node; Transform is a mixed service node.
      mock_service_job_manager.is_pure_service_node.side_effect = (
          lambda _, node_id: node_id == 'ExampleGen')
      mock_service_job_manager.is_mixed_service_node.side_effect = (
          lambda _, node_id: node_id == 'Transform')

      pipeline_ops.initiate_pipeline_start(m, pipeline)
      with pstate.PipelineState.load(
          m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
        pipeline_state.initiate_stop(
            status_lib.Status(code=status_lib.Code.CANCELLED))
      # Keep a handle on the pipeline execution for later state checks.
      pipeline_execution = pipeline_state.execution

      task_queue = tq.TaskQueue()

      # For the stop-initiated pipeline, "Transform" execution task is in queue,
      # "Trainer" has an active execution in MLMD but no task in queue,
      # "Evaluator" has no active execution.
      task_queue.enqueue(
          test_utils.create_exec_node_task(
              task_lib.NodeUid(
                  pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline),
                  node_id='Transform')))
      transform_task = task_queue.dequeue()  # simulates task being processed
      mock_gen_task_from_active.side_effect = [
          test_utils.create_exec_node_task(
              node_uid=task_lib.NodeUid(
                  pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline),
                  node_id='Trainer'),
              is_cancelled=True), None, None, None, None
      ]

      pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)

      # There are no active pipelines so these shouldn't be called.
      mock_async_task_gen.assert_not_called()
      mock_sync_task_gen.assert_not_called()

      # stop_node_services should be called for ExampleGen which is a pure
      # service node.
      mock_service_job_manager.stop_node_services.assert_called_once_with(
          mock.ANY, 'ExampleGen')
      mock_service_job_manager.reset_mock()

      task_queue.task_done(transform_task)  # Pop out transform task.

      # CancelNodeTask for the "Transform" ExecNodeTask should be next.
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_cancel_node_task(task))
      self.assertEqual('Transform', task.node_uid.node_id)

      # ExecNodeTask (with is_cancelled=True) for "Trainer" is next.
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_exec_node_task(task))
      self.assertEqual('Trainer', task.node_uid.node_id)
      self.assertTrue(task.is_cancelled)

      self.assertTrue(task_queue.is_empty())

      # Cancellation tasks should have been requested for Trainer and
      # Evaluator (pipeline nodes 2 and 3).
      mock_gen_task_from_active.assert_has_calls([
          mock.call(
              m,
              pipeline_state.pipeline,
              pipeline.nodes[2].pipeline_node,
              mock.ANY,
              is_cancelled=True),
          mock.call(
              m,
              pipeline_state.pipeline,
              pipeline.nodes[3].pipeline_node,
              mock.ANY,
              is_cancelled=True)
      ])
      self.assertEqual(2, mock_gen_task_from_active.call_count)

      # Pipeline execution should continue to be active since active node
      # executions were found in the last call to `orchestrate`.
      [execution] = m.store.get_executions_by_id([pipeline_execution.id])
      self.assertTrue(execution_lib.is_execution_active(execution))

      # Call `orchestrate` again; this time there are no more active node
      # executions so the pipeline should be marked as cancelled.
      pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)
      self.assertTrue(task_queue.is_empty())
      [execution] = m.store.get_executions_by_id([pipeline_execution.id])
      self.assertEqual(metadata_store_pb2.Execution.CANCELED,
                       execution.last_known_state)

      # stop_node_services should be called on both ExampleGen and Transform
      # which are service nodes.
      mock_service_job_manager.stop_node_services.assert_has_calls(
          [mock.call(mock.ANY, 'ExampleGen'),
           mock.call(mock.ANY, 'Transform')],
          any_order=True)
Example #20
0
  def test_orchestrate_active_pipelines(self, mock_async_task_gen,
                                        mock_sync_task_gen):
    """Tests that `orchestrate` generates tasks for all active pipelines."""
    with self._mlmd_connection as m:
      # Sync and async active pipelines.
      async_pipelines = [
          _test_pipeline('pipeline1'),
          _test_pipeline('pipeline2'),
      ]
      sync_pipelines = [
          _test_pipeline('pipeline3', pipeline_pb2.Pipeline.SYNC),
          _test_pipeline('pipeline4', pipeline_pb2.Pipeline.SYNC),
      ]

      for pipeline in async_pipelines + sync_pipelines:
        pipeline_ops.initiate_pipeline_start(m, pipeline)

      # Active executions for active async pipelines.
      mock_async_task_gen.return_value.generate.side_effect = [
          [
              test_utils.create_exec_node_task(
                  task_lib.NodeUid(
                      pipeline_uid=task_lib.PipelineUid.from_pipeline(
                          async_pipelines[0]),
                      node_id='Transform'))
          ],
          [
              test_utils.create_exec_node_task(
                  task_lib.NodeUid(
                      pipeline_uid=task_lib.PipelineUid.from_pipeline(
                          async_pipelines[1]),
                      node_id='Trainer'))
          ],
      ]

      # Active executions for active sync pipelines.
      mock_sync_task_gen.return_value.generate.side_effect = [
          [
              test_utils.create_exec_node_task(
                  task_lib.NodeUid(
                      pipeline_uid=task_lib.PipelineUid.from_pipeline(
                          sync_pipelines[0]),
                      node_id='Trainer'))
          ],
          [
              test_utils.create_exec_node_task(
                  task_lib.NodeUid(
                      pipeline_uid=task_lib.PipelineUid.from_pipeline(
                          sync_pipelines[1]),
                      node_id='Validator'))
          ],
      ]

      task_queue = tq.TaskQueue()
      pipeline_ops.orchestrate(m, task_queue,
                               service_jobs.DummyServiceJobManager())

      # One generator call per active pipeline of each kind.
      self.assertEqual(2, mock_async_task_gen.return_value.generate.call_count)
      self.assertEqual(2, mock_sync_task_gen.return_value.generate.call_count)

      # Verify that tasks are enqueued in the expected order.
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_exec_node_task(task))
      self.assertEqual(
          test_utils.create_node_uid('pipeline1', 'Transform'), task.node_uid)
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_exec_node_task(task))
      self.assertEqual(
          test_utils.create_node_uid('pipeline2', 'Trainer'), task.node_uid)
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_exec_node_task(task))
      self.assertEqual(
          test_utils.create_node_uid('pipeline3', 'Trainer'), task.node_uid)
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_exec_node_task(task))
      self.assertEqual(
          test_utils.create_node_uid('pipeline4', 'Validator'), task.node_uid)
      self.assertTrue(task_queue.is_empty())
Example #21
0
    def test_stop_initiated_async_pipelines(self, mock_gen_task_from_active,
                                            mock_async_task_gen,
                                            mock_sync_task_gen):
        """Tests task generation winding down a stop-initiated pipeline.

        A three-node pipeline is started and then stop-initiated. Nodes in
        different states (task already in queue, active MLMD execution without
        a queued task, no active execution) must each be driven to
        cancellation, and the pipeline execution must be marked CANCELED once
        no active node executions remain.
        """
        with self._mlmd_connection as m:
            pipeline1 = _test_pipeline('pipeline1')
            pipeline1.nodes.add().pipeline_node.node_info.id = 'Transform'
            pipeline1.nodes.add().pipeline_node.node_info.id = 'Trainer'
            pipeline1.nodes.add().pipeline_node.node_info.id = 'Evaluator'
            pipeline_ops.initiate_pipeline_start(m, pipeline1)
            # NOTE(review): reaches into a private pipeline_ops helper to flip
            # the pipeline into the stop-initiated state.
            pipeline1_execution = pipeline_ops._initiate_pipeline_stop(
                m, task_lib.PipelineUid.from_pipeline(pipeline1))

            task_queue = tq.TaskQueue()

            # For the stop-initiated pipeline, "Transform" execution task is in queue,
            # "Trainer" has an active execution in MLMD but no task in queue,
            # "Evaluator" has no active execution.
            task_queue.enqueue(
                test_utils.create_exec_node_task(node_uid=task_lib.NodeUid(
                    pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline1',
                                                      pipeline_run_id=None),
                    node_id='Transform')))
            transform_task = task_queue.dequeue(
            )  # simulates task being processed
            # First entry yields a cancelled "Trainer" task; the trailing Nones
            # cover subsequent calls to the mocked generator (presumably one
            # per remaining node across both `generate_tasks` invocations --
            # TODO confirm against pipeline_ops internals).
            mock_gen_task_from_active.side_effect = [
                test_utils.create_exec_node_task(node_uid=task_lib.NodeUid(
                    pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline1',
                                                      pipeline_run_id=None),
                    node_id='Trainer'),
                                                 is_cancelled=True), None,
                None, None, None
            ]

            pipeline_ops.generate_tasks(m, task_queue)

            # There are no active pipelines so these shouldn't be called.
            mock_async_task_gen.assert_not_called()
            mock_sync_task_gen.assert_not_called()

            # Simulate finishing the "Transform" ExecNodeTask.
            task_queue.task_done(transform_task)

            # CancelNodeTask for the "Transform" ExecNodeTask should be next.
            task = task_queue.dequeue()
            task_queue.task_done(task)
            self.assertTrue(task_lib.is_cancel_node_task(task))
            self.assertEqual('Transform', task.node_uid.node_id)

            # ExecNodeTask for "Trainer" is next.
            task = task_queue.dequeue()
            task_queue.task_done(task)
            self.assertTrue(task_lib.is_exec_node_task(task))
            self.assertEqual('Trainer', task.node_uid.node_id)

            self.assertTrue(task_queue.is_empty())

            # Trainer (nodes[1]) and Evaluator (nodes[2]) should each have been
            # queried for active executions with is_cancelled=True.
            mock_gen_task_from_active.assert_has_calls([
                mock.call(m,
                          pipeline1,
                          pipeline1.nodes[1].pipeline_node,
                          mock.ANY,
                          is_cancelled=True),
                mock.call(m,
                          pipeline1,
                          pipeline1.nodes[2].pipeline_node,
                          mock.ANY,
                          is_cancelled=True)
            ])
            self.assertEqual(2, mock_gen_task_from_active.call_count)

            # Pipeline execution should continue to be active since active node
            # executions were found in the last call to `generate_tasks`.
            [execution
             ] = m.store.get_executions_by_id([pipeline1_execution.id])
            self.assertTrue(execution_lib.is_execution_active(execution))

            # Call `generate_tasks` again; this time there are no more active node
            # executions so the pipeline should be marked as cancelled.
            pipeline_ops.generate_tasks(m, task_queue)
            self.assertTrue(task_queue.is_empty())
            [execution
             ] = m.store.get_executions_by_id([pipeline1_execution.id])
            self.assertEqual(metadata_store_pb2.Execution.CANCELED,
                             execution.last_known_state)
Example #22
0
    def test_generate_tasks_async_active_pipelines(self, mock_async_task_gen,
                                                   mock_sync_task_gen):
        """Tests that async task generation runs only for active pipelines."""
        with self._mlmd_connection as mlmd:
            # One active pipeline.
            pipeline1 = _test_pipeline('pipeline1')
            pipeline_ops.initiate_pipeline_start(mlmd, pipeline1)

            # Another active pipeline (with previously completed execution).
            pipeline2 = _test_pipeline('pipeline2')
            execution2 = pipeline_ops.initiate_pipeline_start(mlmd, pipeline2)
            execution2.last_known_state = metadata_store_pb2.Execution.COMPLETE
            mlmd.store.put_executions([execution2])
            execution2 = pipeline_ops.initiate_pipeline_start(mlmd, pipeline2)

            # Inactive pipelines should be ignored.
            pipeline3 = _test_pipeline('pipeline3')
            execution3 = pipeline_ops.initiate_pipeline_start(mlmd, pipeline3)
            execution3.last_known_state = metadata_store_pb2.Execution.COMPLETE
            mlmd.store.put_executions([execution3])

            def _node_uid(pipeline_id, node_id):
                # Builds the NodeUid for a node of the given async pipeline.
                return task_lib.NodeUid(
                    pipeline_uid=task_lib.PipelineUid(
                        pipeline_id=pipeline_id, pipeline_run_id=None),
                    node_id=node_id)

            # For active pipelines pipeline1 and pipeline2, the generator
            # yields one batch of exec node tasks each (Transform, Trainer).
            def _exec_node_tasks():
                for pipeline_id in ('pipeline1', 'pipeline2'):
                    yield [
                        test_utils.create_exec_node_task(
                            node_uid=_node_uid(pipeline_id, 'Transform')),
                        test_utils.create_exec_node_task(
                            node_uid=_node_uid(pipeline_id, 'Trainer')),
                    ]

            mock_async_task_gen.return_value.generate.side_effect = (
                _exec_node_tasks())

            task_queue = tq.TaskQueue()
            pipeline_ops.generate_tasks(mlmd, task_queue)

            # Generation runs once per active pipeline; the sync path is never
            # taken.
            self.assertEqual(
                2, mock_async_task_gen.return_value.generate.call_count)
            mock_sync_task_gen.assert_not_called()

            # Verify that tasks are enqueued in the expected order.
            for pipeline_id in ('pipeline1', 'pipeline2'):
                for node_id in ('Transform', 'Trainer'):
                    dequeued = task_queue.dequeue()
                    task_queue.task_done(dequeued)
                    self.assertTrue(task_lib.is_exec_node_task(dequeued))
                    self.assertEqual(node_id, dequeued.node_uid.node_id)
                    self.assertEqual(
                        pipeline_id,
                        dequeued.node_uid.pipeline_uid.pipeline_id)
            self.assertTrue(task_queue.is_empty())
Example #23
0
    def test_active_pipelines_with_stop_initiated_nodes(
            self, mock_gen_task_from_active, mock_async_task_gen):
        """Tests orchestration of an active pipeline with stop-initiated nodes.

        Trainer and Evaluator are marked stop-initiated: the trainer (which
        already has a task in the queue) should get a CancelNodeTask, and the
        evaluator (which has an active MLMD execution) should get an
        ExecNodeTask with is_cancelled=True. Transform is unaffected and its
        newly generated task is enqueued last.
        """
        with self._mlmd_connection as m:
            pipeline = _test_pipeline('pipeline')
            pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
            pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
            pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'

            transform_node_uid = task_lib.NodeUid.from_pipeline_node(
                pipeline, pipeline.nodes[0].pipeline_node)
            transform_task = test_utils.create_exec_node_task(
                node_uid=transform_node_uid)

            trainer_node_uid = task_lib.NodeUid.from_pipeline_node(
                pipeline, pipeline.nodes[1].pipeline_node)
            trainer_task = test_utils.create_exec_node_task(
                node_uid=trainer_node_uid)

            evaluator_node_uid = task_lib.NodeUid.from_pipeline_node(
                pipeline, pipeline.nodes[2].pipeline_node)
            cancelled_evaluator_task = test_utils.create_exec_node_task(
                node_uid=evaluator_node_uid, is_cancelled=True)

            pipeline_ops.initiate_pipeline_start(m, pipeline)
            with pstate.PipelineState.load(
                    m, task_lib.PipelineUid.from_pipeline(
                        pipeline)) as pipeline_state:
                # Stop trainer and evaluator.
                pipeline_state.initiate_node_stop(trainer_node_uid)
                pipeline_state.initiate_node_stop(evaluator_node_uid)

            task_queue = tq.TaskQueue()

            # Simulate a new transform execution being triggered.
            mock_async_task_gen.return_value.generate.return_value = [
                transform_task
            ]
            # Simulate ExecNodeTask for trainer already present in the task queue.
            task_queue.enqueue(trainer_task)
            # Simulate Evaluator having an active execution in MLMD. Since the
            # node is stop-initiated, the mocked generator returns the
            # cancelled task (previously this returned a non-cancelled task,
            # masked by the vacuous assertion fixed below).
            mock_gen_task_from_active.side_effect = [cancelled_evaluator_task]

            pipeline_ops.orchestrate(m, task_queue)
            self.assertEqual(
                1, mock_async_task_gen.return_value.generate.call_count)

            # Verify that tasks are enqueued in the expected order:

            # Pre-existing trainer task.
            task = task_queue.dequeue()
            task_queue.task_done(task)
            self.assertEqual(trainer_task, task)

            # CancelNodeTask for trainer.
            task = task_queue.dequeue()
            task_queue.task_done(task)
            self.assertTrue(task_lib.is_cancel_node_task(task))
            self.assertEqual(trainer_node_uid, task.node_uid)

            # ExecNodeTask with is_cancelled=True for evaluator.
            # BUG FIX: the original `assertTrue(cancelled_evaluator_task, task)`
            # passed vacuously -- assertTrue's second positional argument is the
            # failure *message*, and the first argument (a task object) is
            # always truthy. Assert equality instead.
            task = task_queue.dequeue()
            task_queue.task_done(task)
            self.assertEqual(cancelled_evaluator_task, task)

            # ExecNodeTask for newly triggered transform node.
            task = task_queue.dequeue()
            task_queue.task_done(task)
            self.assertEqual(transform_task, task)

            # No more tasks.
            self.assertTrue(task_queue.is_empty())