def WaitCMD(self, request, context):
    """Wait for a command in DebuggerCache."""
    # check if the graph has already been received
    log.info("Received WaitCMD at %s-th step.", request.cur_step)
    if self._status == ServerStatus.PENDING:
        log.warning("No graph received before WaitCMD.")
        reply = get_ack_reply(1)
        return reply
    self._send_received_tensor_tag()
    # send graph if it has not been sent before
    self._pre_process(request)
    # deal with old command
    reply = self._deal_with_old_command()
    if reply:
        log.info("Reply to WaitCMD with old command: %s", reply)
        return reply
    # continue multi-step training
    if self._continue_steps:
        reply = get_ack_reply()
        reply.run_cmd.run_steps = 1
        reply.run_cmd.run_level = 'step'
        # positive counts tick down; a non-positive count collapses to the -1 sentinel
        self._continue_steps = self._continue_steps - 1 if self._continue_steps > 0 else -1
        self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
        log.debug("Send RunCMD. Clean watchpoint hit.")
    # wait for command
    else:
        reply = self._wait_for_next_command()
    if reply is None:
        reply = get_ack_reply(1)
        log.warning("Failed to get command event.")
    else:
        log.info("Reply to WaitCMD: %s", reply)
    return reply
def SendHeartbeat(self, request, context):
    """Deal with heartbeat sent from the training client."""
    # only support single card training now
    if self._heartbeat is None or not self._heartbeat.is_alive():
        period = request.period
        min_period_seconds = 5
        max_period_seconds = 3600
        if not min_period_seconds <= period <= max_period_seconds:
            log.error("Invalid heartbeat period, which should be within [5, 3600] seconds.")
            return get_ack_reply(-1)
        self._heartbeat = HeartbeatListener(self._cache_store, request.period)
        self._heartbeat.start()
    self._heartbeat.send()
    return get_ack_reply()
def SendTensors(self, request_iterator, context):
    """Send tensors into DebuggerCache."""
    log.info("Received tensor.")
    tensor_construct = []
    tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR)
    metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
    tensor_names = []
    step = metadata_stream.step
    for tensor in request_iterator:
        tensor_construct.append(tensor)
        if tensor.finished:
            update_flag = tensor_stream.put({'step': step, 'tensor_protos': tensor_construct})
            if self._received_view_cmd.get('wait_for_tensor') and update_flag:
                # update_flag is used to avoid querying empty tensors again
                self._received_view_cmd['wait_for_tensor'] = False
                log.debug("Set wait for tensor flag to False.")
            tensor_construct = []
            tensor_names.append(':'.join([tensor.node_name, tensor.slot]))
            continue
    reply = get_ack_reply()
    return reply
def SendWatchpointHits(self, request_iterator, context):
    """Send watchpoint hits info into DebuggerCache."""
    log.info("Received WatchpointHits. Left run cmd %s changed to empty.", self._old_run_cmd)
    self._old_run_cmd.clear()
    if self._cache_store.get_stream_handler(Streams.METADATA).state == ServerStatus.RUNNING.value:
        # if the client session is running a script, all the cached commands should be
        # cleared when watchpoint hits are received
        self._cache_store.clean_command()
    # save the watchpoint_hits data
    watchpoint_hits = []
    watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT)
    graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH)
    for watchpoint_hit_proto in request_iterator:
        node_full_name = watchpoint_hit_proto.tensor.node_name
        graph_name = graph_stream.get_graph_id_by_full_name(node_full_name)
        if not graph_name:
            log.warning("Cannot find node %s in graph. Skip it.", node_full_name)
            continue
        ui_node_name = graph_stream.get_node_name_by_full_name(node_full_name, graph_name)
        log.debug("Receive watchpoint hit: %s", watchpoint_hit_proto)
        if not ui_node_name:
            log.info("Not supported to show %s on graph.", node_full_name)
            continue
        watchpoint_hit = {
            'tensor_proto': watchpoint_hit_proto.tensor,
            'watchpoint': copy.deepcopy(watchpoint_stream.get_watchpoint_by_id(watchpoint_hit_proto.id)),
            'node_name': ui_node_name,
            'graph_name': graph_name
        }
        # collect the actual values reported for each watch-condition param
        hit_params = {}
        for param in watchpoint_hit_proto.watch_condition.params:
            if param.actual_value:
                hit_params[param.name] = param.actual_value
        for i, param in enumerate(watchpoint_hit['watchpoint'].condition['params']):
            name = param['name']
            if name in hit_params:
                watchpoint_hit['watchpoint'].condition['params'][i]['actual_value'] = hit_params[name]
            else:
                watchpoint_hit['watchpoint'].condition['params'][i]['actual_value'] = None
        if watchpoint_hit_proto.error_code:
            watchpoint_hit['error_code'] = watchpoint_hit_proto.error_code
        watchpoint_hits.append(watchpoint_hit)
    self._received_hit = watchpoint_hits
    reply = get_ack_reply()
    return reply
def test_send_watchpoint_hit(self, *args):
    """Test SendWatchpointHits interface."""
    args[0].side_effect = [None, 'mock_full_name']
    watchpoint_hit = MockDataGenerator.get_watchpoint_hit()
    res = self._server.SendWatchpointHits([watchpoint_hit, watchpoint_hit], MagicMock())
    assert res == get_ack_reply()
def SendGraph(self, request_iterator, context):
    """Send graph into DebuggerCache."""
    log.info("Received graph.")
    reply = get_ack_reply()
    if self._status == ServerStatus.MISMATCH:
        log.info("MindSpore and MindInsight versions do not match; "
                 "waiting for the user to terminate the service.")
        return reply
    serial_graph = b""
    for chunk in request_iterator:
        serial_graph += chunk.buffer
    graph = GraphProto.FromString(serial_graph)
    log.debug("Deserialize the graph %s. Receive %s nodes.", graph.name, len(graph.node))
    graph_dict = {graph.name: graph}
    self._cache_store.get_stream_handler(Streams.GRAPH).put({0: graph_dict})
    self._cache_store.get_stream_handler(Streams.GRAPH).parse_stack_infos()
    self._cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(0).put_const_vals(
        graph.const_vals)
    self._cache_store.get_stream_handler(Streams.METADATA).graph_name = graph.name
    self._record_parameter_names()
    self._status = ServerStatus.RECEIVE_GRAPH
    log.debug("Send the reply for graph.")
    return reply
def test_send_metadata_with_mismatched(self):
    """Test SendMetadata interface with mismatched versions."""
    res = self._server.SendMetadata(MagicMock(training_done=False, ms_version='1.0.0'), MagicMock())
    expect_reply = get_ack_reply()
    expect_reply.version_matched = False
    assert res == expect_reply
def test_waitcmd_with_next_command_is_none(self, *args):
    """Test WaitCMD interface when the next command is None."""
    args[0].return_value = None
    setattr(self._server, '_status', ServerStatus.RECEIVE_GRAPH)
    res = self._server.WaitCMD(MagicMock(cur_step=1, cur_node=''), MagicMock())
    assert res == get_ack_reply(1)
def SendWatchpointHits(self, request_iterator, context):
    """Send watchpoint hits info into DebuggerCache."""
    log.info("Received WatchpointHits. Left steps %d changed to 0.", self._continue_steps)
    self._continue_steps = 0
    watchpoint_hit_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
    watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT)
    graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH)
    for watchpoint_hit_proto in request_iterator:
        ui_node_name = graph_stream.get_node_name_by_full_name(watchpoint_hit_proto.tensor.node_name)
        log.debug("Receive watchpoint hit: %s", watchpoint_hit_proto)
        if not ui_node_name:
            log.info("Not supported to show %s on graph.", watchpoint_hit_proto.tensor.node_name)
            continue
        watchpoint_hit = {
            'tensor_proto': watchpoint_hit_proto.tensor,
            'watchpoint': watchpoint_stream.get_watchpoint_by_id(watchpoint_hit_proto.id),
            'node_name': ui_node_name
        }
        watchpoint_hit_stream.put(watchpoint_hit)
    watchpoint_hits_info = watchpoint_hit_stream.get()
    self._cache_store.put_data(watchpoint_hits_info)
    log.info("Send the watchpoint hits to DataQueue.\nSend the reply.")
    reply = get_ack_reply()
    return reply
def SendMultiGraphs(self, request_iterator, context):
    """Send graphs into DebuggerCache."""
    log.info("Received multi_graphs.")
    reply = get_ack_reply()
    if self._status == ServerStatus.MISMATCH:
        log.info("MindSpore and MindInsight versions do not match; "
                 "waiting for the user to terminate the service.")
        return reply
    serial_graph = b""
    graph_dict = {}
    for chunk in request_iterator:
        serial_graph += chunk.buffer
        if chunk.finished:
            sub_graph = GraphProto.FromString(serial_graph)
            graph_dict[sub_graph.name] = sub_graph
            log.debug("Deserialize the graph %s. Receive %s nodes.", sub_graph.name, len(sub_graph.node))
            serial_graph = b""
            self._cache_store.get_stream_handler(Streams.TENSOR).put_const_vals(sub_graph.const_vals)
    self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_dict)
    self._record_parameter_names()
    self._status = ServerStatus.RECEIVE_GRAPH
    log.debug("Send the reply for graph.")
    return reply
def get_view_cmd():
    """Get view command."""
    view_event = get_ack_reply()
    ms_tensor = view_event.view_cmd.tensors.add()
    ms_tensor.node_name, ms_tensor.slot = 'mock_node_name', '0'
    event = {'view_cmd': view_event, 'node_name': 'mock_node_name', 'graph_name': 'mock_graph_name'}
    return event
def get_run_cmd(steps=0, level='step', node_name=''):
    """Get run command."""
    event = get_ack_reply()
    event.run_cmd.run_level = level
    if level == 'node':
        event.run_cmd.node_name = node_name
    else:
        event.run_cmd.run_steps = steps
    return event
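# A minimal usage sketch for get_run_cmd above; the field names
# (run_cmd.run_level, run_cmd.run_steps, run_cmd.node_name) come from the
# helper itself, and 'mock_node_name' is the same placeholder used elsewhere
# in these snippets.
step_cmd = get_run_cmd(steps=2)                                   # step-level: run two steps
node_cmd = get_run_cmd(level='node', node_name='mock_node_name')  # node-level: run to a node
assert step_cmd.run_cmd.run_steps == 2
assert node_cmd.run_cmd.node_name == 'mock_node_name'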
def stop(self):
    """Stop server."""
    self._is_running_flag = False
    self._command_listener.stop()
    self._cache_store.clean()
    event = get_ack_reply()
    event.exit = True
    self._cache_store.put_command(event)
    log.info("Stop debugger offline manager.")
def _send_watchpoints(self):
    """Send watchpoints to client."""
    set_commands = self._watchpoint_stream.get_pending_commands(self._graph_stream)
    if not set_commands:
        return
    for set_cmd in set_commands:
        event = get_ack_reply()
        event.set_cmd.CopyFrom(set_cmd)
        self._cache_store.put_command(event)
        log.debug("Send SetCMD to MindSpore. %s", event)
def _terminate(self, metadata_stream):
    """
    Terminate the training.

    Args:
        metadata_stream (MetadataHandler): The metadata stream handler.

    Returns:
        dict, the updated metadata state.
    """
    metadata_stream.state = 'pending'
    event = get_ack_reply()
    event.exit = True
    self.cache_store.put_command(event)
    log.debug("Send the ExitCMD.")
    return {'metadata': {'state': 'pending'}}
def _send_watchpoints(self):
    """Set watchpoints."""
    watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
    watchpoints = watchpoint_stream.get(filter_condition=True).get('watch_points')
    if watchpoints:
        for watchpoint in watchpoints:
            event = get_ack_reply()
            event.set_cmd.CopyFrom(watchpoint)
            self.cache_store.put_command(event)
        watchpoint_stream.sync_set_cmd()
        log.debug("Send SetCMD to MindSpore. %s", event)
def SendGraph(self, request_iterator, context):
    """Send graph into DebuggerCache."""
    log.info("Received graph.")
    serial_graph = b""
    for chunk in request_iterator:
        serial_graph += chunk.buffer
    graph = GraphProto.FromString(serial_graph)
    log.debug("Deserialize the graph. Receive %s nodes.", len(graph.node))
    self._cache_store.get_stream_handler(Streams.GRAPH).put(graph)
    self._cache_store.get_stream_handler(Streams.TENSOR).put_const_vals(graph.const_vals)
    self._status = ServerStatus.RECEIVE_GRAPH
    reply = get_ack_reply()
    log.info("Send the reply for graph.")
    return reply
def WaitCMD(self, request, context):
    """Wait for a command in DebuggerCache."""
    # check if the graph has already been received
    log.info("Received WaitCMD at %s-th step.", request.cur_step)
    if self._status == ServerStatus.PENDING:
        log.warning("No graph received before WaitCMD.")
        reply = get_ack_reply(1)
        return reply
    # send graph if it has not been sent before
    self._pre_process(request)
    # deal with old command
    reply = self._deal_with_old_command()
    # wait for next command
    if reply is None:
        reply = self._wait_for_next_command()
    # check the reply
    if reply is None:
        reply = get_ack_reply(1)
        log.warning("Failed to get command event.")
    else:
        log.debug("Reply to WaitCMD: %s", reply)
    return reply
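# Sketch of the command-resolution order implemented by WaitCMD above:
#   1. no graph received yet        -> error ack, get_ack_reply(1)
#   2. a pending old command exists -> replay it via _deal_with_old_command()
#   3. otherwise                    -> block on _wait_for_next_command()
#   4. still no command             -> error ack, get_ack_reply(1)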
def _pause(self, metadata_stream):
    """
    Pause the training.

    Args:
        metadata_stream (MetadataHandler): The metadata stream handler.

    Returns:
        dict, the updated metadata state.
    """
    if metadata_stream.state != ServerStatus.RUNNING.value:
        log.error("MindSpore is not running.")
        raise DebuggerPauseError("MindSpore is not running.")
    metadata_stream.state = 'waiting'
    event = get_ack_reply()
    event.run_cmd.CopyFrom(RunCMD(run_level='step', run_steps=0))
    self.cache_store.put_command(event)
    log.debug("Send the Pause command.")
    return {'metadata': {'state': 'waiting'}}
def SendTensors(self, request_iterator, context):
    """Send tensors into DebuggerCache."""
    log.info("Received tensor.")
    tensor_contents = []
    tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR).get_tensor_handler_by_rank_id(0)
    metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
    step = metadata_stream.step
    oversize = False
    node_info = self._received_view_cmd.get('node_info')
    if node_info and node_info.get('load'):
        load_info = node_info.get('load')
        load_info['graph_name'] = node_info.get('graph_name')
        load_info['node_name'] = node_info.get('node_name')
        # the rank_id of the online debugger is 0
        load_tensor(load_info=load_info, step=step, request_iterator=request_iterator,
                    cache_store=self._cache_store, rank_id=0)
    else:
        for tensor in request_iterator:
            tensor_contents.append(tensor.tensor_content)
            if len(tensor_contents) >= MAX_SINGLE_TENSOR_CACHE or oversize:
                # drop the cached contents once the tensor exceeds the cache limit
                oversize = True
                tensor_contents = []
            if tensor.finished:
                value = {
                    'step': step,
                    'tensor_proto': tensor,
                    'tensor_contents': tensor_contents,
                    'oversize': oversize,
                    'stats': bool(node_info and node_info.get('stats'))
                }
                update_flag = tensor_stream.put(value)
                if self._received_view_cmd.get('wait_for_tensor') and update_flag:
                    # update_flag is used to avoid querying empty tensors again
                    self._received_view_cmd['wait_for_tensor'] = False
                    log.debug("Set wait for tensor flag to False.")
                tensor_contents = []
                oversize = False
                continue
    reply = get_ack_reply()
    return reply
def terminate_training(self):
    """
    Terminate the training.

    Returns:
        dict, metadata info.
    """
    metadata_stream = self._metadata_stream
    metadata_stream.state = ServerStatus.SENDING.value
    self._cache_store.clean_data()
    self._cache_store.clean_command()
    event = get_ack_reply()
    event.exit = True
    self._cache_store.put_command(event)
    metadata_stream.enable_recheck = False
    log.debug("Send the ExitCMD.")
    return metadata_stream.get(['state', 'enable_recheck'])
def pause_training(self):
    """
    Pause the training.

    Returns:
        dict, metadata info.
    """
    metadata_stream = self._metadata_stream
    if metadata_stream.state != ServerStatus.RUNNING.value:
        log.error("MindSpore is not running.")
        raise DebuggerPauseError("MindSpore is not running.")
    metadata_stream.state = 'waiting'
    event = get_ack_reply()
    event.run_cmd.CopyFrom(RunCMD(run_level='step', run_steps=0))
    self._cache_store.put_command(event)
    metadata_stream.enable_recheck = False
    log.debug("Send the Pause command.")
    return metadata_stream.get(['state', 'enable_recheck'])
def SendMetadata(self, request, context):
    """Send metadata into DebuggerCache."""
    log.info("Received Metadata.")
    if self._status != ServerStatus.PENDING:
        log.info("Re-initialize cache store when a new session comes.")
        self.init()
    client_ip = context.peer().split(':', 1)[-1]
    metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
    reply = get_ack_reply()
    if request.training_done:
        log.info("The training from %s has finished.", client_ip)
    else:
        ms_version = request.ms_version
        if not ms_version:
            ms_version = '1.0.x'
        if version_match(ms_version, mindinsight.__version__) is False:
            log.info("Version mismatch. MindSpore: %s, MindInsight: %s",
                     ms_version, mindinsight.__version__)
            self._status = ServerStatus.MISMATCH
            reply.version_matched = False
            metadata_stream.state = 'mismatch'
        else:
            log.info("Version is matched.")
            reply.version_matched = True
        metadata_stream.debugger_version = {'ms': ms_version, 'mi': mindinsight.__version__}
        log.debug("Put ms_version from %s into cache.", client_ip)
        metadata_stream.put(request)
        metadata_stream.client_ip = client_ip
        log.debug("Put new metadata from %s into cache.", client_ip)
    # put metadata into data queue
    metadata = metadata_stream.get()
    self._cache_store.put_data(metadata)
    log.debug("Send the reply to %s.", client_ip)
    return reply
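# An illustrative stand-in for version_match (not the real implementation),
# sketching the kind of check SendMetadata relies on: treat versions as
# matched when their major.minor components agree.
def _versions_agree(ms_version, mi_version):
    """Return True if the major and minor version components match."""
    return ms_version.split('.')[:2] == mi_version.split('.')[:2]

assert _versions_agree('1.0.x', '1.0.0')
assert not _versions_agree('1.0.0', '1.1.0')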
def _deal_with_left_continue_node(self, node_name):
    """
    Construct run command with left continue nodes.

    Args:
        node_name (str): The target node name.

    Returns:
        Union[None, Event], the run command event.
    """
    cur_full_name = self._cache_store.get_stream_handler(Streams.METADATA).full_name
    if cur_full_name == node_name:
        log.info("Execute to target node: %s", node_name)
        self._old_run_cmd.clear()
        return None
    event = get_ack_reply()
    event.run_cmd.run_level = 'node'
    # an empty node name tells the client to continue to the next node
    event.run_cmd.node_name = ''
    log.debug("Send old node RunCMD, cur node: %s, target node: %s", cur_full_name, node_name)
    return event
def _deal_with_left_continue_step(self, left_step_count):
    """
    Construct run command with left continue step count.

    Args:
        left_step_count (int): The count of left steps to be executed.

    Returns:
        Event, the run command event.
    """
    event = get_ack_reply()
    event.run_cmd.run_steps = 1
    event.run_cmd.run_level = 'step'
    left_step_count = left_step_count - 1 if left_step_count > 0 else -1
    if not left_step_count:
        self._old_run_cmd.clear()
    else:
        self._old_run_cmd['left_step_count'] = left_step_count
    log.debug("Send old step RunCMD. Left step count: %s", left_step_count)
    return event
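# A small sketch of the countdown convention used by
# _deal_with_left_continue_step above, assuming the same semantics as its
# ternary: positive counts tick down toward 0 (at which point the old
# command is cleared), while a non-positive count collapses to the -1
# sentinel that keeps step commands flowing indefinitely.
def next_left_count(left_step_count):
    return left_step_count - 1 if left_step_count > 0 else -1

assert next_left_count(3) == 2     # finite run: 3 -> 2 -> 1 -> 0, then clear
assert next_left_count(-1) == -1   # sentinel: keep running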
def _construct_run_event(self, params):
    """
    Construct run cmd from input control params.

    Args:
        params (dict): The control params.

            - level (str): The control granularity, `node` level or `step` level.
              Default: `step`.
            - steps (int): Specify the steps that training should run.
              Used when `level` is `step`.
            - full_name (str): Specify the name of the node. Used when `level` is `node`.

    Returns:
        EventReply, control event with run command.
    """
    level = params.get('level', 'step')
    event = get_ack_reply()
    if level == 'step':
        steps = params.get('steps')
        if not steps:
            steps = 1
        run_cmd = RunCMD(run_level='step', run_steps=steps)
    elif level == 'node':
        self._validate_node_type(params.get('name'))
        name = self.cache_store.get_stream_handler(Streams.GRAPH).get_full_name(params['name'])
        if not name:
            name = ''
        run_cmd = RunCMD(run_level='node', node_name=name)
    else:
        log.error("Invalid value. `level` should be `step` or `node`. Got %s", level)
        raise DebuggerParamValueError("`level` should be `step` or `node`.")
    event.run_cmd.CopyFrom(run_cmd)
    log.debug("Construct run event. %s", event)
    return event
def SendMetadata(self, request, context):
    """Send metadata into DebuggerCache."""
    log.info("Received Metadata.")
    if self._status != ServerStatus.PENDING:
        log.info("Re-initialize cache store when a new session comes.")
        self.init()
    client_ip = context.peer().split(':', 1)[-1]
    metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
    if request.training_done:
        log.info("The training from %s has finished.", client_ip)
    else:
        metadata_stream.put(request)
        metadata_stream.client_ip = client_ip
        log.info("Put new metadata from %s into cache.", client_ip)
    # put metadata into data queue
    metadata = metadata_stream.get()
    self._cache_store.put_data(metadata)
    reply = get_ack_reply()
    log.info("Send the reply to %s.", client_ip)
    return reply
def _construct_run_event(self, params):
    """
    Construct run cmd from input control params.

    Args:
        params (dict): The control params.

            - level (str): The control granularity, `node`, `step` or `recheck` level.
              Default: `step`.
            - steps (int): Specify the steps that training should run.
              Used when `level` is `step`.
            - name (str): Specify the name of the node. Used when `level` is `node`.
            - graph_name (str): The graph name.

    Returns:
        EventReply, control event with run command.
    """
    level = params.get('level', 'step')
    # construct run command events
    event = get_ack_reply()
    if level == 'step':
        steps = params.get('steps', 1)
        run_cmd = RunCMD(run_level='step', run_steps=steps)
    elif level == 'node':
        name = params.get('name', '')
        graph_name = params.get('graph_name')
        if name:
            rank_id = params.get('rank_id', 0)
            name = self._multi_card_graph_stream.get_graph_handler_by_rank_id(rank_id).get_full_name(
                name, graph_name)
        run_cmd = RunCMD(run_level='node', node_name=name)
    else:
        run_cmd = RunCMD(run_level='recheck')
    event.run_cmd.CopyFrom(run_cmd)
    log.debug("Construct run event. %s", event)
    return event
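# Hypothetical call sites for _construct_run_event above (`server` stands for
# an instance of this class with its streams wired up); the params dicts use
# only the keys documented in the docstring.
step_event = server._construct_run_event({'level': 'step', 'steps': 5})
recheck_event = server._construct_run_event({'level': 'recheck'})
assert step_event.run_cmd.run_steps == 5
assert recheck_event.run_cmd.run_level == 'recheck'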
def get_set_cmd():
    """Get set command."""
    event = get_ack_reply()
    event.set_cmd.CopyFrom(SetCMD(id=1, watch_condition=1))
    return event
def get_exit_cmd():
    """Get exit command."""
    event = get_ack_reply()
    event.exit = True
    return event
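# A minimal sketch of how these mock command helpers can be consumed,
# assuming a `cache_store` that exposes the same put_command interface used
# by stop() and _send_watchpoints above.
cache_store.put_command(get_exit_cmd())  # ask the training client to exit
cache_store.put_command(get_set_cmd())   # install the mock watchpoint (id=1)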