def run(self):
    """
    Start the debugger offline server.

    The offline server manager is initialized, the running flag is set and
    the manager is then waited on in a loop. Every pass through the loop is
    counted, whether it finished normally or raised MindInsightException.
    The loop stops when the running flag is cleared, when the manager is no
    longer runnable, or when the count reaches `_MAX_TRY_EXCEPT_COUNT`.
    If the count ends up equal to that limit, the cache store is cleaned and
    the metadata is handed to `put_data`.
    """
    log.info("Initialize Offline Debugger Server for dbg_dir: %s",
             self._context.dbg_dir)
    manager = self._offline_server_manager
    manager.initialize()
    log.info("Start Offline Debugger Server for dbg_dir: %s",
             self._context.dbg_dir)
    self._running.set()

    attempts = 0
    while True:
        if not self._running.is_set():
            break
        if attempts >= self._MAX_TRY_EXCEPT_COUNT:
            break
        try:
            manager.wait_for_termination()
            if not manager.is_runnable():
                break
        except MindInsightException as err:
            log.exception(err)
            log.warning(
                "Error happens during listening on user commands. Restart listening again."
            )
        finally:
            # counted on every pass, including the one that breaks out
            attempts += 1

    # protect server from too much failure commands.
    if attempts != self._MAX_TRY_EXCEPT_COUNT:
        return
    self._cache_store.clean()
    metadata = self._cache_store.get_stream_handler(Streams.METADATA).get()
    self._cache_store.put_data(metadata)
    log.warning("Exception exceed %d times, stop server.", attempts)
def SendWatchpointHits(self, request_iterator, context):
    """
    Receive watchpoint hits and store them on the server.

    The pending run command is cleared first. If the metadata state is
    RUNNING, the cached commands are cleared as well. Each incoming hit is
    resolved to a graph name and a UI node name; hits that cannot be
    resolved are skipped. The stored watchpoint is deep-copied and each of
    its condition params gets an `actual_value` taken from the hit (None when
    the hit carries no truthy value for that param).

    Args:
        request_iterator (Iterable): The watchpoint hit protos.
        context: Not used by this method.

    Returns:
        The reply produced by `get_ack_reply()`.
    """
    log.info("Received WatchpointHits. Left run cmd %s change to emtpy.",
             self._old_run_cmd)
    self._old_run_cmd.clear()
    metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
    if metadata_stream.state == ServerStatus.RUNNING.value:
        # if the client session is running a script, all the cached command
        # should be cleared when received watchpoint_hits.
        self._cache_store.clean_command()

    # save the watchpoint_hits data
    received_hits = []
    watchpoint_stream = self._cache_store.get_stream_handler(
        Streams.WATCHPOINT)
    graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH)
    for hit_proto in request_iterator:
        full_name = hit_proto.tensor.node_name
        graph_name = graph_stream.get_graph_id_by_full_name(full_name)
        if not graph_name:
            log.warning("Cannot find node %s in graph. Skip it.", full_name)
            continue
        ui_node_name = graph_stream.get_node_name_by_full_name(
            full_name, graph_name)
        log.debug("Receive watch point hit: %s", hit_proto)
        if not ui_node_name:
            log.info("Not support to show %s on graph.", full_name)
            continue

        hit = {
            'tensor_proto': hit_proto.tensor,
            'watchpoint': copy.deepcopy(
                watchpoint_stream.get_watchpoint_by_id(hit_proto.id)),
            'node_name': ui_node_name,
            'graph_name': graph_name
        }
        # only params with a truthy actual value are carried over
        hit_params = {
            param.name: param.actual_value
            for param in hit_proto.watch_condition.params
            if param.actual_value
        }
        condition_params = hit['watchpoint'].condition['params']
        for index, param in enumerate(condition_params):
            condition_params[index]['actual_value'] = hit_params.get(
                param['name'])
        if hit_proto.error_code:
            hit['error_code'] = hit_proto.error_code
        received_hits.append(hit)

    self._received_hit = received_hits
    return get_ack_reply()
