def test_stop_node_wait_for_inactivation(self):
  """Tests that stop_node succeeds once the node's active execution finishes.

  A trainer execution is faked as active. A background thread marks it
  COMPLETE after a short delay while `stop_node` runs with a 5s timeout.
  Afterwards the node must be stop-initiated with a CANCELLED reason, and
  restarting it via `initiate_node_start` must clear that reason.
  """
  pipeline = pipeline_pb2.Pipeline()
  self.load_proto_from_text(
      os.path.join(
          os.path.dirname(__file__), 'testdata', 'async_pipeline.pbtxt'),
      pipeline)
  trainer = pipeline.nodes[2].pipeline_node
  test_utils.fake_component_output(
      self._mlmd_connection, trainer, active=True)
  pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
  node_uid = task_lib.NodeUid(node_id='my_trainer', pipeline_uid=pipeline_uid)
  with self._mlmd_connection as m:
    pstate.PipelineState.new(m, pipeline).commit()

    # Exceptions raised in a worker thread do not reach the test thread on
    # their own; collect them here so they can be re-raised after the join.
    inactivate_errors = []

    def _inactivate(execution):
      try:
        # Sleep so the update lands after stop_node has been called, but
        # well within its 5s timeout.
        time.sleep(2.0)
        # Write to MLMD while holding the same lock that pipeline_ops uses.
        with pipeline_ops._PIPELINE_OPS_LOCK:  # pylint: disable=protected-access
          execution.last_known_state = metadata_store_pb2.Execution.COMPLETE
          m.store.put_executions([execution])
      except Exception as e:  # pylint: disable=broad-except
        inactivate_errors.append(e)

    execution = task_gen_utils.get_executions(m, trainer)[0]
    # Give the worker its own copy so it never mutates the proto object that
    # this thread holds.
    thread = threading.Thread(
        target=_inactivate, args=(copy.deepcopy(execution),))
    thread.start()
    try:
      pipeline_ops.stop_node(m, node_uid, timeout_secs=5.0)
    finally:
      # Always join, even if stop_node raised. Otherwise the worker could
      # keep writing through `m` after this `with` block exits.
      thread.join()
      if inactivate_errors:
        # Surface the worker's failure. Any in-flight stop_node exception
        # is kept as this exception's __context__.
        raise inactivate_errors[0]

    pipeline_state = pstate.PipelineState.load(m, pipeline_uid)
    self.assertEqual(status_lib.Code.CANCELLED,
                     pipeline_state.node_stop_initiated_reason(node_uid).code)

    # Restart node; the stop-initiated reason should be cleared.
    pipeline_state = pipeline_ops.initiate_node_start(m, node_uid)
    self.assertIsNone(pipeline_state.node_stop_initiated_reason(node_uid))
def test_stop_node_no_active_executions(self):
  """Tests stopping a node that has no active executions.

  The node must still be stop-initiated with a CANCELLED reason. Starting it
  again via `initiate_node_start` must clear that reason.
  """
  pbtxt_path = os.path.join(
      os.path.dirname(__file__), 'testdata', 'async_pipeline.pbtxt')
  pipeline = pipeline_pb2.Pipeline()
  self.load_proto_from_text(pbtxt_path, pipeline)
  pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
  trainer_uid = task_lib.NodeUid(
      pipeline_uid=pipeline_uid, node_id='my_trainer')

  with self._mlmd_connection as mlmd_handle:
    pstate.PipelineState.new(mlmd_handle, pipeline)
    pipeline_ops.stop_node(mlmd_handle, trainer_uid)

    # Even an inactive node is marked stop-initiated, so that it is not
    # triggered again in the future.
    stopped_state = pstate.PipelineState.load(mlmd_handle, pipeline_uid)
    with stopped_state:
      stop_reason = stopped_state.node_stop_initiated_reason(trainer_uid)
      self.assertEqual(status_lib.Code.CANCELLED, stop_reason.code)

    # Starting the node again should drop the stop-initiated reason.
    restarted_state = pipeline_ops.initiate_node_start(mlmd_handle, trainer_uid)
    with restarted_state:
      self.assertIsNone(
          restarted_state.node_stop_initiated_reason(trainer_uid))