def _retrieve_watchpoint_hit(self, filter_condition):
    """
    Retrieve watchpoint hit.

    Args:
        filter_condition (dict): Filter condition.

            - name (str): The name of single node. If not given, the whole
              watchpoint hit list is returned.

            - single_node (bool): If False, return the sub-layer of single node.
              If True, return the node list from root node to single node.

    Returns:
        dict, the watchpoint hit list when no node name is given; otherwise the
        result of `_get_tensor_history` for the node, updated with the graph
        info from `_get_nodes_info` when `single_node` is set.
    """
    node_name = filter_condition.get('name')
    # No node given: return the whole watchpoint hit list.
    if node_name is None:
        reply = self.cache_store.get_stream_handler(
            Streams.WATCHPOINT_HIT).get()
        return reply

    self._validate_leaf_name(node_name)
    # get tensor history
    reply = self._get_tensor_history(node_name)
    log.debug("Get tensor history for watchpoint hit node.")
    # get single graph
    if filter_condition.get('single_node'):
        graph = self._get_nodes_info(filter_condition)
        reply.update(graph)
        log.debug("Get graph info for watchpoint hit node.")
    return reply
def WaitCMD(self, request, context):
    """
    Wait for a command in DebuggerCache.

    Args:
        request: The WaitCMD request; its `cur_step` is logged.
        context: The RPC context (not used by this method).

    Returns:
        EventReply, the command to reply with, or `get_ack_reply(1)` when no
        graph has been received yet or no command could be obtained.
    """
    log.info("Received WaitCMD at %s-th step.", request.cur_step)
    # A graph must have been received before any command is served.
    if self._status == ServerStatus.PENDING:
        log.warning("No graph received before WaitCMD.")
        return get_ack_reply(1)

    self._send_received_tensor_tag()
    # Send the graph if it has not been sent before.
    self._pre_process(request)

    # Commands queued before this request take priority.
    old_cmd = self._deal_with_old_command()
    if old_cmd:
        log.info("Reply to WaitCMD with old command: %s", old_cmd)
        return old_cmd

    if not self._continue_steps:
        new_cmd = self._wait_for_next_command()
    else:
        # A multi-step run is in progress: emit one more single-step RunCMD.
        new_cmd = get_ack_reply()
        new_cmd.run_cmd.run_steps = 1
        new_cmd.run_cmd.run_level = 'step'
        remaining = self._continue_steps
        self._continue_steps = remaining - 1 if remaining > 0 else -1
        self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
        log.debug("Send RunCMD. Clean watchpoint hit.")

    if new_cmd is None:
        new_cmd = get_ack_reply(1)
        log.warning("Failed to get command event.")
    else:
        log.info("Reply to WaitCMD: %s", new_cmd)
    return new_cmd
def get_tensor_history(self, node_name, depth=0):
    """
    Get the tensor history of a specified node.

    Args:
        node_name (str): The debug name of the node.
        depth (int): How many layers to trace back. Default: 0.

    Returns:
        dict, with key `tensor_history` holding the basic tensor history
        entries (tensor name, tensor type and node type).

    Raises:
        DebuggerNodeNotInGraphError: If `node_name` is not in the graph.
    """
    self._graph_exists()
    graph = self._graph
    if not graph.exist_node(node_name):
        raise DebuggerNodeNotInGraphError(node_name)

    history, output_count = graph.get_tensor_history(node_name, depth)
    # The leading `output_count` entries are outputs; the remainder are inputs.
    outputs, inputs = history[:output_count], history[output_count:]
    self._update_tensor_history(outputs, 'output')
    self._update_tensor_history(inputs, 'input')
    log.debug("Get %d tensors in tensor history for node <%s>.",
              len(history), node_name)
    return {'tensor_history': history}