def _add_hit_node_info(watchpoint_hit, multi_card_graph_streams, rank_id, hit):
    """
    Add hit node info.

    Looks up the graph handler of `rank_id`, resolves the graph name and the
    UI node name of `hit['name']`, and writes `tensor_proto`, `node_name`
    and `graph_name` into `watchpoint_hit`. Nothing is written when either
    lookup returns a falsy value.

    Args:
        watchpoint_hit (dict): The record that is updated in place.
        multi_card_graph_streams: Object providing `get_graph_handler_by_rank_id`.
        rank_id: The rank whose graph handler is used.
        hit (dict): Must provide the keys 'name' and 'slot'.
    """
    rank_graph = multi_card_graph_streams.get_graph_handler_by_rank_id(rank_id)
    full_name = hit['name']
    graph_name = rank_graph.get_graph_id_by_full_name(full_name)
    if not graph_name:
        log.warning("Cannot find node %s in graph. Skip it.", full_name)
        return
    ui_node_name = rank_graph.get_node_name_by_full_name(full_name, graph_name)
    log.debug("Receive watch point hit: %s:%s", full_name, hit['slot'])
    if not ui_node_name:
        log.info("Not support to show %s on graph.", full_name)
        return
    hit_tensor = TensorProto(node_name=full_name, slot=str(hit['slot']))
    watchpoint_hit.update({
        'tensor_proto': hit_tensor,
        'node_name': ui_node_name,
        'graph_name': graph_name
    })
def _pre_process(self, request):
    """
    Pre-process before dealing with command.

    When the server status is MISMATCH, only the 'state' and
    'debugger_version' metadata are put into the data queue and the method
    returns. Otherwise cached data is cleaned when the request starts a new
    step or targets a different node, the graph flag is sent while the status
    is RECEIVE_GRAPH, the metadata is updated on a new step or node, and the
    received-tensor tag and watchpoint-hit flag are sent.

    Args:
        request: The incoming request; `cur_step` and `cur_node` are read.
    """
    metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)

    # check if version is mismatch, if mismatch, send mismatch info to UI
    if self._status == ServerStatus.MISMATCH:
        # The two literals are concatenated, so the separating space must be
        # part of one of them.
        log.warning("Version of Mindspore and Mindinsight are unmatched, "
                    "waiting for user to terminate the script.")
        # put metadata into data queue
        metadata = metadata_stream.get(['state', 'debugger_version'])
        self._cache_store.put_data(metadata)
        return

    is_new_step = metadata_stream.step < request.cur_step
    is_new_node = metadata_stream.full_name != request.cur_node
    # clean cache data at the beginning of new step or node has been changed.
    if is_new_step or is_new_node:
        self._cache_store.clean_data()
        self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors(
            request.cur_step)
    if is_new_step:
        self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
    # receive graph at the beginning of the training
    if self._status == ServerStatus.RECEIVE_GRAPH:
        self._send_graph_flag(metadata_stream)
    # receive new metadata
    if is_new_step or is_new_node:
        self._update_metadata(metadata_stream, request)
    self._send_received_tensor_tag()
    self._send_watchpoint_hit_flag()
def get_bfs_graph(self, node_name, bfs_order):
    """
    Traverse the graph in order of breadth-first search.

    Starting from `node_name`, the names of the reachable leaf nodes are
    appended to `bfs_order` in BFS order. Neighbours are taken from the keys
    of each node's `inputs` and then `outputs`. Names already present in
    `bfs_order` when the call starts are not visited again (the start node
    itself is always processed). A name missing from the leaf nodes is
    logged and skipped.

    Args:
        node_name (str): The name of the node to start from.
        bfs_order (list): Output list, extended in place.

    Returns:
        None, the result is delivered through `bfs_order`.
    """
    pending = deque([node_name])
    # Everything that has been emitted or queued. A set keeps the membership
    # test O(1) instead of scanning the queue and the result list each time.
    seen = set(bfs_order)
    seen.add(node_name)
    while pending:
        current = pending.popleft()
        node = self._leaf_nodes.get(current)
        if not node:
            log.warning('Cannot find node %s in graph. Ignored.', current)
            continue
        bfs_order.append(current)
        for neighbours in (node.inputs, node.outputs):
            if not neighbours:
                continue
            for name in neighbours.keys():
                if name not in seen:
                    seen.add(name)
                    pending.append(name)
