def test_stop_node_wait_for_inactivation_timeout(self):
  """A timed-out stop_node still leaves the node stop-initiated.

  The trainer is given an active execution that nothing in this test ever
  inactivates, so `stop_node` must give up after its timeout with
  DEADLINE_EXCEEDED. The node's stop-initiated reason must nonetheless be
  recorded as CANCELLED.
  """
  testdata_path = os.path.join(
      os.path.dirname(__file__), 'testdata', 'async_pipeline.pbtxt')
  pipeline = pipeline_pb2.Pipeline()
  self.load_proto_from_text(testdata_path, pipeline)

  trainer_node = pipeline.nodes[2].pipeline_node
  test_utils.fake_component_output(
      self._mlmd_connection, trainer_node, active=True)

  pipe_uid = task_lib.PipelineUid.from_pipeline(pipeline)
  node_uid = task_lib.NodeUid(node_id='my_trainer', pipeline_uid=pipe_uid)

  with self._mlmd_connection as m:
    pstate.PipelineState.new(m, pipeline).commit()

    with self.assertRaisesRegex(
        status_lib.StatusNotOkError,
        'Timed out.*waiting for execution inactivation.') as raised:
      pipeline_ops.stop_node(m, node_uid, timeout_secs=1.0)
    self.assertEqual(status_lib.Code.DEADLINE_EXCEEDED, raised.exception.code)

    # Even when waiting for inactivation times out, the node must already be
    # stop-initiated so that it is not triggered again.
    reloaded_state = pstate.PipelineState.load(m, pipe_uid)
    stop_reason = reloaded_state.node_stop_initiated_reason(node_uid)
    self.assertEqual(status_lib.Code.CANCELLED, stop_reason.code)
def test_stop_node_wait_for_inactivation(self):
  """stop_node blocks until the active execution becomes inactive.

  A background thread marks the trainer's active execution COMPLETE after a
  delay that is shorter than the `stop_node` timeout. `stop_node` should
  then return without error and leave the node stop-initiated (CANCELLED).
  The node can then be restarted via `initiate_node_start`.
  """
  pipeline = pipeline_pb2.Pipeline()
  self.load_proto_from_text(
      os.path.join(
          os.path.dirname(__file__), 'testdata', 'async_pipeline.pbtxt'),
      pipeline)
  trainer = pipeline.nodes[2].pipeline_node
  test_utils.fake_component_output(
      self._mlmd_connection, trainer, active=True)
  pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
  node_uid = task_lib.NodeUid(node_id='my_trainer', pipeline_uid=pipeline_uid)

  with self._mlmd_connection as m:
    pstate.PipelineState.new(m, pipeline).commit()

    def _inactivate(execution):
      # Let stop_node start waiting before the execution finishes.
      time.sleep(2.0)
      # Takes pipeline_ops' lock, presumably to serialize this MLMD write
      # with pipeline_ops' own access to the store.
      with pipeline_ops._PIPELINE_OPS_LOCK:
        execution.last_known_state = metadata_store_pb2.Execution.COMPLETE
        m.store.put_executions([execution])

    execution = task_gen_utils.get_executions(m, trainer)[0]
    # Hand the thread its own copy so it does not mutate this proto.
    thread = threading.Thread(
        target=_inactivate, args=(copy.deepcopy(execution),))
    thread.start()
    try:
      pipeline_ops.stop_node(m, node_uid, timeout_secs=5.0)
    finally:
      # Always join, even if stop_node raises. Otherwise the thread would
      # outlive this `with` block and write through `m` after the context
      # has exited.
      thread.join()

    pipeline_state = pstate.PipelineState.load(m, pipeline_uid)
    self.assertEqual(status_lib.Code.CANCELLED,
                     pipeline_state.node_stop_initiated_reason(node_uid).code)

    # Restart node.
    pipeline_state = pipeline_ops.initiate_node_start(m, node_uid)
    self.assertIsNone(pipeline_state.node_stop_initiated_reason(node_uid))
