def _retrieve_all(self, filter_condition=None):
    """Retrieve metadata, root graph and watchpoint list."""
    if filter_condition:
        log.error("No filter condition required for retrieve all request.")
        raise DebuggerParamTypeError("filter_condition should be empty.")
    self.cache_store.clean_data()
    log.info("Cleaned data queue cache for retrieve all request.")
    result = {}
    for stream in [Streams.METADATA, Streams.GRAPH, Streams.DEVICE]:
        sub_res = self.cache_store.get_stream_handler(stream).get()
        result.update(sub_res)
    devices = result['devices']
    if not devices:
        # Fall back to a single default device built from metadata when the
        # device stream is empty.
        graph = result['graph']
        metadata = result['metadata']
        device = {
            'rank_id': 0,
            'server_ip': metadata.get('ip', 'localhost'),
            'device_id': metadata.get('device_name', ''),
            'graph_names': graph.get('graph_names', [])
        }
        devices.append(device)
    sub_res = self._hide_parameters_for_ui()
    result.update(sub_res)
    return result
def download(self, name, prev, graph_name=None, rank_id=0):
    """
    Download the tensor value.

    Args:
        name (str): Node name shown in UI.
        prev (bool): Whether to download the previous step instead of the current step.
        graph_name (Union[str, None]): The graph name. Default: None.
        rank_id (int): The id of rank. Default: 0.

    Returns:
        str, the file path.
        str, the file name.
    """
    if not isinstance(name, str) or ':' not in name:
        log.error("Invalid tensor name. Received: %s", name)
        raise DebuggerParamValueError("Invalid tensor name.")
    _, tensor_name, graph_name = self._get_tensor_name_and_type_by_ui_name(
        name, graph_name, rank_id)
    log.info("Download the tensor value: name: %s", tensor_name)
    tensor_stream = self.cache_store.get_stream_handler(
        Streams.TENSOR).get_tensor_handler_by_rank_id(rank_id)
    step = tensor_stream.cur_step
    if prev:
        step -= 1
    tensor_info = {
        "tensor_name": tensor_name,
        "graph_name": graph_name,
        "step": step,
        "rank_id": rank_id
    }
    return tensor_stream.download_mgr.get(**tensor_info)
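
# Illustrative only: a minimal sketch of calling `download` on an existing
# server instance. The instance name `debugger_server` and the tensor name
# are hypothetical; a UI tensor name is expected to contain a ':' slot suffix.
file_path, file_name = debugger_server.download(
    name='Default/network/Conv2D-op1:0',  # hypothetical UI node name with slot
    prev=False,                           # current step, not the previous one
    graph_name=None,                      # let the server resolve the graph
    rank_id=0)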
def stop(self):
    """Stop debugger server."""
    self._running.wait()
    log.info("Start to stop offline debugger server.")
    self._running.clear()
    self.send()
    self.join()
def SendTensors(self, request_iterator, context):
    """Send tensors into DebuggerCache."""
    log.info("Received tensor.")
    tensor_construct = []
    tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR)
    metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
    tensor_names = []
    step = metadata_stream.step
    for tensor in request_iterator:
        tensor_construct.append(tensor)
        if tensor.finished:
            update_flag = tensor_stream.put({
                'step': step,
                'tensor_protos': tensor_construct
            })
            if self._received_view_cmd.get('wait_for_tensor') and update_flag:
                # update_flag is used to avoid querying empty tensors again
                self._received_view_cmd['wait_for_tensor'] = False
                log.debug("Set wait for tensor flag to False.")
            tensor_construct = []
            tensor_names.append(':'.join([tensor.node_name, tensor.slot]))
            continue
    reply = get_ack_reply()
    return reply
def retrieve_node_by_bfs(self, node_name, graph_name=None, ascend=False):
    """
    Get the graph of the next node according to node_name.

    Args:
        node_name (str): The name of the current chosen leaf node.
        graph_name (str): The graph name.
        ascend (bool): If True, traverse the input nodes;
            if False, traverse the output nodes. Default: False.

    Returns:
        dict, the next node information.
    """
    log.info("Retrieve node <%s> by bfs, `ascend` is: %s", node_name, ascend)
    reply = {}
    graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
    graph_name = graph_stream.validate_graph_name(graph_name)
    next_node_name = graph_stream.get_node_by_bfs_order(node_name, ascend)
    # no next node
    if next_node_name is None:
        return reply
    # add graph and tensor history for next node
    filter_condition = {
        'name': next_node_name,
        'graph_name': graph_name,
        'single_node': True
    }
    search_graph = self._get_nodes_info(filter_condition)
    reply = {'name': next_node_name}
    reply.update(search_graph)
    return reply
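
# Illustrative only: stepping through a graph in BFS order with hypothetical
# node, graph and instance names. Passing ascend=True walks toward input nodes.
next_info = debugger_server.retrieve_node_by_bfs(
    node_name='Default/network/Conv2D-op1',  # hypothetical current leaf node
    graph_name='graph_0',                    # hypothetical graph name
    ascend=True)
if next_info:
    print(next_info['name'])  # name of the next node in BFS order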
def put(self, value):
    """
    Put value into graph cache. Called by grpc server.

    Args:
        value (GraphProto): The Graph proto message.
    """
    log.info("Put graph into cache.")
    sorted_value_list = self._sort_graph(value)
    for graph_name, graph_value in sorted_value_list:
        self._graph_proto[graph_name] = graph_value
        # build sub graph
        graph = DebuggerGraph()
        graph.build_graph(graph_value)
        self._graph[graph_name] = graph
        self.bfs_order.extend(graph.get_bfs_order())
        leaf_nodes = graph.leaf_nodes
        self._all_leaf_nodes.update(leaf_nodes)
        for _, node in leaf_nodes.items():
            self.graph_node_map[node.full_name] = graph_name
    # build whole graph
    graph = DebuggerMultiGraph()
    graph.add_graph(self._graph)
    self._whole_graph = graph
def load_device_info(self):
    """Load device_info from dump path."""
    device_info = {}
    if not self._rank_dirs:
        log.info("No rank directory found under dump path.")
        return device_info
    rank_dir = self._rank_dirs[0].path
    hccl_json = self._load_json_file(rank_dir / '.dump_metadata' / 'hccl.json')
    if hccl_json.get('server_list'):
        device_info = {
            'device_target': self._device_target,
            'server_list': hccl_json['server_list']
        }
    else:
        log.info("Server list info is missing. Set device id to rank id by default.")
        devices = []
        for rank_dir in self._rank_dirs:
            rank_id = rank_dir.rank_id
            devices.append({
                'device_id': str(rank_id),
                'rank_id': str(rank_id)
            })
        device_info = {
            'device_target': self._device_target,
            'server_list': [{
                'server_id': 'localhost',
                'device': devices
            }]
        }
    return device_info
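
# Illustrative only: the fallback `device_info` structure produced when
# hccl.json has no `server_list`, shown here for a hypothetical two-rank dump
# on Ascend. The keys mirror the construction above; the values are made up.
device_info_example = {
    'device_target': 'Ascend',
    'server_list': [{
        'server_id': 'localhost',
        'device': [
            {'device_id': '0', 'rank_id': '0'},
            {'device_id': '1', 'rank_id': '1'},
        ],
    }],
}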
def _add_hit_node_info(watchpoint_hit, multi_card_graph_streams, rank_id, hit):
    """Add hit node info."""
    graph_stream = multi_card_graph_streams.get_graph_handler_by_rank_id(rank_id)
    node_full_name = hit['name']
    graph_name = graph_stream.get_graph_id_by_full_name(node_full_name)
    if not graph_name:
        log.warning("Cannot find node %s in graph. Skip it.", node_full_name)
        return
    ui_node_name = graph_stream.get_node_name_by_full_name(
        node_full_name, graph_name)
    log.debug("Receive watchpoint hit: %s:%s", node_full_name, hit['slot'])
    if not ui_node_name:
        log.info("Showing %s on the graph is not supported.", node_full_name)
        return
    watchpoint_hit.update({
        'tensor_proto': TensorProto(node_name=node_full_name,
                                    slot=str(hit['slot'])),
        'node_name': ui_node_name,
        'graph_name': graph_name
    })
def _deal_with_set_cmd(self, event):
    """
    Deal with set cmd.

    Args:
        event (EventReply): User command event including set_cmd.
    """
    set_cmd = event.set_cmd
    set_cmd_id = set_cmd.id
    delete = set_cmd.delete
    if not delete:
        log.info("Add watchpoint by using dbg_server.")
        watch_condition = set_cmd.watch_condition
        param_list = []
        for param in watch_condition.params:
            param_list.append(
                self._dbg_services_module.Parameter(
                    param.name, param.disabled, param.value))
        watch_nodes = set_cmd.watch_nodes
        check_nodes = self._get_check_nodes(watch_nodes)
        log.debug("Watchpoint %s, condition: %s, watch nodes: %s",
                  set_cmd_id, watch_condition.condition, check_nodes)
        self._dbg_service.add_watchpoint(set_cmd_id, watch_condition.condition,
                                         check_nodes, param_list)
    else:
        log.info("Remove watchpoint by using dbg_server.")
        self._dbg_service.remove_watchpoint(set_cmd_id)
def _save_watchpoint_hits(self, hits):
    """Save watchpoint hits."""
    multi_card_hit_streams = self._cache_store.get_stream_handler(
        Streams.WATCHPOINT_HIT)
    multi_card_graph_streams = self._cache_store.get_stream_handler(
        Streams.GRAPH)
    watchpoint_stream = self._cache_store.get_stream_handler(
        Streams.WATCHPOINT)
    watchpoint_hits = defaultdict(list)
    for hit in hits:
        log.info(
            "Received hit:\n"
            "name: %s, slot: %s, condition: %s, watchpoint_id: %s, "
            "error_code: %s, rank_id: %s",
            hit['name'], hit['slot'], hit['condition'],
            hit['watchpoint_id'], hit['error_code'], hit['rank_id'])
        rank_id = hit['rank_id']
        watchpoint_hit = {}
        self._add_hit_node_info(watchpoint_hit, multi_card_graph_streams,
                                rank_id, hit)
        if not watchpoint_hit:
            continue
        self._add_hit_watchpoint_info(watchpoint_hit, watchpoint_stream, hit)
        watchpoint_hit['error_code'] = hit['error_code']
        watchpoint_hits[rank_id].append(watchpoint_hit)
    # save hit info into cache
    multi_card_hit_streams.put(watchpoint_hits)
    self._cache_store.put_data({'receive_watchpoint_hits': True})
    log.debug("Send the watchpoint hits to DataQueue.")
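
# Illustrative only: the shape of a single `hit` record consumed above, with
# made-up values. The keys match those read in `_save_watchpoint_hits`.
hit_example = {
    'name': 'Default/network/Conv2D-op1',  # full node name
    'slot': 0,                             # output slot of the tensor
    'condition': 'tensor_too_large',       # watch condition id
    'watchpoint_id': 1,
    'error_code': 0,                       # 0 means the check succeeded
    'rank_id': 0,
}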
def retrieve_tensor_value(self, name, detail, shape, graph_name=None, prev=False):
    """Retrieve the tensor value."""
    log.info("Retrieve tensor value: name: %s, detail: %s, shape: %s",
             name, detail, shape)
    self.validate_tensor_param(name, detail)
    # Limit to query max two dimensions for tensor in table view.
    parsed_shape = TensorUtils.parse_shape(shape, limit=MAX_DIMENSIONS_FOR_TENSOR)
    node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(
        name, graph_name)
    reply = self.cache_store.get_stream_handler(Streams.TENSOR).get({
        'name': tensor_name,
        'node_type': node_type,
        'shape': parsed_shape,
        'prev': prev
    })
    reply['tensor_value']['name'] = name
    return reply
def put(self, value):
    """
    Put value into tensor cache. Called by grpc server.

    Args:
        value (dict): The Tensor proto message.

            - step (int): The current step of the tensor.
            - tensor_proto (TensorProto): The tensor proto.
            - tensor_contents (list[byte]): The list of tensor content values.

    Returns:
        bool, whether the tensor was updated successfully.
    """
    tensor_proto = value.get('tensor_proto')
    tensor_proto.ClearField('tensor_content')
    step = value.get('step', 0)
    if tensor_proto.iter and step > 0:
        log.debug("Received previous tensor.")
        step -= 1
    tensor_content = b''.join(value.get('tensor_contents'))
    tensor = OpTensor(tensor_proto, tensor_content, step)
    flag = self._put_tensor_into_cache(tensor, step)
    log.info("Put tensor %s of step: %d, into cache. Flag: %s",
             tensor.name, step, flag)
    return flag
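
# Illustrative only: the `value` dict this handler expects, assuming a
# TensorProto built elsewhere (e.g. deserialized from the gRPC stream); the
# variable names here are hypothetical.
value = {
    'step': 3,                               # current step of the tensor
    'tensor_proto': received_tensor_proto,   # TensorProto; content cleared inside put()
    'tensor_contents': [chunk_bytes_1, chunk_bytes_2],  # raw byte chunks, joined before caching
}
updated = tensor_stream.put(value)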
def retrieve(self, mode, filter_condition=None):
    """
    Retrieve data according to mode and params.

    Args:
        mode (str): The type of info message.
        filter_condition (dict): The filter condition.

    Returns:
        dict, the retrieved data.
    """
    log.info("Receive retrieve request for mode: %s, filter_condition: %s",
             mode, filter_condition)
    mode_mapping = {
        'all': self._retrieve_all,
        'node': self._retrieve_node,
        'watchpoint': self._retrieve_watchpoint,
    }
    # validate param <mode>
    if mode not in mode_mapping:
        log.error(
            "Invalid param <mode>. <mode> should be in ['all', 'node', "
            "'watchpoint'], but got %s.", mode)
        raise DebuggerParamValueError("Invalid mode.")
    # validate backend status
    metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
    if metadata_stream.state == ServerStatus.PENDING.value:
        log.info("The backend is in pending status.")
        return metadata_stream.get()
    filter_condition = {} if filter_condition is None else filter_condition
    reply = mode_mapping[mode](filter_condition)
    return reply
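
# Illustrative only: dispatching retrieve requests; the instance name and
# filter values are hypothetical.
all_info = debugger_server.retrieve('all')
node_info = debugger_server.retrieve(
    'node', {'name': 'Default/network', 'graph_name': 'graph_0'})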
def stop(self):
    """Stop debugger server."""
    log.info("Send terminate info to client.")
    self.control({'mode': 'terminate'})
    self.grpc_server_manager.stop(grace=None)
    self.back_server.join()
    log.info("Stop debugger server.")
def search(self, filter_condition):
    """
    Search for single node in graph.

    Args:
        filter_condition (dict): Filter condition.

            - name (str): The name pattern.
            - graph_name (str): The graph name.
            - watch_point_id (int): The id of watchpoint. Default: 0.
            - node_category (str): The node category. Default: None.

    Returns:
        dict, the searched nodes.
    """
    log.info("Receive search request with filter_condition: %s", filter_condition)
    # validate watchpoint id
    watch_point_id = filter_condition.pop('watch_point_id', 0)
    watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
    watchpoint_stream.validate_watchpoint_id(watch_point_id)
    # validate and update graph name
    graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
    graph_name = graph_stream.validate_graph_name(
        filter_condition.get('graph_name'))
    filter_condition['graph_name'] = graph_name
    # get searched graph
    graph = graph_stream.search_nodes(filter_condition)
    # add watched label to graph
    watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id, graph_name)
    return graph
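
# Illustrative only: a filter_condition for `search`, using made-up values.
# The keys follow the docstring above.
filter_condition = {
    'name': 'conv',           # name pattern to match
    'graph_name': 'graph_0',  # hypothetical graph name
    'watch_point_id': 0,      # 0 means no watchpoint filtering
    'node_category': None,
}
result = debugger_server.search(filter_condition)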
def create_session(self, session_type, train_job=None):
    """
    Create the session by the train job info or session type if the session doesn't exist.

    Args:
        session_type (str): The session type.
        train_job (str): The train job info.

    Returns:
        str, session id.
    """
    with self._lock:
        if self._exiting:
            logger.info("System is exiting, will terminate the thread.")
            _thread.exit()
        if session_type == self.ONLINE_TYPE:
            if self.ONLINE_SESSION_ID not in self.sessions:
                logger.error(
                    'Online session is unavailable, set --enable-debugger as true/1 '
                    'to enable debugger when starting the MindInsight server.')
                raise DebuggerOnlineSessionUnavailable()
            return self.ONLINE_SESSION_ID
        if train_job in self.train_jobs:
            return self.train_jobs.get(train_job)
        return self._create_offline_session(train_job)
def put(self, value):
    """
    Put value into tensor cache. Called by grpc server.

    Args:
        value (dict): The Tensor proto message.

            - step (int): The current step of the tensor.
            - tensor_proto (TensorProto): The tensor proto.
            - tensor_contents (list[byte]): The list of tensor content values.

    Returns:
        bool, whether the tensor was updated successfully.
    """
    tensor = self._deal_with_tensor(value)
    stats = None
    if value.get('stats', False) and tensor.status == TensorStatusEnum.CACHED.value:
        tensor.calculate_stats()
        stats = tensor.stats
    flag = self._put_tensors(tensor)
    new_tensor = self._tensors.get(tensor.name).get(tensor.step)
    new_tensor.stats = stats
    log.info("Put tensor %s of step: %d, into cache. Flag: %s",
             tensor.name, tensor.step, flag)
    return flag
def SendWatchpointHits(self, request_iterator, context):
    """Send watchpoint hits into DebuggerCache."""
    log.info("Received WatchpointHits. Left run cmd %s changed to empty.",
             self._old_run_cmd)
    self._old_run_cmd.clear()
    if self._cache_store.get_stream_handler(
            Streams.METADATA).state == ServerStatus.RUNNING.value:
        # if the client session is running a script, all the cached commands
        # should be cleared when watchpoint_hits are received.
        self._cache_store.clean_command()
    # save the watchpoint_hits data
    watchpoint_hits = []
    watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT)
    graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH)
    for watchpoint_hit_proto in request_iterator:
        node_full_name = watchpoint_hit_proto.tensor.node_name
        graph_name = graph_stream.get_graph_id_by_full_name(node_full_name)
        if not graph_name:
            log.warning("Cannot find node %s in graph. Skip it.", node_full_name)
            continue
        ui_node_name = graph_stream.get_node_name_by_full_name(
            node_full_name, graph_name)
        log.debug("Receive watchpoint hit: %s", watchpoint_hit_proto)
        if not ui_node_name:
            log.info("Showing %s on the graph is not supported.", node_full_name)
            continue
        watchpoint_hit = {
            'tensor_proto': watchpoint_hit_proto.tensor,
            'watchpoint': copy.deepcopy(
                watchpoint_stream.get_watchpoint_by_id(watchpoint_hit_proto.id)),
            'node_name': ui_node_name,
            'graph_name': graph_name
        }
        hit_params = {}
        for param in watchpoint_hit_proto.watch_condition.params:
            if param.actual_value:
                hit_params[param.name] = param.actual_value
        for i, param in enumerate(
                watchpoint_hit['watchpoint'].condition['params']):
            # fill in the actual value for each condition param; params
            # without an actual value are explicitly set to None
            name = param['name']
            watchpoint_hit['watchpoint'].condition['params'][i][
                'actual_value'] = hit_params.get(name)
        if watchpoint_hit_proto.error_code:
            watchpoint_hit['error_code'] = watchpoint_hit_proto.error_code
        watchpoint_hits.append(watchpoint_hit)
    self._received_hit = watchpoint_hits
    reply = get_ack_reply()
    return reply
def run(self):
    """Start the debugger offline server."""
    log.info("Initialize Offline Debugger Server for dbg_dir: %s",
             self._context.dbg_dir)
    self._offline_server_manager.initialize()
    log.info("Start Offline Debugger Server for dbg_dir: %s",
             self._context.dbg_dir)
    self._running.set()
    try_count = 0
    while self._running.is_set() and try_count < self._MAX_TRY_EXCEPT_COUNT:
        try:
            self._offline_server_manager.wait_for_termination()
            if not self._offline_server_manager.is_runnable():
                break
        except MindInsightException as err:
            log.exception(err)
            log.warning("Error occurred while listening for user commands. "
                        "Restart listening.")
        finally:
            try_count += 1
    # protect server from too many failed commands
    if try_count == self._MAX_TRY_EXCEPT_COUNT:
        self._cache_store.clean()
        metadata = self._cache_store.get_stream_handler(Streams.METADATA).get()
        self._cache_store.put_data(metadata)
        log.warning("Exceptions exceeded %d times, stopping server.", try_count)
def SendGraph(self, request_iterator, context):
    """Send graph into DebuggerCache."""
    log.info("Received graph.")
    reply = get_ack_reply()
    if self._status == ServerStatus.MISMATCH:
        log.info("MindSpore and MindInsight version mismatch, "
                 "waiting for user to terminate the service.")
        return reply
    serial_graph = b""
    for chunk in request_iterator:
        serial_graph += chunk.buffer
    graph = GraphProto.FromString(serial_graph)
    log.debug("Deserialize the graph %s. Receive %s nodes", graph.name,
              len(graph.node))
    graph_dict = {graph.name: graph}
    self._cache_store.get_stream_handler(Streams.GRAPH).put({0: graph_dict})
    self._cache_store.get_stream_handler(Streams.GRAPH).parse_stack_infos()
    self._cache_store.get_stream_handler(
        Streams.TENSOR).get_tensor_handler_by_rank_id(0).put_const_vals(
            graph.const_vals)
    self._cache_store.get_stream_handler(
        Streams.METADATA).graph_name = graph.name
    self._record_parameter_names()
    self._status = ServerStatus.RECEIVE_GRAPH
    log.debug("Send the reply for graph.")
    return reply
def SendMultiGraphs(self, request_iterator, context):
    """Send graphs into DebuggerCache."""
    log.info("Received multi_graphs.")
    reply = get_ack_reply()
    if self._status == ServerStatus.MISMATCH:
        log.info("MindSpore and MindInsight version mismatch, "
                 "waiting for user to terminate the service.")
        return reply
    serial_graph = b""
    graph_dict = {}
    for chunk in request_iterator:
        serial_graph += chunk.buffer
        if chunk.finished:
            sub_graph = GraphProto.FromString(serial_graph)
            graph_dict[sub_graph.name] = sub_graph
            log.debug("Deserialize the graph %s. Receive %s nodes",
                      sub_graph.name, len(sub_graph.node))
            serial_graph = b""
            self._cache_store.get_stream_handler(
                Streams.TENSOR).put_const_vals(sub_graph.const_vals)
    self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_dict)
    self._record_parameter_names()
    self._status = ServerStatus.RECEIVE_GRAPH
    log.debug("Send the reply for graph.")
    return reply
def search_watchpoint_hits(self, group_condition):
    """
    Retrieve watchpoint hits.

    Args:
        group_condition (dict): Filter condition.

            - limit (int): The limit of each page.
            - offset (int): The offset of current page.
            - node_name (str): The retrieved node name.
            - graph_name (str): The retrieved graph name.

    Returns:
        dict, watchpoint hit list or relative graph.
    """
    if not isinstance(group_condition, dict):
        log.error("Group condition for watchpoint-hits request should be a dict.")
        raise DebuggerParamTypeError(
            "Group condition for watchpoint-hits request should be a dict")
    metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
    if metadata_stream.state == ServerStatus.PENDING.value:
        log.info("The backend is in pending status.")
        return metadata_stream.get()
    reply = self.cache_store.get_stream_handler(
        Streams.WATCHPOINT_HIT).group_by(group_condition)
    reply['outdated'] = self.cache_store.get_stream_handler(
        Streams.WATCHPOINT).is_recheckable()
    return reply
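
# Illustrative only: a group_condition for paging watchpoint hits, with
# made-up values matching the keys documented above.
group_condition = {
    'limit': 10,                                # hits per page
    'offset': 0,                                # first page
    'node_name': 'Default/network/Conv2D-op1',  # hypothetical node name
    'graph_name': 'graph_0',                    # hypothetical graph name
}
hits_page = debugger_server.search_watchpoint_hits(group_condition)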
def create_watchpoint(self, params):
    """
    Create watchpoint.

    Args:
        params (dict): Params for creating a watchpoint.

            - watch_condition (dict): The watch condition. The format is like:
                {
                    "id": "tensor_too_large",
                    "params": [
                        {
                            "name": "abs_mean_gt",
                            "value": 1.1
                        }
                    ]
                }

                - id (str): Id of condition.
                - params (list[dict]): The list of params for this condition.
            - watch_nodes (list[str]): The list of node names.
            - watch_point_id (int): The id of watchpoint.
            - search_pattern (dict): The search pattern.
            - graph_name (str): The relative graph_name of the watched node.

    Returns:
        dict, the id of new watchpoint and metadata info.
    """
    watch_condition = params.get('watch_condition')
    log.info("Received create watchpoint request. WatchCondition: %s",
             watch_condition)
    metadata_stream = self._metadata_stream
    if metadata_stream.state != ServerStatus.WAITING.value:
        log.error("Failed to create watchpoint as MindSpore is not in waiting state.")
        raise DebuggerCreateWatchPointError(
            "Failed to create watchpoint as MindSpore is not in waiting state.")
    self._validate_watch_condition(watch_condition)
    watch_nodes = self._get_watch_node_with_basic_info(
        node_names=params.get('watch_nodes'),
        search_pattern=params.get('search_pattern'),
        graph_name=params.get('graph_name'))
    validate_watch_condition(self._condition_mgr, watch_condition)
    condition_id = watch_condition.get('id')
    condition = self._condition_mgr.get_condition(condition_id)
    condition_context = ConditionContext(metadata_stream.backend,
                                         metadata_stream.step)
    if not condition.is_available(condition_context):
        log.error("Failed to create watchpoint as the condition is not available.")
        raise DebuggerConditionUnavailableError(
            "Failed to create watchpoint as the condition is not available.")
    watch_nodes = get_basic_node_info(condition.supported_target_type.value,
                                      self._graph_stream).copy()
    watchpoint_stream = self._watchpoint_stream
    watch_point_id = watchpoint_stream.create_watchpoint(
        self._condition_mgr, watch_condition, watch_nodes,
        params.get('watch_point_id'))
    log.info("Create watchpoint %d", watch_point_id)
    metadata_stream.enable_recheck = watchpoint_stream.is_recheckable()
    res = metadata_stream.get(['state', 'enable_recheck'])
    res['id'] = watch_point_id
    return res
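
# Illustrative only: a params dict for `create_watchpoint`, using the
# condition format shown in the docstring; all values are made up.
params = {
    'watch_condition': {
        'id': 'tensor_too_large',
        'params': [{'name': 'abs_mean_gt', 'value': 1.1}],
    },
    'watch_nodes': ['Default/network/Conv2D-op1'],  # hypothetical node name
    'watch_point_id': 0,    # 0 means create a new watchpoint
    'search_pattern': None,
    'graph_name': 'graph_0',
}
res = debugger_server.create_watchpoint(params)
print(res['id'])  # id of the newly created watchpoint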
def _read_tensor_work(self, tensor_infos, res):
    """Read tensor values in another process."""
    log.info("Start read tensor process.")
    tensor_data_res = self._dbg_service.read_tensors(tensor_infos)
    for tensor_data in tensor_data_res:
        tensor_data_dict = convert_tensor_data(tensor_data)
        res.append(tensor_data_dict)
    log.info("Reading tensor process is finished.")
def exit(self):
    """Called when the gunicorn worker process is exiting."""
    with self._lock:
        logger.info("Start to exit sessions.")
        self._exiting = True
        for session in self.sessions.values():
            session.stop()
        logger.info("Sessions exited.")
def _check_watchpoint_work(self, hits, step):
    """Check watchpoint hits in another process."""
    log.info("Start checking WatchPointHit process.")
    res = self._dbg_service.check_watchpoints(step)
    for watchpoint_hit in res:
        hit_dict = convert_watchpointhit(watchpoint_hit)
        hits.append(hit_dict)
    log.info("Checking WatchPointHit process is finished.")
# Inner wrapper of a decorator; ``func`` is the wrapped method supplied by
# the enclosing decorator function.
def send_latest_metadata(self, *args, **kwargs):
    """Send the latest metadata to the data queue when the wrapped call fails."""
    try:
        return func(self, *args, **kwargs)
    except MindInsightException as err:
        metadata = self.cache_store.get_stream_handler(Streams.METADATA).get()
        self.cache_store.put_data(metadata)
        log.info("Put latest metadata into data-queue.")
        raise err
def stop(self):
    """Stop server."""
    self._is_running_flag = False
    self._command_listener.stop()
    self._cache_store.clean()
    event = get_ack_reply()
    event.exit = True
    self._cache_store.put_command(event)
    log.info("Stop debugger offline manager.")
def _deal_with_run_cmd(self, event):
    """Deal with run cmd."""
    run_cmd = event.run_cmd
    parsed_run_cmd = self._get_parsed_run_cmd(run_cmd)
    if parsed_run_cmd.run_steps > 0:
        self._execute_one_step()
    elif run_cmd.run_level == RunLevel.RECHECK.value:
        log.info("Deal with recheck command.")
        self._check_watchpoint(self._metadata_stream.step)
def stop(self):
    """Stop offline debugger server."""
    if not self.is_alive():
        log.info("Offline debugger has already stopped.")
        return
    log.debug("Start to wait for thread started.")
    self._running.wait()
    log.info("Start to stop offline debugger server.")
    self._running.clear()
    self._offline_server_manager.stop()
    self.join()