def get_node_name_by_full_name(self, full_name):
    """
    Get node name by full names.

    Args:
        full_name (str): The key looked up in the full-name map.

    Returns:
        The mapped inner name, or '' when `full_name` is not in the map.
        A falsy result is logged as a warning.
    """
    inner_name = self._full_name_map_name.get(full_name, '')
    if inner_name:
        return inner_name
    log.warning("Node %s does not find the relative inner node name.",
                full_name)
    return inner_name
def _load_graphs(self):
    """
    Load graphs.

    The loaded graphs are a list of
    {'rank_id': int, 'graph_protos': [GraphProto]}. They are regrouped into
    Dict[<rank_id>, Dict[<graph_name>, <GraphProto>]], and the const values
    of every graph proto are put into the tensor handler of its rank. The
    grouped graphs are then given to the graph stream and the device stream.
    If that raises DebuggerParamValueError, the graph stream is cleaned.
    Finally the metadata state is set to RECEIVE_GRAPH.
    """
    log.debug("Begin to load graphs.")
    graphs = self._data_loader.load_graphs()
    device_stream = self._cache_store.get_stream_handler(Streams.DEVICE)

    graph_per_rank = {}
    for rank_graphs in graphs:
        rank_id = rank_graphs.get('rank_id')
        protos_by_name = {}
        graph_per_rank[rank_id] = protos_by_name
        tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR)
        rank_tensor_stream = tensor_stream.get_tensor_handler_by_rank_id(
            rank_id, create_if_not_exit=True)
        for graph_proto in rank_graphs.get('graph_protos'):
            protos_by_name[graph_proto.name] = graph_proto
            rank_tensor_stream.put_const_vals(graph_proto.const_vals)

    # the graph_per_rank is format like: Dict[<rank_id>, Dict[<graph_name>, <GraphProto>]]
    try:
        self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_per_rank)
        self._cache_store.get_stream_handler(Streams.GRAPH).parse_stack_infos()
        device_stream.add_graph_name_info(graph_per_rank)
    except DebuggerParamValueError:
        log.warning("Parse graph failed. The graph file is invalid.")
        self._cache_store.get_stream_handler(Streams.GRAPH).clean()
    self._metadata_stream.state = ServerStatus.RECEIVE_GRAPH.value
    log.debug("Finish to load graphs.")
def get_valid_tensor_by_name(self, tensor_name, step, prev=False):
    """
    Get tensor value by name in numpy type.

    Args:
        tensor_name (str): The name passed on to `_get_tensor`.
        step (int): The requested step.
        prev (bool): If True, the tensor of `step - 1` is looked up instead
            of the tensor of `step`. Default: False.

    Returns:
        The result of `_get_tensor` for the target step, or None when the
        target step is negative.
    """
    target_step = step - 1 if prev else step
    if target_step < 0:
        # Report the step the caller asked about, not the computed target
        # (which would read "Step -1 ..." for step 0 with prev=True).
        log.warning("Step %d has no previous value for tensor: %s",
                    step, tensor_name)
        return None
    return self._get_tensor(tensor_name, step=target_step)
def add_nodes(self, nodes):
    """
    Add node into watchpoint.

    Args:
        nodes: A single node or a list of nodes. Each node must provide
            `name`, `type` and `full_name`. A falsy value is logged and
            ignored.
    """
    if not nodes:
        log.warning("Add empty nodes.")
        return
    node_list = nodes if isinstance(nodes, list) else [nodes]
    for item in node_list:
        self._watch_node.add_node(item.name, item.type, item.full_name)
