Example #1
    def get_tensor_value_by_shape(self, shape=None):
        """
        Get tensor value by shape.

        Args:
            shape (tuple): The specified shape.

        Returns:
            Union[None, str, numpy.ndarray], the sub-tensor.
        """
        if self._value is None:
            log.warning("%s has no value yet.", self.name)
            return None
        if shape is None or not isinstance(shape, tuple):
            log.info("Get the whole tensor value with shape is %s", shape)
            return self._value
        if len(shape) != len(self.shape):
            log.error("Invalid shape. Received: %s, tensor shape: %s", shape,
                      self.shape)
            raise DebuggerParamValueError("Invalid shape. Shape unmatched.")
        try:
            value = self._value[shape]
        except IndexError as err:
            log.error("Invalid shape. Received: %s, tensor shape: %s", shape,
                      self.shape)
            log.exception(err)
            raise DebuggerParamValueError("Invalid shape. Shape unmatched.")
        if isinstance(value, np.ndarray):
            if value.size > self.max_number_data_show_on_ui:
                value = "Too large to show."
                log.info(
                    "The tensor size is %s, which is too large to show on UI.")
        else:
            value = np.asarray(value)
        return value
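Since the method indexes the cached numpy array directly with the shape tuple, each element of the tuple may be an int or a slice. A minimal sketch (with a made-up array, not taken from the source) of how that indexing behaves:

import numpy as np

# Made-up 3-D tensor standing in for self._value.
value = np.arange(24).reshape(2, 3, 4)
# Analogous to a parsed "[0:1, :, 2]": a tuple of slices and ints.
shape = (slice(0, 1), slice(None), 2)
sub_tensor = value[shape]  # numpy consumes the whole tuple at once
print(sub_tensor.shape)    # (1, 3)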
Example #2
    def control(self, params=None):
        """
        Control the training process.

        Args:
            params (dict): The control params.

                - mode (str): Acceptable control command, including `continue`,
                    `pause` and `terminate`.

                - level (str): The control granularity, `node` level or `step` level.
                    Default: `step`.

                - steps (int): Specify the steps that training should run.
                    Used when `level` is `step`.

                - name (str): Specify the name of the node. Used when `level` is `node`.

        Returns:
            dict, the response.
        """
        log.info("Receive control request: %s.", params)
        mode = params.get('mode')
        metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
        if mode == 'continue':
            reply = self._continue(metadata_stream, params)
        elif mode in ['pause', 'terminate']:
            mode_mapping = {'pause': self._pause, 'terminate': self._terminate}
            reply = mode_mapping.get(mode)(metadata_stream)
        else:
            log.error("Invalid control mode %s", mode)
            raise DebuggerParamValueError("Invalid control mode.")

        return reply
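A hypothetical usage sketch (the session object holding this method is an assumption, not shown in the source; the params layout follows the docstring above):

# Run five more steps, then pause; `session` is a stand-in name.
reply = session.control({'mode': 'continue', 'level': 'step', 'steps': 5})
reply = session.control({'mode': 'pause'})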
Example #3
    def _retrieve_node(self, filter_condition):
        """
        Retrieve node info.

        Args:
            filter_condition (dict): Filter condition.

                - name (str): The name of single node.

                - single_node (bool): If False, return the sub-layer of single node. If True, return
                    the node list from root node to single node.

        Returns:
            dict, the node info.
        """
        log.info("Retrieve node %s.", filter_condition)
        node_name = filter_condition.get('name')
        if node_name:
            # validate node name
            self.cache_store.get_stream_handler(
                Streams.GRAPH).get_node_type(node_name)
        filter_condition['single_node'] = bool(
            filter_condition.get('single_node'))
        reply = self._get_nodes_info(filter_condition)
        return reply
Example #4
    def retrieve_node_by_bfs(self, node_name, ascend=False):
        """
        Get the graph of the next node according to node_name.

        Args:
            node_name (str): The name of current chosen leaf node.
            ascend (bool): If True, traverse the input nodes;
                If False, traverse the output nodes. Default: False.

        Returns:
            dict, the next node information.
        """
        log.info("Retrieve node <%s> by bfs, `ascend` is :%s", node_name,
                 ascend)
        reply = {}
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        next_node_name = graph_stream.get_node_by_bfs_order(node_name, ascend)
        # no next node
        if next_node_name is None:
            return reply
        # add graph and tensor history for next node
        filter_condition = {'name': next_node_name, 'single_node': True}
        search_graph = self._get_nodes_info(filter_condition)
        reply = {'name': next_node_name}
        reply.update(search_graph)

        return reply
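A hypothetical usage sketch (the `session` stand-in and the node name are assumptions):

next_info = session.retrieve_node_by_bfs('Default/conv1', ascend=False)
if not next_info:  # an empty dict means there is no next node in BFS order
    print('Reached the end of the BFS order.')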
Example #5
    def retrieve(self, mode, filter_condition=None):
        """
        Retrieve data according to mode and params.

        Args:
            mode (str): The type of info message.
            filter_condition (dict): The filter condition.

        Returns:
            dict, the retrieved data.
        """
        log.info(
            "Receive retrieve request for mode: %s, filter_condition: %s",
            mode, filter_condition)
        mode_mapping = {
            'all': self._retrieve_all,
            'node': self._retrieve_node,
            'watchpoint': self._retrieve_watchpoint,
            'watchpoint_hit': self._retrieve_watchpoint_hit
        }
        # validate param <mode>
        if mode not in mode_mapping:
            log.error(
                "Invalid param <mode>. <mode> should be in %s, but got %s.",
                list(mode_mapping), mode)
            raise DebuggerParamTypeError("Invalid mode.")
        filter_condition = {} if filter_condition is None else filter_condition
        reply = mode_mapping[mode](filter_condition)

        return reply
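The dict-based dispatch keeps mode validation and handler lookup in one place. Hypothetical calls (the `session` stand-in and the node name are assumptions):

everything = session.retrieve('all')
node_info = session.retrieve('node', {'name': 'Default/conv1',
                                      'single_node': True})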
Example #6
 def get_node_name_by_full_name(self, full_name):
     """Get UI node name by full name."""
     if self._graph:
         node_name = self._graph.get_node_name_by_full_name(full_name)
     else:
         node_name = ''
         log.info("No graph received yet.")
     return node_name
Example #7
 def search(self, name, watch_point_id):
     """Search for single node in graph."""
     log.info("receive search request for node:%s, in watchpoint:%d", name,
              watch_point_id)
     graph = self.cache_store.get_stream_handler(
         Streams.GRAPH).search_nodes(name)
     self.cache_store.get_stream_handler(
         Streams.WATCHPOINT).set_watch_nodes(graph, watch_point_id)
     return graph
Example #8
 def _send_received_tensor_tag(self):
     """Send received_finish_tag."""
     node_name = self._received_view_cmd.get('node_name')
     if not node_name or self._received_view_cmd.get('wait_for_tensor'):
         return
     metadata = self._cache_store.get_stream_handler(Streams.METADATA).get()
     ret = {'receive_tensor': {'node_name': node_name}}
     ret.update(metadata)
     self._cache_store.put_data(ret)
     self._received_view_cmd.clear()
     log.info("Send receive tensor flag for %s", node_name)
Example #9
    def _retrieve_all(self, filter_condition=None):
        """Retrieve metadata, root graph and watchpoint list."""
        if filter_condition:
            log.error("No filter condition required for retrieve all request.")
            raise DebuggerParamTypeError("filter_condition should be empty.")
        result = {}
        self._watch_point_id = 0
        self.cache_store.clean_data()
        log.info("Clean data queue cache when retrieve all request.")
        for stream in [Streams.METADATA, Streams.GRAPH, Streams.WATCHPOINT]:
            sub_res = self.cache_store.get_stream_handler(stream).get()
            result.update(sub_res)

        return result
Example #10
 def SendGraph(self, request_iterator, context):
     """Send graph into DebuggerCache."""
     log.info("Received graph.")
     serial_graph = b""
     for chunk in request_iterator:
         serial_graph += chunk.buffer
     graph = GraphProto.FromString(serial_graph)
     log.debug("Deserialize the graph. Receive %s nodes", len(graph.node))
     self._cache_store.get_stream_handler(Streams.GRAPH).put(graph)
     self._cache_store.get_stream_handler(Streams.TENSOR).put_const_vals(graph.const_vals)
     self._status = ServerStatus.RECEIVE_GRAPH
     reply = get_ack_reply()
     log.info("Send the reply for graph.")
     return reply
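Streaming the serialized graph in chunks keeps each gRPC message small (gRPC enforces a per-message size limit, 4 MB by default). A plausible sender-side counterpart, assuming a Chunk message type with a `buffer` bytes field:

def chunk_graph(serial_graph, chunk_size=1024 * 1024):
    """Yield 1 MB Chunk messages covering the serialized graph."""
    for offset in range(0, len(serial_graph), chunk_size):
        yield Chunk(buffer=serial_graph[offset:offset + chunk_size])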
Example #11
    def retrieve_tensor_history(self, node_name):
        """
        Retrieve tensor history for leaf node.

        Args:
            node_name (str): The name of leaf node.

        Returns:
            dict, the tensor history and metadata.
        """
        log.info("Retrieve tensor history for node: %s.", node_name)
        self._validate_leaf_name(node_name)
        res = self._get_tensor_history(node_name)
        return res
Example #12
    def put(self, value):
        """
        Put value into graph cache. Called by grpc server.

        Args:
            value (GraphProto): The Graph proto message.
        """
        self._graph_proto = value
        log.info("Put graph into cache.")

        # build graph
        graph = DebuggerGraph()
        graph.build_graph(value)
        self._graph = graph
        self.bfs_order = self._graph.get_bfs_order()
Example #13
 def _update_metadata(self, metadata_stream, metadata_proto):
     """Update metadata."""
     # reset view round and clean cache data
     if metadata_stream.step < metadata_proto.cur_step:
         self._cache_store.clean_data()
         self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors(
             metadata_proto.cur_step)
     # put new metadata into cache
     metadata_stream.put(metadata_proto)
     cur_node = self._cache_store.get_stream_handler(Streams.GRAPH).get_node_name_by_full_name(
         metadata_proto.cur_node) if metadata_proto.cur_node else ''
     metadata_stream.node_name = cur_node
     metadata = metadata_stream.get()
     self._cache_store.put_data(metadata)
     log.info("Put new metadata into data queue.")
Example #14
    def _get_next_command(self):
        """Get next command."""
        self._pos, event = self._cache_store.get_command(self._pos)
        log.debug("Received event :%s", event)
        if event is None:
            return event
        if isinstance(event, dict):
            event = self._deal_with_view_cmd(event)
        elif event.HasField('run_cmd'):
            event = self._deal_with_run_cmd(event)
        elif event.HasField('exit'):
            self._cache_store.clean()
            log.info("Clean cache for exit cmd.")

        return event
Example #15
    def _wait_for_next_command(self):
        """
        Wait for next command.

        Returns:
            EventReply, the command event.
        """
        log.info("Start to wait for command.")
        self._cache_store.get_stream_handler(Streams.METADATA).state = 'waiting'
        self._cache_store.put_data({'metadata': {'state': 'waiting'}})
        event = None
        while event is None and self._status == ServerStatus.WAITING:
            log.debug("Wait for %s-th command", self._pos)
            event = self._get_next_command()
        return event
Example #16
    def _pre_process(self, request):
        """Send graph and metadata when WaitCMD first called."""
        metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
        if self._status == ServerStatus.RECEIVE_GRAPH:
            self._status = ServerStatus.WAITING
            metadata_stream.state = 'waiting'
            metadata = metadata_stream.get()
            self._cache_store.clean_command()
            res = self._cache_store.get_stream_handler(Streams.GRAPH).get()
            res.update(metadata)
            self._cache_store.put_data(res)
            log.info("Put graph into data queue.")

        if metadata_stream.step < request.cur_step or metadata_stream.full_name != request.cur_node:
            # clean tensor cache and DataQueue at the beginning of each step
            self._update_metadata(metadata_stream, request)
Example #17
 def parse_shape(shape):
     """Parse shape."""
     if shape is None:
         return shape
     if not (isinstance(shape, str) and shape.startswith('[')
             and shape.endswith(']')):
         log.error("Invalid shape. Received: %s", shape)
         raise DebuggerParamValueError("Invalid shape.")
     shape = shape.strip('[]')
     if shape.count(':') > 2:
         log.error("Invalid shape. At most two dimensions are specified.")
         raise DebuggerParamValueError("Invalid shape.")
     parsed_shape = tuple(
         str_to_slice_or_int(dim)
         for dim in shape.split(',')) if shape else tuple()
     log.info("Parsed shape: %s from %s", parsed_shape, shape)
     return parsed_shape
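The helper str_to_slice_or_int is not shown in the source; a plausible sketch consistent with parse_shape's contract (a dimension like '0:2' becomes a slice, ':' becomes slice(None), a bare number becomes an int):

def str_to_slice_or_int(dim):
    """Convert '0:2' to slice(0, 2), ':' to slice(None), '3' to 3."""
    if ':' in dim:
        parts = [int(part) if part.strip() else None
                 for part in dim.split(':')]
        return slice(*parts)
    return int(dim)

# parse_shape("[0:2, :, 3]") would then yield a tuple of slices and ints:
# (slice(0, 2), slice(None, None, None), 3).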
Example #18
 def SendTensors(self, request_iterator, context):
     """Send tensors into DebuggerCache."""
     log.info("Received tensor.")
     tensor_construct = []
     tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR)
     metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
     tensor_names = []
     step = metadata_stream.step
     for tensor in request_iterator:
         tensor_construct.append(tensor)
         if tensor.finished:
             if self._received_view_cmd.get('wait_for_tensor') and tensor.tensor_content:
                 self._received_view_cmd['wait_for_tensor'] = False
             tensor_stream.put({'step': step, 'tensor_protos': tensor_construct})
             tensor_construct = []
             tensor_names.append(':'.join([tensor.node_name, tensor.slot]))
     reply = get_ack_reply()
     return reply
Example #19
    def _get_tensor_infos_of_node(cur_node, slot=None):
        """Get tensors info of specified node."""
        tensors_info = []
        if slot is None:
            slots = range(cur_node.output_nums)
        elif slot >= 0:
            slots = [slot]
        else:
            log.info("Skip get tensor info for %s:%s.", cur_node.name, slot)
            return tensors_info
        for num in slots:
            tensor_info = {
                'name': cur_node.name + ':' + str(num),
                'full_name': cur_node.full_name + ':' + str(num),
                'node_type': cur_node.type
            }
            tensors_info.append(tensor_info)

        return tensors_info
Example #20
    def retrieve_tensor_value(self, name, detail, shape):
        """Retrieve the tensor value."""
        log.info("Retrieve tensor value: name: %s, detail: %s, shape: %s",
                 name, detail, shape)
        self.validate_tensor_param(name, detail)
        parsed_shape = self.parse_shape(shape)
        node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(
            name)
        reply = self.cache_store.get_stream_handler(Streams.TENSOR).get({
            'name': tensor_name,
            'node_type': node_type,
            'shape': parsed_shape
        })
        reply['tensor_value']['name'] = name

        return reply
Example #21
    def create_watchpoint(self,
                          watch_condition,
                          watch_nodes=None,
                          watch_point_id=None):
        """
        Create watchpoint.

        Args:
            watch_condition (dict): The watch condition.

                - condition (str): Accept `INF` or `NAN`.

                - param (list[float]): Not defined yet.
            watch_nodes (list[str]): The list of node names.
            watch_point_id (int): The id of watchpoint.

        Returns:
            dict, the id of new watchpoint.
        """
        log.info("Received create watchpoint request. WatchCondition: %s",
                 watch_condition)
        metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
        if metadata_stream.state != ServerStatus.WAITING.value:
            log.error(
                "Failed to create watchpoint as the MindSpore is not in waiting state."
            )
            raise DebuggerCreateWatchPointError(
                "Failed to create watchpoint as the MindSpore is not in waiting state."
            )
        if metadata_stream.backend == 'GPU' and watch_condition.get(
                'condition') == 'OVERFLOW':
            log.error("GPU doesn't support OVERFLOW watch condition.")
            raise DebuggerParamValueError(
                "GPU doesn't support OVERFLOW watch condition.")

        watch_nodes = self._get_node_basic_infos(watch_nodes)
        watch_point_id = self.cache_store.get_stream_handler(
            Streams.WATCHPOINT).create_watchpoint(watch_condition, watch_nodes,
                                                  watch_point_id)
        self._watch_point_id = 0
        log.info("Create watchpoint %d", watch_point_id)
        return {'id': watch_point_id}
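A hypothetical usage sketch (the `session` stand-in and the node name are assumptions; the condition values follow the docstring):

reply = session.create_watchpoint(
    watch_condition={'condition': 'NAN'},
    watch_nodes=['Default/network/conv1'])
watch_point_id = reply['id']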
Example #22
    def put(self, value):
        """
        Put value into tensor cache. Called by grpc server.

        Args:
            value (dict): The Tensor proto message.

                - step (int): The current step of tensor.

                - tensor_protos (list[TensorProto]): The tensor proto.
        """
        tensor_protos = value.get('tensor_protos')
        merged_tensor = self._get_merged_tensor(tensor_protos)
        step = value.get('step', 0)
        if merged_tensor.iter and step > 0:
            log.debug("Received previous tensor.")
            step -= 1
        tensor = OpTensor(merged_tensor, step)
        self._put_tensor_into_cache(tensor, step)
        log.info("Put tensor %s of step: %d, into cache", tensor.name, step)
Example #23
    def SendMetadata(self, request, context):
        """Send metadata into DebuggerCache."""
        log.info("Received Metadata.")
        if self._status != ServerStatus.PENDING:
            log.info("Re-initialize cache store when new session comes.")
            self.init()

        client_ip = context.peer().split(':', 1)[-1]
        metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
        if request.training_done:
            log.info("The training from %s has finished.", client_ip)
        else:
            metadata_stream.put(request)
            metadata_stream.client_ip = client_ip
            log.info("Put new metadata from %s into cache.", client_ip)
        # put metadata into data queue
        metadata = metadata_stream.get()
        self._cache_store.put_data(metadata)
        reply = get_ack_reply()
        log.info("Send the reply to %s.", client_ip)
        return reply
Example #24
 def start(self):
     """Start server."""
     grpc_port = self.grpc_port if self.grpc_port else "50051"
     host = settings.HOST if hasattr(settings, 'HOST') else '[::]'
     hostname = "{}:{}".format(host, grpc_port)
     # initialize a grpc server
     grpc_server_manager = grpc.server(
         futures.ThreadPoolExecutor(max_workers=10))
     grpc_server_base.add_EventListenerServicer_to_server(
         self.grpc_server, grpc_server_manager)
     grpc_server_manager.add_insecure_port(hostname)
     grpc_server_manager.start()
     my_server_thread = Thread(
         target=grpc_server_manager.wait_for_termination)
     # start grpc server
     my_server_thread.start()
     self.back_server = my_server_thread
     self.grpc_server_manager = grpc_server_manager
     # register stop server handler
     signal.signal(signal.SIGINT, self._stop_handler)
     log.info("Start grpc server %s", hostname)
Example #25
    def WaitCMD(self, request, context):
        """Wait for a command in DebuggerCache."""
        # check if the graph has already been received.
        log.info("Received WaitCMD at %s-th step.", request.cur_step)
        if self._status == ServerStatus.PENDING:
            log.warning("No graph received before WaitCMD.")
            reply = get_ack_reply(1)
            return reply
        self._send_received_tensor_tag()
        # send graph if it has not been sent before
        self._pre_process(request)
        # deal with old command
        reply = self._deal_with_old_command()
        if reply:
            log.info("Reply to WaitCMD with old command: %s", reply)
            return reply
        # continue multiple steps training
        if self._continue_steps:
            reply = get_ack_reply()
            reply.run_cmd.run_steps = 1
            reply.run_cmd.run_level = 'step'
            self._continue_steps = self._continue_steps - 1 if self._continue_steps > 0 else -1
            self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
            log.debug("Send RunCMD. Clean watchpoint hit.")
        # wait for command
        else:
            reply = self._wait_for_next_command()

        if reply is None:
            reply = get_ack_reply(1)
            log.warning("Failed to get command event.")
        else:
            log.info("Reply to WaitCMD: %s", reply)
        return reply
Example #26
 def SendWatchpointHits(self, request_iterator, context):
     """Send watchpoint hits info DebuggerCache."""
     log.info("Received WatchpointHits. Left steps %d change to 0.", self._continue_steps)
     self._continue_steps = 0
     watchpoint_hit_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
     watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT)
     graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH)
     for watchpoint_hit_proto in request_iterator:
         ui_node_name = graph_stream.get_node_name_by_full_name(
             watchpoint_hit_proto.tensor.node_name)
         log.debug("Receive watch point hit: %s", watchpoint_hit_proto)
         if not ui_node_name:
             log.info("Not support to show %s on graph.", watchpoint_hit_proto.tensor.node_name)
             continue
         watchpoint_hit = {
             'tensor_proto': watchpoint_hit_proto.tensor,
             'watchpoint': watchpoint_stream.get_watchpoint_by_id(watchpoint_hit_proto.id),
             'node_name': ui_node_name
         }
         watchpoint_hit_stream.put(watchpoint_hit)
     watchpoint_hits_info = watchpoint_hit_stream.get()
     self._cache_store.put_data(watchpoint_hits_info)
     log.info("Send the watchpoint hits to DataQueue.\nSend the reply.")
     reply = get_ack_reply()
     return reply
Example #27
    def update_watchpoint(self, watch_point_id, watch_nodes, mode, name=None):
        """
        Update watchpoint.

        Args:
            watch_point_id (int): The id of watchpoint.
            watch_nodes (list[str]): The list of node names.
            mode (int): The update operator on nodes. 0 for remove nodes from watch nodes.
                1 for add nodes to watch nodes.
            name (str): The search name. Default: None.

        Returns:
            dict, empty response.
        """
        if self.cache_store.get_stream_handler(
                Streams.METADATA).state != ServerStatus.WAITING.value:
            log.error(
                "Failed to update watchpoint as the MindSpore is not in waiting state."
            )
            raise DebuggerUpdateWatchPointError(
                "Failed to update watchpoint as the MindSpore is not in waiting state."
            )
        # validate
        if not watch_nodes or not watch_point_id:
            log.error("Invalid parameter for update watchpoint.")
            raise DebuggerParamValueError(
                "Invalid parameter for update watchpoint.")
        # update watch node
        if name is not None:
            watch_nodes = self._get_watch_nodes_by_search(watch_nodes)
        elif mode == 1:
            watch_nodes = self._get_node_basic_infos(watch_nodes)

        self.cache_store.get_stream_handler(
            Streams.WATCHPOINT).update_watchpoint(watch_point_id, watch_nodes,
                                                  mode)
        self._watch_point_id = watch_point_id
        log.info("Update watchpoint with id: %d", watch_point_id)
        return {}
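A hypothetical usage sketch (node names are assumptions; mode=1 adds nodes to the watch list, mode=0 removes them, per the docstring):

session.update_watchpoint(watch_point_id=1,
                          watch_nodes=['Default/conv1', 'Default/conv2'],
                          mode=1)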
Example #28
    def delete_watchpoint(self, watch_point_id):
        """
        Delete watchpoint.

        Args:
            watch_point_id (int): The id of watchpoint.

        Returns:
            dict, empty response.
        """
        if self.cache_store.get_stream_handler(
                Streams.METADATA).state != ServerStatus.WAITING.value:
            log.error(
                "Failed to delete watchpoint as the MindSpore is not in waiting state."
            )
            raise DebuggerDeleteWatchPointError(
                "Failed to delete watchpoint as the MindSpore is not in waiting state."
            )
        self.cache_store.get_stream_handler(
            Streams.WATCHPOINT).delete_watchpoint(watch_point_id)
        self._watch_point_id = 0
        log.info("Delete watchpoint with id: %d", watch_point_id)
        return {}
Example #29
    def get_bfs_order(self):
        """
        Traverse the graph in breadth-first search (BFS) order.

        Returns:
            list, including the leaf nodes arranged in BFS order.
        """
        root = self.get_default_root()
        log.info('Randomly choose node %s as root to do BFS.', root.name)

        bfs_order = []
        self.get_bfs_graph(root.name, bfs_order)
        length = len(self._leaf_nodes)
        # Find the remaining un-traversed nodes
        for node_name in self._leaf_nodes:
            if node_name not in bfs_order:
                self.get_bfs_graph(node_name, bfs_order)

        if len(bfs_order) != length:
            log.error("The length of bfs and leaf nodes are not equal.")
            msg = "Not all nodes are traversed!"
            raise DebuggerParamValueError(msg)

        return bfs_order
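get_bfs_graph is not shown in the source; a generic sketch of the breadth-first pass it is assumed to perform (`neighbors` stands in for the real adjacency lookup), here under the hypothetical name bfs_from:

from collections import deque

def bfs_from(start, neighbors, bfs_order):
    """Append nodes reachable from `start` to bfs_order in BFS order."""
    seen = set(bfs_order)
    queue = deque([start])
    while queue:
        node = queue.popleft()
        if node in seen:
            continue
        seen.add(node)
        bfs_order.append(node)
        queue.extend(neighbors(node))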
Example #30
 def stop(self):
     """Stop debugger server."""
     self.grpc_server_manager.stop(grace=None)
     self.back_server.join()
     log.info("Stop debugger server.")