def get(self, filter_condition=None, rank_id=0):
    """Get the watchpoint hits of a specific device."""
    if rank_id in self.watchpoint_hit_handlers:
        return self.watchpoint_hit_handlers.get(rank_id).get(filter_condition)
    log.error("There is no rank id %d.", rank_id)
    raise ValueError("There is no rank id {}.".format(rank_id))

def get(self, filter_condition=None):
    """
    Get the watchpoints.

    Args:
        filter_condition (Union[None, int]): The filter condition. Get the
            watchpoint by id. If None, return all watchpoints. Default: None.

    Returns:
        dict, the watchpoint list.
    """
    reply = []
    if not filter_condition:
        # get the watch condition info of all watchpoints
        for _, watchpoint in self._watchpoints.items():
            watchpoint_info = watchpoint.get_watch_condition_info()
            reply.append(watchpoint_info)
    else:
        self.validate_watchpoint_id(filter_condition)
        reply = [self._watchpoints.get(filter_condition)]
    log.debug("Get the watchpoints with filter_condition: %s", filter_condition)
    return {'watch_points': reply}

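# A minimal usage sketch, assuming a populated handler instance named
# `watchpoint_stream` (the variable name is illustrative):
#
#     all_points = watchpoint_stream.get()   # {'watch_points': [<info>, ...]}
#     one_point = watchpoint_stream.get(1)   # only the watchpoint with id 1
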
def _get_watch_nodes_by_search(self, node_names, search_pattern, graph_name):
    """
    Get watched leaf nodes by search name.

    Args:
        node_names (list[str]): A list of node names.
        search_pattern (dict): Get watch node with search pattern.

            - name (str): The name pattern.
            - node_category (str): The node_category.
        graph_name (str): The relative graph_name of the watched node.

    Returns:
        list[NodeBasicInfo], a list of node basic infos.
    """
    search_pattern['graph_name'] = graph_name
    search_nodes = self._graph_stream.search_nodes(search_pattern)
    watch_node_names = set()
    for name in node_names:
        names = self._get_watch_names_by_search(search_nodes, name)
        watch_node_names.update(names)
    watch_node_info = self._get_node_basic_infos(watch_node_names, graph_name=graph_name)
    log.debug("Update nodes: %s", watch_node_info)
    return watch_node_info

def _retrieve_node(self, filter_condition):
    """
    Retrieve node info.

    Args:
        filter_condition (dict): Filter condition.

            - name (str): The name of single node.
            - graph_name (str): The relative graph_name of the node.
            - single_node (bool): If False, return the sub-layer of single node.
              If True, return the node list from root node to single node.
            - watch_point_id (int): The id of watchpoint.

    Returns:
        dict, reply with graph.
    """
    log.debug("Retrieve node %s.", filter_condition)
    node_name = filter_condition.get('name')
    graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
    graph_name = graph_stream.validate_graph_name(filter_condition.get('graph_name'))
    if node_name:
        # validate node name
        graph_stream.get_node_type(node_name, graph_name)
    filter_condition['single_node'] = bool(filter_condition.get('single_node'))
    filter_condition['graph_name'] = graph_name
    reply = self._get_nodes_info(filter_condition)
    return reply

def search_watchpoint_hits(self, group_condition):
    """
    Retrieve watchpoint hits.

    Args:
        group_condition (dict): Filter condition.

            - limit (int): The limit of each page.
            - offset (int): The offset of current page.
            - node_name (str): The retrieved node name.
            - graph_name (str): The retrieved graph name.

    Returns:
        dict, watchpoint hit list or relative graph.
    """
    if not isinstance(group_condition, dict):
        log.error("Group condition for watchpoint-hits request should be a dict.")
        raise DebuggerParamTypeError("Group condition for watchpoint-hits request should be a dict.")
    metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
    if metadata_stream.state == ServerStatus.PENDING.value:
        log.info("The backend is in pending status.")
        return metadata_stream.get()
    reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).group_by(group_condition)
    reply['outdated'] = self.cache_store.get_stream_handler(Streams.WATCHPOINT).is_recheckable()
    return reply

def add_node(self, node_name, node_type, full_name=''):
    """
    Add watch node to watch node tree.

    Args:
        node_name (str): The node name.
        node_type (str): The node type.
        full_name (str): The full name of node.
    """
    log.debug("Add node %s with type: %s, full_name: %s", node_name, node_type, full_name)
    scope_names = node_name.split('/', 1)
    if len(scope_names) == 1:
        target_node = self.get(node_name)
        if not target_node:
            self.add(node_name, node_type, full_name, watch_status=WatchNodeTree.TOTAL_WATCH)
        else:
            target_node.update_metadata(node_type, full_name, WatchNodeTree.TOTAL_WATCH)
        return
    scope_name, sub_names = scope_names
    sub_tree = self.get(scope_name)
    if not sub_tree:
        # watch_status=1 marks an intermediate scope as partially watched
        sub_tree = self.add(scope_name, watch_status=1)
    sub_tree.add_node(sub_names, node_type, full_name)

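# Hedged illustration of the recursive insertion (node names are made up):
#
#     tree.add_node('Default/conv1/BiasAdd-op1', 'BiasAdd', 'Default/conv1/BiasAdd-op1')
#
# The scope 'Default' is created (or kept) as partially watched, then the
# call recurses with 'conv1/BiasAdd-op1' until the leaf 'BiasAdd-op1' is
# marked as fully watched.
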
def stop(self):
    """Stop debugger server."""
    log.info("Send terminate info to client.")
    self.control({'mode': 'terminate'})
    self.grpc_server_manager.stop(grace=None)
    self.back_server.join()
    log.info("Stop debugger server.")

def _save_watchpoint_hits(self, hits):
    """Save watchpoint hits."""
    multi_card_hit_streams = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
    multi_card_graph_streams = self._cache_store.get_stream_handler(Streams.GRAPH)
    watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT)
    watchpoint_hits = defaultdict(list)
    for hit in hits:
        log.info(
            "Received hit:\n"
            "name: %s, slot: %s, condition: %s, watchpoint_id: %s, "
            "error_code: %s, rank_id: %s",
            hit['name'], hit['slot'], hit['condition'],
            hit['watchpoint_id'], hit['error_code'], hit['rank_id'])
        rank_id = hit['rank_id']
        watchpoint_hit = {}
        self._add_hit_node_info(watchpoint_hit, multi_card_graph_streams, rank_id, hit)
        if not watchpoint_hit:
            continue
        self._add_hit_watchpoint_info(watchpoint_hit, watchpoint_stream, hit)
        watchpoint_hit['error_code'] = hit['error_code']
        watchpoint_hits[rank_id].append(watchpoint_hit)
    # save hit info into cache
    multi_card_hit_streams.put(watchpoint_hits)
    self._cache_store.put_data({'receive_watchpoint_hits': True})
    log.debug("Send the watchpoint hits to DataQueue.")

def _deal_with_run_cmd(self, event):
    """Deal with run cmd."""
    metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
    run_cmd = event.run_cmd
    # receive step command
    if run_cmd.run_level == RunLevel.STEP.value:
        # receive pause cmd
        if not run_cmd.run_steps:
            log.debug("Pause training and wait for next command.")
            self._old_run_cmd.clear()
            # update metadata state from sending to waiting
            metadata_stream.state = ServerStatus.WAITING.value
            return None
        # receive step cmd
        left_steps = run_cmd.run_steps - 1
        event.run_cmd.run_steps = 1
        if left_steps:
            self._old_run_cmd['left_step_count'] = left_steps if left_steps > 0 else -1
    elif run_cmd.node_name:
        self._old_run_cmd['node_name'] = run_cmd.node_name
        run_cmd.node_name = ''
    # clean watchpoint hit cache
    if run_cmd.run_level == RunLevel.RECHECK.value:
        self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
        log.debug("Receive RunCMD. Clean watchpoint hit cache.")
    # update metadata state from sending to running
    metadata_stream.state = ServerStatus.RUNNING.value
    return event

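# Hedged walk-through of the step splitting above: a RunCMD with run_steps=3
# is rewritten to execute a single step now while the remainder is parked for
# the next round; any negative remainder is normalized to -1.
#
#     run_steps=3  ->  event.run_cmd.run_steps=1, _old_run_cmd['left_step_count']=2
#     run_steps=1  ->  event.run_cmd.run_steps=1, nothing parked
#     run_steps=0  ->  treated as a pause command
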
def _get_by_offset(self, group_condition):
    """Return the list of watchpoint hits on the offset page."""
    limit = group_condition.get('limit')
    offset = group_condition.get('offset')
    if not isinstance(limit, int) or not isinstance(offset, int):
        log.error("Param limit or offset is not an integer.")
        raise DebuggerParamValueError("Param limit or offset is not an integer.")
    watch_point_hits = []
    total = len(self._ordered_hits)
    if limit * offset >= total and offset != 0:
        log.error("Param offset out of bounds.")
        raise DebuggerParamValueError("Param offset out of bounds.")
    if total == 0:
        return {}
    for watchpoint_hits in self._ordered_hits[(limit * offset):(limit * (offset + 1))]:
        self._get_tensors(watchpoint_hits, watch_point_hits)
    return {
        'watch_point_hits': watch_point_hits,
        'offset': offset,
        'total': total
    }

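# Pagination sketch, assuming a handler with 25 ordered hits (the instance
# name `hit_stream` is illustrative):
#
#     hit_stream._get_by_offset({'limit': 10, 'offset': 2})
#
# returns hits 20..24 plus {'offset': 2, 'total': 25}; offset 3 would raise
# DebuggerParamValueError because 10 * 3 >= 25.
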
def run(self):
    """Start the debugger offline server."""
    log.info("Initialize Offline Debugger Server for dbg_dir: %s", self._context.dbg_dir)
    self._offline_server_manager.initialize()
    log.info("Start Offline Debugger Server for dbg_dir: %s", self._context.dbg_dir)
    self._running.set()
    try_count = 0
    while self._running.is_set() and try_count < self._MAX_TRY_EXCEPT_COUNT:
        try:
            self._offline_server_manager.wait_for_termination()
            if not self._offline_server_manager.is_runnable():
                break
        except MindInsightException as err:
            log.exception(err)
            log.warning("Error happens during listening on user commands. Restart listening again.")
        finally:
            try_count += 1
    # protect server from too many failed commands
    if try_count == self._MAX_TRY_EXCEPT_COUNT:
        self._cache_store.clean()
        metadata = self._cache_store.get_stream_handler(Streams.METADATA).get()
        self._cache_store.put_data(metadata)
        log.warning("Exceptions exceeded %d times, stop server.", try_count)

def _get_missing_tensor_info(self, tensor_name, node_type):
    """
    Get missing tensor infos.

    Args:
        tensor_name (str): The full name of Tensor.
        node_type (str): The type of the relative node.

    Returns:
        list, list of missing tensor basic information.
    """
    step = self.cur_step
    missing_tensors_info = []
    # check whether the current step value is missing
    if self._is_tensor_value_missing(tensor_name, step):
        missing_tensors_info.append(
            TensorBasicInfo(full_name=tensor_name, node_type=node_type, iter=''))
        log.debug("Add current step view cmd for %s", tensor_name)
    # check whether the previous step value is missing
    if node_type == NodeTypeEnum.PARAMETER.value and self._is_tensor_value_missing(tensor_name, step - 1):
        missing_tensors_info.append(
            TensorBasicInfo(full_name=tensor_name, node_type=node_type, iter='prev'))
        log.debug("Add previous view cmd for %s", tensor_name)
    return missing_tensors_info

def _put_tensor_into_cache(self, tensor, step):
    """
    Put tensor into cache.

    Args:
        tensor (OpTensor): The tensor value.
        step (int): The step of tensor.

    Returns:
        bool, whether the tensor has been updated successfully.
    """
    cache_tensor = self._tensors.get(tensor.name)
    if cache_tensor is None:
        cache_tensor = {}
        self._tensors[tensor.name] = cache_tensor
    old_tensor = cache_tensor.get(step)
    if old_tensor and not self._is_value_diff(old_tensor.value, tensor.value):
        log.debug("Tensor %s of step %s has no change. Ignore it.", tensor.name, step)
        return False
    cache_tensor[step] = tensor
    log.debug("Put updated tensor value for %s of step %s.", tensor.name, step)
    return True

def get_bfs_graph(self, node_name, bfs_order):
    """
    Traverse the graph in breadth-first-search order.

    Args:
        node_name (str): The name of the node to start from.
        bfs_order (list[str]): The list that accumulates the leaf node names
            in BFS order. The result is collected in place.
    """
    temp_list = deque()
    temp_list.append(node_name)
    while temp_list:
        node_name = temp_list.popleft()
        node = self._leaf_nodes.get(node_name)
        if not node:
            log.warning('Cannot find node %s in graph. Ignored.', node_name)
            continue
        bfs_order.append(node_name)
        if node.inputs:
            for name in node.inputs.keys():
                if name not in temp_list and name not in bfs_order:
                    temp_list.append(name)
        if node.outputs:
            for name in node.outputs.keys():
                if name not in temp_list and name not in bfs_order:
                    temp_list.append(name)

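# Usage sketch (the node name is made up): the caller owns the result list
# and the method fills it in place rather than returning a new one.
#
#     order = []
#     graph.get_bfs_graph('Default/conv1/Conv2D-op1', order)
#     # `order` now lists the reachable leaf nodes in BFS order
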
def put(self, value):
    """
    Put value into tensor cache. Called by grpc server.

    Args:
        value (dict): The Tensor proto message.

            - step (int): The current step of tensor.
            - tensor_proto (TensorProto): The tensor proto.
            - tensor_contents (list[byte]): The list of tensor content values.

    Returns:
        bool, whether the tensor has been updated successfully.
    """
    tensor = self._deal_with_tensor(value)
    stats = None
    if value.get('stats', False) and tensor.status == TensorStatusEnum.CACHED.value:
        tensor.calculate_stats()
        stats = tensor.stats
    flag = self._put_tensors(tensor)
    new_tensor = self._tensors.get(tensor.name).get(tensor.step)
    new_tensor.stats = stats
    log.info("Put tensor %s of step: %d, into cache. Flag: %s", tensor.name, tensor.step, flag)
    return flag

def _deal_with_run_cmd(self, event):
    """Deal with run cmd."""
    run_cmd = event.run_cmd
    # receive step command
    if run_cmd.run_level == RunLevel.STEP.value:
        # receive pause cmd
        if not run_cmd.run_steps:
            log.debug("Pause training and wait for next command.")
            self._old_run_cmd.clear()
            return None
        # receive step cmd
        left_steps = run_cmd.run_steps - 1
        event.run_cmd.run_steps = 1
        if left_steps:
            self._old_run_cmd['left_step_count'] = left_steps if left_steps > 0 else -1
    elif run_cmd.node_name:
        self._old_run_cmd['node_name'] = run_cmd.node_name
        run_cmd.node_name = ''
    # clean watchpoint hit cache
    if run_cmd.run_level == RunLevel.RECHECK.value:
        self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
        log.debug("Receive RunCMD. Clean watchpoint hit cache.")
    return event

def get(self, filter_condition=None):
    """
    Get full tensor value.

    Args:
        filter_condition (dict): Filter condition.

            - name (str): The full name of tensor.
            - node_type (str): The type of the node.
            - prev (bool): Whether to get previous tensor.

    Returns:
        dict, the tensor_value and whether a view_command needs to be sent.
    """
    name = filter_condition.get('name')
    node_type = filter_condition.get('node_type')
    shape = filter_condition.get('shape')
    if filter_condition.get('prev'):
        step = self.prev_step
    else:
        step = self.cur_step
    tensor = self._get_tensor(name, node_type, step)
    if not tensor:
        log.error("No tensor named %s at the step %s", name, step)
        raise DebuggerParamValueError("No tensor named {}".format(name))
    tensor_info = tensor.get_full_info(shape)
    self._update_has_prev_step_field(tensor_info, name, node_type, self.cur_step)
    res = {'tensor_value': tensor_info, 'view_cmd': False}
    if tensor.status == TensorStatusEnum.UNCACHED.value:
        self._add_hold_value_tensors(name, step)
        res['view_cmd'] = True
    return res

def SendTensors(self, request_iterator, context):
    """Send tensors into DebuggerCache."""
    log.info("Received tensor.")
    tensor_construct = []
    tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR)
    metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
    tensor_names = []
    step = metadata_stream.step
    for tensor in request_iterator:
        tensor_construct.append(tensor)
        if tensor.finished:
            update_flag = tensor_stream.put({'step': step, 'tensor_protos': tensor_construct})
            if self._received_view_cmd.get('wait_for_tensor') and update_flag:
                # update_flag is used to avoid querying empty tensors again
                self._received_view_cmd['wait_for_tensor'] = False
                log.debug("Set wait for tensor flag to False.")
            tensor_construct = []
            tensor_names.append(':'.join([tensor.node_name, tensor.slot]))
            continue
    reply = get_ack_reply()
    return reply

def create_session(self, session_type, train_job=None):
    """
    Create the session by the train job info or session type if the session doesn't exist.

    Args:
        session_type (str): The session type.
        train_job (str): The train job info.

    Returns:
        str, session id.
    """
    with self._lock:
        if self._exiting:
            logger.info("System is exiting, will terminate the thread.")
            _thread.exit()
        if session_type == self.ONLINE_TYPE:
            if self.ONLINE_SESSION_ID not in self.sessions:
                logger.error('Online session is unavailable, set --enable-debugger as true/1 to enable '
                             'debugger when starting the MindInsight server.')
                raise DebuggerOnlineSessionUnavailable()
            return self.ONLINE_SESSION_ID
        if train_job in self.train_jobs:
            return self.train_jobs.get(train_job)
        return self._create_offline_session(train_job)

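# Usage sketch, assuming an OFFLINE_TYPE constant symmetric to the
# ONLINE_TYPE used above (the dump path is illustrative). Offline sessions
# are keyed by train job, so repeated calls with the same job reuse the
# existing session id:
#
#     session_id = session_manager.create_session(session_manager.OFFLINE_TYPE, '/path/to/dump_dir')
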
def _pre_process(self, request):
    """Pre-process before dealing with command."""
    # check if the version is mismatched; if so, send the mismatch info to UI
    if self._status == ServerStatus.MISMATCH:
        log.warning("Versions of MindSpore and MindInsight are mismatched, "
                    "waiting for user to terminate the script.")
        metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
        # put metadata into data queue
        metadata = metadata_stream.get(['state', 'debugger_version'])
        self._cache_store.put_data(metadata)
        return
    metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
    is_new_step = metadata_stream.step < request.cur_step
    is_new_node = metadata_stream.full_name != request.cur_node
    # clean cache data at the beginning of a new step or when the node has changed
    if is_new_step or is_new_node:
        self._cache_store.clean_data()
        self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors(request.cur_step)
    if is_new_step:
        self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
    # receive graph at the beginning of the training
    if self._status == ServerStatus.RECEIVE_GRAPH:
        self._send_graph_flag(metadata_stream)
    # receive new metadata
    if is_new_step or is_new_node:
        self._update_metadata(metadata_stream, request)
    self._send_received_tensor_tag()
    self._send_watchpoint_hit_flag()

def search(self, filter_condition):
    """
    Search for single node in graph.

    Args:
        filter_condition (dict): Filter condition.

            - name (str): The name pattern.
            - graph_name (str): The graph name.
            - watch_point_id (int): The id of watchpoint. Default: 0.
            - node_category (str): The node_category. Default: None.

    Returns:
        dict, the searched nodes.
    """
    log.info("Receive search request with filter_condition: %s", filter_condition)
    # validate watchpoint id
    watch_point_id = filter_condition.pop('watch_point_id', 0)
    watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
    watchpoint_stream.validate_watchpoint_id(watch_point_id)
    # validate and update graph name
    graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
    graph_name = graph_stream.validate_graph_name(filter_condition.get('graph_name'))
    filter_condition['graph_name'] = graph_name
    # get searched graph
    graph = graph_stream.search_nodes(filter_condition)
    # add watched label to graph
    watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id, graph_name)
    return graph

def get_tensor_history(self, node_name, graph_name=None, depth=0):
    """
    Get the tensor history of a specified node.

    Args:
        node_name (str): The debug name of the node.
        graph_name (str): The graph name. Default: None.
        depth (int): The number of layers the user wants to trace. Default: 0.

    Returns:
        dict, basic tensor history, only including tensor name, tensor type and node type.
    """
    graph_name, node_name = self._parse_node_name(node_name, graph_name)
    graph = self._get_graph(graph_name=graph_name, node_name=node_name)
    # validate node type; a scope node has no tensor history
    node_type = graph.get_node_type(node_name)
    if is_scope_type(node_type):
        log.error("Scope type node has no tensor history.")
        raise DebuggerParamValueError("Invalid leaf node name.")
    # get tensor history
    tensor_history, cur_outputs_nums = graph.get_tensor_history(node_name, depth)
    # add the tensor type for tensor history
    self._update_tensor_history(tensor_history[0:cur_outputs_nums], 'output', graph_name)
    self._update_tensor_history(tensor_history[cur_outputs_nums:], 'input', graph_name)
    log.debug("Get %d tensors in tensor history for node <%s>.", len(tensor_history), node_name)
    return {'tensor_history': tensor_history}

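# Usage sketch (node and graph names are made up): the first
# `cur_outputs_nums` entries are tagged as outputs, the remainder as inputs.
#
#     history = server.get_tensor_history('Default/conv1/Conv2D-op1', graph_name='graph_0')
#     # history == {'tensor_history': [{'name': ..., 'type': 'output', ...}, ...]}
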
def retrieve_tensor_value(self, name, detail, shape, graph_name=None, prev=False):
    """Retrieve the tensor value."""
    log.info("Retrieve tensor value: name: %s, detail: %s, shape: %s", name, detail, shape)
    self.validate_tensor_param(name, detail)
    # Limit to query max two dimensions for tensor in table view.
    parsed_shape = TensorUtils.parse_shape(shape, limit=MAX_DIMENSIONS_FOR_TENSOR)
    node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name, graph_name)
    reply = self.cache_store.get_stream_handler(Streams.TENSOR).get({
        'name': tensor_name,
        'node_type': node_type,
        'shape': parsed_shape,
        'prev': prev
    })
    reply['tensor_value']['name'] = name
    return reply

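# Hedged example: query at most a two-dimensional slice for the table view.
# The UI tensor name and the 'data' detail value are illustrative only.
#
#     reply = server.retrieve_tensor_value(
#         name='Default/conv1/Conv2D-op1:0', detail='data', shape='[:, :]')
#     # reply['tensor_value'] holds the sliced values under the UI name
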
def put(self, value):
    """
    Put value into graph cache. Called by grpc server.

    Args:
        value (GraphProto): The Graph proto message.
    """
    log.info("Put graph into cache.")
    sorted_value_list = self._sort_graph(value)
    for graph_name, graph_value in sorted_value_list:
        self._graph_proto[graph_name] = graph_value
        # build sub graph
        graph = DebuggerGraph()
        graph.build_graph(graph_value)
        self._graph[graph_name] = graph
        self.bfs_order.extend(graph.get_bfs_order())
        leaf_nodes = graph.leaf_nodes
        self._all_leaf_nodes.update(leaf_nodes)
        for _, node in leaf_nodes.items():
            self.graph_node_map[node.full_name] = graph_name
    # build whole graph
    graph = DebuggerMultiGraph()
    graph.add_graph(self._graph)
    self._whole_graph = graph

def retrieve_node_by_bfs(self, node_name, graph_name=None, ascend=False):
    """
    Get the graph of the next node according to node_name.

    Args:
        node_name (str): The name of current chosen leaf node.
        graph_name (str): The graph name.
        ascend (bool): If True, traverse the input nodes;
            if False, traverse the output nodes. Default: False.

    Returns:
        dict, the next node information.
    """
    log.info("Retrieve node <%s> by bfs, `ascend` is: %s", node_name, ascend)
    reply = {}
    graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
    graph_name = graph_stream.validate_graph_name(graph_name)
    next_node_name = graph_stream.get_node_by_bfs_order(node_name, ascend)
    # no next node
    if next_node_name is None:
        return reply
    # add graph and tensor history for next node
    filter_condition = {
        'name': next_node_name,
        'graph_name': graph_name,
        'single_node': True
    }
    search_graph = self._get_nodes_info(filter_condition)
    reply = {'name': next_node_name}
    reply.update(search_graph)
    return reply

def get(self, filter_condition=None):
    """
    Get full tensor value.

    Args:
        filter_condition (dict): Filter condition.

            - name (str): The full name of tensor.
            - node_type (str): The type of the node.
            - prev (bool): Whether to get previous tensor.

    Returns:
        dict, the tensor_value.
    """
    name = filter_condition.get('name')
    node_type = filter_condition.get('node_type')
    shape = filter_condition.get('shape')
    if filter_condition.get('prev'):
        step = self.prev_step
    else:
        step = self.cur_step
    tensor = self._get_tensor(name, node_type, step)
    if not tensor:
        log.error("No tensor named %s at the step %s", name, step)
        raise DebuggerParamValueError("No tensor named {}".format(name))
    tensor_info = tensor.get_full_info(shape)
    self._update_has_prev_step_field(tensor_info, name, node_type)
    return {'tensor_value': tensor_info}

def put(self, value):
    """
    Put value into event_cache.

    Args:
        value (dict): The event to be put into cache.
    """
    if not isinstance(value, dict):
        log.error("Dict type required when put event message.")
        raise DebuggerParamValueError("Dict type required when put event message.")
    with self._lock:
        log.debug("Put the %d-th message into queue. \n %d requests are waiting.",
                  self._next_idx, len(self._pending_requests))
        cur_pos = self._next_idx
        # update next pos
        self._next_idx += 1
        if self._next_idx >= self.max_limit:
            self._next_idx = 0
            self._prev_flag = self._cur_flag
            self._cur_flag = str(uuid.uuid4())
        # set next pos
        if not value.get('metadata'):
            value['metadata'] = {}
        value['metadata']['pos'] = self.next_pos
        self._event_cache[cur_pos] = value
        # feed the value for pending requests
        self.clean_pending_requests(value)

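# Sketch of the ring-buffer behavior above, assuming max_limit == 3:
# _next_idx cycles 0 -> 1 -> 2 -> 0 -> ..., and each wrap-around refreshes
# the (prev, cur) flag pair so that clients polling an old position can
# detect that it has been overwritten.
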
def put(self, value):
    """
    Put value into tensor cache. Called by grpc server.

    Args:
        value (dict): The Tensor proto message.

            - step (int): The current step of tensor.
            - tensor_proto (TensorProto): The tensor proto.
            - tensor_contents (list[byte]): The list of tensor content values.

    Returns:
        bool, whether the tensor has been updated successfully.
    """
    tensor_proto = value.get('tensor_proto')
    tensor_proto.ClearField('tensor_content')
    step = value.get('step', 0)
    if tensor_proto.iter and step > 0:
        log.debug("Received previous tensor.")
        step -= 1
    tensor_content = b''.join(value.get('tensor_contents'))
    tensor = OpTensor(tensor_proto, tensor_content, step)
    flag = self._put_tensor_into_cache(tensor, step)
    log.info("Put tensor %s of step: %d, into cache. Flag: %s", tensor.name, step, flag)
    return flag

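# Hedged note on the step adjustment above: a proto whose `iter` field is
# set carries the value of a previous iteration, so it is cached under
# step - 1 rather than the current step. For example, such a tensor arriving
# at step 5 is stored as the step-4 value.
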
def _get_watch_names_by_search(self, search_nodes, target_node_name):
    """
    Get watch names according to search results.

    Args:
        search_nodes (dict): Search result.
            The format is like {'nodes': [<Search Node>]}. The <Search Node> format is like
            {'name': <UI node name>, 'type': <node type>, 'nodes': [<Search Node>]}.
        target_node_name (str): Node name for UI.

    Returns:
        set[str], collection of names.
    """
    names = set()
    tmp_queue = Queue()
    tmp_queue.put(search_nodes)
    while not tmp_queue.empty():
        cur_node = tmp_queue.get()
        for node in cur_node.get('nodes'):
            node_name = node.get('name')
            if not target_node_name.startswith(node_name) or is_cst_type(node.get('type')):
                continue
            if target_node_name == node_name:
                self._add_leaf_node_collection(node, names)
                return names
            tmp_queue.put(node)
    # the target node name is not in search nodes
    log.debug("Node %s is not in search nodes.", target_node_name)
    names.add(target_node_name)
    return names

def get_node_name_by_full_name(self, full_name):
    """Get node name by full name."""
    inner_name = self._full_name_map_name.get(full_name, '')
    if not inner_name:
        log.warning("Cannot find the relative inner node name for node %s.", full_name)
    return inner_name