def _check_session_num(self):
    """
    Check the amount of sessions.

    The limit is `MAX_OFFLINE_SESSION_NUM`, plus one when the online session
    id is among the current sessions.

    Raises:
        DebuggerSessionNumOverBoundError: If the number of sessions has
            reached the limit.
    """
    session_limitation = self.MAX_OFFLINE_SESSION_NUM
    online_present = self.ONLINE_SESSION_ID in self.sessions
    session_limitation = session_limitation + (1 if online_present else 0)
    session_count = len(self.sessions)
    if session_count < session_limitation:
        return
    logger.warning(
        'Offline debugger session num %s is reach the limitation %s',
        session_count, session_limitation)
    raise DebuggerSessionNumOverBoundError()
def get_valid_tensor_by_name(self, tensor_name, prev=False):
    """
    Get tensor value by name in numpy type.

    Args:
        tensor_name (str): The name passed on to `_get_tensor`.
        prev (bool): If True, `prev_step` is used, otherwise `cur_step`.
            Default: False.

    Returns:
        The result of `_get_tensor`, or None when the chosen step is
        negative or the found tensor reports `empty`.
    """
    if prev:
        step = self.prev_step
    else:
        step = self.cur_step
    if step < 0:
        log.warning("%d step has no previous value for tensor: %s",
                    self.cur_step, tensor_name)
        return None
    tensor = self._get_tensor(tensor_name, step=step)
    if not tensor or not tensor.empty:
        return tensor
    log.warning("%s has empty value.", tensor_name)
    return None
def add_nodes(self, nodes, rank_id):
    """
    Add node into watchpoint.

    Args:
        nodes: A single node or a list of nodes. Each node must provide
            `name`, `type` and `full_name`. A falsy value is logged and
            ignored.
        rank_id: The key of the watch node tree the nodes are added to.
            A new WatchNodeTree is created when the key is not present yet.
    """
    if not nodes:
        log.warning("Add empty nodes.")
        return
    if rank_id not in self._watch_node:
        self._watch_node[rank_id] = WatchNodeTree()
    node_list = nodes if isinstance(nodes, list) else [nodes]
    for item in node_list:
        rank_tree = self._watch_node.get(rank_id)
        rank_tree.add_node(item.name, item.type, item.full_name)
def get_tensor_value_by_shape(self, shape=None):
    """
    Get tensor value by shape.

    The stored value is returned as is; a truthy `shape` only triggers a
    warning.

    Args:
        shape (tuple): The specified shape.

    Returns:
        Union[None, str, int, float], the value of parsed tensor.
    """
    if not shape:
        return self._value
    log.warning("Invalid shape for const value.")
    return self._value
def _put_tensor_value_into_cache(self, cur_step, node_info, rank_id,
                                 tensor_protos):
    """
    Put tensor value into tensor cache.

    Each tensor proto is put into the tensor handler of `rank_id`. When
    `node_info` carries a 'load' entry, that entry is completed with the
    graph and node name and the tensor is also passed to `load_tensor`.
    Content longer than MAX_SINGLE_TENSOR_CACHE is marked as oversize and
    not kept in 'tensor_contents'. A ValueError for one tensor is logged and
    the tensor is skipped. If any put reported an update, a
    'receive_tensor' message with the 'step' and 'state' metadata is put
    into the cache store.

    Args:
        cur_step: The step stored with each tensor.
        node_info (dict): Node description; 'load', 'graph_name',
            'node_name' and 'stats' are read.
        rank_id: The rank whose tensor handler is used.
        tensor_protos (Iterable): The tensor protos to store.
    """
    tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR). \
        get_tensor_handler_by_rank_id(rank_id)
    any_update = False
    for proto in tensor_protos:
        if not proto.tensor_content:
            log.warning("Tensor %s:%s is empty.", proto.node_name, proto.slot)
        try:
            load_info = node_info.get('load')
            if load_info is not None:
                load_info['graph_name'] = node_info.get('graph_name')
                load_info['node_name'] = node_info.get('node_name')
                load_tensor(load_info=load_info,
                            step=cur_step,
                            request_iterator=iter([proto]),
                            cache_store=self._cache_store,
                            rank_id=rank_id)
            oversize = len(proto.tensor_content) > MAX_SINGLE_TENSOR_CACHE
            value = {
                'step': cur_step,
                'tensor_proto': proto,
                'tensor_contents': [] if oversize else [proto.tensor_content],
                'stats': node_info.get('stats', False),
                'oversize': oversize,
            }
            has_update = tensor_stream.put(value)
        except ValueError as err:
            log.warning("Failed to put %s:%s into cache. Ignore it. %s",
                        proto.node_name, proto.slot, str(err))
            continue
        if has_update:
            any_update = True

    if not any_update:
        return
    # send message to frontend
    metadata = self._metadata_stream.get(['step', 'state'])
    message = {'receive_tensor': node_info.copy()}
    message.update(metadata)
    self._cache_store.put_data(message)
