Exemplo n.º 1
0
  def test_stop_node_wait_for_inactivation(self):
    """Checks that `stop_node` returns once the active execution is completed.

    A fake trainer output is registered with `active=True`. A background
    thread sleeps for 2 seconds and then stores the execution with state
    COMPLETE while holding `pipeline_ops._PIPELINE_OPS_LOCK`. Meanwhile the
    test calls `stop_node` with a 5 second timeout. Afterwards the node is
    expected to carry a CANCELLED stop-initiated reason, and
    `initiate_node_start` is expected to leave no stop-initiated reason.
    """
    pipeline = pipeline_pb2.Pipeline()
    self.load_proto_from_text(
        os.path.join(
            os.path.dirname(__file__), 'testdata', 'async_pipeline.pbtxt'),
        pipeline)
    trainer = pipeline.nodes[2].pipeline_node
    test_utils.fake_component_output(
        self._mlmd_connection, trainer, active=True)
    pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
    node_uid = task_lib.NodeUid(node_id='my_trainer', pipeline_uid=pipeline_uid)
    with self._mlmd_connection as m:
      pstate.PipelineState.new(m, pipeline).commit()

      def _inactivate(execution):
        # Delay so that `stop_node` below is already waiting when the
        # execution state is flipped.
        time.sleep(2.0)
        with pipeline_ops._PIPELINE_OPS_LOCK:
          execution.last_known_state = metadata_store_pb2.Execution.COMPLETE
          m.store.put_executions([execution])

      execution = task_gen_utils.get_executions(m, trainer)[0]
      # The thread works on its own copy of the execution proto.
      thread = threading.Thread(
          target=_inactivate, args=(copy.deepcopy(execution),))
      thread.start()
      try:
        pipeline_ops.stop_node(m, node_uid, timeout_secs=5.0)
      finally:
        # Join even when `stop_node` raises, so the helper thread never
        # outlives the MLMD connection context it writes through.
        thread.join()

      pipeline_state = pstate.PipelineState.load(m, pipeline_uid)
      self.assertEqual(status_lib.Code.CANCELLED,
                       pipeline_state.node_stop_initiated_reason(node_uid).code)

      # Restart node.
      pipeline_state = pipeline_ops.initiate_node_start(m, node_uid)
      self.assertIsNone(pipeline_state.node_stop_initiated_reason(node_uid))
Exemplo n.º 2
0
  def test_stop_node_wait_for_inactivation_timeout(self):
    """Checks that `stop_node` fails with DEADLINE_EXCEEDED on timeout.

    A fake trainer output is registered with `active=True` and nothing ever
    changes it, so `stop_node` with a 1 second timeout is expected to raise
    `StatusNotOkError`. The node must still carry a CANCELLED stop-initiated
    reason afterwards.
    """
    pbtxt_path = os.path.join(
        os.path.dirname(__file__), 'testdata', 'async_pipeline.pbtxt')
    async_pipeline = pipeline_pb2.Pipeline()
    self.load_proto_from_text(pbtxt_path, async_pipeline)

    trainer_node = async_pipeline.nodes[2].pipeline_node
    test_utils.fake_component_output(
        self._mlmd_connection, trainer_node, active=True)

    uid = task_lib.PipelineUid.from_pipeline(async_pipeline)
    trainer_uid = task_lib.NodeUid(node_id='my_trainer', pipeline_uid=uid)

    with self._mlmd_connection as mlmd_handle:
      pstate.PipelineState.new(mlmd_handle, async_pipeline).commit()

      timeout_pattern = 'Timed out.*waiting for execution inactivation.'
      with self.assertRaisesRegex(
          status_lib.StatusNotOkError, timeout_pattern) as raised:
        pipeline_ops.stop_node(mlmd_handle, trainer_uid, timeout_secs=1.0)
      self.assertEqual(
          status_lib.Code.DEADLINE_EXCEEDED, raised.exception.code)

      # The stop must be recorded even though waiting for inactivation timed
      # out, so that the node is not triggered again.
      reloaded_state = pstate.PipelineState.load(mlmd_handle, uid)
      stop_reason = reloaded_state.node_stop_initiated_reason(trainer_uid)
      self.assertEqual(status_lib.Code.CANCELLED, stop_reason.code)
Exemplo n.º 3
0
    def test_stop_node_no_active_executions(self):
        """Stopping a node with no faked executions still marks it stopped.

        Loads the async pipeline fixture, calls `stop_node` on `my_trainer`
        without registering any component output first, and checks that a
        CANCELLED stop-initiated reason is recorded. It then checks that
        `initiate_node_start` leaves no stop-initiated reason on the node.
        """
        pbtxt_path = os.path.join(
            os.path.dirname(__file__), 'testdata', 'async_pipeline.pbtxt')
        async_pipeline = pipeline_pb2.Pipeline()
        self.load_proto_from_text(pbtxt_path, async_pipeline)

        uid = task_lib.PipelineUid.from_pipeline(async_pipeline)
        trainer_uid = task_lib.NodeUid(node_id='my_trainer', pipeline_uid=uid)

        with self._mlmd_connection as mlmd_handle:
            pstate.PipelineState.new(mlmd_handle, async_pipeline)
            pipeline_ops.stop_node(mlmd_handle, trainer_uid)

            # Stop initiation is recorded even for an inactive node so that it
            # is not triggered again later.
            state = pstate.PipelineState.load(mlmd_handle, uid)
            with state:
                stop_reason = state.node_stop_initiated_reason(trainer_uid)
                self.assertEqual(status_lib.Code.CANCELLED, stop_reason.code)

            # Starting the node again must leave no stop-initiated reason.
            state = pipeline_ops.initiate_node_start(mlmd_handle, trainer_uid)
            with state:
                stop_reason = state.node_stop_initiated_reason(trainer_uid)
                self.assertIsNone(stop_reason)