def test_generate_task_from_active_execution(self):
  """A task is generated only while the trainer has an active execution."""

  def _task_for_trainer(mlmd_handle):
    # Fetch the trainer's executions and ask for a task from the active one.
    return task_gen_utils.generate_task_from_active_execution(
        mlmd_handle, self._pipeline, self._trainer,
        task_gen_utils.get_executions(mlmd_handle, self._trainer))

  with self._mlmd_connection as m:
    # Without any active execution there is nothing to generate.
    self.assertIsNone(_task_for_trainer(m))

  # Produce trainer output, then force its execution to RUNNING so that the
  # trainer has an active execution.
  otu.fake_component_output(self._mlmd_connection, self._trainer)
  with self._mlmd_connection as m:
    running = m.store.get_executions()[0]
    running.last_known_state = metadata_store_pb2.Execution.RUNNING
    m.store.put_executions([running])

    # The generated task must refer to the active execution.
    task = _task_for_trainer(m)
    self.assertEqual(running.id, task.execution_id)

    # After the execution is marked COMPLETE, no task is generated.
    completed = m.store.get_executions()[0]
    completed.last_known_state = metadata_store_pb2.Execution.COMPLETE
    m.store.put_executions([completed])
    self.assertIsNone(_task_for_trainer(m))
def test_get_executions(self):
  """get_executions only returns executions in the current pipeline context."""
  with self._mlmd_connection as m:
    for entry in self._pipeline.nodes:
      self.assertEmpty(task_gen_utils.get_executions(m, entry.pipeline_node))

  # Create identically shaped runs under two different pipeline contexts:
  # two ExampleGen runs followed by one Transform output in each.
  for context_name in ('my_pipeline1', 'my_pipeline2'):
    self._set_pipeline_context(context_name)
    for run_index in (1, 2):
      otu.fake_example_gen_run(
          self._mlmd_connection, self._example_gen, run_index, 1)
    otu.fake_component_output(self._mlmd_connection, self._transform)

  # Gather every execution of each node type across all contexts, in id order.
  with self._mlmd_connection as m:

    def _all_by_type_sorted(node):
      return sorted(
          m.store.get_executions_by_type(node.node_info.type.name),
          key=lambda e: e.id)

    eg_execs = _all_by_type_sorted(self._example_gen)
    transform_execs = _all_by_type_sorted(self._transform)

  # Relies on ids increasing with creation order: the first context owns the
  # lower ids, the second one the rest. Trainer never ran in either context.
  expectations = (
      ('my_pipeline1', eg_execs[:2], transform_execs[:1]),
      ('my_pipeline2', eg_execs[2:], transform_execs[1:]),
  )
  for context_name, want_eg, want_transform in expectations:
    self._set_pipeline_context(context_name)
    with self._mlmd_connection as m:
      self.assertCountEqual(
          want_eg, task_gen_utils.get_executions(m, self._example_gen))
      self.assertCountEqual(
          want_transform, task_gen_utils.get_executions(m, self._transform))
      self.assertEmpty(task_gen_utils.get_executions(m, self._trainer))
def test_get_latest_successful_execution(self):
  """A FAILED third execution is skipped in favour of the second one."""
  for _ in range(3):
    otu.fake_component_output(self._mlmd_connection, self._transform)

  with self._mlmd_connection as m:
    # Mark the third execution (by id) as FAILED.
    ordered = sorted(m.store.get_executions(), key=lambda e: e.id)
    failed = ordered[2]
    failed.last_known_state = metadata_store_pb2.Execution.FAILED
    m.store.put_executions([failed])

    ordered_transform = sorted(
        task_gen_utils.get_executions(m, self._transform), key=lambda e: e.id)
    expected = ordered_transform[1]
    latest_ok = task_gen_utils.get_latest_successful_execution(
        ordered_transform)
    self.assertEqual(expected, latest_ok)
def _finish_node_execution(self, use_task_queue, node, execution):
  """Simulates successful execution of a node.

  Writes fake component output for `node`/`execution` to MLMD, then hands
  off to `_dequeue_and_test` with the execution's id.
  """
  connection = self._mlmd_connection
  otu.fake_component_output(connection, node, execution)
  # Read the id only after the fake output has been written, as the
  # original did.
  finished_id = execution.id
  self._dequeue_and_test(use_task_queue, node, finished_id)