def get_graph_id_by_full_name(self, node_name):
    """
    Get graph id by full name.

    Args:
        node_name (str): The full name of the node.

    Returns:
        The entry of `graph_node_map` for `node_name`, or None when
        `node_name` is falsy or not in the map. A falsy result is logged as
        a warning; no exception is raised here.
    """
    graph_id = None
    if node_name:
        graph_id = self.graph_node_map.get(node_name)
    if graph_id:
        return graph_id
    log.warning("Failed to get graph id by full name: %s", node_name)
    return graph_id
def generate_value_from_proto(tensor_proto):
    """
    Generate tensor value from proto.

    The value is the first entry of `tensor_proto.value.ListFields()` whose
    field is not named 'dtype'. A field count other than two is logged but
    still processed.

    Args:
        tensor_proto (TensorProto): The tensor proto.

    Returns:
        Union[None, np.ndarray], the value of the tensor. None when every
        listed field is named 'dtype' or no field is set.
    """
    fields = tensor_proto.value.ListFields()
    if len(fields) != 2:
        log.warning("Unexpected const proto <%s>.\n Please check offline.",
                    tensor_proto)
    for field, field_value in fields:
        # ListFields() pairs each value with a field descriptor, so the
        # comparison has to use the descriptor's name; a plain string key is
        # accepted as well.
        field_name = getattr(field, 'name', field)
        if field_name != 'dtype':
            return field_value
    return None
def _load_json_file(file):
    """
    Load json file content.

    The file is read as UTF-8 so that the result does not depend on the
    locale's default encoding.

    Args:
        file (Path): The Path object.

    Returns:
        dict, the json content. An empty dict is returned when the file is
        missing, is not valid JSON or cannot be decoded as UTF-8.
    """
    if not file.is_file():
        log.info("File <%s> is missing.", str(file))
        return {}
    with file.open(encoding='utf-8') as handler:
        try:
            return json.load(handler)
        except (json.decoder.JSONDecodeError, UnicodeDecodeError) as err:
            log.warning("Failed to load json file %s. %s", str(file), str(err))
            return {}
def _deal_with_old_command(self):
    """
    Deal with old command.

    Cached commands are fetched while the cache store has a command at the
    current position and no event has been found. If none yields an event
    and an old run command is pending, that run command is continued by
    step count or by node name. A run command that sets both or neither of
    them is logged, cleared and dropped.

    Returns:
        The event to reply with, or None when there is none.
    """
    event = None
    while self._cache_store.has_command(self._pos) and event is None:
        event = self._get_next_command()
        log.debug("Deal with old %s-th command:\n%s.", self._pos, event)

    if event is not None or not self._old_run_cmd:
        return event

    # deal with continue run command
    left_step_count = self._old_run_cmd.get('left_step_count')
    node_name = self._old_run_cmd.get('node_name')
    # node_name and left_step_count should not set at the same time,
    # and exactly one of them has to be set
    if bool(left_step_count) == bool(node_name):
        log.warning("Invalid old run command. %s", self._old_run_cmd)
        self._old_run_cmd.clear()
        return None
    if left_step_count:
        event = self._deal_with_left_continue_step(left_step_count)
    else:
        event = self._deal_with_left_continue_node(node_name)
    log.debug("Send old RunCMD.")
    return event
