Ejemplo n.º 1
0
  def test_stop_node_wait_for_inactivation_timeout(self):
    """`stop_node` fails with DEADLINE_EXCEEDED if the execution stays active.

    The trainer's output is faked with `active=True` and nothing ever
    inactivates it, so a 1 second `stop_node` wait must time out. Even so,
    the node must end up stop-initiated with reason code CANCELLED.
    """
    testdata_path = os.path.join(
        os.path.dirname(__file__), 'testdata', 'async_pipeline.pbtxt')
    pipeline = pipeline_pb2.Pipeline()
    self.load_proto_from_text(testdata_path, pipeline)

    trainer = pipeline.nodes[2].pipeline_node
    # Fake trainer output whose execution is left active, so the wait below
    # has nothing that will ever inactivate it.
    test_utils.fake_component_output(
        self._mlmd_connection, trainer, active=True)

    pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
    node_uid = task_lib.NodeUid(node_id='my_trainer', pipeline_uid=pipeline_uid)

    with self._mlmd_connection as m:
      pstate.PipelineState.new(m, pipeline).commit()

      timeout_pattern = 'Timed out.*waiting for execution inactivation.'
      with self.assertRaisesRegex(status_lib.StatusNotOkError,
                                  timeout_pattern) as raised:
        pipeline_ops.stop_node(m, node_uid, timeout_secs=1.0)
      self.assertEqual(status_lib.Code.DEADLINE_EXCEEDED,
                       raised.exception.code)

      # The timeout must not undo the stop request: the node stays
      # stop-initiated so that it is not triggered again.
      reloaded_state = pstate.PipelineState.load(m, pipeline_uid)
      stop_reason = reloaded_state.node_stop_initiated_reason(node_uid)
      self.assertEqual(status_lib.Code.CANCELLED, stop_reason.code)
Ejemplo n.º 2
0
  def test_stop_node_wait_for_inactivation(self):
    """`stop_node` returns once the node's active execution is inactivated.

    A background thread marks the trainer's execution COMPLETE after about
    2 seconds, inside the 5 second `stop_node` timeout. Afterwards the node
    must be stop-initiated with CANCELLED, and restarting the node must clear
    that stop reason.
    """
    pipeline = pipeline_pb2.Pipeline()
    self.load_proto_from_text(
        os.path.join(
            os.path.dirname(__file__), 'testdata', 'async_pipeline.pbtxt'),
        pipeline)
    trainer = pipeline.nodes[2].pipeline_node
    test_utils.fake_component_output(
        self._mlmd_connection, trainer, active=True)
    pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
    node_uid = task_lib.NodeUid(node_id='my_trainer', pipeline_uid=pipeline_uid)
    with self._mlmd_connection as m:
      pstate.PipelineState.new(m, pipeline).commit()

      def _inactivate(execution):
        # Delay so that `stop_node` is already waiting when the state flips.
        time.sleep(2.0)
        # NOTE(review): presumably taken to serialize this MLMD write with
        # pipeline_ops' own use of the same lock — confirm.
        with pipeline_ops._PIPELINE_OPS_LOCK:
          execution.last_known_state = metadata_store_pb2.Execution.COMPLETE
          m.store.put_executions([execution])

      execution = task_gen_utils.get_executions(m, trainer)[0]
      # The thread works on its own deep copy of the execution proto.
      thread = threading.Thread(
          target=_inactivate, args=(copy.deepcopy(execution),))
      thread.start()
      try:
        pipeline_ops.stop_node(m, node_uid, timeout_secs=5.0)
      finally:
        # Always join. If `stop_node` raises, the exception would otherwise
        # leave this `with` block while the thread is still sleeping, and the
        # thread would later write through `m` after the context has exited.
        thread.join()

      pipeline_state = pstate.PipelineState.load(m, pipeline_uid)
      self.assertEqual(status_lib.Code.CANCELLED,
                       pipeline_state.node_stop_initiated_reason(node_uid).code)

      # Restart node.
      pipeline_state = pipeline_ops.initiate_node_start(m, node_uid)
      self.assertIsNone(pipeline_state.node_stop_initiated_reason(node_uid))