def _retrieve_watchpoint(self, filter_condition):
    """
    Retrieve watchpoint.

    Args:
        filter_condition (dict): Filter condition.

            - watch_point_id (int): The id of watchpoint. If not given, return all watchpoints.

            - name (str): The name of single node.

            - single_node (bool): If False, return the sub-layer of single node.
              If True, return the node list from root node to single node.

    Returns:
        dict, watch point list or relative graph.
    """
    watchpoint_id = filter_condition.get('watch_point_id')
    watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
    watchpoint_stream.validate_watchpoint_id(watchpoint_id)
    # Falsy ids (None / 0) are normalized to 0.
    self._watch_point_id = watchpoint_id if watchpoint_id else 0
    if not watchpoint_id:
        # Reuse the handler fetched above instead of looking it up again.
        reply = watchpoint_stream.get()
        log.debug("Get condition of watchpoints.")
    else:
        reply = self._retrieve_node(filter_condition)
        log.debug("Get graph of %d-th watchpoint.", watchpoint_id)
    return reply
def add_node(self, node_name, node_type, full_name=''):
    """
    Add a watch node to the watch node tree, creating intermediate scopes.

    Args:
        node_name (str): The node name, '/'-separated by scope.
        node_type (str): The node type.
        full_name (str): The full name of node. Default: ''.
    """
    log.debug("Add node %s with type: %s, full_name: %s", node_name, node_type, full_name)
    head, sep, rest = node_name.partition('/')
    if not sep:
        # Leaf level: create a fully-watched node or re-enable an existing one.
        existing = self.get(node_name)
        if existing:
            existing.enable_watch_status()
        else:
            self.add(node_name, node_type, full_name, watch_status=WatchNodeTree.TOTAL_WATCH)
        return

    # NOTE(review): watch_status=1 is a bare constant; presumably a partial-watch
    # status for scope nodes — confirm against WatchNodeTree's constants.
    sub_tree = self.get(head) or self.add(head, watch_status=1)
    sub_tree.add_node(rest, node_type, full_name)
def SendWatchpointHits(self, request_iterator, context):
    """
    Receive watchpoint hits from the request stream and store them in the cache.

    Hits whose node cannot be mapped to a UI node name are skipped. After all
    hits are stored, the full hit list is pushed to the data stream.

    Args:
        request_iterator: Iterator of watchpoint hit protos.
        context: The RPC context (not used by this method).

    Returns:
        EventReply, the ack reply.
    """
    log.info("Received WatchpointHits. Left steps %d change to 0.", self._continue_steps)
    # A hit stops any remaining multi-step run.
    self._continue_steps = 0
    handler = self._cache_store.get_stream_handler
    hit_stream = handler(Streams.WATCHPOINT_HIT)
    watchpoint_stream = handler(Streams.WATCHPOINT)
    graph_stream = handler(Streams.GRAPH)
    for hit_proto in request_iterator:
        tensor_proto = hit_proto.tensor
        ui_node_name = graph_stream.get_node_name_by_full_name(tensor_proto.node_name)
        log.debug("Receive watch point hit: %s", hit_proto)
        if not ui_node_name:
            log.info("Not support to show %s on graph.", tensor_proto.node_name)
            continue
        hit_stream.put({
            'tensor_proto': tensor_proto,
            'watchpoint': watchpoint_stream.get_watchpoint_by_id(hit_proto.id),
            'node_name': ui_node_name,
        })
    self._cache_store.put_data(hit_stream.get())
    log.info("Send the watchpoint hits to DataQueue.\nSend the reply.")
    return get_ack_reply()
def get(self, filter_condition=False):
    """
    Get the watchpoints.

    Args:
        filter_condition (bool): If False, get the watch condition info of all
            watchpoints. If True, get the updated watchpoints in SetCMD proto
            format, followed by the SetCMD protos of deleted watchpoints.
            Default: False.

    Returns:
        dict, the watchpoints under the key `watch_points`.
    """
    if not filter_condition:
        # get watch condition list
        reply = [watchpoint.get_watch_condition_info()
                 for watchpoint in self._watchpoints.values()]
    else:
        # get updated watchpoint list, then append the pending delete commands
        reply = [watchpoint.get_set_cmd()
                 for watchpoint in self._updated_watchpoints.values()]
        reply.extend(self._deleted_watchpoints)
    log.debug("get the watch points with filter_condition:%s", filter_condition)
    return {'watch_points': reply}