def _get_merged_tensor(tensor_protos):
    """
    Merged list of parsed tensor value into one.

    With more than one proto, the `tensor_content` of all protos is
    concatenated in order and stored on the last proto. Concatenation stops
    at the first proto without content, so only the chunks before it are
    kept. A single proto is returned untouched.

    Args:
        tensor_protos (list[TensorProto]): List of tensor proto.

    Returns:
        TensorProto, merged tensor proto (the last element of the list).

    Raises:
        IndexError: If `tensor_protos` is empty.
    """
    merged_tensor = tensor_protos[-1]
    if len(tensor_protos) > 1:
        # Collect the chunks and join once; repeated `bytes +=` copies the
        # accumulated data on every step.
        chunks = []
        for tensor_proto in tensor_protos:
            content = tensor_proto.tensor_content
            if not content:
                log.warning("Doesn't find tensor value for %s:%s",
                            tensor_proto.node_name, tensor_proto.slot)
                break
            chunks.append(content)
        merged_tensor.tensor_content = b''.join(chunks)
        log.debug("Merge multi tensor values into one.")
    return merged_tensor
def get_graph_protos_from_dir(graphs_dir):
    """
    Get graph from graph directory.

    Only entries whose name starts with "ms_output_trace_code_graph_" and
    ends with ".pb" are loaded. If any of them fails to decode, the whole
    result is discarded.

    Args:
        graphs_dir (Path): The Path object of graph directory.

    Returns:
        list, list of 'GraphProto' object. Empty when a graph file is
        invalid.
    """
    prefix = "ms_output_trace_code_graph_"
    loaded = []
    for entry in graphs_dir.iterdir():
        entry_name = entry.name
        if not (entry_name.startswith(prefix) and entry_name.endswith(".pb")):
            continue
        try:
            graph_proto = load_graph_from_file(entry)
        except DecodeError:
            log.warning("Load graph failed. The graph file is invalid.")
            return []
        loaded.append(graph_proto)
    return loaded
def get(self, filter_condition=None):
    """
    Get the graph of specific node.

    Args:
        filter_condition (dict):

            - name (str): The full debug node name.
            - graph_name (str): The relative graph_name of the node.
            - single_node (bool): If True, the nodes come from
              `_get_single_node`; otherwise from `_list_nodes`.
              Default: False.

    Returns:
        dict, {'graph': {...}} holding 'graph_names' plus the node info, or
        {'graph': {}} when the graph does not exist.
    """
    try:
        self._graph_exists()
    except DebuggerGraphNotExistError:
        log.warning('The graph is empty. To view a graph, '
                    'please start the training script first.')
        return {'graph': {}}

    condition = {} if filter_condition is None else filter_condition
    graph = {'graph_names': self.graph_names}
    want_single_node = condition.get('single_node', False)
    name = condition.get('name')
    graph_name = condition.get('graph_name')
    # only the literal True selects the single-node view
    if want_single_node is True:
        nodes = self._get_single_node(name, graph_name)
    else:
        nodes = self._list_nodes(name, graph_name)
    graph.update(nodes)
    return {'graph': graph}
def run(self):
    """
    Function that should be called when thread started.

    Sets the running flag and waits on the heartbeat queue, one period at a
    time. Each timeout counts as a miss and a received heartbeat resets the
    count. The loop ends when the running flag is cleared or the misses
    reach `_MAX_TRY_EXCEPT_COUNT`. After the loop the cache store is cleaned
    and the metadata is put into the data queue.
    """
    self._running.set()
    log.info("Start listening for heartbeat.")
    missed = 0
    while self._running.is_set():
        try:
            self._heartbeat_queue.get(timeout=self._period)
        except Empty:
            missed += 1
            if missed >= self._MAX_TRY_EXCEPT_COUNT:
                break
            log.info("Missing heartbeat. Try again.")
        else:
            # reset try count if received heartbeat
            missed = 0

    log.warning("Missing heartbeat. Reset online session.")
    self._cache_store.clean()
    metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
    # put metadata into data queue
    self._cache_store.put_data(metadata_stream.get())