Ejemplo n.º 3
0
    def test_generate_task_from_active_execution(self):
        """A task is generated only while the trainer has an active execution.

        Three states are checked in order:
        1. No execution exists: no task.
        2. The execution is RUNNING: a task bound to that execution.
        3. The execution is COMPLETE: no task.
        """

        def _task_for_trainer(m):
            # Re-read the trainer's executions and try to build a task.
            execs = task_gen_utils.get_executions(m, self._trainer)
            return task_gen_utils.generate_task_from_active_execution(
                m, self._pipeline, self._trainer, execs)

        def _set_first_execution_state(m, state):
            # Overwrite the state of the first stored execution and persist it.
            first = m.store.get_executions()[0]
            first.last_known_state = state
            m.store.put_executions([first])
            return first

        with self._mlmd_connection as m:
            # Before any trainer output exists there is nothing to run.
            self.assertIsNone(_task_for_trainer(m))

        # Fake trainer output so that there is an execution to manipulate.
        otu.fake_component_output(self._mlmd_connection, self._trainer)
        with self._mlmd_connection as m:
            running = _set_first_execution_state(
                m, metadata_store_pb2.Execution.RUNNING)
            task = _task_for_trainer(m)
            self.assertEqual(running.id, task.execution_id)

            # A COMPLETE execution no longer yields a task.
            _set_first_execution_state(
                m, metadata_store_pb2.Execution.COMPLETE)
            self.assertIsNone(_task_for_trainer(m))
Ejemplo n.º 4
0
    def test_get_executions(self):
        """`get_executions` is scoped to the current pipeline context.

        The same example_gen and transform nodes are run under two different
        pipeline contexts. Querying under each context must return exactly
        that context's executions, and nothing for the trainer, which is never
        run.
        """

        def _by_id(execution):
            return execution.id

        with self._mlmd_connection as m:
            # Fresh store: no node has any execution yet.
            for node in (n.pipeline_node for n in self._pipeline.nodes):
                self.assertEmpty(task_gen_utils.get_executions(m, node))

        # Under each context: two example_gen runs, then one transform output.
        for pipeline_name in ('my_pipeline1', 'my_pipeline2'):
            self._set_pipeline_context(pipeline_name)
            for i in (1, 2):
                otu.fake_example_gen_run(self._mlmd_connection,
                                         self._example_gen, i, 1)
            otu.fake_component_output(self._mlmd_connection, self._transform)

        # Snapshot all executions of each type across both contexts, by id.
        with self._mlmd_connection as m:
            all_eg_execs = sorted(
                m.store.get_executions_by_type(
                    self._example_gen.node_info.type.name),
                key=_by_id)
            all_transform_execs = sorted(
                m.store.get_executions_by_type(
                    self._transform.node_info.type.name),
                key=_by_id)

        # Each context must see only its own slice of the snapshot.
        expectations = (
            ('my_pipeline1', slice(0, 2), slice(0, 1)),
            ('my_pipeline2', slice(2, None), slice(1, None)),
        )
        for pipeline_name, eg_slice, transform_slice in expectations:
            self._set_pipeline_context(pipeline_name)
            with self._mlmd_connection as m:
                self.assertCountEqual(
                    all_eg_execs[eg_slice],
                    task_gen_utils.get_executions(m, self._example_gen))
                self.assertCountEqual(
                    all_transform_execs[transform_slice],
                    task_gen_utils.get_executions(m, self._transform))
                self.assertEmpty(
                    task_gen_utils.get_executions(m, self._trainer))
Ejemplo n.º 5
0
 def test_get_latest_successful_execution(self):
     """A FAILED execution is skipped when picking the latest successful one.

     Three transform executions are faked. The one with the highest id is
     then marked FAILED, so the second-highest must be reported as the
     latest successful execution.
     """
     for _ in range(3):
         otu.fake_component_output(self._mlmd_connection, self._transform)
     with self._mlmd_connection as m:
         stored = sorted(m.store.get_executions(), key=lambda e: e.id)
         highest = stored[2]
         highest.last_known_state = metadata_store_pb2.Execution.FAILED
         m.store.put_executions([highest])

         transform_execs = sorted(
             task_gen_utils.get_executions(m, self._transform),
             key=lambda e: e.id)
         expected = transform_execs[1]
         actual = task_gen_utils.get_latest_successful_execution(
             transform_execs)
         self.assertEqual(expected, actual)
Ejemplo n.º 6
0
 def _finish_node_execution(self, use_task_queue, node, execution):
   """Fakes a successful run of `node`, then dequeues and checks its task.

   Args:
     use_task_queue: Passed through unchanged to `self._dequeue_and_test`.
     node: The pipeline node whose output is faked.
     execution: The execution the faked output is recorded against; its id
       is handed to `self._dequeue_and_test`.
   """
   connection = self._mlmd_connection
   otu.fake_component_output(connection, node, execution)
   finished_id = execution.id
   self._dequeue_and_test(use_task_queue, node, finished_id)