def get_tensor_value_by_shape(self, shape=None):
    """
    Get tensor value by shape.

    Args:
        shape (tuple): The specified shape.

    Returns:
        Union[None, str, numpy.ndarray], the sub-tensor.
    """
    if self._value is None:
        log.warning("%s has no value yet.", self.name)
        return None
    if shape is None or not isinstance(shape, tuple):
        log.info("Return the whole tensor value since the specified shape is %s.", shape)
        return self._value
    if len(shape) != len(self.shape):
        log.error("Invalid shape. Received: %s, tensor shape: %s", shape, self.shape)
        raise DebuggerParamValueError("Invalid shape. Shape unmatched.")
    try:
        value = self._value[shape]
    except IndexError as err:
        log.error("Invalid shape. Received: %s, tensor shape: %s", shape, self.shape)
        log.exception(err)
        raise DebuggerParamValueError("Invalid shape. Shape unmatched.")
    if isinstance(value, np.ndarray):
        if value.size > self.max_number_data_show_on_ui:
            # log the size before replacing the value with the placeholder string
            log.info("The tensor size is %s, which is too large to show on UI.", value.size)
            value = "Too large to show."
    else:
        value = np.asarray(value)
    return value
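# A minimal standalone sketch (not part of the original source) showing why
# `self._value[shape]` above returns a sub-tensor: indexing a numpy array with
# a tuple that mixes ints and slices selects along each dimension in turn.
import numpy as np

value = np.arange(24).reshape(2, 3, 4)
sub = value[(0, slice(0, 2), slice(None))]  # same as value[0, 0:2, :]
print(sub.shape)  # (2, 4)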
def control(self, params=None):
    """
    Control the training process.

    Args:
        params (dict): The control params.

            - mode (str): Acceptable control command, including `continue`,
                `pause` and `terminate`.
            - level (str): The control granularity, `node` level or `step` level.
                Default: `step`.
            - steps (int): Specify the steps that training should run.
                Used when `level` is `step`.
            - name (str): Specify the name of the node. Used when `level` is `node`.

    Returns:
        dict, the response.
    """
    log.info("Receive control request: %s.", params)
    # guard against params being None before dispatching on mode
    mode = params.get('mode') if params else None
    metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
    if mode == 'continue':
        reply = self._continue(metadata_stream, params)
    elif mode in ['pause', 'terminate']:
        mode_mapping = {'pause': self._pause, 'terminate': self._terminate}
        reply = mode_mapping.get(mode)(metadata_stream)
    else:
        log.error("Invalid control mode %s.", mode)
        raise DebuggerParamValueError("Invalid control mode.")
    return reply
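# Hedged usage sketch (the `server` name is illustrative, assuming an instance
# of this class): per the docstring's param layout, a `continue` request that
# runs two more steps, followed by a pause request, might look like:
#
#   reply = server.control({'mode': 'continue', 'level': 'step', 'steps': 2})
#   reply = server.control({'mode': 'pause'})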
def _retrieve_node(self, filter_condition):
    """
    Retrieve node info.

    Args:
        filter_condition (dict): Filter condition.

            - name (str): The name of a single node.
            - single_node (bool): If False, return the sub-layers of the single
                node. If True, return the node list from the root node to the
                single node.

    Returns:
        dict, the node info.
    """
    log.info("Retrieve node %s.", filter_condition)
    node_name = filter_condition.get('name')
    if node_name:
        # validate node name
        self.cache_store.get_stream_handler(Streams.GRAPH).get_node_type(node_name)
    filter_condition['single_node'] = bool(filter_condition.get('single_node'))
    reply = self._get_nodes_info(filter_condition)
    return reply
def retrieve_node_by_bfs(self, node_name, ascend=False):
    """
    Get the graph of the next node according to node_name.

    Args:
        node_name (str): The name of the current chosen leaf node.
        ascend (bool): If True, traverse the input nodes;
            if False, traverse the output nodes. Default: False.

    Returns:
        dict, the next node information.
    """
    log.info("Retrieve node <%s> by bfs, `ascend` is %s", node_name, ascend)
    reply = {}
    graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
    next_node_name = graph_stream.get_node_by_bfs_order(node_name, ascend)
    # no next node
    if next_node_name is None:
        return reply
    # add graph and tensor history for the next node
    filter_condition = {'name': next_node_name, 'single_node': True}
    search_graph = self._get_nodes_info(filter_condition)
    reply = {'name': next_node_name}
    reply.update(search_graph)
    return reply
def retrieve(self, mode, filter_condition=None):
    """
    Retrieve data according to mode and params.

    Args:
        mode (str): The type of info message.
        filter_condition (dict): The filter condition.

    Returns:
        dict, the retrieved data.
    """
    log.info("Receive retrieve request for mode: %s, filter_condition: %s",
             mode, filter_condition)
    mode_mapping = {
        'all': self._retrieve_all,
        'node': self._retrieve_node,
        'watchpoint': self._retrieve_watchpoint,
        'watchpoint_hit': self._retrieve_watchpoint_hit
    }
    # validate param <mode>
    if mode not in mode_mapping:
        log.error("Invalid param <mode>. <mode> should be in %s, but got %s.",
                  list(mode_mapping), mode)
        raise DebuggerParamTypeError("Invalid mode.")
    filter_condition = {} if filter_condition is None else filter_condition
    reply = mode_mapping[mode](filter_condition)
    return reply
def get_node_name_by_full_name(self, full_name):
    """Get UI node name by full name."""
    if self._graph:
        node_name = self._graph.get_node_name_by_full_name(full_name)
    else:
        node_name = ''
        log.info("No graph received yet.")
    return node_name
def search(self, name, watch_point_id):
    """Search for a single node in the graph."""
    log.info("Receive search request for node: %s, in watchpoint: %d",
             name, watch_point_id)
    graph = self.cache_store.get_stream_handler(Streams.GRAPH).search_nodes(name)
    self.cache_store.get_stream_handler(Streams.WATCHPOINT).set_watch_nodes(
        graph, watch_point_id)
    return graph
def _send_received_tensor_tag(self):
    """Send received_finish_tag."""
    node_name = self._received_view_cmd.get('node_name')
    if not node_name or self._received_view_cmd.get('wait_for_tensor'):
        return
    metadata = self._cache_store.get_stream_handler(Streams.METADATA).get()
    ret = {'receive_tensor': {'node_name': node_name}}
    ret.update(metadata)
    self._cache_store.put_data(ret)
    self._received_view_cmd.clear()
    log.info("Send receive tensor flag for %s", node_name)
def _retrieve_all(self, filter_condition=None):
    """Retrieve metadata, root graph and watchpoint list."""
    if filter_condition:
        log.error("No filter condition required for retrieve all request.")
        raise DebuggerParamTypeError("filter_condition should be empty.")
    result = {}
    self._watch_point_id = 0
    self.cache_store.clean_data()
    log.info("Clean data queue cache for retrieve all request.")
    for stream in [Streams.METADATA, Streams.GRAPH, Streams.WATCHPOINT]:
        sub_res = self.cache_store.get_stream_handler(stream).get()
        result.update(sub_res)
    return result
def SendGraph(self, request_iterator, context):
    """Send graph into DebuggerCache."""
    log.info("Received graph.")
    serial_graph = b""
    for chunk in request_iterator:
        serial_graph += chunk.buffer
    graph = GraphProto.FromString(serial_graph)
    log.debug("Deserialize the graph. Receive %s nodes", len(graph.node))
    self._cache_store.get_stream_handler(Streams.GRAPH).put(graph)
    self._cache_store.get_stream_handler(Streams.TENSOR).put_const_vals(graph.const_vals)
    self._status = ServerStatus.RECEIVE_GRAPH
    reply = get_ack_reply()
    log.info("Send the reply for graph.")
    return reply
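# Minimal standalone sketch (assumed names, not from the original source) of the
# reassembly pattern above: a streamed message arrives as byte chunks, which are
# concatenated and parsed in one shot. Joining a list is generally preferred
# over repeated `+=` on bytes, which copies the buffer on every iteration.
def reassemble(chunks):
    """Concatenate an iterable of bytes chunks into one payload."""
    return b"".join(chunks)

payload = reassemble([b"\x08", b"\x01"])  # e.g. a tiny serialized protobuf
print(len(payload))  # 2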
def retrieve_tensor_history(self, node_name):
    """
    Retrieve tensor history for a leaf node.

    Args:
        node_name (str): The name of the leaf node.

    Returns:
        dict, the tensor history and metadata.
    """
    log.info("Retrieve tensor history for node: %s.", node_name)
    self._validate_leaf_name(node_name)
    res = self._get_tensor_history(node_name)
    return res
def put(self, value):
    """
    Put value into graph cache. Called by grpc server.

    Args:
        value (GraphProto): The Graph proto message.
    """
    self._graph_proto = value
    log.info("Put graph into cache.")
    # build graph
    graph = DebuggerGraph()
    graph.build_graph(value)
    self._graph = graph
    self.bfs_order = self._graph.get_bfs_order()
def _update_metadata(self, metadata_stream, metadata_proto):
    """Update metadata."""
    # reset view round and clean cache data
    if metadata_stream.step < metadata_proto.cur_step:
        self._cache_store.clean_data()
        self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors(
            metadata_proto.cur_step)
    # put new metadata into cache
    metadata_stream.put(metadata_proto)
    cur_node = self._cache_store.get_stream_handler(Streams.GRAPH).get_node_name_by_full_name(
        metadata_proto.cur_node) if metadata_proto.cur_node else ''
    metadata_stream.node_name = cur_node
    metadata = metadata_stream.get()
    self._cache_store.put_data(metadata)
    log.info("Put new metadata into data queue.")
def _get_next_command(self):
    """Get next command."""
    self._pos, event = self._cache_store.get_command(self._pos)
    log.debug("Received event: %s", event)
    if event is None:
        return event
    if isinstance(event, dict):
        event = self._deal_with_view_cmd(event)
    elif event.HasField('run_cmd'):
        event = self._deal_with_run_cmd(event)
    elif event.HasField('exit'):
        self._cache_store.clean()
        log.info("Clean cache for exit cmd.")
    return event
def _wait_for_next_command(self):
    """
    Wait for next command.

    Returns:
        EventReply, the command event.
    """
    log.info("Start to wait for command.")
    self._cache_store.get_stream_handler(Streams.METADATA).state = 'waiting'
    self._cache_store.put_data({'metadata': {'state': 'waiting'}})
    event = None
    while event is None and self._status == ServerStatus.WAITING:
        log.debug("Wait for %s-th command", self._pos)
        event = self._get_next_command()
    return event
def _pre_process(self, request):
    """Send graph and metadata when WaitCMD is first called."""
    metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
    if self._status == ServerStatus.RECEIVE_GRAPH:
        self._status = ServerStatus.WAITING
        metadata_stream.state = 'waiting'
        metadata = metadata_stream.get()
        self._cache_store.clean_command()
        res = self._cache_store.get_stream_handler(Streams.GRAPH).get()
        res.update(metadata)
        self._cache_store.put_data(res)
        log.info("Put graph into data queue.")
    if metadata_stream.step < request.cur_step or metadata_stream.full_name != request.cur_node:
        # clean tensor cache and DataQueue at the beginning of each step
        self._update_metadata(metadata_stream, request)
def parse_shape(shape):
    """Parse shape."""
    if shape is None:
        return shape
    if not (isinstance(shape, str) and shape.startswith('[') and shape.endswith(']')):
        log.error("Invalid shape. Received: %s", shape)
        raise DebuggerParamValueError("Invalid shape.")
    shape = shape.strip('[]')
    if shape.count(':') > 2:
        log.error("Invalid shape. At most two slice separators (':') are supported.")
        raise DebuggerParamValueError("Invalid shape.")
    parsed_shape = tuple(
        str_to_slice_or_int(dim) for dim in shape.split(',')) if shape else tuple()
    log.info("Parsed shape: %s from %s", parsed_shape, shape)
    return parsed_shape
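# Hedged sketch of what `str_to_slice_or_int` (not shown in this file) might do,
# to make the behavior of parse_shape above concrete. This is an assumption
# about the helper, not its actual implementation.
def str_to_slice_or_int(dim):
    """Convert '0:10' or '0:10:2' to a slice, and '3' to an int."""
    if ':' in dim:
        return slice(*(int(d) if d else None for d in dim.split(':')))
    return int(dim)

# With that assumption, parse_shape("[0:2,1]") would yield (slice(0, 2), 1),
# which can index a numpy array directly.
print(str_to_slice_or_int('0:2'))  # slice(0, 2, None)
print(str_to_slice_or_int('3'))    # 3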
def SendTensors(self, request_iterator, context):
    """Send tensors into DebuggerCache."""
    log.info("Received tensor.")
    tensor_construct = []
    tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR)
    metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
    tensor_names = []
    step = metadata_stream.step
    for tensor in request_iterator:
        tensor_construct.append(tensor)
        if tensor.finished:
            if self._received_view_cmd.get('wait_for_tensor') and tensor.tensor_content:
                self._received_view_cmd['wait_for_tensor'] = False
            tensor_stream.put({'step': step, 'tensor_protos': tensor_construct})
            tensor_construct = []
            tensor_names.append(':'.join([tensor.node_name, tensor.slot]))
    reply = get_ack_reply()
    return reply
def _get_tensor_infos_of_node(cur_node, slot=None):
    """Get tensors info of specified node."""
    tensors_info = []
    if slot is None:
        slots = range(cur_node.output_nums)
    elif slot >= 0:
        slots = [slot]
    else:
        log.info("Skip get tensor info for %s:%s.", cur_node.name, slot)
        return tensors_info
    for num in slots:
        tensor_info = {
            'name': cur_node.name + ':' + str(num),
            'full_name': cur_node.full_name + ':' + str(num),
            'node_type': cur_node.type
        }
        tensors_info.append(tensor_info)
    return tensors_info
def retrieve_tensor_value(self, name, detail, shape):
    """Retrieve the tensor value."""
    log.info("Retrieve tensor value: name: %s, detail: %s, shape: %s",
             name, detail, shape)
    self.validate_tensor_param(name, detail)
    parsed_shape = self.parse_shape(shape)
    node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name)
    reply = self.cache_store.get_stream_handler(Streams.TENSOR).get({
        'name': tensor_name,
        'node_type': node_type,
        'shape': parsed_shape
    })
    reply['tensor_value']['name'] = name
    return reply
def create_watchpoint(self, watch_condition, watch_nodes=None, watch_point_id=None):
    """
    Create watchpoint.

    Args:
        watch_condition (dict): The watch condition.

            - condition (str): Accept `INF` or `NAN`.
            - param (list[float]): Not defined yet.
        watch_nodes (list[str]): The list of node names.
        watch_point_id (int): The id of watchpoint.

    Returns:
        dict, the id of the new watchpoint.
    """
    log.info("Received create watchpoint request. WatchCondition: %s", watch_condition)
    metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
    if metadata_stream.state != ServerStatus.WAITING.value:
        log.error("Failed to create watchpoint as MindSpore is not in waiting state.")
        raise DebuggerCreateWatchPointError(
            "Failed to create watchpoint as MindSpore is not in waiting state.")
    if metadata_stream.backend == 'GPU' and watch_condition.get('condition') == 'OVERFLOW':
        log.error("GPU doesn't support OVERFLOW watch condition.")
        raise DebuggerParamValueError("GPU doesn't support OVERFLOW watch condition.")
    watch_nodes = self._get_node_basic_infos(watch_nodes)
    watch_point_id = self.cache_store.get_stream_handler(Streams.WATCHPOINT).create_watchpoint(
        watch_condition, watch_nodes, watch_point_id)
    self._watch_point_id = 0
    log.info("Create watchpoint %d", watch_point_id)
    return {'id': watch_point_id}
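# Hedged usage sketch (assuming a `server` instance of this class): per the
# docstring, a watch condition is a dict with a `condition` key, so creating a
# NAN watchpoint over two nodes might look like:
#
#   reply = server.create_watchpoint(
#       watch_condition={'condition': 'NAN'},
#       watch_nodes=['Default/network/conv1', 'Default/network/conv2'])
#   watch_point_id = reply['id']
#
# The node names here are illustrative, not taken from a real graph.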
def put(self, value):
    """
    Put value into tensor cache. Called by grpc server.

    Args:
        value (dict): The Tensor proto message.

            - step (int): The current step of tensor.
            - tensor_protos (list[TensorProto]): The tensor protos.
    """
    tensor_protos = value.get('tensor_protos')
    merged_tensor = self._get_merged_tensor(tensor_protos)
    step = value.get('step', 0)
    if merged_tensor.iter and step > 0:
        log.debug("Received previous tensor.")
        step -= 1
    tensor = OpTensor(merged_tensor, step)
    self._put_tensor_into_cache(tensor, step)
    log.info("Put tensor %s of step %d into cache.", tensor.name, step)
def SendMetadata(self, request, context):
    """Send metadata into DebuggerCache."""
    log.info("Received Metadata.")
    if self._status != ServerStatus.PENDING:
        log.info("Re-initialize cache store when new session comes.")
        self.init()
    client_ip = context.peer().split(':', 1)[-1]
    metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
    if request.training_done:
        log.info("The training from %s has finished.", client_ip)
    else:
        metadata_stream.put(request)
        metadata_stream.client_ip = client_ip
        log.info("Put new metadata from %s into cache.", client_ip)
    # put metadata into data queue
    metadata = metadata_stream.get()
    self._cache_store.put_data(metadata)
    reply = get_ack_reply()
    log.info("Send the reply to %s.", client_ip)
    return reply
def start(self):
    """Start server."""
    grpc_port = self.grpc_port if self.grpc_port else "50051"
    host = settings.HOST if hasattr(settings, 'HOST') else '[::]'
    hostname = "{}:{}".format(host, grpc_port)
    # initialize a grpc server
    grpc_server_manager = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    grpc_server_base.add_EventListenerServicer_to_server(
        self.grpc_server, grpc_server_manager)
    grpc_server_manager.add_insecure_port(hostname)
    grpc_server_manager.start()
    my_server_thread = Thread(target=grpc_server_manager.wait_for_termination)
    # start grpc server
    my_server_thread.start()
    self.back_server = my_server_thread
    self.grpc_server_manager = grpc_server_manager
    # register stop server handler
    signal.signal(signal.SIGINT, self._stop_handler)
    log.info("Start grpc server %s", hostname)
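# Minimal standalone sketch of the gRPC startup pattern used above, with
# assumed placeholder names (`add_MyServicer_to_server` would be generated by
# grpc_tools from a .proto file). Unlike the original, this blocks the caller
# directly; the original spawns a Thread so start() can return and a SIGINT
# handler can stop the server later.
from concurrent import futures
import grpc

def serve(port="50051"):
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    # add_MyServicer_to_server(MyServicer(), server)  # generated registration
    server.add_insecure_port("[::]:{}".format(port))
    server.start()
    server.wait_for_termination()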
def WaitCMD(self, request, context):
    """Wait for a command in DebuggerCache."""
    # check whether the graph has been received
    log.info("Received WaitCMD at %s-th step.", request.cur_step)
    if self._status == ServerStatus.PENDING:
        log.warning("No graph received before WaitCMD.")
        reply = get_ack_reply(1)
        return reply
    self._send_received_tensor_tag()
    # send graph if it has not been sent before
    self._pre_process(request)
    # deal with old command
    reply = self._deal_with_old_command()
    if reply:
        log.info("Reply to WaitCMD with old command: %s", reply)
        return reply
    # continue multi-step training
    if self._continue_steps:
        reply = get_ack_reply()
        reply.run_cmd.run_steps = 1
        reply.run_cmd.run_level = 'step'
        self._continue_steps = self._continue_steps - 1 if self._continue_steps > 0 else -1
        self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
        log.debug("Send RunCMD. Clean watchpoint hit.")
    # wait for command
    else:
        reply = self._wait_for_next_command()
    if reply is None:
        reply = get_ack_reply(1)
        log.warning("Failed to get command event.")
    else:
        log.info("Reply to WaitCMD: %s", reply)
    return reply
def SendWatchpointHits(self, request_iterator, context):
    """Send watchpoint hits into DebuggerCache."""
    log.info("Received WatchpointHits. Left steps %d change to 0.", self._continue_steps)
    self._continue_steps = 0
    watchpoint_hit_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
    watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT)
    graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH)
    for watchpoint_hit_proto in request_iterator:
        ui_node_name = graph_stream.get_node_name_by_full_name(
            watchpoint_hit_proto.tensor.node_name)
        log.debug("Receive watchpoint hit: %s", watchpoint_hit_proto)
        if not ui_node_name:
            log.info("Cannot show %s on the graph.", watchpoint_hit_proto.tensor.node_name)
            continue
        watchpoint_hit = {
            'tensor_proto': watchpoint_hit_proto.tensor,
            'watchpoint': watchpoint_stream.get_watchpoint_by_id(watchpoint_hit_proto.id),
            'node_name': ui_node_name
        }
        watchpoint_hit_stream.put(watchpoint_hit)
    watchpoint_hits_info = watchpoint_hit_stream.get()
    self._cache_store.put_data(watchpoint_hits_info)
    log.info("Send the watchpoint hits to DataQueue.\nSend the reply.")
    reply = get_ack_reply()
    return reply
def update_watchpoint(self, watch_point_id, watch_nodes, mode, name=None):
    """
    Update watchpoint.

    Args:
        watch_point_id (int): The id of watchpoint.
        watch_nodes (list[str]): The list of node names.
        mode (int): The update operation on nodes. 0 removes nodes from the
            watch nodes; 1 adds nodes to the watch nodes.
        name (str): The search name. Default: None.

    Returns:
        dict, empty response.
    """
    if self.cache_store.get_stream_handler(
            Streams.METADATA).state != ServerStatus.WAITING.value:
        log.error("Failed to update watchpoint as MindSpore is not in waiting state.")
        raise DebuggerUpdateWatchPointError(
            "Failed to update watchpoint as MindSpore is not in waiting state.")
    # validate parameters
    if not watch_nodes or not watch_point_id:
        log.error("Invalid parameter for update watchpoint.")
        raise DebuggerParamValueError("Invalid parameter for update watchpoint.")
    # update watch nodes
    if name is not None:
        watch_nodes = self._get_watch_nodes_by_search(watch_nodes)
    elif mode == 1:
        watch_nodes = self._get_node_basic_infos(watch_nodes)
    self.cache_store.get_stream_handler(Streams.WATCHPOINT).update_watchpoint(
        watch_point_id, watch_nodes, mode)
    self._watch_point_id = watch_point_id
    log.info("Update watchpoint with id: %d", watch_point_id)
    return {}
def delete_watchpoint(self, watch_point_id):
    """
    Delete watchpoint.

    Args:
        watch_point_id (int): The id of watchpoint.

    Returns:
        dict, empty response.
    """
    if self.cache_store.get_stream_handler(
            Streams.METADATA).state != ServerStatus.WAITING.value:
        log.error("Failed to delete watchpoint as MindSpore is not in waiting state.")
        raise DebuggerDeleteWatchPointError(
            "Failed to delete watchpoint as MindSpore is not in waiting state.")
    self.cache_store.get_stream_handler(Streams.WATCHPOINT).delete_watchpoint(watch_point_id)
    self._watch_point_id = 0
    log.info("Delete watchpoint with id: %d", watch_point_id)
    return {}
def get_bfs_order(self):
    """
    Traverse the graph in breadth-first search order.

    Returns:
        list, the leaf nodes arranged in BFS order.
    """
    root = self.get_default_root()
    log.info('Choose node %s as the root for BFS.', root.name)
    bfs_order = []
    self.get_bfs_graph(root.name, bfs_order)
    length = len(self._leaf_nodes.keys())
    # find the remaining un-traversed nodes
    for node_name, _ in self._leaf_nodes.items():
        if node_name not in bfs_order:
            self.get_bfs_graph(node_name, bfs_order)
    if len(bfs_order) != length:
        log.error("The numbers of BFS-ordered nodes and leaf nodes are not equal.")
        msg = "Not all nodes are traversed!"
        raise DebuggerParamValueError(msg)
    return bfs_order
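# Standalone sketch (assumed graph shape, not the original DebuggerGraph) of the
# two-phase traversal above: BFS from a chosen root, then a second pass that
# restarts BFS from any leaf the first pass missed, so nodes in disconnected
# subgraphs still end up in the order.
from collections import deque

def bfs_all(adjacency, root):
    """Return every node of a possibly disconnected graph in BFS order."""
    order, visited = [], set()

    def bfs(start):
        queue = deque([start])
        while queue:
            node = queue.popleft()
            if node in visited:
                continue
            visited.add(node)
            order.append(node)
            queue.extend(adjacency.get(node, []))

    bfs(root)
    for node in adjacency:  # pick up un-traversed components
        if node not in visited:
            bfs(node)
    return order

print(bfs_all({'a': ['b'], 'b': [], 'c': ['d'], 'd': []}, 'a'))  # ['a', 'b', 'c', 'd']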
def stop(self):
    """Stop debugger server."""
    self.grpc_server_manager.stop(grace=None)
    self.back_server.join()
    log.info("Stop debugger server.")