def WaitCMD(self, request, context):
    """
    Wait for a command in DebuggerCache.

    While the server status is PENDING, `get_ack_reply(1)` is returned right
    away. Otherwise the request is pre-processed, old commands are handled
    first and, if they produce no reply, the next command is waited for.
    `get_ack_reply(1)` is also the fallback when no command event could be
    obtained.

    Args:
        request: The incoming request; `cur_step` is logged.
        context: Not used by this method.

    Returns:
        The reply for the client.
    """
    log.info("Received WaitCMD at %s-th step.", request.cur_step)
    # check if graph have already received.
    if self._status == ServerStatus.PENDING:
        log.warning("No graph received before WaitCMD.")
        return get_ack_reply(1)

    # send graph if it has not been sent before
    self._pre_process(request)
    # deal with old command
    reply = self._deal_with_old_command()
    # wait for next command
    if reply is None:
        reply = self._wait_for_next_command()
    # check the reply
    if reply is not None:
        log.debug("Reply to WaitCMD: %s", reply)
        return reply
    reply = get_ack_reply(1)
    log.warning("Failed to get command event.")
    return reply
def generate_value_from_proto(self, tensor_proto):
    """
    Generate tensor value from proto.

    The value is the first listed field that is not named 'dtype'. Unless
    this tensor's dtype is the string type, a found value is converted to a
    numpy array with the dtype mapped by NUMPY_TYPE_MAP.

    Args:
        tensor_proto (TensorProto): The tensor proto.

    Returns:
        Union[None, str, np.ndarray], the value of the tensor.
    """
    fields = tensor_proto.value.ListFields()
    if len(fields) != 2:
        log.warning("Unexpected const proto <%s>.\n Please check offline.",
                    tensor_proto)
    tensor_value = next(
        (value for descriptor, value in fields if descriptor.name != 'dtype'),
        None)
    if tensor_value is None or self.dtype == self._STRING_TYPE:
        return tensor_value
    return np.array(tensor_value, dtype=NUMPY_TYPE_MAP.get(self.dtype))
def get_tensor_value_by_shape(self, shape=None):
    """
    Get tensor value by shape.

    Args:
        shape (tuple): The index applied to the stored value. Anything that
            is not a tuple (including None) selects the whole value.

    Returns:
        Union[None, str, numpy.ndarray], the value of parsed tensor. None
        when no value is stored yet; the string "Too large to show." when
        the selected array has more elements than
        `max_number_data_show_on_ui`.

    Raises:
        DebuggerParamValueError: If the length of `shape` differs from the
            tensor's number of dimensions, or indexing raises IndexError.
    """
    if self._value is None:
        log.warning("%s has no value yet.", self.name)
        return None
    # None is not a tuple either, so one check covers both cases
    if not isinstance(shape, tuple):
        log.info("Get the whole tensor value with shape is %s", shape)
        return self._value
    if len(shape) != len(self.shape):
        log.error("Invalid shape. Received: %s, tensor shape: %s", shape,
                  self.shape)
        raise DebuggerParamValueError("Invalid shape. Shape unmatched.")
    try:
        selected = self._value[shape]
    except IndexError as err:
        log.error("Invalid shape. Received: %s, tensor shape: %s", shape,
                  self.shape)
        log.exception(err)
        raise DebuggerParamValueError("Invalid shape. Shape unmatched.")

    if not isinstance(selected, np.ndarray):
        # a non-array result of the indexing is wrapped into an array
        return np.asarray(selected)
    if selected.size > self.max_number_data_show_on_ui:
        log.info("The tensor size is %d, which is too large to show on UI.",
                 selected.size)
        return "Too large to show."
    return selected
def get_tensor_serializable_value_by_shape(self, shape=None):
    """
    Get tensor info with value.

    The stored value is returned as is; passing any shape other than None
    only triggers a warning.

    Args:
        shape: Expected to be None for this kind of value.

    Returns:
        The stored value.
    """
    if shape is None:
        return self._value
    log.warning("Invalid shape for const value.")
    return self._value
def validate_rank_id(self, rank_id):
    """
    Validate the rank id.

    A rank id without an entry in the watch nodes is only logged as a
    warning; nothing is raised or returned.

    Args:
        rank_id: The rank id to look up.
    """
    if rank_id in self._watch_node:
        return
    log.warning("Rank_id not exist")