def put(self, value):
    """
    Put value into event_cache.

    The value is written at index `self._next_idx` of `self._event_cache`;
    the index wraps to 0 once it reaches `self.max_limit`, at which point the
    flags are rotated. Every pending request is then fed the value.

    Args:
        value (dict): The event to be put into cache. NOTE: mutated in place —
            a `metadata` dict is added if missing and its `pos` key is set.

    Raises:
        DebuggerParamValueError: If `value` is not a dict.
    """
    if not isinstance(value, dict):
        log.error("Dict type required when put event message.")
        raise DebuggerParamValueError("Dict type required when put event message.")

    with self._lock:
        log.debug("Put the %d-th message into queue. \n %d requests is waiting.",
                  self._next_idx, len(self._pending_requests))
        cur_pos = self._next_idx
        # update next pos; on wrap-around the current flag becomes the previous
        # one and a fresh flag is generated
        self._next_idx += 1
        if self._next_idx >= self.max_limit:
            self._next_idx = 0
            self._prev_flag = self._cur_flag
            self._cur_flag = str(uuid.uuid4())
        # set next pos
        # NOTE(review): `next_pos` presumably derives from `_next_idx`/`_cur_flag`,
        # so this must stay after the update above — confirm.
        if not value.get('metadata'):
            value['metadata'] = {}
        value['metadata']['pos'] = self.next_pos
        self._event_cache[cur_pos] = value
        # feed the value for pending requests
        self.clean_pending_requests(value)
def _continue(self, metadata_stream, params):
    """
    Send RunCMD to MindSpore.

    Args:
        metadata_stream (MetadataHandler): The metadata_handler
        params (dict): The control params.

    Returns:
        dict, metadata holding the new (running) state.

    Raises:
        DebuggerContinueError: If MindSpore is not waiting, or if sending the
            run command fails (the state is then rolled back to waiting).
    """
    if metadata_stream.state != ServerStatus.WAITING.value:
        log.error("MindSpore is not ready to run. Current state is: %s", metadata_stream.state)
        raise DebuggerContinueError(
            "MindSpore is not ready to run or is running currently.")
    metadata_stream.state = ServerStatus.RUNNING.value
    try:
        event = self._construct_run_event(params)
        # Pending watchpoint SetCMDs are queued before the RunCMD.
        self._send_watchpoints()
        self.cache_store.put_command(event)
    except MindInsightException as err:
        log.error("Failed to send run event.")
        log.exception(err)
        # Roll back so the user can retry.
        metadata_stream.state = ServerStatus.WAITING.value
        # Chain the cause so the original failure is not lost.
        raise DebuggerContinueError("Failed to send run command.") from err
    log.debug("Send the RunCMD to command queue.")
    return {'metadata': {'state': ServerStatus.RUNNING.value}}
def _deal_with_old_command(self):
    """
    Deal with old command.

    Consume commands while the cache reports one at `self._pos`, stopping at
    the first that yields an event.

    Returns:
        The first non-None event from `_get_next_command`, or None when no
        queued command yields one.
    """
    while self._cache_store.has_command(self._pos):
        event = self._get_next_command()
        log.debug("Deal with old %s-th command:\n%s.", self._pos, event)
        if event is not None:
            return event
    return None
def put_data(self, value):
    """
    Put updated data into the data stream.

    Args:
        value (dict): The updated data.

    Returns:
        The result of `self._put` for the data stream.
    """
    # sys.getsizeof is shallow: it sizes the container, not its contents.
    size = sys.getsizeof(value)
    log.debug("Set <%d> bytes data", size)
    return self._put(Streams.DATA, value)
def put_command(self, cmd):
    """
    Put a command into the command stream.

    Args:
        cmd (EventReply): The command EventReply; stored wrapped as `{'cmd': cmd}`.

    Returns:
        The result of `self._put` for the command stream.
    """
    log.debug("Set command %s", cmd)
    wrapped = {'cmd': cmd}
    return self._put(Streams.COMMAND, wrapped)
def clean(self):
    """
    Clean event cache.

    Under the lock: regenerate both flags, reset the write index and all cache
    slots, then feed pending requests a metadata-only event at position '0'.
    """
    with self._lock:
        self._prev_flag, self._cur_flag = str(uuid.uuid4()), str(uuid.uuid4())
        self._next_idx = 0
        self._event_cache = [None for _ in range(self.max_limit)]
        reset_event = {'metadata': {'pos': '0'}}
        self.clean_pending_requests(reset_event)
        log.debug("Clean event cache. %d request is waiting.", len(self._pending_requests))
