Example #1
    def _retrieve_watchpoint_hit(self, filter_condition):
        """
        Retrieve watchpoint hit.

        Args:
            filter_condition (dict): Filter condition.

                - name (str): The name of single node.

                - single_node (bool): If False, return the sub-layer of single node. If True, return
                    the node list from root node to single node.

        Returns:
            dict, watchpoint hit list, or tensor history with the related graph.
        """
        node_name = filter_condition.get('name')
        # get watchpoint hit list
        if node_name is None:
            reply = self.cache_store.get_stream_handler(
                Streams.WATCHPOINT_HIT).get()
            return reply

        self._validate_leaf_name(node_name)
        # get tensor history
        reply = self._get_tensor_history(node_name)
        log.debug("Get tensor history for watchpoint hit node.")
        # get single graph
        if filter_condition.get('single_node'):
            graph = self._get_nodes_info(filter_condition)
            reply.update(graph)
            log.debug("Get single node graph for watchpoint hit node.")

        return reply
Example #2
    def WaitCMD(self, request, context):
        """Wait for a command in DebuggerCache."""
        # check if graph has already been received.
        log.info("Received WaitCMD at %s-th step.", request.cur_step)
        if self._status == ServerStatus.PENDING:
            log.warning("No graph received before WaitCMD.")
            reply = get_ack_reply(1)
            return reply
        self._send_received_tensor_tag()
        # send graph if has not been sent before
        self._pre_process(request)
        # deal with old command
        reply = self._deal_with_old_command()
        if reply:
            log.info("Reply to WaitCMD with old command: %s", reply)
            return reply
        # continue multiple steps training
        if self._continue_steps:
            reply = get_ack_reply()
            reply.run_cmd.run_steps = 1
            reply.run_cmd.run_level = 'step'
            self._continue_steps = self._continue_steps - 1 if self._continue_steps > 0 else -1
            self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
            log.debug("Send RunCMD. Clean watchpoint hit.")
        # wait for command
        else:
            reply = self._wait_for_next_command()

        if reply is None:
            reply = get_ack_reply(1)
            log.warning("Failed to get command event.")
        else:
            log.info("Reply to WaitCMD: %s", reply)
        return reply
Example #3
    def get_tensor_history(self, node_name, depth=0):
        """
        Get the tensor history of a specified node.

        Args:
            node_name (str): The debug name of the node.
            depth (int): The number of layers the user
                wants to trace. Default is 0.

        Returns:
            dict, basic tensor history, including only the tensor name, tensor type, and node type.
        """
        self._graph_exists()
        if not self._graph.exist_node(node_name):
            raise DebuggerNodeNotInGraphError(node_name)

        tensor_history, cur_outputs_nums = self._graph.get_tensor_history(
            node_name, depth)
        # add the tensor type for tensor history
        self._update_tensor_history(tensor_history[0:cur_outputs_nums],
                                    'output')
        self._update_tensor_history(tensor_history[cur_outputs_nums:], 'input')
        log.debug("Get %d tensors in tensor history for node <%s>.",
                  len(tensor_history), node_name)
        return {'tensor_history': tensor_history}
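
A minimal standalone sketch of the slicing step above, with plain dicts standing in for the real history entries and tag_tensor_history as a hypothetical helper name: the first cur_outputs_nums entries are treated as outputs and the remainder as inputs.

# Illustrative sketch only: mimics the output/input split used above.
def tag_tensor_history(tensor_history, cur_outputs_nums):
    """Tag the first cur_outputs_nums entries as outputs and the rest as inputs."""
    for item in tensor_history[0:cur_outputs_nums]:
        item['type'] = 'output'
    for item in tensor_history[cur_outputs_nums:]:
        item['type'] = 'input'
    return tensor_history

history = [{'name': 'Conv2D-op1:0'}, {'name': 'x'}, {'name': 'w'}]
print(tag_tensor_history(history, 1))
# [{'name': 'Conv2D-op1:0', 'type': 'output'}, {'name': 'x', 'type': 'input'}, {'name': 'w', 'type': 'input'}]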
Example #4
    def _retrieve_watchpoint(self, filter_condition):
        """
        Retrieve watchpoint.

        Args:
            filter_condition (dict): Filter condition.

                - watch_point_id (int): The id of the watchpoint. If not given, return all watchpoints.

                - name (str): The name of single node.

                - single_node (bool): If False, return the sub-layer of single node. If True, return
                    the node list from root node to single node.

        Returns:
            dict, watchpoint list or the related graph.
        """
        watchpoint_id = filter_condition.get('watch_point_id')
        watchpoint_stream = self.cache_store.get_stream_handler(
            Streams.WATCHPOINT)
        watchpoint_stream.validate_watchpoint_id(watchpoint_id)
        self._watch_point_id = watchpoint_id if watchpoint_id else 0
        if not watchpoint_id:
            reply = self.cache_store.get_stream_handler(
                Streams.WATCHPOINT).get()
            log.debug("Get condition of watchpoints.")
        else:
            reply = self._retrieve_node(filter_condition)
            log.debug("Get graph of %d-th watchpoint.", watchpoint_id)

        return reply
Example #5
    def add_node(self, node_name, node_type, full_name=''):
        """
        Add watch node to watch node tree.

        Args:
            node_name (str): The node name.
            node_type (str): The node type.
            full_name (str): The full name of node.
        """
        log.debug("Add node %s with type: %s, full_name: %s", node_name,
                  node_type, full_name)
        scope_names = node_name.split('/', 1)
        if len(scope_names) == 1:
            if not self.get(node_name):
                self.add(node_name,
                         node_type,
                         full_name,
                         watch_status=WatchNodeTree.TOTAL_WATCH)
            else:
                self.get(node_name).enable_watch_status()
            return

        scope_name, sub_names = scope_names
        sub_tree = self.get(scope_name)
        if not sub_tree:
            sub_tree = self.add(scope_name, watch_status=1)
        sub_tree.add_node(sub_names, node_type, full_name)
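
The recursion above peels one scope off the node name with split('/', 1) per call and creates intermediate scopes as needed. A self-contained sketch of that traversal, using a plain dict in place of WatchNodeTree (illustration only, not the real class):

# Sketch only: a plain-dict stand-in for the WatchNodeTree recursion above.
def add_watch_node(tree, node_name):
    scope_names = node_name.split('/', 1)
    if len(scope_names) == 1:
        tree.setdefault(node_name, {})          # leaf node
        return
    scope_name, sub_names = scope_names
    sub_tree = tree.setdefault(scope_name, {})  # intermediate scope
    add_watch_node(sub_tree, sub_names)

tree = {}
add_watch_node(tree, 'Default/network/conv1')
add_watch_node(tree, 'Default/network/relu1')
print(tree)  # {'Default': {'network': {'conv1': {}, 'relu1': {}}}}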
Example #6
    def SendWatchpointHits(self, request_iterator, context):
        """Send watchpoint hits info to DebuggerCache."""
        log.info("Received WatchpointHits. Remaining steps %d changed to 0.", self._continue_steps)
        self._continue_steps = 0
        watchpoint_hit_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
        watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT)
        graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH)
        for watchpoint_hit_proto in request_iterator:
            ui_node_name = graph_stream.get_node_name_by_full_name(
                watchpoint_hit_proto.tensor.node_name)
            log.debug("Receive watchpoint hit: %s", watchpoint_hit_proto)
            if not ui_node_name:
                log.info("Cannot show %s on graph.", watchpoint_hit_proto.tensor.node_name)
                continue
            watchpoint_hit = {
                'tensor_proto': watchpoint_hit_proto.tensor,
                'watchpoint': watchpoint_stream.get_watchpoint_by_id(watchpoint_hit_proto.id),
                'node_name': ui_node_name
            }
            watchpoint_hit_stream.put(watchpoint_hit)
        watchpoint_hits_info = watchpoint_hit_stream.get()
        self._cache_store.put_data(watchpoint_hits_info)
        log.info("Send the watchpoint hits to DataQueue.\nSend the reply.")
        reply = get_ack_reply()
        return reply
Example #7
    def get(self, filter_condition=False):
        """
        Get the watchpoints.

        Args:
            filter_condition (bool): If False, get the watch condition info of all watchpoints.
                If True, get the updated watchpoints in SetCMD proto format. Default: False.

        Returns:
            dict, the watchpoints.
        """
        reply = []
        if not filter_condition:
            # get watch condition list
            for _, watchpoint in self._watchpoints.items():
                watchpoint_info = watchpoint.get_watch_condition_info()
                reply.append(watchpoint_info)
        else:
            # get updated watchpoint list
            for _, watchpoint in self._updated_watchpoints.items():
                set_cmd = watchpoint.get_set_cmd()
                reply.append(set_cmd)
            reply.extend(self._deleted_watchpoints)

        log.debug("get the watch points with filter_condition:%s",
                  filter_condition)

        return {'watch_points': reply}
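
A hedged usage sketch of the two modes, with FakeWatchpoint standing in for the real Watchpoint object: filter_condition=False yields the UI-facing watch condition list, while True yields SetCMD-style entries (updates plus deletions) that still need to be synced to MindSpore.

# Sketch only: stand-in objects showing the two filter_condition modes of get().
class FakeWatchpoint:
    """Illustrative stand-in for the real Watchpoint object."""
    def __init__(self, wp_id, condition):
        self.watchpoint_id = wp_id
        self.condition = condition
    def get_watch_condition_info(self):
        return {'id': self.watchpoint_id, 'watch_condition': self.condition}
    def get_set_cmd(self):
        return {'id': self.watchpoint_id, 'set': True}   # stand-in for a SetCMD proto

watchpoints = {1: FakeWatchpoint(1, {'condition': 'INF'})}
updated_watchpoints = dict(watchpoints)
deleted_watchpoints = []

# filter_condition=False: UI-facing watch condition list
print([wp.get_watch_condition_info() for wp in watchpoints.values()])
# filter_condition=True: SetCMD-style entries that still need to be synced
print([wp.get_set_cmd() for wp in updated_watchpoints.values()] + deleted_watchpoints)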
Example #8
    def put(self, value):
        """
        Put value into event_cache.

        Args:
            value (dict): The event to be put into cache.
        """
        if not isinstance(value, dict):
            log.error("Dict type required when put event message.")
            raise DebuggerParamValueError("Dict type required when put event message.")

        with self._lock:
            log.debug("Put the %d-th message into queue. \n %d requests is waiting.",
                      self._next_idx, len(self._pending_requests))
            cur_pos = self._next_idx
            # update next pos
            self._next_idx += 1
            if self._next_idx >= self.max_limit:
                self._next_idx = 0
                self._prev_flag = self._cur_flag
                self._cur_flag = str(uuid.uuid4())
            # set next pos
            if not value.get('metadata'):
                value['metadata'] = {}
            value['metadata']['pos'] = self.next_pos
            self._event_cache[cur_pos] = value
            # feed the value for pending requests
            self.clean_pending_requests(value)
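
The index bookkeeping above behaves like a small ring buffer: positions wrap at max_limit, and a fresh UUID flag marks each pass so stale positions can be told apart. A self-contained sketch under those assumptions (MiniEventCache is illustrative, not the real handler, and the pos stamped into value['metadata'] is omitted):

import uuid

# Sketch only: ring-buffer index and flag bookkeeping mirroring put() above.
class MiniEventCache:
    def __init__(self, max_limit=3):
        self.max_limit = max_limit
        self._next_idx = 0
        self._cur_flag = str(uuid.uuid4())
        self._prev_flag = self._cur_flag
        self._event_cache = [None] * max_limit

    def put(self, value):
        cur_pos = self._next_idx
        self._next_idx += 1
        if self._next_idx >= self.max_limit:
            # wrap around and start a new "epoch" flag so stale positions can be detected
            self._next_idx = 0
            self._prev_flag = self._cur_flag
            self._cur_flag = str(uuid.uuid4())
        self._event_cache[cur_pos] = value
        return cur_pos

cache = MiniEventCache()
for step in range(4):
    print(cache.put({'step': step}), cache._event_cache)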
Example #9
    def _continue(self, metadata_stream, params):
        """
        Send RunCMD to MindSpore.

        Args:
            metadata_stream (MetadataHandler): The metadata stream handler.
            params (dict): The control params.
        """
        if metadata_stream.state != ServerStatus.WAITING.value:
            log.error("MindSpore is not ready to run. Current state is: %s",
                      metadata_stream.state)
            raise DebuggerContinueError(
                "MindSpore is not ready to run or is running currently.")
        metadata_stream.state = ServerStatus.RUNNING.value
        current_state = ServerStatus.RUNNING.value
        try:
            event = self._construct_run_event(params)
            self._send_watchpoints()
            self.cache_store.put_command(event)
        except MindInsightException as err:
            log.error("Failed to send run event.")
            log.exception(err)
            current_state = ServerStatus.WAITING.value
            metadata_stream.state = current_state
            raise DebuggerContinueError("Failed to send run command.")
        else:
            log.debug("Send the RunCMD to command queue.")

        return {'metadata': {'state': current_state}}
Example #10
    def _deal_with_old_command(self):
        """Deal with old command."""
        event = None
        while self._cache_store.has_command(self._pos) and event is None:
            event = self._get_next_command()
            log.debug("Deal with old %s-th command:\n%s.", self._pos, event)

        return event
Example #11
    def put_data(self, value):
        """
        Set updated data to data stream.

        Args:
            value (dict): The updated data.
        """
        log.debug("Set <%d> bytes data", sys.getsizeof(value))
        return self._put(Streams.DATA, value)
Example #12
    def put_command(self, cmd):
        """
        Set command to command stream.

        Args:
            cmd (EventReply): The command EventReply.
        """
        log.debug("Set command %s", cmd)
        return self._put(Streams.COMMAND, {'cmd': cmd})
Example #13
    def clean(self):
        """Clean event cache."""
        with self._lock:
            self._prev_flag = str(uuid.uuid4())
            self._cur_flag = str(uuid.uuid4())
            self._next_idx = 0
            self._event_cache = [None] * self.max_limit
            value = {'metadata': {'pos': '0'}}
            self.clean_pending_requests(value)
            log.debug("Clean event cache. %d requests are waiting.", len(self._pending_requests))
Example #14
    def _deal_with_view_cmd(self, event):
        """Deal with view cmd."""
        view_cmd = event.get('view_cmd')
        node_name = event.get('node_name')
        log.debug("Receive view cmd %s for node: %s.", view_cmd, node_name)
        if not (view_cmd and node_name):
            log.warning("Invalid view command. Ignore it.")
            return None
        self._received_view_cmd['node_name'] = node_name
        self._received_view_cmd['wait_for_tensor'] = True
        return view_cmd
Example #15
    def put(self, value):
        """
        Put value into metadata cache. Called by grpc server.

        Args:
            value (MetadataProto): The Metadata proto message.
        """
        self._device_name = value.device_name.split(':')[0]
        self._step = value.cur_step
        self._cur_full_name = value.cur_node
        self._backend = value.backend if value.backend else "Ascend"
        log.debug("Put metadata into cache at the %d-th step.", self._step)
Example #16
    def put(self, value):
        """
        Put Watchpoint into watchpoint handler.

        Args:
            value (Watchpoint): The watchpoint to be stored in the cache.
        """
        new_id = value.watchpoint_id
        self._watchpoints[new_id] = value
        self._updated_watchpoints[new_id] = value
        self._latest_id = new_id
        log.debug("Put watchpoint %d into cache.", new_id)
Example #17
    def _terminate(self, metadata_stream):
        """
        Terminate the training.

        Args:
            metadata_stream (MetadataHandler): The metadata stream handler.
        """
        metadata_stream.state = 'pending'
        event = get_ack_reply()
        event.exit = True
        self.cache_store.put_command(event)
        log.debug("Send the ExitCMD.")
        return {'metadata': {'state': 'pending'}}
Example #18
    def _send_watchpoints(self):
        """Set watchpoints."""
        watchpoint_stream = self.cache_store.get_stream_handler(
            Streams.WATCHPOINT)
        watchpoints = watchpoint_stream.get(
            filter_condition=True).get('watch_points')
        if watchpoints:
            for watchpoint in watchpoints:
                event = get_ack_reply()
                event.set_cmd.CopyFrom(watchpoint)
                self.cache_store.put_command(event)
            watchpoint_stream.sync_set_cmd()
            log.debug("Send SetCMD to MindSpore. %s", event)
Example #19
    def set_watch_nodes(self, graph, watch_point_id):
        """
        Set watch nodes for graph.

        Args:
            graph (dict): The graph with list of nodes.
            watch_point_id (int): The id of watchpoint.
        """
        if not (watch_point_id and graph):
            return
        self.validate_watchpoint_id(watch_point_id)
        log.debug("add watch flags")
        watchpoint = self._watchpoints.get(watch_point_id)
        self._set_watch_status_recursively(graph, watchpoint)
Example #20
    def SendGraph(self, request_iterator, context):
        """Send graph into DebuggerCache."""
        log.info("Received graph.")
        serial_graph = b""
        for chunk in request_iterator:
            serial_graph += chunk.buffer
        graph = GraphProto.FromString(serial_graph)
        log.debug("Deserialize the graph. Receive %s nodes", len(graph.node))
        self._cache_store.get_stream_handler(Streams.GRAPH).put(graph)
        self._cache_store.get_stream_handler(Streams.TENSOR).put_const_vals(graph.const_vals)
        self._status = ServerStatus.RECEIVE_GRAPH
        reply = get_ack_reply()
        log.info("Send the reply for graph.")
        return reply
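
The graph arrives as a stream of chunks whose raw bytes are concatenated before a single protobuf parse. A standalone sketch of that reassembly, with a namedtuple standing in for the gRPC chunk message (the real server then hands the bytes to GraphProto.FromString):

from collections import namedtuple

# Sketch only: reassemble a chunked byte stream the way SendGraph does above.
Chunk = namedtuple('Chunk', ['buffer'])   # stand-in for the gRPC chunk message

def reassemble(request_iterator):
    serial_graph = b""
    for chunk in request_iterator:
        serial_graph += chunk.buffer
    return serial_graph

chunks = [Chunk(b"\x0a\x03"), Chunk(b"abc")]
print(reassemble(chunks))  # b'\n\x03abc'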
Example #21
    def _wait_for_next_command(self):
        """
        Wait for next command.

        Returns:
            EventReply, the command event.
        """
        log.info("Start to wait for command.")
        self._cache_store.get_stream_handler(Streams.METADATA).state = 'waiting'
        self._cache_store.put_data({'metadata': {'state': 'waiting'}})
        event = None
        while event is None and self._status == ServerStatus.WAITING:
            log.debug("Wait for %s-th command", self._pos)
            event = self._get_next_command()
        return event
Example #22
    def _get_next_command(self):
        """Get next command."""
        self._pos, event = self._cache_store.get_command(self._pos)
        log.debug("Received event :%s", event)
        if event is None:
            return event
        if isinstance(event, dict):
            event = self._deal_with_view_cmd(event)
        elif event.HasField('run_cmd'):
            event = self._deal_with_run_cmd(event)
        elif event.HasField('exit'):
            self._cache_store.clean()
            log.info("Clean cache for exit cmd.")

        return event
Example #23
    def _pause(self, metadata_stream):
        """
        Pause the training.

        Args:
            metadata_stream (MetadataHandler): The metadata stream handler.
        """
        if metadata_stream.state != ServerStatus.RUNNING.value:
            log.error("The MindSpore is not running.")
            raise DebuggerPauseError("The MindSpore is not running.")
        metadata_stream.state = 'waiting'
        event = get_ack_reply()
        event.run_cmd.CopyFrom(RunCMD(run_level='step', run_steps=0))
        self.cache_store.put_command(event)
        log.debug("Send the Pause command")
        return {'metadata': {'state': 'waiting'}}
Example #24
    def _get_watch_nodes_by_search(self, watch_nodes):
        """Get watched leaf nodes by search name."""
        watched_leaf_nodes = []
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        for search_name in watch_nodes:
            search_nodes = graph_stream.get_searched_node_list()
            search_node_names = [
                NodeBasicInfo(name=node.name,
                              full_name=node.full_name,
                              type=node.type) for node in search_nodes
                if node.name.startswith(search_name)
            ]
            watched_leaf_nodes.extend(search_node_names)

        log.debug("Update nodes: %s", watched_leaf_nodes)

        return watched_leaf_nodes
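
A self-contained sketch of the prefix matching above, with a namedtuple standing in for NodeBasicInfo: only searched nodes whose display names start with the requested name are kept.

from collections import namedtuple

# Sketch only: prefix matching of searched nodes, NodeInfo stands in for NodeBasicInfo.
NodeInfo = namedtuple('NodeInfo', ['name', 'full_name', 'type'])

search_nodes = [
    NodeInfo('Default/conv1', 'conv1-op1', 'Conv2D'),
    NodeInfo('Default/relu1', 'relu1-op2', 'ReLU'),
]
watch_nodes = ['Default/conv']

watched = [node for search_name in watch_nodes
           for node in search_nodes if node.name.startswith(search_name)]
print(watched)  # [NodeInfo(name='Default/conv1', full_name='conv1-op1', type='Conv2D')]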
Example #25
    def _wait_for_event(self, cur_id, cur_queue, pos):
        """Wait for the pos-th event."""
        try:
            # set the timeout to 25 seconds, which is less than the timeout limit from the UI
            event = cur_queue.get(timeout=25)
        except Empty:
            event = None

        if event is None:
            with self._lock:
                if self._pending_requests.get(cur_id):
                    self._pending_requests.pop(cur_id)
                log.debug("Clean timeout request. Left pending requests: %d",
                          len(self._pending_requests))
            event = {'metadata': {'pos': pos}}

        return event
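
A standalone sketch of the timeout fallback above, using the standard queue module: if nothing arrives before the deadline, a placeholder event carrying only the position is returned so the UI can poll again. The real server waits 25 seconds; a short timeout is used here for illustration.

from queue import Queue, Empty

# Sketch only: block on a queue with a timeout and fall back to a placeholder event.
def wait_for_event(cur_queue, pos, timeout=0.1):
    try:
        return cur_queue.get(timeout=timeout)
    except Empty:
        # nothing arrived in time; return a placeholder so the caller can re-poll at `pos`
        return {'metadata': {'pos': pos}}

q = Queue()
print(wait_for_event(q, '0'))          # times out -> {'metadata': {'pos': '0'}}
q.put({'watch_points': []})
print(wait_for_event(q, '0'))          # real event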
Example #26
    def _deal_with_run_cmd(self, event):
        """Deal with run cmd."""
        run_cmd = event.run_cmd
        # receive step command
        if run_cmd.run_level == 'step':
            # receive pause cmd
            if run_cmd.run_steps == 0:
                log.debug("Pause training and wait for next command.")
                self._continue_steps = 0
                return None
            # receive step cmd
            self._continue_steps = run_cmd.run_steps - 1
            event.run_cmd.run_steps = 1
        self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
        log.debug("Receive RunCMD. Clean watchpoint hit cache.")

        return event
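
Read together with WaitCMD (Example #2), a run command for N steps appears to be answered one step at a time: this method rewrites the command to a single step and remembers N - 1, and WaitCMD then keeps replying with one-step commands until the counter runs out. A standalone sketch of that counter, under the assumption that a non-positive remainder means "keep stepping until paused":

# Sketch only: the one-step-at-a-time counter used by _deal_with_run_cmd / WaitCMD.
def start_run(run_steps):
    """Return (steps_to_send_now, remaining_counter), as in _deal_with_run_cmd."""
    if run_steps == 0:          # pause: stop stepping
        return 0, 0
    return 1, run_steps - 1     # send one step now, remember the rest

def next_step(remaining):
    """Decrement like WaitCMD: positive counts down, non-positive keeps running."""
    return remaining - 1 if remaining > 0 else -1

now, remaining = start_run(3)
sent = [now]
while remaining:                # truthy: either > 0 or the -1 "keep running" marker
    sent.append(1)
    remaining = next_step(remaining)
    if len(sent) > 10:          # guard against the unbounded case in this sketch
        break
print(sent, remaining)          # [1, 1, 1] 0 -> three one-step commands for run_steps=3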
Example #27
    def delete_watchpoint(self, watch_point_id):
        """
        Delete watchpoint.

        Args:
            watch_point_id (int): The id of watchpoint.

        """
        self.validate_watchpoint_id(watch_point_id)
        self._watchpoints.pop(watch_point_id)
        set_cmd = SetCMD()
        set_cmd.id = watch_point_id
        set_cmd.delete = True
        self._deleted_watchpoints.append(set_cmd)
        log.debug("Delete watchpoint %d in cache.", watch_point_id)
Example #28
    def get(self, filter_condition=None):
        """
        Get watchpoint hit list.

        Args:
            filter_condition (str): Get the watchpoint hit according to the specified node name.
                If not given, get all watchpoint hits. Default: None.

        Returns:
            dict, the watchpoint hit list.
        """
        if filter_condition is None:
            log.debug("Get all watchpoint hit list.")
            reply = self.get_watchpoint_hits()
        else:
            log.debug("Get the watchpoint for node: <%s>.", filter_condition)
            reply = self._hits.get(filter_condition)

        return reply
Example #29
    def _add_tensor_value_for_tensor_history(self, tensor_history, node_name):
        """
        Add tensor values to the tensor history and send a ViewCMD if any tensor value is missing.

        Args:
            tensor_history (list[dict]): A list of tensor info, including name and type.
            node_name (str): The UI node name.

        Returns:
            dict, the tensor info.
        """
        tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR)
        missed_tensors = tensor_stream.update_tensor_history(tensor_history)
        if missed_tensors:
            view_cmd = create_view_event_from_tensor_history(missed_tensors)
            self.cache_store.put_command({
                'view_cmd': view_cmd,
                'node_name': node_name
            })
            log.debug("Send view cmd.")
Example #30
    def put(self, value):
        """
        Put value into tensor cache. Called by grpc server.

        Args:
            value (dict): The Tensor proto message.

                - step (int): The current step of tensor.

                - tensor_protos (list[TensorProto]): The tensor proto.
        """
        tensor_protos = value.get('tensor_protos')
        merged_tensor = self._get_merged_tensor(tensor_protos)
        step = value.get('step', 0)
        if merged_tensor.iter and step > 0:
            log.debug("Received previous tensor.")
            step -= 1
        tensor = OpTensor(merged_tensor, step)
        self._put_tensor_into_cache(tensor, step)
        log.info("Put tensor %s of step: %d, into cache", tensor.name, step)