Example #1
    def test_node_state_for_skipped_nodes_in_partial_pipeline_run(self):
        """Tests that nodes marked to be skipped in a partial pipeline run have the right node state."""
        with self._mlmd_connection as m:
            pipeline = pipeline_pb2.Pipeline()
            pipeline.pipeline_info.id = 'pipeline1'
            pipeline.execution_mode = pipeline_pb2.Pipeline.SYNC
            pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
            pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
            pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
            pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'

            # Mark ExampleGen and Transform to be skipped.
            pipeline.nodes[0].pipeline_node.execution_options.skip.SetInParent(
            )
            pipeline.nodes[1].pipeline_node.execution_options.skip.SetInParent(
            )

            eg_node_uid = task_lib.NodeUid(pipeline_uid, 'ExampleGen')
            transform_node_uid = task_lib.NodeUid(pipeline_uid, 'Transform')
            trainer_node_uid = task_lib.NodeUid(pipeline_uid, 'Trainer')

            pstate.PipelineState.new(m, pipeline)
            with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
                self.assertEqual(
                    {
                        eg_node_uid:
                        pstate.NodeState(state=pstate.NodeState.COMPLETE),
                        transform_node_uid:
                        pstate.NodeState(state=pstate.NodeState.COMPLETE),
                        trainer_node_uid:
                        pstate.NodeState(state=pstate.NodeState.STARTED),
                    }, pipeline_state.get_node_states_dict())
Example #2
    def test_pipeline_view_get_node_run_states(self):
        with self._mlmd_connection as m:
            pipeline = _test_pipeline('pipeline1', pipeline_pb2.Pipeline.SYNC)
            pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
            pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
            pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
            pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
            pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'
            pipeline.nodes.add().pipeline_node.node_info.id = 'Pusher'
            eg_node_uid = task_lib.NodeUid(pipeline_uid, 'ExampleGen')
            transform_node_uid = task_lib.NodeUid(pipeline_uid, 'Transform')
            trainer_node_uid = task_lib.NodeUid(pipeline_uid, 'Trainer')
            evaluator_node_uid = task_lib.NodeUid(pipeline_uid, 'Evaluator')
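            # Move four of the five nodes into distinct states; Pusher is left
            # untouched.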
            with pstate.PipelineState.new(m, pipeline) as pipeline_state:
                with pipeline_state.node_state_update_context(
                        eg_node_uid) as node_state:
                    node_state.update(pstate.NodeState.RUNNING)
                with pipeline_state.node_state_update_context(
                        transform_node_uid) as node_state:
                    node_state.update(pstate.NodeState.STARTING)
                with pipeline_state.node_state_update_context(
                        trainer_node_uid) as node_state:
                    node_state.update(pstate.NodeState.STARTED)
                with pipeline_state.node_state_update_context(
                        evaluator_node_uid) as node_state:
                    node_state.update(
                        pstate.NodeState.FAILED,
                        status_lib.Status(code=status_lib.Code.ABORTED,
                                          message='foobar error'))

            [view] = pstate.PipelineView.load_all(
                m, task_lib.PipelineUid.from_pipeline(pipeline))
            run_states_dict = view.get_node_run_states()
            self.assertEqual(
                run_state_pb2.RunState(state=run_state_pb2.RunState.RUNNING),
                run_states_dict['ExampleGen'])
            self.assertEqual(
                run_state_pb2.RunState(state=run_state_pb2.RunState.UNKNOWN),
                run_states_dict['Transform'])
            self.assertEqual(
                run_state_pb2.RunState(state=run_state_pb2.RunState.READY),
                run_states_dict['Trainer'])
            self.assertEqual(
                run_state_pb2.RunState(
                    state=run_state_pb2.RunState.FAILED,
                    status_code=run_state_pb2.RunState.StatusCodeValue(
                        value=status_lib.Code.ABORTED),
                    status_msg='foobar error'), run_states_dict['Evaluator'])
            self.assertEqual(
                run_state_pb2.RunState(state=run_state_pb2.RunState.READY),
                run_states_dict['Pusher'])
Example #3
    def test_stop_node_no_active_executions(self):
        pipeline = pipeline_pb2.Pipeline()
        self.load_proto_from_text(
            os.path.join(os.path.dirname(__file__), 'testdata',
                         'async_pipeline.pbtxt'), pipeline)
        pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
        node_uid = task_lib.NodeUid(node_id='my_trainer',
                                    pipeline_uid=pipeline_uid)
        with self._mlmd_connection as m:
            pstate.PipelineState.new(m, pipeline)
            pipeline_ops.stop_node(m, node_uid)
            pipeline_state = pstate.PipelineState.load(m, pipeline_uid)

            # The node should be stop-initiated even when node is inactive to prevent
            # future triggers.
            with pipeline_state:
                self.assertEqual(
                    status_lib.Code.CANCELLED,
                    pipeline_state.node_stop_initiated_reason(node_uid).code)

            # Restart node.
            pipeline_state = pipeline_ops.initiate_node_start(m, node_uid)
            with pipeline_state:
                self.assertIsNone(
                    pipeline_state.node_stop_initiated_reason(node_uid))
Example #4
  def test_stop_node_wait_for_inactivation(self):
    pipeline = pipeline_pb2.Pipeline()
    self.load_proto_from_text(
        os.path.join(
            os.path.dirname(__file__), 'testdata', 'async_pipeline.pbtxt'),
        pipeline)
    trainer = pipeline.nodes[2].pipeline_node
    test_utils.fake_component_output(
        self._mlmd_connection, trainer, active=True)
    pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
    node_uid = task_lib.NodeUid(node_id='my_trainer', pipeline_uid=pipeline_uid)
    with self._mlmd_connection as m:
      pstate.PipelineState.new(m, pipeline).commit()

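      # Marks the trainer execution COMPLETE from a separate thread so that
      # stop_node's wait for inactivation below returns within its timeout.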
      def _inactivate(execution):
        time.sleep(2.0)
        with pipeline_ops._PIPELINE_OPS_LOCK:
          execution.last_known_state = metadata_store_pb2.Execution.COMPLETE
          m.store.put_executions([execution])

      execution = task_gen_utils.get_executions(m, trainer)[0]
      thread = threading.Thread(
          target=_inactivate, args=(copy.deepcopy(execution),))
      thread.start()
      pipeline_ops.stop_node(m, node_uid, timeout_secs=5.0)
      thread.join()

      pipeline_state = pstate.PipelineState.load(m, pipeline_uid)
      self.assertEqual(status_lib.Code.CANCELLED,
                       pipeline_state.node_stop_initiated_reason(node_uid).code)

      # Restart node.
      pipeline_state = pipeline_ops.initiate_node_start(m, node_uid)
      self.assertIsNone(pipeline_state.node_stop_initiated_reason(node_uid))
Example #5
 def test_task_ids(self):
   pipeline_uid = task_lib.PipelineUid(pipeline_id='pipeline')
   node_uid = task_lib.NodeUid(pipeline_uid=pipeline_uid, node_id='Trainer')
   exec_node_task = test_utils.create_exec_node_task(node_uid)
   self.assertEqual(('ExecNodeTask', node_uid), exec_node_task.task_id)
   cancel_node_task = task_lib.CancelNodeTask(node_uid=node_uid)
   self.assertEqual(('CancelNodeTask', node_uid), cancel_node_task.task_id)
Example #6
  def test_stop_node_wait_for_inactivation_timeout(self):
    pipeline = pipeline_pb2.Pipeline()
    self.load_proto_from_text(
        os.path.join(
            os.path.dirname(__file__), 'testdata', 'async_pipeline.pbtxt'),
        pipeline)
    trainer = pipeline.nodes[2].pipeline_node
    test_utils.fake_component_output(
        self._mlmd_connection, trainer, active=True)
    pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
    node_uid = task_lib.NodeUid(node_id='my_trainer', pipeline_uid=pipeline_uid)
    with self._mlmd_connection as m:
      pstate.PipelineState.new(m, pipeline).commit()
      with self.assertRaisesRegex(
          status_lib.StatusNotOkError,
          'Timed out.*waiting for execution inactivation.'
      ) as exception_context:
        pipeline_ops.stop_node(m, node_uid, timeout_secs=1.0)
      self.assertEqual(status_lib.Code.DEADLINE_EXCEEDED,
                       exception_context.exception.code)

      # Even if `wait_for_inactivation` times out, the node should be stop
      # initiated to prevent future triggers.
      pipeline_state = pstate.PipelineState.load(m, pipeline_uid)
      self.assertEqual(status_lib.Code.CANCELLED,
                       pipeline_state.node_stop_initiated_reason(node_uid).code)
Example #7
 def test_scheduler_not_found(self):
     task = test_utils.create_exec_node_task(node_uid=task_lib.NodeUid(
         pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline'),
         node_id='Transform'),
                                             pipeline=self._pipeline)
     with self.assertRaisesRegex(ValueError, 'No task scheduler found'):
         ts.TaskSchedulerRegistry.create_task_scheduler(
             mock.Mock(), self._pipeline, task)
Example #8
def _test_exec_node_task(node_id,
                         pipeline_id,
                         pipeline_run_id=None,
                         pipeline=None):
    node_uid = task_lib.NodeUid(pipeline_uid=task_lib.PipelineUid(
        pipeline_id=pipeline_id, pipeline_run_id=pipeline_run_id),
                                node_id=node_id)
    return test_utils.create_exec_node_task(node_uid, pipeline=pipeline)
Example #9
 def test_node_uid_from_pipeline_node(self):
   pipeline = pipeline_pb2.Pipeline()
   pipeline.pipeline_info.id = 'pipeline'
   node = pipeline_pb2.PipelineNode()
   node.node_info.id = 'Trainer'
   self.assertEqual(
       task_lib.NodeUid(
           pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline'),
           node_id='Trainer'),
       task_lib.NodeUid.from_pipeline_node(pipeline, node))
Example #10
File: task_test.py  Project: kp425/tfx
 def test_task_ids(self):
     node_uid = task_lib.NodeUid(pipeline_id='pipeline',
                                 pipeline_run_id='run0',
                                 node_id='Trainer')
     exec_node_task = task_lib.ExecNodeTask(node_uid=node_uid,
                                            execution_id=123)
     self.assertEqual(('ExecNodeTask', node_uid), exec_node_task.task_id)
     cancel_node_task = task_lib.CancelNodeTask(node_uid=node_uid)
     self.assertEqual(('CancelNodeTask', node_uid),
                      cancel_node_task.task_id)
Example #11
 def test_task_type_ids(self):
     self.assertEqual('ExecNodeTask', task_lib.ExecNodeTask.task_type_id())
     self.assertEqual('CancelNodeTask',
                      task_lib.CancelNodeTask.task_type_id())
     node_uid = task_lib.NodeUid(pipeline_id='pipeline',
                                 pipeline_run_id='run0',
                                 node_id='Trainer')
     exec_node_task = test_utils.create_exec_node_task(node_uid)
     self.assertEqual('ExecNodeTask', exec_node_task.task_type_id())
     cancel_node_task = task_lib.CancelNodeTask(node_uid=node_uid)
     self.assertEqual('CancelNodeTask', cancel_node_task.task_type_id())
Example #12
 def test_node_uid_from_pipeline_node(self):
     pipeline = pipeline_pb2.Pipeline()
     pipeline.pipeline_info.id = 'pipeline'
     pipeline.runtime_spec.pipeline_run_id.field_value.string_value = 'run0'
     node = pipeline_pb2.PipelineNode()
     node.node_info.id = 'Trainer'
     self.assertEqual(
         task_lib.NodeUid(pipeline_uid=task_lib.PipelineUid(
             pipeline_id='pipeline', pipeline_run_id='run0'),
                          node_id='Trainer'),
         task_lib.NodeUid.from_pipeline_node(pipeline, node))
Example #13
    def test_register_using_node_type_name(self):
        # Register a fake task scheduler.
        ts.TaskSchedulerRegistry.register(constants.IMPORTER_NODE_TYPE,
                                          _FakeTaskScheduler)

        # Create a task and verify that the correct scheduler is instantiated.
        task = test_utils.create_exec_node_task(node_uid=task_lib.NodeUid(
            pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline'),
            node_id='Importer'),
                                                pipeline=self._pipeline)
        task_scheduler = ts.TaskSchedulerRegistry.create_task_scheduler(
            mock.Mock(), self._pipeline, task)
        self.assertIsInstance(task_scheduler, _FakeTaskScheduler)
Example #14
File: task_test.py  Project: kp425/tfx
 def test_exec_node_task_create(self):
     pipeline = pipeline_pb2.Pipeline()
     pipeline.pipeline_info.id = 'pipeline'
     pipeline.runtime_spec.pipeline_run_id.field_value.string_value = 'run0'
     node = pipeline_pb2.PipelineNode()
     node.node_info.id = 'Trainer'
     self.assertEqual(
         task_lib.ExecNodeTask(node_uid=task_lib.NodeUid(
             pipeline_id='pipeline',
             pipeline_run_id='run0',
             node_id='Trainer'),
                               execution_id=123),
         task_lib.ExecNodeTask.create(pipeline, node, 123))
Example #15
  def test_register_using_executor_spec_type_url(self):
    # Register a fake task scheduler.
    ts.TaskSchedulerRegistry.register(self._spec_type_url, _FakeTaskScheduler)

    # Create a task and verify that the correct scheduler is instantiated.
    task = test_utils.create_exec_node_task(
        node_uid=task_lib.NodeUid(
            pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline'),
            node_id='Trainer'),
        pipeline=self._pipeline)
    task_scheduler = ts.TaskSchedulerRegistry.create_task_scheduler(
        mock.Mock(), self._pipeline, task)
    self.assertIsInstance(task_scheduler, _FakeTaskScheduler)
Example #16
    def test_initiate_node_start_stop(self):
        with self._mlmd_connection as m:
            pipeline = _test_pipeline('pipeline1')
            node_uid = task_lib.NodeUid(
                node_id='Trainer',
                pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline))
            with pstate.PipelineState.new(m, pipeline) as pipeline_state:
                with pipeline_state.node_state_update_context(
                        node_uid) as node_state:
                    node_state.update(pstate.NodeState.STARTING)
                node_state = pipeline_state.get_node_state(node_uid)
                self.assertEqual(pstate.NodeState.STARTING, node_state.state)

            # Reload from MLMD and verify node is started.
            with pstate.PipelineState.load(
                    m, task_lib.PipelineUid.from_pipeline(
                        pipeline)) as pipeline_state:
                node_state = pipeline_state.get_node_state(node_uid)
                self.assertEqual(pstate.NodeState.STARTING, node_state.state)

                # Set node state to STOPPING.
                status = status_lib.Status(code=status_lib.Code.ABORTED,
                                           message='foo bar')
                with pipeline_state.node_state_update_context(
                        node_uid) as node_state:
                    node_state.update(pstate.NodeState.STOPPING, status)
                node_state = pipeline_state.get_node_state(node_uid)
                self.assertEqual(pstate.NodeState.STOPPING, node_state.state)
                self.assertEqual(status, node_state.status)

            # Reload from MLMD and verify node is stopped.
            with pstate.PipelineState.load(
                    m, task_lib.PipelineUid.from_pipeline(
                        pipeline)) as pipeline_state:
                node_state = pipeline_state.get_node_state(node_uid)
                self.assertEqual(pstate.NodeState.STOPPING, node_state.state)
                self.assertEqual(status, node_state.status)

                # Set node state to STARTED.
                with pipeline_state.node_state_update_context(
                        node_uid) as node_state:
                    node_state.update(pstate.NodeState.STARTED)
                node_state = pipeline_state.get_node_state(node_uid)
                self.assertEqual(pstate.NodeState.STARTED, node_state.state)

            # Reload from MLMD and verify node is started.
            with pstate.PipelineState.load(
                    m, task_lib.PipelineUid.from_pipeline(
                        pipeline)) as pipeline_state:
                node_state = pipeline_state.get_node_state(node_uid)
                self.assertEqual(pstate.NodeState.STARTED, node_state.state)
Example #17
    def test_handling_finalize_node_task(self, task_gen):
        with self._mlmd_connection as m:
            pipeline = _test_pipeline('pipeline1')
            pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
            pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
            pipeline_ops.initiate_pipeline_start(m, pipeline)
            pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
            finalize_reason = status_lib.Status(code=status_lib.Code.ABORTED,
                                                message='foo bar')
            task_gen.return_value.generate.side_effect = [
                [
                    test_utils.create_exec_node_task(
                        task_lib.NodeUid(pipeline_uid=pipeline_uid,
                                         node_id='Transform')),
                    task_lib.FinalizeNodeTask(node_uid=task_lib.NodeUid(
                        pipeline_uid=pipeline_uid, node_id='Trainer'),
                                              status=finalize_reason)
                ],
            ]

            task_queue = tq.TaskQueue()
            pipeline_ops.orchestrate(m, task_queue,
                                     service_jobs.DummyServiceJobManager())
            task_gen.return_value.generate.assert_called_once()
            task = task_queue.dequeue()
            task_queue.task_done(task)
            self.assertTrue(task_lib.is_exec_node_task(task))
            self.assertEqual(
                test_utils.create_node_uid('pipeline1', 'Transform'),
                task.node_uid)

            # Load pipeline state and verify node stop initiation.
            with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
                self.assertEqual(
                    finalize_reason,
                    pipeline_state.node_stop_initiated_reason(
                        task_lib.NodeUid(pipeline_uid=pipeline_uid,
                                         node_id='Trainer')))
Example #18
 def test_get_node_states_dict(self):
     with self._mlmd_connection as m:
         pipeline = pipeline_pb2.Pipeline()
         pipeline.pipeline_info.id = 'pipeline1'
         pipeline.execution_mode = pipeline_pb2.Pipeline.SYNC
         pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
         pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
         pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
         pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
         pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'
         eg_node_uid = task_lib.NodeUid(pipeline_uid, 'ExampleGen')
         transform_node_uid = task_lib.NodeUid(pipeline_uid, 'Transform')
         trainer_node_uid = task_lib.NodeUid(pipeline_uid, 'Trainer')
         evaluator_node_uid = task_lib.NodeUid(pipeline_uid, 'Evaluator')
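         # Evaluator is never updated here, so it should read back with the
         # default STARTED node state.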
         with pstate.PipelineState.new(m, pipeline) as pipeline_state:
             with pipeline_state.node_state_update_context(
                     eg_node_uid) as node_state:
                 node_state.update(pstate.NodeState.COMPLETE)
             with pipeline_state.node_state_update_context(
                     transform_node_uid) as node_state:
                 node_state.update(pstate.NodeState.RUNNING)
             with pipeline_state.node_state_update_context(
                     trainer_node_uid) as node_state:
                 node_state.update(pstate.NodeState.STARTING)
         with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
             self.assertEqual(
                 {
                     eg_node_uid:
                     pstate.NodeState(state=pstate.NodeState.COMPLETE),
                     transform_node_uid:
                     pstate.NodeState(state=pstate.NodeState.RUNNING),
                     trainer_node_uid:
                     pstate.NodeState(state=pstate.NodeState.STARTING),
                     evaluator_node_uid:
                     pstate.NodeState(state=pstate.NodeState.STARTED),
                 }, pipeline_state.get_node_states_dict())
Example #19
    def test_initiate_node_start_stop(self):
        with self._mlmd_connection as m:
            pipeline = _test_pipeline('pipeline1')
            node_uid = task_lib.NodeUid(
                node_id='Trainer',
                pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline))
            with pstate.PipelineState.new(m, pipeline) as pipeline_state:
                pipeline_state.initiate_node_start(node_uid)
                self.assertIsNone(
                    pipeline_state.node_stop_initiated_reason(node_uid))

            # Reload from MLMD and verify node is started.
            with pstate.PipelineState.load(
                    m, task_lib.PipelineUid.from_pipeline(
                        pipeline)) as pipeline_state:
                self.assertIsNone(
                    pipeline_state.node_stop_initiated_reason(node_uid))

                # Stop the node.
                status = status_lib.Status(code=status_lib.Code.ABORTED,
                                           message='foo bar')
                pipeline_state.initiate_node_stop(node_uid, status)
                self.assertEqual(
                    status,
                    pipeline_state.node_stop_initiated_reason(node_uid))

            # Reload from MLMD and verify node is stopped.
            with pstate.PipelineState.load(
                    m, task_lib.PipelineUid.from_pipeline(
                        pipeline)) as pipeline_state:
                self.assertEqual(
                    status,
                    pipeline_state.node_stop_initiated_reason(node_uid))

                # Restart node.
                pipeline_state.initiate_node_start(node_uid)
                self.assertIsNone(
                    pipeline_state.node_stop_initiated_reason(node_uid))

            # Reload from MLMD and verify node is started.
            with pstate.PipelineState.load(
                    m, task_lib.PipelineUid.from_pipeline(
                        pipeline)) as pipeline_state:
                self.assertIsNone(
                    pipeline_state.node_stop_initiated_reason(node_uid))
Example #20
    def test_registration_and_creation(self):
        # Create a pipeline IR containing deployment config for testing.
        deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
        executor_spec = pipeline_pb2.ExecutorSpec.PythonClassExecutorSpec(
            class_path='trainer.TrainerExecutor')
        deployment_config.executor_specs['Trainer'].Pack(executor_spec)
        pipeline = pipeline_pb2.Pipeline()
        pipeline.deployment_config.Pack(deployment_config)

        # Register a fake task scheduler.
        spec_type_url = deployment_config.executor_specs['Trainer'].type_url
        ts.TaskSchedulerRegistry.register(spec_type_url, _FakeTaskScheduler)

        # Create a task and verify that the correct scheduler is instantiated.
        task = test_utils.create_exec_node_task(node_uid=task_lib.NodeUid(
            pipeline_id='pipeline', pipeline_run_id=None, node_id='Trainer'))
        task_scheduler = ts.TaskSchedulerRegistry.create_task_scheduler(
            mock.Mock(), pipeline, task)
        self.assertIsInstance(task_scheduler, _FakeTaskScheduler)
Example #21
    def test_initiate_node_start_stop(self):
        with self._mlmd_connection as m:
            pipeline = _test_pipeline('pipeline1')
            node_uid = task_lib.NodeUid(
                node_id='Trainer',
                pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline))
            with pstate.PipelineState.new(m, pipeline) as pipeline_state:
                pipeline_state.initiate_node_start(node_uid)
                self.assertFalse(
                    pipeline_state.is_node_stop_initiated(node_uid))

            # Reload from MLMD and verify node is started.
            with pstate.PipelineState.load(
                    m, task_lib.PipelineUid.from_pipeline(
                        pipeline)) as pipeline_state:
                self.assertFalse(
                    pipeline_state.is_node_stop_initiated(node_uid))

                # Stop the node.
                pipeline_state.initiate_node_stop(node_uid)
                self.assertTrue(
                    pipeline_state.is_node_stop_initiated(node_uid))

            # Reload from MLMD and verify node is stopped.
            with pstate.PipelineState.load(
                    m, task_lib.PipelineUid.from_pipeline(
                        pipeline)) as pipeline_state:
                self.assertTrue(
                    pipeline_state.is_node_stop_initiated(node_uid))

                # Restart node.
                pipeline_state.initiate_node_start(node_uid)
                self.assertFalse(
                    pipeline_state.is_node_stop_initiated(node_uid))

            # Reload from MLMD and verify node is started.
            with pstate.PipelineState.load(
                    m, task_lib.PipelineUid.from_pipeline(
                        pipeline)) as pipeline_state:
                self.assertFalse(
                    pipeline_state.is_node_stop_initiated(node_uid))
Example #22
def create_node_uid(pipeline_id, node_id):
  """Creates node uid."""
  return task_lib.NodeUid(
      pipeline_uid=task_lib.PipelineUid(pipeline_id=pipeline_id),
      node_id=node_id)
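
A short usage sketch for the helper above (the same call appears in examples #17 and #28); since PipelineUid and NodeUid compare by value, the result equals a directly constructed uid:

    node_uid = create_node_uid('pipeline1', 'Transform')
    assert node_uid == task_lib.NodeUid(
        pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline1'),
        node_id='Transform')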
Example #23
def _test_cancel_task(node_id, pipeline_id, pipeline_run_id=None):
    node_uid = task_lib.NodeUid(pipeline_id=pipeline_id,
                                pipeline_run_id=pipeline_run_id,
                                node_id=node_id)
    return task_lib.CancelNodeTask(node_uid=node_uid)
Example #24
def _test_cancel_node_task(node_id, pipeline_id):
    node_uid = task_lib.NodeUid(
        pipeline_uid=task_lib.PipelineUid(pipeline_id=pipeline_id),
        node_id=node_id)
    return task_lib.CancelNodeTask(node_uid=node_uid)
Example #25
    def test_stop_initiated_async_pipelines(self, mock_gen_task_from_active,
                                            mock_async_task_gen,
                                            mock_sync_task_gen):
        with self._mlmd_connection as m:
            pipeline1 = _test_pipeline('pipeline1')
            pipeline1.nodes.add().pipeline_node.node_info.id = 'Transform'
            pipeline1.nodes.add().pipeline_node.node_info.id = 'Trainer'
            pipeline1.nodes.add().pipeline_node.node_info.id = 'Evaluator'
            pipeline_ops.initiate_pipeline_start(m, pipeline1)
            pipeline1_execution = pipeline_ops._initiate_pipeline_stop(
                m, task_lib.PipelineUid.from_pipeline(pipeline1))

            task_queue = tq.TaskQueue()

            # For the stop-initiated pipeline, "Transform" execution task is in queue,
            # "Trainer" has an active execution in MLMD but no task in queue,
            # "Evaluator" has no active execution.
            task_queue.enqueue(
                test_utils.create_exec_node_task(node_uid=task_lib.NodeUid(
                    pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline1',
                                                      pipeline_run_id=None),
                    node_id='Transform')))
            # Simulates the task being processed.
            transform_task = task_queue.dequeue()
            mock_gen_task_from_active.side_effect = [
                test_utils.create_exec_node_task(node_uid=task_lib.NodeUid(
                    pipeline_uid=task_lib.PipelineUid(pipeline_id='pipeline1',
                                                      pipeline_run_id=None),
                    node_id='Trainer'),
                                                 is_cancelled=True), None,
                None, None, None
            ]

            pipeline_ops.generate_tasks(m, task_queue)

            # There are no active pipelines so these shouldn't be called.
            mock_async_task_gen.assert_not_called()
            mock_sync_task_gen.assert_not_called()

            # Simulate finishing the "Transform" ExecNodeTask.
            task_queue.task_done(transform_task)

            # CancelNodeTask for the "Transform" ExecNodeTask should be next.
            task = task_queue.dequeue()
            task_queue.task_done(task)
            self.assertTrue(task_lib.is_cancel_node_task(task))
            self.assertEqual('Transform', task.node_uid.node_id)

            # ExecNodeTask for "Trainer" is next.
            task = task_queue.dequeue()
            task_queue.task_done(task)
            self.assertTrue(task_lib.is_exec_node_task(task))
            self.assertEqual('Trainer', task.node_uid.node_id)

            self.assertTrue(task_queue.is_empty())

            mock_gen_task_from_active.assert_has_calls([
                mock.call(m,
                          pipeline1,
                          pipeline1.nodes[1].pipeline_node,
                          mock.ANY,
                          is_cancelled=True),
                mock.call(m,
                          pipeline1,
                          pipeline1.nodes[2].pipeline_node,
                          mock.ANY,
                          is_cancelled=True)
            ])
            self.assertEqual(2, mock_gen_task_from_active.call_count)

            # Pipeline execution should continue to be active since active node
            # executions were found in the last call to `generate_tasks`.
            [execution
             ] = m.store.get_executions_by_id([pipeline1_execution.id])
            self.assertTrue(execution_lib.is_execution_active(execution))

            # Call `generate_tasks` again; this time there are no more active node
            # executions so the pipeline should be marked as cancelled.
            pipeline_ops.generate_tasks(m, task_queue)
            self.assertTrue(task_queue.is_empty())
            [execution
             ] = m.store.get_executions_by_id([pipeline1_execution.id])
            self.assertEqual(metadata_store_pb2.Execution.CANCELED,
                             execution.last_known_state)
Example #26
 def _update_successful_nodes_cache(self, node_ids: Set[str]) -> None:
     for node_id in node_ids:
         node_uid = task_lib.NodeUid(pipeline_uid=self._pipeline_uid,
                                     node_id=node_id)
         _successful_nodes_cache[self._node_cache_key(node_uid)] = True
Example #27
def _test_task(node_id, pipeline_id, pipeline_run_id=None):
    node_uid = task_lib.NodeUid(pipeline_id=pipeline_id,
                                pipeline_run_id=pipeline_run_id,
                                node_id=node_id)
    return task_lib.ExecNodeTask(node_uid=node_uid, execution_id=123)
Example #28
  def test_orchestrate_active_pipelines(self, mock_async_task_gen,
                                        mock_sync_task_gen):
    with self._mlmd_connection as m:
      # Sync and async active pipelines.
      async_pipelines = [
          _test_pipeline('pipeline1'),
          _test_pipeline('pipeline2'),
      ]
      sync_pipelines = [
          _test_pipeline('pipeline3', pipeline_pb2.Pipeline.SYNC),
          _test_pipeline('pipeline4', pipeline_pb2.Pipeline.SYNC),
      ]

      for pipeline in async_pipelines + sync_pipelines:
        pipeline_ops.initiate_pipeline_start(m, pipeline)

      # Active executions for active async pipelines.
      mock_async_task_gen.return_value.generate.side_effect = [
          [
              test_utils.create_exec_node_task(
                  task_lib.NodeUid(
                      pipeline_uid=task_lib.PipelineUid.from_pipeline(
                          async_pipelines[0]),
                      node_id='Transform'))
          ],
          [
              test_utils.create_exec_node_task(
                  task_lib.NodeUid(
                      pipeline_uid=task_lib.PipelineUid.from_pipeline(
                          async_pipelines[1]),
                      node_id='Trainer'))
          ],
      ]

      # Active executions for active sync pipelines.
      mock_sync_task_gen.return_value.generate.side_effect = [
          [
              test_utils.create_exec_node_task(
                  task_lib.NodeUid(
                      pipeline_uid=task_lib.PipelineUid.from_pipeline(
                          sync_pipelines[0]),
                      node_id='Trainer'))
          ],
          [
              test_utils.create_exec_node_task(
                  task_lib.NodeUid(
                      pipeline_uid=task_lib.PipelineUid.from_pipeline(
                          sync_pipelines[1]),
                      node_id='Validator'))
          ],
      ]

      task_queue = tq.TaskQueue()
      pipeline_ops.orchestrate(m, task_queue,
                               service_jobs.DummyServiceJobManager())

      self.assertEqual(2, mock_async_task_gen.return_value.generate.call_count)
      self.assertEqual(2, mock_sync_task_gen.return_value.generate.call_count)

      # Verify that tasks are enqueued in the expected order.
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_exec_node_task(task))
      self.assertEqual(
          test_utils.create_node_uid('pipeline1', 'Transform'), task.node_uid)
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_exec_node_task(task))
      self.assertEqual(
          test_utils.create_node_uid('pipeline2', 'Trainer'), task.node_uid)
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_exec_node_task(task))
      self.assertEqual(
          test_utils.create_node_uid('pipeline3', 'Trainer'), task.node_uid)
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_exec_node_task(task))
      self.assertEqual(
          test_utils.create_node_uid('pipeline4', 'Validator'), task.node_uid)
      self.assertTrue(task_queue.is_empty())
Example #29
  def test_stop_initiated_pipelines(self, pipeline, mock_gen_task_from_active,
                                    mock_async_task_gen, mock_sync_task_gen):
    with self._mlmd_connection as m:
      pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
      pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
      pipeline.nodes.add().pipeline_node.node_info.id = 'Trainer'
      pipeline.nodes.add().pipeline_node.node_info.id = 'Evaluator'

      mock_service_job_manager = mock.create_autospec(
          service_jobs.ServiceJobManager, instance=True)
      mock_service_job_manager.is_pure_service_node.side_effect = (
          lambda _, node_id: node_id == 'ExampleGen')
      mock_service_job_manager.is_mixed_service_node.side_effect = (
          lambda _, node_id: node_id == 'Transform')

      pipeline_ops.initiate_pipeline_start(m, pipeline)
      with pstate.PipelineState.load(
          m, task_lib.PipelineUid.from_pipeline(pipeline)) as pipeline_state:
        pipeline_state.initiate_stop(
            status_lib.Status(code=status_lib.Code.CANCELLED))
      pipeline_execution = pipeline_state.execution

      task_queue = tq.TaskQueue()

      # For the stop-initiated pipeline, "Transform" execution task is in queue,
      # "Trainer" has an active execution in MLMD but no task in queue,
      # "Evaluator" has no active execution.
      task_queue.enqueue(
          test_utils.create_exec_node_task(
              task_lib.NodeUid(
                  pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline),
                  node_id='Transform')))
      transform_task = task_queue.dequeue()  # simulates task being processed
      mock_gen_task_from_active.side_effect = [
          test_utils.create_exec_node_task(
              node_uid=task_lib.NodeUid(
                  pipeline_uid=task_lib.PipelineUid.from_pipeline(pipeline),
                  node_id='Trainer'),
              is_cancelled=True), None, None, None, None
      ]

      pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)

      # There are no active pipelines so these shouldn't be called.
      mock_async_task_gen.assert_not_called()
      mock_sync_task_gen.assert_not_called()

      # stop_node_services should be called for ExampleGen which is a pure
      # service node.
      mock_service_job_manager.stop_node_services.assert_called_once_with(
          mock.ANY, 'ExampleGen')
      mock_service_job_manager.reset_mock()

      task_queue.task_done(transform_task)  # Pop out transform task.

      # CancelNodeTask for the "Transform" ExecNodeTask should be next.
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_cancel_node_task(task))
      self.assertEqual('Transform', task.node_uid.node_id)

      # ExecNodeTask (with is_cancelled=True) for "Trainer" is next.
      task = task_queue.dequeue()
      task_queue.task_done(task)
      self.assertTrue(task_lib.is_exec_node_task(task))
      self.assertEqual('Trainer', task.node_uid.node_id)
      self.assertTrue(task.is_cancelled)

      self.assertTrue(task_queue.is_empty())

      mock_gen_task_from_active.assert_has_calls([
          mock.call(
              m,
              pipeline_state.pipeline,
              pipeline.nodes[2].pipeline_node,
              mock.ANY,
              is_cancelled=True),
          mock.call(
              m,
              pipeline_state.pipeline,
              pipeline.nodes[3].pipeline_node,
              mock.ANY,
              is_cancelled=True)
      ])
      self.assertEqual(2, mock_gen_task_from_active.call_count)

      # Pipeline execution should continue to be active since active node
      # executions were found in the last call to `orchestrate`.
      [execution] = m.store.get_executions_by_id([pipeline_execution.id])
      self.assertTrue(execution_lib.is_execution_active(execution))

      # Call `orchestrate` again; this time there are no more active node
      # executions so the pipeline should be marked as cancelled.
      pipeline_ops.orchestrate(m, task_queue, mock_service_job_manager)
      self.assertTrue(task_queue.is_empty())
      [execution] = m.store.get_executions_by_id([pipeline_execution.id])
      self.assertEqual(metadata_store_pb2.Execution.CANCELED,
                       execution.last_known_state)

      # stop_node_services should be called on both ExampleGen and Transform
      # which are service nodes.
      mock_service_job_manager.stop_node_services.assert_has_calls(
          [mock.call(mock.ANY, 'ExampleGen'),
           mock.call(mock.ANY, 'Transform')],
          any_order=True)
Example #30
def _test_task(node_id, pipeline_id, key=''):
  node_uid = task_lib.NodeUid(
      pipeline_uid=task_lib.PipelineUid(pipeline_id=pipeline_id, key=key),
      node_id=node_id)
  return test_utils.create_exec_node_task(node_uid)
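
Taken together, the examples center on a small surface: derive a NodeUid from the pipeline IR (or construct it directly) and pass it to task and pipeline-state APIs. Below is a condensed sketch, assuming the import paths these tests use (tfx.orchestration.experimental.core.task as task_lib and tfx.proto.orchestration.pipeline_pb2) and the API variant shown in examples #5 and #9; several other examples use a variant in which the uids also carry a pipeline_run_id:

    from tfx.orchestration.experimental.core import task as task_lib
    from tfx.proto.orchestration import pipeline_pb2

    # Minimal pipeline IR and node, mirroring example #9.
    pipeline = pipeline_pb2.Pipeline()
    pipeline.pipeline_info.id = 'pipeline'
    node = pipeline_pb2.PipelineNode()
    node.node_info.id = 'Trainer'

    # Derive the NodeUid from the IR.
    node_uid = task_lib.NodeUid.from_pipeline_node(pipeline, node)

    # The NodeUid identifies the node inside tasks; task_id pairs the task type
    # name with the uid (examples #5 and #24).
    cancel_task = task_lib.CancelNodeTask(node_uid=node_uid)
    assert cancel_task.task_id == ('CancelNodeTask', node_uid)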