def _deal_with_view_cmd(self, event):
    """
    Deal with view cmd.

    Args:
        event (dict): The queued command; `view_cmd` and `node_name` are read from it.

    Returns:
        The view command to send, or None if either field is missing/empty.
    """
    view_cmd = event.get('view_cmd')
    node_name = event.get('node_name')
    log.debug("Receive view cmd %s for node: %s.", view_cmd, node_name)
    if not (view_cmd and node_name):
        log.warning("Invalid view command. Ignore it.")
        return None
    # Record which node the view was for and that tensors are now awaited.
    self._received_view_cmd['node_name'] = node_name
    self._received_view_cmd['wait_for_tensor'] = True
    return view_cmd
def put(self, value):
    """
    Put value into metadata cache. Called by grpc server.

    Args:
        value (MetadataProto): The Metadata proto message. The part of
            `device_name` before the first ':' is kept; an empty `backend`
            falls back to "Ascend".
    """
    device_name, _, _ = value.device_name.partition(':')
    self._device_name = device_name
    self._step = value.cur_step
    self._cur_full_name = value.cur_node
    self._backend = value.backend or "Ascend"
    log.debug("Put metadata into cache at the %d-th step.", self._step)
def put(self, value):
    """
    Put Watchpoint into watchpoint handler.

    The watchpoint is registered both as a known watchpoint and as an updated
    one, and becomes the latest id.

    Args:
        value (Watchpoint): The watchpoint to store, keyed by its `watchpoint_id`.
    """
    watchpoint_id = value.watchpoint_id
    for registry in (self._watchpoints, self._updated_watchpoints):
        registry[watchpoint_id] = value
    self._latest_id = watchpoint_id
    log.debug("Put watchpoint %d into cache.", watchpoint_id)
def _terminate(self, metadata_stream):
    """
    Terminate the training by queueing an ExitCMD.

    Args:
        metadata_stream (MetadataHandler): The metadata stream handler.

    Returns:
        dict, metadata with state 'pending'.
    """
    new_state = 'pending'
    metadata_stream.state = new_state
    exit_event = get_ack_reply()
    exit_event.exit = True
    self.cache_store.put_command(exit_event)
    log.debug("Send the ExitCMD.")
    return {'metadata': {'state': new_state}}
def _send_watchpoints(self):
    """
    Set watchpoints.

    Queue one SetCMD event per entry returned by the watchpoint stream with
    `filter_condition=True`, then mark the stream as synced. Does nothing when
    there is nothing to send.
    """
    watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
    set_cmds = watchpoint_stream.get(filter_condition=True).get('watch_points')
    if not set_cmds:
        return
    for set_cmd in set_cmds:
        event = get_ack_reply()
        event.set_cmd.CopyFrom(set_cmd)
        self.cache_store.put_command(event)
    watchpoint_stream.sync_set_cmd()
    log.debug("Send SetCMD to MindSpore. %s", event)
def set_watch_nodes(self, graph, watch_point_id):
    """
    Set watch flags on the nodes of a graph for a watchpoint.

    Does nothing when either the graph or the watchpoint id is empty.

    Args:
        graph (dict): The graph with list of nodes.
        watch_point_id (int): The id of watchpoint.
    """
    if not watch_point_id or not graph:
        return
    self.validate_watchpoint_id(watch_point_id)
    log.debug("add watch flags")
    self._set_watch_status_recursively(graph, self._watchpoints.get(watch_point_id))
def SendGraph(self, request_iterator, context):
    """
    Send graph into DebuggerCache.

    Args:
        request_iterator: Iterator of chunks whose `buffer` fields, concatenated
            in order, form the serialized GraphProto.
        context: The RPC context (not used by this method).

    Returns:
        EventReply, the ack reply.
    """
    log.info("Received graph.")
    # Join once; `+=` per chunk re-copies the accumulated bytes each time (quadratic).
    serial_graph = b"".join(chunk.buffer for chunk in request_iterator)
    graph = GraphProto.FromString(serial_graph)
    log.debug("Deserialize the graph. Receive %s nodes", len(graph.node))
    self._cache_store.get_stream_handler(Streams.GRAPH).put(graph)
    self._cache_store.get_stream_handler(Streams.TENSOR).put_const_vals(graph.const_vals)
    self._status = ServerStatus.RECEIVE_GRAPH
    reply = get_ack_reply()
    log.info("Send the reply for graph.")
    return reply
