    def get(self, filter_condition=None, rank_id=0):
        """Get the watchpoint hits of the specified device."""
        if rank_id in self.watchpoint_hit_handlers:
            return self.watchpoint_hit_handlers.get(rank_id).get(
                filter_condition)
        log.error("There is no rank id %d.", rank_id)
        raise ValueError("There is no rank id {}.".format(rank_id))
Example #2
    def get(self, filter_condition=None):
        """
        Get the watchpoints.

        Args:
            filter_condition (Union[None, int]): The filter conditions. Get a watchpoint by
                id. If None, return all watchpoints. Default: None.

        Returns:
            dict, the watchpoint list.
        """
        reply = []
        if not filter_condition:
            # get watch condition list
            for watchpoint in self._watchpoints.values():
                watchpoint_info = watchpoint.get_watch_condition_info()
                reply.append(watchpoint_info)
        else:
            self.validate_watchpoint_id(filter_condition)
            reply = [self._watchpoints.get(filter_condition)]

        log.debug("get the watch points with filter_condition:%s",
                  filter_condition)

        return {'watch_points': reply}
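
The shape of this handler is worth noting: with no filter it serializes every stored watchpoint, and with an id it validates the id and returns a single-element list. A minimal self-contained sketch of the same pattern (the stub class and sample data are invented for illustration, not the real handler API):

    # Illustrative sketch only; WatchpointStub and the sample data are stand-ins.
    class WatchpointStub:
        def __init__(self, wp_id, condition):
            self.wp_id = wp_id
            self.condition = condition

        def get_watch_condition_info(self):
            return {'id': self.wp_id, 'watch_condition': self.condition}

    watchpoints = {1: WatchpointStub(1, 'inf'), 2: WatchpointStub(2, 'overflow')}

    def get_watchpoints(filter_condition=None):
        """Return all watchpoints, or only the one with the given id."""
        if not filter_condition:
            reply = [wp.get_watch_condition_info() for wp in watchpoints.values()]
        else:
            reply = [watchpoints[filter_condition].get_watch_condition_info()]
        return {'watch_points': reply}

    print(get_watchpoints())   # both watchpoints
    print(get_watchpoints(2))  # only watchpoint 2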
Example #3
    def _get_watch_nodes_by_search(self, node_names, search_pattern,
                                   graph_name):
        """
        Get watched leaf nodes by search name.

        Args:
            node_names (list[str]): A list of node names.
            search_pattern (dict): The search pattern used to match watch nodes.

                - name (str): The name pattern.
                - node_category (str): The node_category.
            graph_name (str): The relative graph_name of the watched node.

        Returns:
            list[NodeBasicInfo], a list of node basic infos.
        """
        search_pattern['graph_name'] = graph_name
        search_nodes = self._graph_stream.search_nodes(search_pattern)
        watch_node_names = set()
        for name in node_names:
            names = self._get_watch_names_by_search(search_nodes, name)
            watch_node_names.update(names)
        watch_node_info = self._get_node_basic_infos(watch_node_names,
                                                     graph_name=graph_name)
        log.debug("Update nodes: %s", watch_node_info)

        return watch_node_info
Example #4
    def _retrieve_node(self, filter_condition):
        """
        Retrieve node info.

        Args:
            filter_condition (dict): Filter condition.

                - name (str): The name of single node.
                - graph_name (str): The relative graph_name of the node.
                - single_node (bool): If False, return the sub-layer of single node. If True, return
                    the node list from root node to single node.
                - watch_point_id (int): The id of watchpoint.

        Returns:
            dict, reply with graph.
        """
        log.debug("Retrieve node %s.", filter_condition)
        # validate node name
        node_name = filter_condition.get('name')
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        graph_name = graph_stream.validate_graph_name(
            filter_condition.get('graph_name'))
        if node_name:
            # validate node name
            graph_stream.get_node_type(node_name, graph_name)
        filter_condition['single_node'] = bool(
            filter_condition.get('single_node'))
        filter_condition['graph_name'] = graph_name
        reply = self._get_nodes_info(filter_condition)
        return reply
Example #5
    def search_watchpoint_hits(self, group_condition):
        """
        Retrieve watchpoint hits.

        Args:
            group_condition (dict): Filter condition.

                - limit (int): The limit of each page.
                - offset (int): The offset of current page.
                - node_name (str): The retrieved node name.
                - graph_name (str): The retrieved graph name.

        Returns:
            dict, watch point list or relative graph.
        """
        if not isinstance(group_condition, dict):
            log.error(
                "Group condition for watchpoint-hits request should be a dict")
            raise DebuggerParamTypeError(
                "Group condition for watchpoint-hits request should be a dict")

        metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
        if metadata_stream.state == ServerStatus.PENDING.value:
            log.info("The backend is in pending status.")
            return metadata_stream.get()

        reply = self.cache_store.get_stream_handler(
            Streams.WATCHPOINT_HIT).group_by(group_condition)
        reply['outdated'] = self.cache_store.get_stream_handler(
            Streams.WATCHPOINT).is_recheckable()
        return reply
Example #6
    def add_node(self, node_name, node_type, full_name=''):
        """
        Add watch node to watch node tree.

        Args:
            node_name (str): The node name.
            node_type (str): The node type.
            full_name (str): The full name of node.
        """
        log.debug("Add node %s with type: %s, full_name: %s", node_name,
                  node_type, full_name)
        scope_names = node_name.split('/', 1)
        if len(scope_names) == 1:
            target_node = self.get(node_name)
            if not target_node:
                self.add(node_name,
                         node_type,
                         full_name,
                         watch_status=WatchNodeTree.TOTAL_WATCH)
            else:
                target_node.update_metadata(node_type, full_name,
                                            WatchNodeTree.TOTAL_WATCH)
            return

        scope_name, sub_names = scope_names
        sub_tree = self.get(scope_name)
        if not sub_tree:
            sub_tree = self.add(scope_name, watch_status=1)
        sub_tree.add_node(sub_names, node_type, full_name)
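
The recursion above hinges on node_name.split('/', 1): each call consumes one scope level and creates intermediate scopes on demand. A self-contained sketch of the same walk over a plain dict tree (the dict stands in for the real WatchNodeTree):

    # Illustrative sketch: insert scoped names into a nested dict, one level per call.
    def add_node(tree, node_name):
        scope_names = node_name.split('/', 1)
        if len(scope_names) == 1:
            tree.setdefault(node_name, {})          # leaf node
            return
        scope_name, sub_names = scope_names
        sub_tree = tree.setdefault(scope_name, {})  # intermediate scope on demand
        add_node(sub_tree, sub_names)

    tree = {}
    add_node(tree, 'Default/network/conv1')
    add_node(tree, 'Default/network/relu1')
    print(tree)  # {'Default': {'network': {'conv1': {}, 'relu1': {}}}}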
Example #7
    def stop(self):
        """Stop debugger server."""
        log.info("Send terminate info to client.")
        self.control({'mode': 'terminate'})
        self.grpc_server_manager.stop(grace=None)
        self.back_server.join()
        log.info("Stop debugger server.")
Example #8
    def _save_watchpoint_hits(self, hits):
        """Save watchpoint hits."""
        multi_card_hit_streams = self._cache_store.get_stream_handler(
            Streams.WATCHPOINT_HIT)
        multi_card_graph_streams = self._cache_store.get_stream_handler(
            Streams.GRAPH)
        watchpoint_stream = self._cache_store.get_stream_handler(
            Streams.WATCHPOINT)

        watchpoint_hits = defaultdict(list)
        for hit in hits:
            log.info(
                "Received hit:\n"
                "name:%s, slot:%s, condition:%s, "
                "watchpoint_id:%s, "
                "error_code:%s, rank_id:%s", hit['name'], hit['slot'],
                hit['condition'], hit['watchpoint_id'], hit['error_code'],
                hit['rank_id'])
            rank_id = hit['rank_id']
            watchpoint_hit = {}
            self._add_hit_node_info(watchpoint_hit, multi_card_graph_streams,
                                    rank_id, hit)
            if not watchpoint_hit:
                continue
            self._add_hit_watchpoint_info(watchpoint_hit, watchpoint_stream,
                                          hit)
            watchpoint_hit['error_code'] = hit['error_code']
            watchpoint_hits[rank_id].append(watchpoint_hit)
        # save hit info into cache
        multi_card_hit_streams.put(watchpoint_hits)
        self._cache_store.put_data({'receive_watchpoint_hits': True})
        log.debug("Send the watchpoint hits to DataQueue.")
Example #9
    def _deal_with_run_cmd(self, event):
        """Deal with run cmd."""
        metadata_stream = self._cache_store.get_stream_handler(
            Streams.METADATA)
        run_cmd = event.run_cmd
        # receive step command
        if run_cmd.run_level == RunLevel.STEP.value:
            # receive pause cmd
            if not run_cmd.run_steps:
                log.debug("Pause training and wait for next command.")
                self._old_run_cmd.clear()
                # update metadata state from sending to waiting
                metadata_stream.state = ServerStatus.WAITING.value
                return None
            # receive step cmd
            left_steps = run_cmd.run_steps - 1
            event.run_cmd.run_steps = 1
            if left_steps:
                self._old_run_cmd[
                    'left_step_count'] = left_steps if left_steps > 0 else -1
        elif run_cmd.node_name:
            self._old_run_cmd['node_name'] = run_cmd.node_name
            run_cmd.node_name = ''
        # clean watchpoint hit cache
        if run_cmd.run_level == RunLevel.RECHECK.value:
            self._cache_store.get_stream_handler(
                Streams.WATCHPOINT_HIT).clean()
            log.debug("Receive RunCMD. Clean watchpoint hit cache.")
        # update metadata state from sending to running
        metadata_stream.state = ServerStatus.RUNNING.value
        return event
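
The step handling turns one N-step run command into N single-step commands: the current event is rewritten to run_steps = 1 and the remaining N - 1 steps are parked in _old_run_cmd for replay. A tiny sketch of just that bookkeeping (the function is a stand-in, not the real handler):

    # Illustrative sketch: split an N-step run command into single steps.
    old_run_cmd = {}

    def deal_with_step(run_steps):
        """Return the step count to execute now; stash the remainder."""
        if not run_steps:            # pause command
            old_run_cmd.clear()
            return None
        left_steps = run_steps - 1
        if left_steps:
            # negative counts mean "run until further notice", stored as -1
            old_run_cmd['left_step_count'] = left_steps if left_steps > 0 else -1
        return 1                     # always execute exactly one step now

    print(deal_with_step(3), old_run_cmd)   # 1 {'left_step_count': 2}
    old_run_cmd.clear()
    print(deal_with_step(-1), old_run_cmd)  # 1 {'left_step_count': -1}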
Example #10
    def _get_by_offset(self, group_condition):
        """Return the list of watchpoint hits on the offset page."""
        limit = group_condition.get('limit')
        offset = group_condition.get('offset')
        if not isinstance(limit, int) or not isinstance(offset, int):
            log.error("Param limit or offset is not an integer")
            raise DebuggerParamValueError(
                "Param limit or offset is not an integer")
        watch_point_hits = []

        total = len(self._ordered_hits)

        if limit * offset >= total and offset != 0:
            log.error("Param offset out of bounds")
            raise DebuggerParamValueError("Param offset out of bounds")

        if total == 0:
            return {}

        for watchpoint_hits in self._ordered_hits[
                limit * offset:limit * (offset + 1)]:
            self._get_tensors(watchpoint_hits, watch_point_hits)

        return {
            'watch_point_hits': watch_point_hits,
            'offset': offset,
            'total': total
        }
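
The paging itself is a plain slice: zero-based page offset of size limit covers indices limit*offset through limit*(offset + 1) - 1. A self-contained sketch of the same bounds check and slice:

    # Illustrative sketch of the offset/limit paging used above.
    def get_by_offset(items, limit, offset):
        if not isinstance(limit, int) or not isinstance(offset, int):
            raise ValueError("Param limit or offset is not an integer")
        total = len(items)
        if limit * offset >= total and offset != 0:
            raise ValueError("Param offset out of bounds")
        if total == 0:
            return {}
        return {'items': items[limit * offset:limit * (offset + 1)],
                'offset': offset, 'total': total}

    hits = list(range(7))
    print(get_by_offset(hits, limit=3, offset=0))  # items [0, 1, 2]
    print(get_by_offset(hits, limit=3, offset=2))  # items [6], the partial last page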
Example #11
    def run(self):
        """Start the debugger offline server."""
        log.info("Initialize Offline Debugger Server for dbg_dir: %s",
                 self._context.dbg_dir)
        self._offline_server_manager.initialize()
        log.info("Start Offline Debugger Server for dbg_dir: %s",
                 self._context.dbg_dir)
        self._running.set()
        try_count = 0
        while self._running.is_set() and try_count < self._MAX_TRY_EXCEPT_COUNT:
            try:
                self._offline_server_manager.wait_for_termination()
                if not self._offline_server_manager.is_runnable():
                    break
            except MindInsightException as err:
                log.exception(err)
                log.warning("Error occurred while listening for user commands. "
                            "Restart listening.")
            finally:
                try_count += 1
        # protect the server from too many failing commands
        if try_count == self._MAX_TRY_EXCEPT_COUNT:
            self._cache_store.clean()
            metadata = self._cache_store.get_stream_handler(
                Streams.METADATA).get()
            self._cache_store.put_data(metadata)
            log.warning("Exceptions exceeded %d times, stopping server.", try_count)
Example #12
    def _get_missing_tensor_info(self, tensor_name, node_type):
        """
        Get missing tensor infos.

        Args:
            tensor_name (str): The full name of Tensor.
            node_type (str): The type of the relative node.

        Returns:
            list, list of missing tensor basic information.
        """
        step = self.cur_step
        missing_tensors_info = []
        # check whether the current step value is missing
        if self._is_tensor_value_missing(tensor_name, step):
            missing_tensors_info.append(
                TensorBasicInfo(full_name=tensor_name,
                                node_type=node_type,
                                iter=''))
            log.debug("Add current step view cmd for %s", tensor_name)
        # check whether the previous step value is missing
        if node_type == NodeTypeEnum.PARAMETER.value and self._is_tensor_value_missing(
                tensor_name, step - 1):
            missing_tensors_info.append(
                TensorBasicInfo(full_name=tensor_name,
                                node_type=node_type,
                                iter='prev'))
            log.debug("Add previous view cmd for %s", tensor_name)
        return missing_tensors_info
Example #13
    def _put_tensor_into_cache(self, tensor, step):
        """
        Put tensor into cache.

        Args:
            tensor (OpTensor): The tensor value.
            step (int): The step of tensor.

        Returns:
            bool, whether the tensor was updated successfully.
        """
        cache_tensor = self._tensors.get(tensor.name)
        if cache_tensor is None:
            cache_tensor = {}
            self._tensors[tensor.name] = cache_tensor

        old_tensor = cache_tensor.get(step)
        if old_tensor and not self._is_value_diff(old_tensor.value,
                                                  tensor.value):
            log.debug("Tensor %s of step %s has no change. Ignore it.",
                      tensor.name, step)
            return False
        cache_tensor[step] = tensor
        log.debug("Put updated tensor value for %s of step %s.", tensor.name,
                  step)
        return True
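
The cache is a dict of dicts keyed by tensor name and step, and a write is skipped when the incoming value does not differ from the cached one. A self-contained sketch with plain equality standing in for _is_value_diff:

    # Illustrative sketch: a {name: {step: value}} cache that skips no-op updates.
    tensors = {}

    def put_tensor_into_cache(name, step, value):
        """Return True if the cache changed, False if the value was identical."""
        cache_tensor = tensors.setdefault(name, {})
        old_value = cache_tensor.get(step)
        if old_value is not None and old_value == value:  # stand-in for _is_value_diff
            return False
        cache_tensor[step] = value
        return True

    print(put_tensor_into_cache('conv1.weight', 0, [1, 2]))  # True, new entry
    print(put_tensor_into_cache('conv1.weight', 0, [1, 2]))  # False, unchanged
    print(put_tensor_into_cache('conv1.weight', 0, [3, 4]))  # True, value changed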
Example #14
    def get_bfs_graph(self, node_name, bfs_order):
        """
        Traverse the graph in order of breath-first search.

        Returns:
            list, including the leaf nodes arranged in BFS order.
        """
        temp_list = deque()
        temp_list.append(node_name)
        while temp_list:
            node_name = temp_list.popleft()
            node = self._leaf_nodes.get(node_name)

            if not node:
                log.warning('Cannot find node %s in graph. Ignored.',
                            node_name)
                continue

            bfs_order.append(node_name)
            if node.inputs:
                for name in node.inputs.keys():
                    if name not in temp_list and name not in bfs_order:
                        temp_list.append(name)
            if node.outputs:
                for name in node.outputs.keys():
                    if name not in temp_list and name not in bfs_order:
                        temp_list.append(name)
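
Stripped of the stream handlers, this is a textbook BFS over an undirected view of the graph: a deque frontier plus membership checks against both the frontier and the visited order. A self-contained sketch on a plain adjacency dict (inputs and outputs merged):

    from collections import deque

    # Illustrative sketch: BFS in the style of get_bfs_graph.
    adjacency = {
        'conv1': ['relu1'],
        'relu1': ['conv1', 'pool1'],
        'pool1': ['relu1'],
    }

    def get_bfs_order(start):
        bfs_order = []
        frontier = deque([start])
        while frontier:
            name = frontier.popleft()
            if name not in adjacency:
                continue                  # unknown node: warn and skip, as above
            bfs_order.append(name)
            for neighbor in adjacency[name]:
                if neighbor not in frontier and neighbor not in bfs_order:
                    frontier.append(neighbor)
        return bfs_order

    print(get_bfs_order('conv1'))  # ['conv1', 'relu1', 'pool1']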
Example #15
    def put(self, value):
        """
        Put value into tensor cache. Called by grpc server.

        Args:
            value (dict): The Tensor proto message.

                - step (int): The current step of tensor.
                - tensor_proto (TensorProto): The tensor proto.
                - tensor_contents (list[byte]): The list of tensor content values.

        Returns:
            bool, whether the tensor was updated successfully.
        """
        tensor = self._deal_with_tensor(value)
        stats = None
        if value.get('stats', False) and tensor.status == TensorStatusEnum.CACHED.value:
            tensor.calculate_stats()
            stats = tensor.stats

        flag = self._put_tensors(tensor)
        new_tensor = self._tensors.get(tensor.name).get(tensor.step)
        new_tensor.stats = stats
        log.info("Put tensor %s of step: %d, into cache. Flag: %s",
                 tensor.name, tensor.step, flag)
        return flag
Example #16
    def _deal_with_run_cmd(self, event):
        """Deal with run cmd."""
        run_cmd = event.run_cmd
        # receive step command
        if run_cmd.run_level == 'step':
            # receive pause cmd
            if not run_cmd.run_steps:
                log.debug("Pause training and wait for next command.")
                self._old_run_cmd.clear()
                return None
            # receive step cmd
            left_steps = run_cmd.run_steps - 1
            event.run_cmd.run_steps = 1
            if left_steps:
                self._old_run_cmd[
                    'left_step_count'] = left_steps if left_steps > 0 else -1
        elif run_cmd.node_name:
            self._old_run_cmd['node_name'] = run_cmd.node_name
            run_cmd.node_name = ''
        # clean watchpoint hit cache
        if run_cmd.run_level == RunLevel.RECHECK.value:
            self._cache_store.get_stream_handler(
                Streams.WATCHPOINT_HIT).clean()
            log.debug("Receive RunCMD. Clean watchpoint hit cache.")

        return event
Example #17
    def get(self, filter_condition=None):
        """
        Get full tensor value.

        Args:
            filter_condition (dict): Filter condition.

                - name (str): The full name of tensor.
                - node_type (str): The type of the node.
                - prev (bool): Whether to get previous tensor.

        Returns:
            dict, the tensor_value and whether a view command needs to be sent.
        """
        name = filter_condition.get('name')
        node_type = filter_condition.get('node_type')
        shape = filter_condition.get('shape')
        if filter_condition.get('prev'):
            step = self.prev_step
        else:
            step = self.cur_step
        tensor = self._get_tensor(name, node_type, step)
        if not tensor:
            log.error("No tensor named %s at the step %s", name, step)
            raise DebuggerParamValueError("No tensor named {}".format(name))
        tensor_info = tensor.get_full_info(shape)
        self._update_has_prev_step_field(tensor_info, name, node_type,
                                         self.cur_step)
        res = {'tensor_value': tensor_info, 'view_cmd': False}
        if tensor.status == TensorStatusEnum.UNCACHED.value:
            self._add_hold_value_tensors(name, step)
            res['view_cmd'] = True
        return res
Example #18
    def SendTensors(self, request_iterator, context):
        """Send tensors into DebuggerCache."""
        log.info("Received tensor.")
        tensor_construct = []
        tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR)
        metadata_stream = self._cache_store.get_stream_handler(
            Streams.METADATA)
        tensor_names = []
        step = metadata_stream.step
        for tensor in request_iterator:
            tensor_construct.append(tensor)
            if tensor.finished:
                update_flag = tensor_stream.put(
                    {'step': step, 'tensor_protos': tensor_construct})
                if self._received_view_cmd.get(
                        'wait_for_tensor') and update_flag:
                    # update_flag is used to avoid querying empty tensors again
                    self._received_view_cmd['wait_for_tensor'] = False
                    log.debug("Set wait for tensor flag to False.")
                tensor_construct = []
                tensor_names.append(':'.join([tensor.node_name, tensor.slot]))
        reply = get_ack_reply()
        return reply
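
The handler accumulates streamed chunks until a message arrives with its finished flag set, then flushes the accumulated list as one tensor. A gRPC-free sketch of the same reassembly loop (the chunk type and stream are invented):

    from collections import namedtuple

    # Illustrative sketch: reassemble a chunked stream on a `finished` marker.
    Chunk = namedtuple('Chunk', ['payload', 'finished'])

    stream = [Chunk('a', False), Chunk('b', True),   # first tensor: 'ab'
              Chunk('c', True)]                      # second tensor: 'c'

    construct, tensors = [], []
    for chunk in stream:
        construct.append(chunk.payload)
        if chunk.finished:
            tensors.append(''.join(construct))
            construct = []                           # reset for the next tensor
    print(tensors)  # ['ab', 'c']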
Example #19
    def create_session(self, session_type, train_job=None):
        """
        Create the session by the train job info or session type if the session doesn't exist.

        Args:
            session_type (str): The session_type.
            train_job (str): The train job info.

        Returns:
            str, session id.
        """
        with self._lock:
            if self._exiting:
                logger.info("System is exiting, will terminate the thread.")
                _thread.exit()

            if session_type == self.ONLINE_TYPE:
                if self.ONLINE_SESSION_ID not in self.sessions:
                    logger.error(
                        'Online session is unavailable, set --enable-debugger to true/1 to enable the debugger '
                        'when starting the MindInsight server.')
                    raise DebuggerOnlineSessionUnavailable()
                return self.ONLINE_SESSION_ID

            if train_job in self.train_jobs:
                return self.train_jobs.get(train_job)

            return self._create_offline_session(train_job)
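
Session creation is a lock-guarded check-then-create: holding the lock across the train_jobs lookup and the creation makes the pair atomic, so two threads asking for the same train job get the same session id. A threading sketch of the pattern (all names are illustrative):

    import threading

    # Illustrative sketch: lock-guarded lookup-or-create, as in create_session.
    lock = threading.Lock()
    sessions = {}
    next_id = [0]

    def create_session(train_job):
        with lock:
            if train_job in sessions:       # existing session: reuse its id
                return sessions[train_job]
            session_id = str(next_id[0])    # otherwise create a new one
            next_id[0] += 1
            sessions[train_job] = session_id
            return session_id

    print(create_session('./run1'))  # '0'
    print(create_session('./run1'))  # '0' again: same job, same session
    print(create_session('./run2'))  # '1'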
Example #20
    def _pre_process(self, request):
        """Pre-process before dealing with command."""

        # check if version is mismatch, if mismatch, send mismatch info to UI
        if self._status == ServerStatus.MISMATCH:
            log.warning("Version of Mindspore and Mindinsight re unmatched,"
                        "waiting for user to terminate the script.")
            metadata_stream = self._cache_store.get_stream_handler(
                Streams.METADATA)
            # put metadata into data queue
            metadata = metadata_stream.get(['state', 'debugger_version'])
            self._cache_store.put_data(metadata)
            return

        metadata_stream = self._cache_store.get_stream_handler(
            Streams.METADATA)
        is_new_step = metadata_stream.step < request.cur_step
        is_new_node = metadata_stream.full_name != request.cur_node
        # clean cache data at the beginning of new step or node has been changed.
        if is_new_step or is_new_node:
            self._cache_store.clean_data()
            self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors(
                request.cur_step)
        if is_new_step:
            self._cache_store.get_stream_handler(
                Streams.WATCHPOINT_HIT).clean()
        # receive graph at the beginning of the training
        if self._status == ServerStatus.RECEIVE_GRAPH:
            self._send_graph_flag(metadata_stream)
        # receive new metadata
        if is_new_step or is_new_node:
            self._update_metadata(metadata_stream, request)
        self._send_received_tensor_tag()
        self._send_watchpoint_hit_flag()
Example #21
    def search(self, filter_condition):
        """
        Search for single node in graph.

        Args:
            filter_condition (dict): Filter condition.

                - name (str): The name pattern.
                - graph_name (str): The graph name.
                - watch_point_id (int): The id of watchpoint. Default: 0.
                - node_category (str): The node category. Default: None.

        Returns:
            dict, the searched nodes.
        """
        log.info("receive search request with filter_condition: %s",
                 filter_condition)
        # validate watchpoint id
        watch_point_id = filter_condition.pop('watch_point_id', 0)
        watchpoint_stream = self.cache_store.get_stream_handler(
            Streams.WATCHPOINT)
        watchpoint_stream.validate_watchpoint_id(watch_point_id)
        # validate and update graph name
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        graph_name = graph_stream.validate_graph_name(
            filter_condition.get('graph_name'))
        filter_condition['graph_name'] = graph_name
        # get searched graph
        graph = graph_stream.search_nodes(filter_condition)
        # add watched label to graph
        watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id,
                                          graph_name)
        return graph
Example #22
    def get_tensor_history(self, node_name, graph_name=None, depth=0):
        """
        Get the tensor history of a specified node.

        Args:
            node_name (str): The debug name of the node.
            graph_name (str): The graph_name. Default: None.
            depth (int): The number of layers the user
                wants to trace. Default is 0.

        Returns:
            dict, basic tensor history, including only the tensor name, tensor type, and node type.
        """
        graph_name, node_name = self._parse_node_name(node_name, graph_name)
        graph = self._get_graph(graph_name=graph_name, node_name=node_name)
        # validate node type, scope node has no tensor history
        node_type = graph.get_node_type(node_name)
        if is_scope_type(node_type):
            log.error("Scope type node has no tensor history.")
            raise DebuggerParamValueError("Invalid leaf node name.")
        # get tensor history
        tensor_history, cur_outputs_nums = graph.get_tensor_history(
            node_name, depth)
        # add the tensor type for tensor history
        self._update_tensor_history(tensor_history[0:cur_outputs_nums],
                                    'output', graph_name)
        self._update_tensor_history(tensor_history[cur_outputs_nums:], 'input',
                                    graph_name)
        log.debug("Get %d tensors in tensor history for node <%s>.",
                  len(tensor_history), node_name)
        return {'tensor_history': tensor_history}
Example #23
    def retrieve_tensor_value(self,
                              name,
                              detail,
                              shape,
                              graph_name=None,
                              prev=False):
        """Retrieve the tensor value."""
        log.info("Retrieve tensor value: name: %s, detail: %s, shape: %s",
                 name, detail, shape)
        self.validate_tensor_param(name, detail)
        # Limit to query max two dimensions for tensor in table view.
        parsed_shape = TensorUtils.parse_shape(shape,
                                               limit=MAX_DIMENSIONS_FOR_TENSOR)
        node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(
            name, graph_name)
        reply = self.cache_store.get_stream_handler(Streams.TENSOR).get({
            'name': tensor_name,
            'node_type': node_type,
            'shape': parsed_shape,
            'prev': prev
        })
        reply['tensor_value']['name'] = name

        return reply
Example #24
    def put(self, value):
        """
        Put value into graph cache. Called by grpc server.

        Args:
            value (GraphProto): The Graph proto message.
        """
        log.info("Put graph into cache.")
        sorted_value_list = self._sort_graph(value)
        for graph_name, graph_value in sorted_value_list:
            self._graph_proto[graph_name] = graph_value
            # build sub graph
            graph = DebuggerGraph()
            graph.build_graph(graph_value)
            self._graph[graph_name] = graph
            self.bfs_order.extend(graph.get_bfs_order())
            leaf_nodes = graph.leaf_nodes
            self._all_leaf_nodes.update(leaf_nodes)
            for _, node in leaf_nodes.items():
                self.graph_node_map[node.full_name] = graph_name

        # build whole graph
        graph = DebuggerMultiGraph()
        graph.add_graph(self._graph)
        self._whole_graph = graph
Example #25
    def retrieve_node_by_bfs(self, node_name, graph_name=None, ascend=False):
        """
        Get the graph of the next node according to node_name.

        Args:
            node_name (str): The name of current chosen leaf node.
            graph_name (str): The graph name.
            ascend (bool): If True, traverse the input nodes;
                if False, traverse the output nodes. Default: False.

        Returns:
            dict, the next node information.
        """
        log.info("Retrieve node <%s> by bfs, `ascend` is :%s", node_name,
                 ascend)
        reply = {}
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        graph_name = graph_stream.validate_graph_name(graph_name)
        next_node_name = graph_stream.get_node_by_bfs_order(node_name, ascend)
        # no next node
        if next_node_name is None:
            return reply
        # add graph and tensor history for next node
        filter_condition = {
            'name': next_node_name,
            'graph_name': graph_name,
            'single_node': True
        }
        search_graph = self._get_nodes_info(filter_condition)
        reply = {'name': next_node_name}
        reply.update(search_graph)

        return reply
Example #26
    def get(self, filter_condition=None):
        """
        Get full tensor value.

        Args:
            filter_condition (dict): Filter condition.

                - name (str): The full name of tensor.
                - node_type (str): The type of the node.
                - prev (bool): Whether to get previous tensor.

        Returns:
            dict, the tensor_value.
        """
        name = filter_condition.get('name')
        node_type = filter_condition.get('node_type')
        shape = filter_condition.get('shape')
        if filter_condition.get('prev'):
            step = self.prev_step
        else:
            step = self.cur_step
        tensor = self._get_tensor(name, node_type, step)
        if not tensor:
            log.error("No tensor named %s at the step %s", name, step)
            raise DebuggerParamValueError("No tensor named {}".format(name))
        tensor_info = tensor.get_full_info(shape)
        self._update_has_prev_step_field(tensor_info, name, node_type)
        return {'tensor_value': tensor_info}
Example #27
    def put(self, value):
        """
        Put value into event_cache.

        Args:
            value (dict): The event to be put into cache.
        """
        if not isinstance(value, dict):
            log.error("Dict type required when put event message.")
            raise DebuggerParamValueError(
                "Dict type required when put event message.")

        with self._lock:
            log.debug(
                "Put the %d-th message into the queue.\n"
                "%d requests are waiting.",
                self._next_idx, len(self._pending_requests))
            cur_pos = self._next_idx
            # update next pos
            self._next_idx += 1
            if self._next_idx >= self.max_limit:
                self._next_idx = 0
                self._prev_flag = self._cur_flag
                self._cur_flag = str(uuid.uuid4())
            # attach the next read position to the message metadata
            if not value.get('metadata'):
                value['metadata'] = {}
            value['metadata']['pos'] = self.next_pos
            self._event_cache[cur_pos] = value
            # feed the value for pending requests
            self.clean_pending_requests(value)
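
The event cache is a fixed-size ring: _next_idx advances on every write and wraps at max_limit, and a fresh UUID flag marks each lap so that readers holding old positions can detect they have been overwritten. A minimal sketch of the index-and-flag bookkeeping (simplified; no pending-request handling):

    import uuid

    # Illustrative sketch of the ring-buffer position logic in put().
    MAX_LIMIT = 3
    event_cache = [None] * MAX_LIMIT
    next_idx = 0
    cur_flag = str(uuid.uuid4())

    def put(value):
        global next_idx, cur_flag
        cur_pos = next_idx
        next_idx += 1
        if next_idx >= MAX_LIMIT:      # wrap around and start a new lap
            next_idx = 0
            cur_flag = str(uuid.uuid4())
        event_cache[cur_pos] = value

    for event in ['a', 'b', 'c', 'd']:
        put(event)
    print(event_cache)  # ['d', 'b', 'c'] -- 'a' was overwritten on the second lap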
Example #28
    def put(self, value):
        """
        Put value into tensor cache. Called by grpc server.

        Args:
            value (dict): The Tensor proto message.

                - step (int): The current step of tensor.
                - tensor_proto (TensorProto): The tensor proto.
                - tensor_contents (list[byte]): The list of tensor content values.

        Returns:
            bool, whether the tensor was updated successfully.
        """
        tensor_proto = value.get('tensor_proto')
        tensor_proto.ClearField('tensor_content')
        step = value.get('step', 0)
        if tensor_proto.iter and step > 0:
            log.debug("Received previous tensor.")
            step -= 1
        tensor_content = b''.join(value.get('tensor_contents'))
        tensor = OpTensor(tensor_proto, tensor_content, step)
        flag = self._put_tensor_into_cache(tensor, step)
        log.info("Put tensor %s of step: %d, into cache. Flag: %s", tensor.name, step, flag)
        return flag
Example #29
    def _get_watch_names_by_search(self, search_nodes, target_node_name):
        """
        Get watch names according to search results.

        Args:
            search_nodes (dict): Search result.
                The format is like {'nodes': [<Search Node>]}. The <Search Node> format is like
                {'name': <UI node name>, 'type': <node type>, 'nodes': [<Search Node>]}
            target_node_name (str): Node name for UI.

        Returns:
            set[str], collection of names.
        """
        names = set()
        tmp_queue = Queue()
        tmp_queue.put(search_nodes)
        while not tmp_queue.empty():
            cur_node = tmp_queue.get()
            for node in cur_node.get('nodes'):
                node_name = node.get('name')
                if not target_node_name.startswith(node_name) or is_cst_type(
                        node.get('type')):
                    continue
                if target_node_name == node_name:
                    self._add_leaf_node_collection(node, names)
                    return names
                tmp_queue.put(node)
        # the target node name is not in search nodes.
        log.debug("node %s is not in search nodes.")
        names.add(target_node_name)
        return names
Example #30
    def get_node_name_by_full_name(self, full_name):
        """Get node name by full names."""
        inner_name = self._full_name_map_name.get(full_name, '')
        if not inner_name:
            log.warning("Node %s does not find the relative inner node name.", full_name)

        return inner_name