def _wait_for_next_command(self):
    """
    Wait for next command.

    Marks the metadata state as 'waiting', pushes that state to the data
    stream, then polls `_get_next_command` while the server status is WAITING.

    Returns:
        The command event, or None if the status left WAITING first.
    """
    log.info("Start to wait for command.")
    self._cache_store.get_stream_handler(Streams.METADATA).state = 'waiting'
    self._cache_store.put_data({'metadata': {'state': 'waiting'}})
    while self._status == ServerStatus.WAITING:
        log.debug("Wait for %s-th command", self._pos)
        event = self._get_next_command()
        if event is not None:
            return event
    return None
def _get_next_command(self):
    """
    Get next command.

    Fetches the command at `self._pos` (updating `self._pos`) and dispatches
    it: dict events go to the view-command handler, protos with `run_cmd` go
    to the run-command handler, and an `exit` proto cleans the cache.

    Returns:
        The event to reply with, or None.
    """
    self._pos, event = self._cache_store.get_command(self._pos)
    log.debug("Received event :%s", event)
    if event is None:
        return None
    if isinstance(event, dict):
        return self._deal_with_view_cmd(event)
    if event.HasField('run_cmd'):
        return self._deal_with_run_cmd(event)
    if event.HasField('exit'):
        self._cache_store.clean()
        log.info("Clean cache for exit cmd.")
    return event
def _pause(self, metadata_stream):
    """
    Pause the training.

    Args:
        metadata_stream (MetadataHandler): The metadata stream handler.

    Returns:
        dict, metadata with state 'waiting'.

    Raises:
        DebuggerPauseError: If MindSpore is not in the running state.
    """
    if metadata_stream.state != ServerStatus.RUNNING.value:
        log.error("The MindSpore is not running.")
        raise DebuggerPauseError("The MindSpore is not running.")
    metadata_stream.state = 'waiting'
    # A step-level RunCMD with zero steps is treated as a pause.
    pause_cmd = RunCMD(run_level='step', run_steps=0)
    event = get_ack_reply()
    event.run_cmd.CopyFrom(pause_cmd)
    self.cache_store.put_command(event)
    log.debug("Send the Pause command")
    return {'metadata': {'state': 'waiting'}}
def _get_watch_nodes_by_search(self, watch_nodes):
    """
    Get watched leaf nodes by search name.

    Args:
        watch_nodes (list[str]): Node names chosen from the search result.

    Returns:
        list[NodeBasicInfo], every node of the current search result whose
        name starts with one of `watch_nodes`, grouped in `watch_nodes` order.
    """
    watched_leaf_nodes = []
    graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
    if watch_nodes:
        # Loop-invariant: fetch the searched node list once (materialized so it
        # can be iterated once per search name even if a generator is returned).
        search_nodes = list(graph_stream.get_searched_node_list())
        for search_name in watch_nodes:
            # NOTE(review): plain prefix match — 'a/conv1' also matches 'a/conv10';
            # confirm this is intended.
            watched_leaf_nodes.extend(
                NodeBasicInfo(name=node.name, full_name=node.full_name, type=node.type)
                for node in search_nodes if node.name.startswith(search_name))
    log.debug("Update nodes: %s", watched_leaf_nodes)
    return watched_leaf_nodes
def _wait_for_event(self, cur_id, cur_queue, pos):
    """
    Wait for the pos-th event.

    Args:
        cur_id: Key of this request in `self._pending_requests`.
        cur_queue: Queue the event is delivered on (presumably a `queue.Queue`).
        pos: Position echoed back in the placeholder event on timeout.

    Returns:
        dict, the received event, or `{'metadata': {'pos': pos}}` if nothing
        arrived within 25 seconds.
    """
    try:
        # 25 seconds keeps this below the UI's own request timeout.
        event = cur_queue.get(timeout=25)
    except Empty:
        event = None
    if event is not None:
        return event
    # Timed out: drop this request from the pending table.
    with self._lock:
        if self._pending_requests.get(cur_id):
            self._pending_requests.pop(cur_id)
            log.debug("Clean timeout request. Left pending requests: %d",
                      len(self._pending_requests))
    return {'metadata': {'pos': pos}}
def _deal_with_run_cmd(self, event):
    """
    Deal with run cmd.

    A step-level RunCMD with zero steps is a pause: `_continue_steps` is reset
    and nothing is forwarded. A step-level RunCMD for N steps is rewritten to a
    single step, with N - 1 recorded in `_continue_steps`. Every forwarded
    RunCMD clears the watchpoint-hit cache.

    Args:
        event (EventReply): The command event carrying `run_cmd`.

    Returns:
        EventReply or None, the event to forward, or None for a pause.
    """
    run_cmd = event.run_cmd
    is_step = run_cmd.run_level == 'step'
    if is_step and run_cmd.run_steps == 0:
        log.debug("Pause training and wait for next command.")
        self._continue_steps = 0
        return None
    if is_step:
        # Forward one step now; WaitCMD issues the remaining steps one by one.
        self._continue_steps = run_cmd.run_steps - 1
        event.run_cmd.run_steps = 1
    self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
    log.debug("Receive RunCMD. Clean watchpoint hit cache.")
    return event
def delete_watchpoint(self, watch_point_id):
    """
    Delete watchpoint.

    The watchpoint is removed from `self._watchpoints`, and a SetCMD with
    `delete=True` is appended to `self._deleted_watchpoints` so the deletion
    is included the next time updated watchpoints are collected.

    Args:
        watch_point_id (int): The id of watchpoint. Validated by
            `validate_watchpoint_id` before deletion.

    Returns:
        None.
    """
    self.validate_watchpoint_id(watch_point_id)
    self._watchpoints.pop(watch_point_id)
    # NOTE(review): the id is not removed from `self._updated_watchpoints`, so a
    # pending SetCMD for it may still be sent before the delete — confirm intended.
    set_cmd = SetCMD()
    set_cmd.id = watch_point_id
    set_cmd.delete = True
    self._deleted_watchpoints.append(set_cmd)
    log.debug("Delete watchpoint %d in cache.", watch_point_id)
def get(self, filter_condition=None):
    """
    Get watchpoint hit list.

    Args:
        filter_condition (str): Node name whose watchpoint hits to fetch. If
            not given, get all watchpoint hits. Default: None.

    Returns:
        The result of `get_watchpoint_hits()` when no node name is given,
        otherwise `self._hits.get(filter_condition)` for that node.
    """
    if filter_condition is not None:
        log.debug("Get the watchpoint for node: <%s>.", filter_condition)
        return self._hits.get(filter_condition)
    log.debug("Get all watchpoint hit list.")
    return self.get_watchpoint_hits()
def _add_tensor_value_for_tensor_history(self, tensor_history, node_name):
    """
    Add tensor value for tensor history and send ViewCMD if tensor value missed.

    Args:
        tensor_history (list[dict]): A list of tensor info, including name and
            type. Passed to the tensor stream's `update_tensor_history`
            (presumably updated in place — confirm).
        node_name (str): The UI node name.

    Returns:
        None.
    """
    tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR)
    missed_tensors = tensor_stream.update_tensor_history(tensor_history)
    if missed_tensors:
        # Request the missing tensor values; the command is tagged with node_name.
        view_cmd = create_view_event_from_tensor_history(missed_tensors)
        self.cache_store.put_command({
            'view_cmd': view_cmd,
            'node_name': node_name
        })
        log.debug("Send view cmd.")
def put(self, value):
    """
    Put value into tensor cache. Called by grpc server.

    Args:
        value (dict): The Tensor proto message.

            - step (int): The current step of tensor. Default: 0.

            - tensor_protos (list[TensorProto]): The tensor proto chunks, merged
              into one tensor before caching.
    """
    merged_tensor = self._get_merged_tensor(value.get('tensor_protos'))
    step = value.get('step', 0)
    # A tensor flagged with `iter` is attributed to the previous step.
    if merged_tensor.iter and step > 0:
        log.debug("Received previous tensor.")
        step -= 1
    tensor = OpTensor(merged_tensor, step)
    self._put_tensor_into_cache(tensor, step)
    log.info("Put tensor %s of step: %d, into cache", tensor.name, step)