Example #1
 def __init__(self):
     self.condition_mgr = ConditionMgr()
     self.cache_store = DebuggerCache()
     self.grpc_server = DebuggerGrpcServer(self.cache_store,
                                           self.condition_mgr)
     self.grpc_server_manager = None
     self.back_server = None
Example #2
 def __init__(self, grpc_port=None):
     self.grpc_port = grpc_port
     self.cache_store = DebuggerCache()
     self.grpc_server = DebuggerGrpcServer(self.cache_store)
     self.grpc_server_manager = None
     self.back_server = None
     self._watch_point_id = 0
Example #3
 def test_get_dbg_online_server(self):
     """Get debugger online server"""
     context = DebuggerServerContext(dbg_mode='online')
     server_obj = self._dbg_server_factory.get_debugger_server(
         DebuggerCache(), context)
     server_obj.start()
     server_obj.stop()
Example #4
 def test_get_dbg_offline_server(self, mock_import):
     """Get debugger offline server"""
     mock_import.return_value = mock_dbg_services
     context = DebuggerServerContext(dbg_mode='offline',
                                     dbg_dir=self._dbg_dir)
     server_obj = self._dbg_server_factory.get_debugger_server(
         DebuggerCache(), context)
     server_obj.start()
     server_obj.stop()
Example #5
 def setup_method(self):
     """Prepare debugger server object."""
     cache_store = DebuggerCache()
     cache_store.initialize()
     self._server = TrainingControlOperator(cache_store)
Example #6
 def setup_method(self):
     """Initialize for each testcase."""
     cache_store = DebuggerCache()
     self._server = DebuggerGrpcServer(cache_store,
                                       condition_mgr=ConditionMgr())
Example #7
class DebuggerServer:
    """The server manager of debugger."""
    def __init__(self):
        self.condition_mgr = ConditionMgr()
        self.cache_store = DebuggerCache()
        self.grpc_server = DebuggerGrpcServer(self.cache_store,
                                              self.condition_mgr)
        self.grpc_server_manager = None
        self.back_server = None

    def get_condition_collections(self, train_id):
        """Get default condition_collections"""
        metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
        condition_context = ConditionContext(metadata_stream.backend,
                                             metadata_stream.step)
        log.debug("Train_id: %s, backend: %s", train_id,
                  condition_context.backend)
        return self.condition_mgr.get_all_collections(condition_context)

    def set_recommended_watch_points(self, set_recommended, train_id):
        """set recommended watch points."""
        if not isinstance(set_recommended, bool):
            log.error("Bool param should be given for set_recommended")
            raise DebuggerParamValueError("Bool param should be given.")

        metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
        if metadata_stream.recommendation_confirmed:
            log.error("User has confirmed setting recommended watchpoints")
            raise DebuggerSetRecommendWatchpointsError()

        metadata_stream.recommendation_confirmed = True
        condition_context = ConditionContext(metadata_stream.backend,
                                             metadata_stream.step)
        log.debug("Train_id: %s, backend: %s", train_id,
                  condition_context.backend)
        res = metadata_stream.get(['state', 'enable_recheck'])
        if set_recommended:
            res['id'] = self._add_recommended_watchpoints(condition_context)

        return res

    def _add_recommended_watchpoints(self, condition_context):
        """Add predefined watchpoints."""
        log.debug("Add predefined watchpoints.")
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        watchpoints = recommend_watchpoints(self.condition_mgr, graph_stream,
                                            condition_context)
        watch_point_stream_handler = self.cache_store.get_stream_handler(
            Streams.WATCHPOINT)
        watch_points_ids = []
        for watchpoint in watchpoints:
            watch_points_id = watch_point_stream_handler.create_watchpoint(
                watch_condition=watchpoint.get_watch_condition_dict(),
                watch_nodes=watchpoint.watch_nodes,
                name=watchpoint.name,
                condition_mgr=self.condition_mgr)
            watch_points_ids.append(watch_points_id)
        return watch_points_ids

    def start(self):
        """Start server."""
        grpc_port = getattr(settings, 'DEBUGGER_PORT', 50051)
        host = getattr(settings, 'HOST', '[::]')
        hostname = "{}:{}".format(host, grpc_port)
        # initialize a grpc server
        grpc_server_manager = grpc.server(
            futures.ThreadPoolExecutor(max_workers=10))
        grpc_server_base.add_EventListenerServicer_to_server(
            self.grpc_server, grpc_server_manager)
        grpc_server_manager.add_insecure_port(hostname)
        grpc_server_manager.start()
        my_server_thread = Thread(
            target=grpc_server_manager.wait_for_termination)
        # start grpc server
        my_server_thread.start()
        self.back_server = my_server_thread
        self.grpc_server_manager = grpc_server_manager
        # register stop server handler
        signal.signal(signal.SIGINT, self._stop_handler)
        log.info("Start grpc server %s", hostname)

    def _stop_handler(self, signum, frame):
        """Register stop server handler."""
        self.stop()
        log.debug("Deal with stop signal: %s, %s", signum, frame)

    def stop(self):
        """Stop debugger server."""
        log.info("Send terminate info to client.")
        self.control({'mode': 'terminate'})
        self.grpc_server_manager.stop(grace=None)
        self.back_server.join()
        log.info("Stop debugger server.")

    def poll_data(self, pos):
        """
        Get the pos-th data from DebuggerCache.

        Args:
            pos (str): The index of data.

        Returns:
            dict, the data to be updated.
        """
        if not isinstance(pos, str):
            log.error("Pos should be string. Received: %s", pos)
            raise DebuggerParamValueError("Pos should be string.")

        reply = self.cache_store.get_data(pos)

        return reply

    def search(self, filter_condition):
        """
        Search for single node in graph.

        Args:
            filter_condition (dict): Filter condition.

                - name (str): The name pattern.
                - graph_name (str): The graph name.
                - watch_point_id (int): The id of watchpoint. Default: 0.
                - node_category (str): The node_category. Default: None

        Returns:
            dict, the searched nodes.
        """
        log.info("receive search request with filter_condition: %s",
                 filter_condition)
        # validate watchpoint id
        watch_point_id = filter_condition.pop('watch_point_id', 0)
        watchpoint_stream = self.cache_store.get_stream_handler(
            Streams.WATCHPOINT)
        watchpoint_stream.validate_watchpoint_id(watch_point_id)
        # validate and update graph name
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        graph_name = graph_stream.validate_graph_name(
            filter_condition.get('graph_name'))
        filter_condition['graph_name'] = graph_name
        # get searched graph
        graph = graph_stream.search_nodes(filter_condition)
        # add watched label to graph
        watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id,
                                          graph_name)
        return graph

    def tensor_comparisons(self, name, shape, detail='data', tolerance='0'):
        """
        Get tensor comparisons data for given name, detail, shape and tolerance.

        Args:
            name (str): The name of tensor for UI.
            shape (str): Specify concrete dimensions of shape.
            detail (str): Specify which data to query. Currently the only available value is 'data',
                          which means concrete tensor data. Histogram or unique count may be supported
                          in the future.
            tolerance (str): Specify tolerance of difference between current step tensor and previous
                             step tensor. Default value is 0.

        Raises:
            DebuggerParamValueError: If the node type is not parameter or the value of detail is not supported.
            DebuggerCompareTensorError: If MindSpore is not in waiting state.

        Returns:
            dict, the retrieved data.
        """
        if self.cache_store.get_stream_handler(
                Streams.METADATA).state != ServerStatus.WAITING.value:
            log.error(
                "Failed to compare tensors as MindSpore is not in waiting state."
            )
            raise DebuggerCompareTensorError(
                "Failed to compare tensors as MindSpore is not in waiting state."
            )
        self.validate_tensor_param(name, detail)
        # Limit to query max two dimensions for tensor in table view.
        parsed_shape = TensorUtils.parse_shape(shape,
                                               limit=MAX_DIMENSIONS_FOR_TENSOR)
        node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(
            name)
        tolerance = to_float(tolerance, 'tolerance')
        tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR)
        if node_type == NodeTypeEnum.PARAMETER.value:
            reply = tensor_stream.get_tensors_diff(tensor_name, parsed_shape,
                                                   tolerance)
        else:
            raise DebuggerParamValueError(
                "The node type must be parameter, but got {}.".format(
                    node_type))
        return reply

    def retrieve(self, mode, filter_condition=None):
        """
        Retrieve data according to mode and params.

        Args:
            mode (str): The type of info message.
            filter_condition (dict): The filter condition.

        Returns:
            dict, the retrieved data.
        """
        log.info(
            "receive retrieve request for mode: %s, filter_condition: %s",
            mode, filter_condition)
        mode_mapping = {
            'all': self._retrieve_all,
            'node': self._retrieve_node,
            'watchpoint': self._retrieve_watchpoint,
        }
        # validate param <mode>
        if mode not in mode_mapping:
            log.error(
                "Invalid param <mode>. <mode> should be in ['all', 'node', 'watchpoint'], "
                "but got %s.", mode)
            raise DebuggerParamValueError("Invalid mode.")
        # validate backend status
        metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
        if metadata_stream.state == ServerStatus.PENDING.value:
            log.info("The backend is in pending status.")
            return metadata_stream.get()

        filter_condition = {} if filter_condition is None else filter_condition
        reply = mode_mapping[mode](filter_condition)

        return reply

    def _retrieve_all(self, filter_condition=None):
        """Retrieve metadata, root graph and watchpoint list."""
        if filter_condition:
            log.error("No filter condition required for retrieve all request.")
            raise DebuggerParamTypeError("filter_condition should be empty.")
        self.cache_store.clean_data()
        log.info("Clean data queue cache when retrieve all request.")
        result = {}
        for stream in [Streams.METADATA, Streams.GRAPH]:
            sub_res = self.cache_store.get_stream_handler(stream).get()
            result.update(sub_res)

        sub_res = self._hide_parameters_for_ui()
        result.update(sub_res)

        return result

    def _retrieve_node(self, filter_condition):
        """
        Retrieve node info.

        Args:
            filter_condition (dict): Filter condition.

                - name (str): The name of single node.
                - graph_name (str): The relative graph_name of the node.
                - single_node (bool): If False, return the sub-layer of single node. If True, return
                    the node list from root node to single node.
                - watch_point_id (int): The id of watchpoint.

        Returns:
            dict, reply with graph.
        """
        log.debug("Retrieve node %s.", filter_condition)
        # validate node name
        node_name = filter_condition.get('name')
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        graph_name = graph_stream.validate_graph_name(
            filter_condition.get('graph_name'))
        if node_name:
            # validate node name
            graph_stream.get_node_type(node_name, graph_name)
        filter_condition['single_node'] = bool(
            filter_condition.get('single_node'))
        filter_condition['graph_name'] = graph_name
        reply = self._get_nodes_info(filter_condition)
        return reply

    def _get_nodes_info(self, filter_condition):
        """
        Get nodes info.

        Args:
            filter_condition (dict): The filter condition.

                - name (str): The node name.
                - graph_name (str): The relative graph_name of the node.
                - single_node (bool): If False, return the sub-layer of single node. If True, return
                    the node list from root node to single node.
                - watch_point_id (int): The id of watchpoint.

        Returns:
            dict, reply with graph.
        """
        # validate watch_point_id
        watch_point_id = filter_condition.get('watch_point_id', 0)
        watchpoint_stream = self.cache_store.get_stream_handler(
            Streams.WATCHPOINT)
        watchpoint_stream.validate_watchpoint_id(watch_point_id)
        # get graph
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        reply = graph_stream.get(filter_condition)
        graph = reply.get('graph')
        # add watched label to graph
        watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id,
                                          filter_condition.get('graph_name'))
        return reply

    def retrieve_tensor_history(self, node_name, graph_name=None):
        """
        Retrieve tensor history for leaf node.

        Args:
            node_name (str): The name of leaf node.
            graph_name (str): The graph name. Default: None.

        Returns:
            dict, the tensor history and metadata.
        """
        log.info("Retrieve tensor history for node: %s.", node_name)
        metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
        if metadata_stream.state == ServerStatus.PENDING.value:
            log.info("The backend is in pending status.")
            return metadata_stream.get(['state', 'step'])
        res = self._get_tensor_history(node_name, graph_name)
        return res

    def _get_tensor_history(self, node_name, graph_name=None):
        """
        Get tensor history for single node.

        Args:
            node_name (str): The name of leaf node.
            graph_name (str): The graph name. Default: None.

        Returns:
            dict, the tensor history and metadata.
        """
        # get basic tensor history
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        tensor_history = graph_stream.get_tensor_history(node_name, graph_name)
        # add tensor value for tensor history
        self._add_tensor_value_for_tensor_history(tensor_history, node_name,
                                                  graph_name)
        # add hit label for tensor history
        watchpoint_hit_stream = self.cache_store.get_stream_handler(
            Streams.WATCHPOINT_HIT)
        watchpoint_hit_stream.update_tensor_history(tensor_history)
        # add metadata
        metadata = self.cache_store.get_stream_handler(Streams.METADATA).get(
            ['state', 'step'])
        tensor_history.update(metadata)
        return tensor_history

    def _add_tensor_value_for_tensor_history(self, tensor_history, node_name,
                                             graph_name):
        """
        Add tensor values for tensor history and send a ViewCMD if any tensor value is missing.

        Args:
            tensor_history (list[dict]): A list of tensor info, including name and type.
            node_name (str): The UI node name.
            graph_name (str): The graph name.
        """
        tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR)
        missed_tensors = tensor_stream.update_tensor_history(tensor_history)
        if missed_tensors:
            view_cmd = create_view_event_from_tensor_basic_info(missed_tensors)
            self.cache_store.put_command({
                'view_cmd': view_cmd,
                'node_name': node_name,
                'graph_name': graph_name
            })
            log.debug("Send view cmd.")

    def retrieve_tensor_value(self,
                              name,
                              detail,
                              shape,
                              graph_name=None,
                              prev=False):
        """Retrieve the tensor value."""
        log.info("Retrieve tensor value: name: %s, detail: %s, shape: %s",
                 name, detail, shape)
        self.validate_tensor_param(name, detail)
        # Limit to query max two dimensions for tensor in table view.
        parsed_shape = TensorUtils.parse_shape(shape,
                                               limit=MAX_DIMENSIONS_FOR_TENSOR)
        node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(
            name, graph_name)
        reply = self.cache_store.get_stream_handler(Streams.TENSOR).get({
            'name': tensor_name,
            'node_type': node_type,
            'shape': parsed_shape,
            'prev': prev
        })
        reply['tensor_value']['name'] = name

        return reply

    def _get_tensor_name_and_type_by_ui_name(self, name, graph_name=None):
        """
        Get inner tensor name and type by UI name.

        Args:
            name (str): Node name shown in UI.
            graph_name (Union[str, None]): The graph name. Default: None.

        Returns:
            str, node type of the tensor.
            str, inner full name of the tensor.
        """
        node_name, slot = name.rsplit(':', 1)
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        graph_name = graph_name if graph_name else graph_stream.get_graph_id_by_name(
            node_name)
        node_type = graph_stream.get_node_type(node_name, graph_name)
        full_name = graph_stream.get_full_name(node_name, graph_name)
        tensor_name = full_name + ':' + slot
        return node_type, tensor_name

    @staticmethod
    def validate_tensor_param(name, detail):
        """Validate params for retrieve tensor request."""
        # validate name
        if not isinstance(name, str) or ':' not in name:
            log.error("Invalid tensor name. Received: %s", name)
            raise DebuggerParamValueError("Invalid tensor name.")
        # validate data
        if detail != 'data':
            log.error("Invalid detail value. Received: %s", detail)
            raise DebuggerParamValueError("Invalid detail value.")

    def _retrieve_watchpoint(self, filter_condition):
        """
        Retrieve watchpoint.

        Args:
            filter_condition (dict): Filter condition.

                - watch_point_id (int): The id of watchpoint. If not given, return all watchpoints.
                - name (str): The name of single node.
                - single_node (bool): If False, return the sub-layer of single node. If True, return
                    the node list from root node to single node.

        Returns:
            dict, watch point list or relative graph.
        """
        watchpoint_id = filter_condition.get('watch_point_id', 0)
        if not watchpoint_id:
            reply = self._hide_parameters_for_ui()
            log.debug("Get condition of watchpoints.")
        else:
            reply = self._retrieve_node(filter_condition)
            log.debug("Get graph of %d-th watchpoint.", watchpoint_id)

        return reply

    def search_watchpoint_hits(self, group_condition):
        """
        Retrieve watchpoint hit.

        Args:
            group_condition (dict): Filter condition.

                - limit (int): The limit of each page.
                - offset (int): The offset of current page.
                - node_name (str): The retrieved node name.
                - graph_name (str): The retrieved graph name.

        Returns:
            dict, watch point list or relative graph.
        """
        if not isinstance(group_condition, dict):
            log.error(
                "Group condition for watchpoint-hits request should be a dict")
            raise DebuggerParamTypeError(
                "Group condition for watchpoint-hits request should be a dict")

        metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
        if metadata_stream.state == ServerStatus.PENDING.value:
            log.info("The backend is in pending status.")
            return metadata_stream.get()

        reply = self.cache_store.get_stream_handler(
            Streams.WATCHPOINT_HIT).group_by(group_condition)
        reply['outdated'] = self.cache_store.get_stream_handler(
            Streams.WATCHPOINT).is_recheckable()
        return reply

    def create_watchpoint(self, params):
        """
        Create watchpoint.

        Args:
            params (dict): Params for create watchpoint.

                - watch_condition (dict): The watch condition. The format is like:
                    {
                        "id": "tensor_too_large",
                        "params": [
                            {
                                "name": "abs_mean_gt",
                                "value": 1.1
                            }
                        ]
                    }

                    - id (str): Id of condition.
                    - params (list[dict]): The list of param for this condition.
                - watch_nodes (list[str]): The list of node names.
                - watch_point_id (int): The id of watchpoint.
                - search_pattern (dict): The search pattern.
                - graph_name (str): The relative graph_name of the watched node.

        Returns:
            dict, the id of new watchpoint and metadata info.
        """
        watchpoint_opt = WatchpointOperator(self.cache_store,
                                            self.condition_mgr)
        return watchpoint_opt.create_watchpoint(params)

    def update_watchpoint(self, params):
        """
        Update watchpoint.

        Args:
            params (dict): Params for update watchpoint.

                - watch_point_id (int): The id of watchpoint.
                - watch_nodes (list[str]): The list of node names.
                - mode (int): The update operation on nodes. 0 for removing nodes from watch nodes,
                    1 for adding nodes to watch nodes.
                - search_pattern (dict): The search pattern.
                - graph_name (str): The relative graph_name of the watched node.

        Returns:
            dict, the metadata info.
        """
        watchpoint_opt = WatchpointOperator(self.cache_store,
                                            self.condition_mgr)
        return watchpoint_opt.update_watchpoint(params)

    def delete_watchpoint(self, watch_point_id=None):
        """
        Delete watchpoint.

        Args:
            watch_point_id (Union[None, int]): The id of watchpoint.
                If None, delete all watchpoints. Default: None.

        Returns:
            dict, the metadata info.
        """
        watchpoint_opt = WatchpointOperator(self.cache_store,
                                            self.condition_mgr)
        return watchpoint_opt.delete_watchpoint(watch_point_id=watch_point_id)

    @try_except
    def control(self, params=None):
        """
        Control the training process.

        Args:
            params (dict): The control params.

                - mode (str): Acceptable control command, including `continue`,
                    `pause` and `terminate`.
                - level (str): The control granularity, `node` level or `step` level.
                    Default: `step`.
                - steps (int): Specify the steps that training should run.
                    Used when `level` is `step`.
                - name (str): Specify the name of the node. Used when `level` is `node`.
                - graph_name (str): The graph name.

        Returns:
            dict, the response.
        """
        log.info("Receive control request: %s.", params)
        mode = params.pop('mode', None) if params else None
        training_controller = TrainingControlOperator(self.cache_store)
        training_controller.validate_mode(mode)
        return training_controller.control(mode, params)

    def retrieve_node_by_bfs(self, node_name, graph_name=None, ascend=False):
        """
        Get the graph of the next node according to node_name.

        Args:
            node_name (str): The name of current chosen leaf node.
            graph_name (str): The graph name.
            ascend (bool): If True, traverse the input nodes;
                if False, traverse the output nodes. Default: False.

        Returns:
            dict, the next node information.
        """
        log.info("Retrieve node <%s> by bfs, `ascend` is :%s", node_name,
                 ascend)
        reply = {}
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        graph_name = graph_stream.validate_graph_name(graph_name)
        next_node_name = graph_stream.get_node_by_bfs_order(node_name, ascend)
        # no next node
        if next_node_name is None:
            return reply
        # add graph and tensor history for next node
        filter_condition = {
            'name': next_node_name,
            'graph_name': graph_name,
            'single_node': True
        }
        search_graph = self._get_nodes_info(filter_condition)
        reply = {'name': next_node_name}
        reply.update(search_graph)

        return reply

    @try_except
    def recheck(self):
        """
        Recheck all watchpoints.

        Returns:
            dict, metadata info.
        """
        return TrainingControlOperator(self.cache_store).recheck()

    def retrieve_tensor_graph(self, tensor_name, graph_name):
        """
        Retrieve tensor graph.

        Args:
            tensor_name (str): The tensor name from UI.
            graph_name (str): The graph name.

        Returns:
            dict, tensor graph object.
        """
        if self.cache_store.get_stream_handler(
                Streams.METADATA).state != ServerStatus.WAITING.value:
            log.error(
                "Failed to get tensor graph as MindSpore is not in waiting state."
            )
            raise DebuggerTensorGraphError
        log.info("Retrieve tensor graph for %s from %s", tensor_name,
                 graph_name)
        tensor_graph_ops = TensorDetailInfo(self.cache_store).get_tensor_graph(
            tensor_name, graph_name)
        return tensor_graph_ops

    def retrieve_tensor_hits(self, tensor_name, graph_name):
        """
        Retrieve tensor hit information.

        Args:
            tensor_name (str): The tensor name from UI.
            graph_name (str): The graph name.

        Returns:
            dict, tensor hit info.
        """
        if self.cache_store.get_stream_handler(
                Streams.METADATA).state != ServerStatus.WAITING.value:
            log.error(
                "Failed to get tensor hits as MindSpore is not in waiting state."
            )
            raise DebuggerTensorHitError
        log.info("Retrieve tensor hits for %s from %s", tensor_name,
                 graph_name)
        watch_points = TensorDetailInfo(
            self.cache_store).get_tensor_watch_points(tensor_name, graph_name)
        return {'watch_points': watch_points}

    def _hide_parameters_for_ui(self):
        """
        Hide some parameters on UI.

        Returns:
            dict, watch point list.
        """
        reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT).get()
        watch_points = reply.get('watch_points')
        for i, watch_point in enumerate(watch_points):
            watch_condition = watch_point.get('watch_condition')
            parameters = watch_condition.get('params')
            watch_condition_id = watch_condition.get('id')
            mgr_condition = self.condition_mgr.get_condition(
                watch_condition_id)
            ui_watch_condition = []
            for param in parameters:
                parameter_definition = mgr_condition.get_parameter_definition(
                    param['name'])
                if not parameter_definition.visible_on_ui:
                    continue
                ui_watch_condition.append(param)
            reply['watch_points'][i]['watch_condition'][
                'params'] = ui_watch_condition
        return reply
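
Below is a minimal driver sketch for the class in Example #7. It is hypothetical glue code, not part of the sources above; the train id and the control payload are illustrative.

# Hypothetical driver for Example #7: start() binds HOST:DEBUGGER_PORT and
# serves the gRPC EventListener on a background thread; a front end then
# long-polls poll_data() with a string cursor and steers training via control().
server = DebuggerServer()
server.start()
try:
    reply = server.poll_data('0')  # pos must be a string, not an int
    collections = server.get_condition_collections('train-demo')  # illustrative train_id
    server.control({'mode': 'continue', 'level': 'step', 'steps': 1})
finally:
    server.stop()  # sends 'terminate', stops the gRPC manager and joins the thread
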
Example #8
class DebuggerSession:
    """The server manager of debugger."""
    def __init__(self, context):
        self.condition_mgr = ConditionMgr()
        self.cache_store = DebuggerCache()
        self.context = context
        self.back_server = DebuggerServerFactory().get_debugger_server(
            self.cache_store, context)

    @property
    def train_job(self):
        """The property of train job."""
        return self.context.train_job

    def get_condition_collections(self, train_id=""):
        """Get default condition_collections"""
        metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
        condition_context = ConditionContext(metadata_stream.backend,
                                             metadata_stream.step)
        log.debug("Train_id: %s, backend: %s", train_id,
                  condition_context.backend)
        return self.condition_mgr.get_all_collections(condition_context)

    def set_recommended_watch_points(self, set_recommended, train_id=""):
        """Set recommended watch points."""
        if not isinstance(set_recommended, bool):
            log.error("Bool param should be given for set_recommended")
            raise DebuggerParamValueError("Bool param should be given.")

        metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
        if metadata_stream.recommendation_confirmed:
            log.error("User has confirmed setting recommended watchpoints")
            raise DebuggerSetRecommendWatchpointsError()

        metadata_stream.recommendation_confirmed = True
        condition_context = ConditionContext(metadata_stream.backend,
                                             metadata_stream.step)
        log.debug("Train_id: %s, backend: %s", train_id,
                  condition_context.backend)
        res = metadata_stream.get(['state', 'enable_recheck'])
        if set_recommended:
            res['id'] = self._add_recommended_watchpoints(condition_context)

        return res

    def _add_recommended_watchpoints(self, condition_context):
        """Add predefined watchpoints."""
        log.debug("Add predefined watchpoints.")
        multi_card_graph_stream = self.cache_store.get_stream_handler(
            Streams.GRAPH)
        watchpoints = recommend_watchpoints(self.condition_mgr,
                                            multi_card_graph_stream,
                                            condition_context)
        watch_point_stream_handler = self.cache_store.get_stream_handler(
            Streams.WATCHPOINT)
        device_stream = self.cache_store.get_stream_handler(Streams.DEVICE)
        watch_points_ids = []
        for watchpoint in watchpoints:
            watch_points_id = watch_point_stream_handler.create_watchpoint(
                watch_condition=watchpoint.get_watch_condition_dict(),
                watch_nodes=watchpoint.watch_nodes,
                name=watchpoint.name,
                condition_mgr=self.condition_mgr,
                device_amount=device_stream.device_amount)
            watch_points_ids.append(watch_points_id)
        return watch_points_ids

    def start(self):
        """Start server."""
        self.back_server.start()
        log.info("Start debugger backend server.")

    def _stop_handler(self, signum, frame):
        """Register stop server handler."""
        self.stop()
        log.debug("Deal with stop signal: %s, %s", signum, frame)

    def stop(self):
        """Stop debugger server."""
        log.info("Send terminate info to client.")
        self.control({'mode': 'terminate'})
        self.back_server.stop()
        log.info("Stop debugger server.")

    def poll_data(self, pos):
        """
        Get the pos-th data from DebuggerCache.

        Args:
            pos (str): The index of data.

        Returns:
            dict, the data to be updated.
        """
        if not isinstance(pos, str):
            log.error("Pos should be string. Received: %s", pos)
            raise DebuggerParamValueError("Pos should be string.")

        reply = self.cache_store.get_data(pos)

        return reply

    def search(self, filter_condition):
        """
        Search for single node in graph.

        Args:
            filter_condition (dict): Filter condition.

                - name (str): The name pattern.
                - graph_name (str): The graph name.
                - watch_point_id (int): The id of watchpoint. Default: 0.
                - node_category (str): The node_category. Default: None
                - rank_id (int): The id of rank. Default: 0.
                - stack_pattern (str): The pattern of stack info. Default: None.

        Returns:
            dict, the searched nodes.
        """
        log.info("receive search request with filter_condition: %s",
                 filter_condition)
        # validate watchpoint id
        watch_point_id = filter_condition.pop('watch_point_id', 0)
        rank_id = filter_condition.pop('rank_id', 0)
        watchpoint_stream = self.cache_store.get_stream_handler(
            Streams.WATCHPOINT)
        watchpoint_stream.validate_watchpoint_id(watch_point_id)
        # validate and update graph name
        graph_stream = self.cache_store.get_stream_handler(
            Streams.GRAPH).get_graph_handler_by_rank_id(rank_id)
        graph_name = graph_stream.validate_graph_name(
            filter_condition.get('graph_name'))
        filter_condition['graph_name'] = graph_name
        # get searched graph
        graph = graph_stream.search_nodes(filter_condition)
        # add watched label to graph
        watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id,
                                          graph_name, rank_id)
        return graph

    def tensor_comparisons(self,
                           name,
                           shape,
                           detail='data',
                           tolerance='0',
                           rank_id=0,
                           graph_name=None):
        """
        Get tensor comparisons data for given name, detail, shape and tolerance.

        Args:
            name (str): The name of tensor for UI.
            shape (str): Specify concrete dimensions of shape.
            detail (str): Specify which data to query. Currently the only available value is 'data',
                          which means concrete tensor data. Histogram or unique count may be supported
                          in the future.
            tolerance (str): Specify tolerance of difference between current step tensor and previous
                             step tensor. Default value is 0.
            rank_id (int): The id of rank. Default: 0.
            graph_name (str): The graph name. Default: None.

        Returns:
            dict, the retrieved data.
        """
        if self.cache_store.get_stream_handler(
                Streams.METADATA).state != ServerStatus.WAITING.value:
            log.error(
                "Failed to compare tensors as MindSpore is not in waiting state."
            )
            raise DebuggerCompareTensorError(
                "Failed to compare tensors as MindSpore is not in waiting state."
            )
        self.validate_tensor_param(name, detail)
        # Limit to query max two dimensions for tensor in table view.
        parsed_shape = TensorUtils.parse_shape(shape,
                                               limit=MAX_DIMENSIONS_FOR_TENSOR)
        node_type, tensor_name, graph_name = self._get_tensor_name_and_type_by_ui_name(
            name, graph_name, rank_id)
        tolerance = to_float(tolerance, 'tolerance')
        tensor_stream = self.cache_store.get_stream_handler(
            Streams.TENSOR).get_tensor_handler_by_rank_id(rank_id)
        cur_step = self.cache_store.get_stream_handler(Streams.METADATA).step
        if node_type == NodeTypeEnum.PARAMETER.value:
            reply = tensor_stream.get_tensors_diff(tensor_name, parsed_shape,
                                                   tolerance, cur_step)
        else:
            raise DebuggerParamValueError(
                "The node type must be parameter, but got {}.".format(
                    node_type))
        if reply.pop('view_cmd', False):
            self._send_view_cmd(name, graph_name, rank_id, tensor_name,
                                node_type)
        return reply

    def retrieve(self, mode, filter_condition=None):
        """
        Retrieve data according to mode and params.

        Args:
            mode (str): The type of info message.
            filter_condition (dict): The filter condition.

        Returns:
            dict, the retrieved data.
        """
        log.info(
            "receive retrieve request for mode: %s, filter_condition: %s",
            mode, filter_condition)
        mode_mapping = {
            'all': self._retrieve_all,
            'node': self._retrieve_node,
            'watchpoint': self._retrieve_watchpoint,
        }
        # validate param <mode>
        if mode not in mode_mapping:
            log.error(
                "Invalid param <mode>. <mode> should be in ['all', 'node', 'watchpoint'], "
                "but got %s.", mode)
            raise DebuggerParamValueError("Invalid mode.")
        # validate backend status
        metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
        if metadata_stream.state == ServerStatus.PENDING.value:
            log.info("The backend is in pending status.")
            return metadata_stream.get()

        filter_condition = {} if filter_condition is None else filter_condition
        reply = mode_mapping[mode](filter_condition)

        return reply

    def _retrieve_all(self, filter_condition=None):
        """Retrieve metadata, root graph and watchpoint list."""
        if filter_condition:
            log.error("No filter condition required for retrieve all request.")
            raise DebuggerParamTypeError("filter_condition should be empty.")
        self.cache_store.clean_data()
        log.info("Clean data queue cache when retrieve all request.")
        result = {}
        for stream in [Streams.METADATA, Streams.GRAPH, Streams.DEVICE]:
            sub_res = self.cache_store.get_stream_handler(stream).get()
            result.update(sub_res)

        devices = result['devices']
        if not devices:
            graph = result['graph']
            metadata = result['metadata']
            device = {
                'rank_id': 0,
                'server_ip': metadata.get('ip', 'localhost'),
                'device_id': metadata.get('device_name', ''),
                'graph_names': graph.get('graph_names', [])
            }
            devices.append(device)
        sub_res = self._hide_parameters_for_ui()
        result.update(sub_res)

        return result

    def _retrieve_node(self, filter_condition):
        """
        Retrieve node info.

        Args:
            filter_condition (dict): Filter condition.

                - name (str): The name of single node.
                - graph_name (str): The relative graph_name of the node.
                - single_node (bool): If False, return the sub-layer of single node. If True, return
                    the node list from root node to single node.
                - watch_point_id (int): The id of watchpoint.

        Returns:
            dict, reply with graph.
        """
        log.debug("Retrieve node %s.", filter_condition)
        # validate node name
        node_name = filter_condition.get('name')
        rank_id = filter_condition.get('rank_id', 0)
        graph_stream = self.cache_store.get_stream_handler(
            Streams.GRAPH).get_graph_handler_by_rank_id(rank_id)
        graph_name = graph_stream.validate_graph_name(
            filter_condition.get('graph_name'))
        if node_name:
            # validate node name
            graph_stream.get_node_type(node_name, graph_name)
        filter_condition['single_node'] = bool(
            filter_condition.get('single_node'))
        filter_condition['graph_name'] = graph_name
        reply = self._get_nodes_info(filter_condition)
        return reply

    def _get_nodes_info(self, filter_condition):
        """
        Get nodes info.

        Args:
            filter_condition (dict): The filter condition.

                - name (str): The node name.
                - graph_name (str): The relative graph_name of the node.
                - single_node (bool): If False, return the sub-layer of single node. If True, return
                    the node list from root node to single node.
                - watch_point_id (int): The id of watchpoint.

        Returns:
            dict, reply with graph.
        """
        # validate watch_point_id
        rank_id = filter_condition.get('rank_id', 0)
        watch_point_id = filter_condition.get('watch_point_id', 0)
        watchpoint_stream = self.cache_store.get_stream_handler(
            Streams.WATCHPOINT)
        watchpoint_stream.validate_watchpoint_id(watch_point_id)
        # get graph
        graph_stream = self.cache_store.get_stream_handler(
            Streams.GRAPH).get_graph_handler_by_rank_id(rank_id)
        reply = graph_stream.get(filter_condition)
        graph = reply.get('graph')
        # add watched label to graph
        watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id,
                                          filter_condition.get('graph_name'),
                                          rank_id)
        return reply

    def retrieve_tensor_history(self, node_name, graph_name=None, rank_id=0):
        """
        Retrieve tensor history for leaf node.

        Args:
            node_name (str): The name of leaf node.
            graph_name (str): The graph name. Default: None.
            rank_id (int): The id of rank. Default: 0.

        Returns:
            dict, the tensor history and metadata.
        """
        log.info("Retrieve tensor history for node: %s.", node_name)
        metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
        if metadata_stream.state == ServerStatus.PENDING.value:
            log.info("The backend is in pending status.")
            return metadata_stream.get(['state', 'step'])
        res = self._get_tensor_history(node_name, graph_name, rank_id)
        return res

    def _get_tensor_history(self, node_name, graph_name=None, rank_id=0):
        """
        Get tensor history for single node.

        Args:
            node_name (str): The name of leaf node.
            graph_name (str): The graph name. Default: None.
            rank_id (int): The id of rank. Default: 0.

        Returns:
            dict, the tensor history and metadata.
        """
        # get basic tensor history
        graph_stream = self.cache_store.get_stream_handler(
            Streams.GRAPH).get_graph_handler_by_rank_id(rank_id)
        tensor_history = graph_stream.get_tensor_history(node_name, graph_name)
        # add tensor value for tensor history
        self._add_tensor_value_for_tensor_history(tensor_history, node_name,
                                                  graph_name, rank_id)
        # add hit label for tensor history
        self.cache_store.get_stream_handler(
            Streams.WATCHPOINT_HIT).update_tensor_history(
                tensor_history, rank_id)
        # add metadata
        metadata = self.cache_store.get_stream_handler(Streams.METADATA).get(
            ['step'])
        tensor_history.update(metadata)
        return tensor_history

    def _add_tensor_value_for_tensor_history(self, tensor_history, node_name,
                                             graph_name, rank_id):
        """
        Add tensor values for tensor history and send a ViewCMD if any tensor value is missing.

        Args:
            tensor_history (list[dict]): A list of tensor info, including name and type.
            node_name (str): The UI node name.
            graph_name (str): The graph name.
            rank_id (int): The id of rank.
        """
        tensor_stream = self.cache_store.get_stream_handler(
            Streams.TENSOR).get_tensor_handler_by_rank_id(rank_id)
        cur_step = self.cache_store.get_stream_handler(Streams.METADATA).step
        missed_tensors = tensor_stream.update_tensor_history(
            tensor_history, cur_step)
        if missed_tensors:
            view_cmd = create_view_event_from_tensor_basic_info(missed_tensors)
            self.cache_store.put_command({
                'view_cmd': view_cmd,
                'node_name': node_name,
                'graph_name': graph_name,
                'rank_id': rank_id,
                'stats': True
            })
            log.debug("Send view cmd.")

    def retrieve_tensor_value(self,
                              name,
                              detail,
                              shape,
                              graph_name=None,
                              prev=False,
                              rank_id=0):
        """Retrieve the tensor value."""
        log.info("Retrieve tensor value: name: %s, detail: %s, shape: %s",
                 name, detail, shape)
        self.validate_tensor_param(name, detail)
        # Limit to query max two dimensions for tensor in table view.
        parsed_shape = TensorUtils.parse_shape(shape,
                                               limit=MAX_DIMENSIONS_FOR_TENSOR)
        node_type, tensor_name, graph_name = self._get_tensor_name_and_type_by_ui_name(
            name, graph_name, rank_id)
        reply = self.cache_store.get_stream_handler(Streams.TENSOR).get(
            {
                'name': tensor_name,
                'node_type': node_type,
                'shape': parsed_shape,
                'prev': prev
            }, rank_id)
        reply['tensor_value']['name'] = name
        if reply.pop('view_cmd', False):
            self._send_view_cmd(name, graph_name, rank_id, tensor_name,
                                node_type)
        return reply

    def _send_view_cmd(self, name, graph_name, rank_id, tensor_name,
                       node_type):
        """Send view command."""
        tensor_basic_info = self.cache_store.get_stream_handler(
            Streams.TENSOR).get_tensor_handler_by_rank_id(
                rank_id).get_missing_tensor_info(tensor_name,
                                                 node_type,
                                                 check_cache=True)
        if tensor_basic_info:
            view_cmd = create_view_event_from_tensor_basic_info(
                tensor_basic_info)
            self.cache_store.put_command({
                'view_cmd': view_cmd,
                'tensor_name': name,
                'graph_name': graph_name,
                'rank_id': rank_id
            })
            log.debug("Send view cmd.")

    def load(self, name, prev, graph_name=None, rank_id=0):
        """
        Load the tensor value.

        Args:
            name (str): Node name shown in UI.
            prev (bool): The previous step or current step.
            graph_name (Union[str, None]): The graph name. Default: None.
            rank_id (int): The id of rank. Default: 0.

        Returns:
            dict, the node name.
        """
        if not isinstance(name, str) or ':' not in name:
            log.error("Invalid tensor name. Received: %s", name)
            raise DebuggerParamValueError("Invalid tensor name.")
        node_type, tensor_name, graph_name = self._get_tensor_name_and_type_by_ui_name(
            name, graph_name, rank_id)
        log.info("Load the tensor value: name: %s", tensor_name)
        reply = self.cache_store.get_stream_handler(
            Streams.TENSOR).get_tensor_handler_by_rank_id(rank_id).load(
                tensor_name=tensor_name,
                graph_name=graph_name,
                prev=prev,
                node_type=node_type)
        if not reply.get('in_memory'):
            prev_step = 'prev' if prev else ''
            tensor_basic_info = self.cache_store.get_stream_handler(Streams.TENSOR).\
                tensor_basic_info(tensor_name, node_type, prev_step)
            view_cmd = create_view_event_from_tensor_basic_info(
                [tensor_basic_info])
            self.cache_store.put_command({
                'view_cmd': view_cmd,
                'node_name': name,
                'graph_name': graph_name,
                'rank_id': rank_id,
                'load': {
                    'tensor_name': tensor_name,
                    'prev': prev,
                    'node_type': node_type
                }
            })
            log.debug("Send view cmd.")
        else:
            metadata = self.cache_store.get_stream_handler(
                Streams.METADATA).get(['step', 'state'])
            ret = {'tensor_file': True, 'node_name': name}
            ret.update(metadata)
            self.cache_store.put_data(ret)
        reply = {'node_name': name}
        return reply

    def download(self, name, prev, graph_name=None, rank_id=0):
        """
        Download the tensor value.

        Args:
            name (str): Node name shown in UI.
            prev (bool): The previous step or current step.
            graph_name (Union[str, None]): The graph name. Default: None.
            rank_id (int): The id of rank. Default: 0.

        Returns:
            str, the file path.
            str, the file name.
        """
        if not isinstance(name, str) or ':' not in name:
            log.error("Invalid tensor name. Received: %s", name)
            raise DebuggerParamValueError("Invalid tensor name.")
        _, tensor_name, graph_name = self._get_tensor_name_and_type_by_ui_name(
            name, graph_name, rank_id)
        log.info("Download the tensor value: name: %s", tensor_name)
        tensor_stream = self.cache_store.get_stream_handler(
            Streams.TENSOR).get_tensor_handler_by_rank_id(rank_id)
        step = tensor_stream.cur_step
        if prev:
            step -= 1
        tensor_info = {
            "tensor_name": tensor_name,
            "graph_name": graph_name,
            "step": step,
            "rank_id": rank_id
        }
        return tensor_stream.download_mgr.get(**tensor_info)

    def _get_tensor_name_and_type_by_ui_name(self,
                                             name,
                                             graph_name=None,
                                             rank_id=0):
        """
        Get inner tensor name and type by UI name.

        Args:
            name (str): Node name shown in UI.
            graph_name (Union[str, None]): The graph name. Default: None.
            rank_id (int): The id of rank. Default: 0.

        Returns:
            str, node type of the tensor.
            str, inner full name of the tensor.
            str, the graph name.
        """
        node_name, slot = name.rsplit(':', 1)
        graph_stream = self.cache_store.get_stream_handler(
            Streams.GRAPH).get_graph_handler_by_rank_id(rank_id)
        graph_name = graph_name if graph_name else graph_stream.get_graph_id_by_name(
            node_name)
        graph_name = graph_stream.validate_graph_name(graph_name)
        node_type = graph_stream.get_node_type(node_name, graph_name)
        full_name = graph_stream.get_full_name(node_name, graph_name)
        tensor_name = full_name + ':' + slot
        return node_type, tensor_name, graph_name

    @staticmethod
    def validate_tensor_param(name, detail):
        """Validate params for retrieve tensor request."""
        # validate name
        if not isinstance(name, str) or ':' not in name:
            log.error("Invalid tensor name. Received: %s", name)
            raise DebuggerParamValueError("Invalid tensor name.")
        # validate data
        if detail != 'data':
            log.error("Invalid detail value. Received: %s", detail)
            raise DebuggerParamValueError("Invalid detail value.")

    def _retrieve_watchpoint(self, filter_condition):
        """
        Retrieve watchpoint.

        Args:
            filter_condition (dict): Filter condition.

                - watch_point_id (int): The id of watchpoint. If not given, return all watchpoints.
                - name (str): The name of single node.
                - single_node (bool): If False, return the sub-layer of single node. If True, return
                    the node list from root node to single node.

        Returns:
            dict, watch point list or relative graph.
        """
        watchpoint_id = filter_condition.get('watch_point_id', 0)
        if not watchpoint_id:
            reply = self._hide_parameters_for_ui()
            log.debug("Get condition of watchpoints.")
        else:
            reply = self._retrieve_node(filter_condition)
            log.debug("Get graph of %d-th watchpoint.", watchpoint_id)

        return reply

    def search_watchpoint_hits(self, group_condition):
        """
        Retrieve watchpoint hit.

        Args:
            group_condition (dict): Filter condition.

                - limit (int): The limit of each page.
                - offset (int): The offset of current page.
                - node_name (str): The retrieved node name.
                - graph_name (str): The retrieved graph name.
                - rank_id (int): The rank id.

        Returns:
            dict, watch point list or relative graph.
        """
        if not isinstance(group_condition, dict):
            log.error(
                "Group condition for watchpoint-hits request should be a dict")
            raise DebuggerParamTypeError(
                "Group condition for watchpoint-hits request should be a dict")

        metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
        if metadata_stream.state == ServerStatus.PENDING.value:
            log.info("The backend is in pending status.")
            return metadata_stream.get()

        rank_id = group_condition.pop('rank_id', 0)
        reply = {}
        multi_watchpoint_hit_stream = self.cache_store.get_stream_handler(
            Streams.WATCHPOINT_HIT)
        if multi_watchpoint_hit_stream.check_rank_id(rank_id):
            watchpoint_hit_stream = multi_watchpoint_hit_stream.get_hit_handler_by_rank_id(
                rank_id)
            reply = watchpoint_hit_stream.group_by(group_condition)

        reply['outdated'] = self.cache_store.get_stream_handler(
            Streams.WATCHPOINT).is_recheckable()
        return reply

    def create_watchpoint(self, params):
        """
        Create watchpoint.

        Args:
            params (dict): Params for create watchpoint.

                - watch_condition (dict): The watch condition. The format is like:
                    {
                        "id": "tensor_too_large",
                        "params": [
                            {
                                "name": "abs_mean_gt",
                                "value": 1.1
                            }
                        ]
                    }

                    - id (str): Id of condition.
                    - params (list[dict]): The list of param for this condition.
                - watch_nodes (list[str]): The list of node names.
                - watch_point_id (int): The id of watchpoint.
                - search_pattern (dict): The search pattern.
                - graph_name (str): The relative graph_name of the watched node.

        Returns:
            dict, the id of new watchpoint and metadata info.
        """
        watchpoint_opt = WatchpointOperator(self.cache_store,
                                            self.condition_mgr)
        return watchpoint_opt.create_watchpoint(params)
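
For illustration, a request matching the documented format could look like the sketch below. The condition id and parameter name mirror the docstring example above; `server` and the watched node name are assumptions:

params = {
    'watch_condition': {
        'id': 'tensor_too_large',
        'params': [{'name': 'abs_mean_gt', 'value': 1.1}]
    },
    'watch_nodes': ['Default/network/conv1'],  # hypothetical node name
    'graph_name': 'graph_0'                    # hypothetical graph name
}
reply = server.create_watchpoint(params)  # reply carries the new watchpoint id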

    def update_watchpoint(self, params):
        """
        Update watchpoint.

        Args:
            params (dict): Params for update watchpoint.

                - watch_point_id (int): The id of watchpoint.
                - watch_nodes (list[str]): The list of node names.
                - mode (int): The update operation on nodes. 0 for removing nodes from watch nodes,
                    1 for adding nodes to watch nodes.
                - search_pattern (dict): The search pattern.
                - graph_name (str): The relative graph_name of the watched node.

        Returns:
            dict, the metadata info.
        """
        watchpoint_opt = WatchpointOperator(self.cache_store,
                                            self.condition_mgr)
        return watchpoint_opt.update_watchpoint(params)

    def delete_watchpoint(self, watch_point_id=None):
        """
        Delete watchpoint.

        Args:
            watch_point_id (Union[None, int]): The id of watchpoint.
                If None, delete all watchpoints. Default: None.

        Returns:
            dict, the metadata info.
        """
        watchpoint_opt = WatchpointOperator(self.cache_store,
                                            self.condition_mgr)
        return watchpoint_opt.delete_watchpoint(watch_point_id=watch_point_id)
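
A companion sketch for updating and deleting; the watchpoint id and node name are placeholders:

# mode=1 adds nodes to the watch list, mode=0 removes them (see the docstring above).
server.update_watchpoint({
    'watch_point_id': 1,                       # placeholder id
    'watch_nodes': ['Default/network/conv1'],  # hypothetical node name
    'mode': 1
})
server.delete_watchpoint(watch_point_id=1)     # pass None to delete all watchpoints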

    @try_except
    def control(self, params=None):
        """
        Control the training process.

        Args:
            params (dict): The control params.

                - mode (str): Acceptable control command, including `continue`,
                    `pause` and `terminate`.
                - level (str): The control granularity, `node` level or `step` level.
                    Default: `step`.
                - steps (int): Specify the steps that training should run.
                    Used when `level` is `step`.
                - name (str): Specify the name of the node. Used when `level` is `node`.
                - graph_name (str): The graph name.

        Returns:
            dict, the response.
        """
        log.info("Receive control request: %s.", params)
        mode = params.pop('mode', None) if params else None
        training_controller = TrainingControlOperator(self.cache_store)
        training_controller.validate_mode(mode)
        return training_controller.control(mode, params)
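
Assuming the backend is in the waiting state, the control API can be driven as in this sketch; the step count is arbitrary:

# Run two more steps, then pause; 'terminate' would end the session.
server.control({'mode': 'continue', 'level': 'step', 'steps': 2})
server.control({'mode': 'pause'})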

    @try_except
    def recheck(self):
        """
        Recheck all watchpoints.

        Returns:
            dict, metadata info.
        """
        return TrainingControlOperator(self.cache_store).recheck()

    def retrieve_tensor_graph(self, tensor_name, graph_name, rank_id=0):
        """
        Retrieve tensor graph.

        Args:
            tensor_name (str): The tensor name from UI.
            graph_name (str): The graph name.
            rank_id (int): The id of rank. Default: 0.

        Returns:
            dict, tensor graph object.
        """
        if self.cache_store.get_stream_handler(
                Streams.METADATA).state != ServerStatus.WAITING.value:
            log.error(
                "Failed to get tensor graph the MindSpore is not in waiting state."
            )
            raise DebuggerTensorGraphError
        log.info("Retrieve tensor graph for %s from %s", tensor_name,
                 graph_name)
        tensor_graph_ops = TensorDetailInfo(self.cache_store).get_tensor_graph(
            tensor_name, graph_name, rank_id)
        return tensor_graph_ops

    def retrieve_tensor_hits(self, tensor_name, graph_name, rank_id=0):
        """
        Retrieve tensor hit information.

        Args:
            tensor_name (str): The tensor name from UI.
            graph_name (str): The graph name.
            rank_id (int): The id of rank. Default: 0.

        Returns:
            dict, tensor hit info.
        """
        if self.cache_store.get_stream_handler(
                Streams.METADATA).state != ServerStatus.WAITING.value:
            log.error(
                "Failed to get tensor hits as the MindSpore is not in waiting state."
            )
            raise DebuggerTensorHitError
        log.info("Retrieve tensor hits for %s from %s", tensor_name,
                 graph_name)
        watch_points = TensorDetailInfo(
            self.cache_store).get_tensor_watch_points(tensor_name, graph_name,
                                                      rank_id)
        return {'watch_points': watch_points}

    def _hide_parameters_for_ui(self):
        """
        Hide some parameters on the UI.

        Returns:
            dict, watch point list.
        """
        reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT).get()
        watch_points = reply.get('watch_points')
        for i, watch_point in enumerate(watch_points):
            watch_condition = watch_point.get('watch_condition')
            parameters = watch_condition.get('params')
            watch_condition_id = watch_condition.get('id')
            mgr_condition = self.condition_mgr.get_condition(
                watch_condition_id)
            ui_watch_condition = []
            for param in parameters:
                parameter_definition = mgr_condition.get_parameter_definition(
                    param['name'])
                if not parameter_definition.visible_on_ui:
                    continue
                ui_watch_condition.append(param)
            reply['watch_points'][i]['watch_condition'][
                'params'] = ui_watch_condition
        return reply

    def get_stack_infos(self, filter_condition):
        """
        Get stack infos.

        Args:
            filter_condition (dict): The filter condition to query stack infos.

                - pattern (str): The pattern of stack infos.
                - limit (int): The size of each page.
                - offset (int): The index of the page. Valid only when `limit` is not 0.

        Returns:
            dict, the stack info object.
        """
        source_handler = self.cache_store.get_stream_handler(
            Streams.GRAPH).source_handler
        res = source_handler.get_stack_info_by_offset(
            pattern=filter_condition.get('pattern'),
            limit=filter_condition.get('limit', 0),
            offset=filter_condition.get('offset', 0))
        return res
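
A minimal sketch of a paginated stack query; the search pattern is a made-up example:

# Page 0 with 20 entries per page; 'conv' is an illustrative pattern.
stacks = server.get_stack_infos({'pattern': 'conv', 'limit': 20, 'offset': 0})
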
Example #10
0
class DebuggerServer:
    """The server manager of debugger."""
    def __init__(self, grpc_port=None):
        self.grpc_port = grpc_port
        self.cache_store = DebuggerCache()
        self.grpc_server = DebuggerGrpcServer(self.cache_store)
        self.grpc_server_manager = None
        self.back_server = None
        self._watch_point_id = 0

    def start(self):
        """Start server."""
        grpc_port = self.grpc_port if self.grpc_port else "50051"
        host = settings.HOST if hasattr(settings, 'HOST') else '[::]'
        hostname = "{}:{}".format(host, grpc_port)
        # initialize a grpc server
        grpc_server_manager = grpc.server(
            futures.ThreadPoolExecutor(max_workers=10))
        grpc_server_base.add_EventListenerServicer_to_server(
            self.grpc_server, grpc_server_manager)
        grpc_server_manager.add_insecure_port(hostname)
        grpc_server_manager.start()
        my_server_thread = Thread(
            target=grpc_server_manager.wait_for_termination)
        # start grpc server
        my_server_thread.start()
        self.back_server = my_server_thread
        self.grpc_server_manager = grpc_server_manager
        # register stop server handler
        signal.signal(signal.SIGINT, self._stop_handler)
        log.info("Start grpc server %s", hostname)

    def _stop_handler(self, signum, frame):
        """Register stop server handler."""
        self.stop()
        log.debug("Deal with stop signal: %s, %s", signum, frame)

    def stop(self):
        """Stop debugger server."""
        self.grpc_server_manager.stop(grace=None)
        self.back_server.join()
        log.info("Stop debugger server.")

    def poll_data(self, pos):
        """
        Get the pos-th data from DebuggerCache.

        Args:
            pos (str): The index of data.

        Returns:
            dict, the data to be updated.
        """
        if not isinstance(pos, str):
            log.error("Pos should be string. Received: %s", pos)
            raise DebuggerParamValueError("Pos should be string.")

        reply = self.cache_store.get_data(pos)

        return reply
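
As the validation above shows, `pos` must be a string cursor; a usage sketch:

# '0' asks for the first pending update; subsequent calls would pass the
# cursor carried in the previous reply (an assumption based on the check above).
update = server.poll_data('0')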

    def search(self, name, watch_point_id):
        """Search for single node in graph."""
        log.info("receive search request for node:%s, in watchpoint:%d", name,
                 watch_point_id)
        graph = self.cache_store.get_stream_handler(
            Streams.GRAPH).search_nodes(name)
        self.cache_store.get_stream_handler(
            Streams.WATCHPOINT).set_watch_nodes(graph, watch_point_id)
        return graph

    def tensor_comparisons(self, name, shape, detail='data', tolerance='0'):
        """
        Get tensor comparisons data for given name, detail, shape and tolerance.

        Args:
            name (str): The name of tensor for ui.
            detail (str): Specify which data to query. Currently the only available value is 'data',
                          which means concrete tensor data. Histogram or unique count may be supported in the future.
            shape (str): Specify concrete dimensions of shape.
            tolerance (str): Specify tolerance of difference between current step tensor and previous
                             step tensor. Default value is 0.

        Raises:
            DebuggerParamValueError, If the node type is not parameter or the value of detail is not supported.
            DebuggerCompareTensorError, If MindSpore is not in waiting state.

        Returns:
            dict, the retrieved data.
        """
        if self.cache_store.get_stream_handler(
                Streams.METADATA).state != ServerStatus.WAITING.value:
            log.error(
                "Failed to compare tensors as the MindSpore is not in waiting state."
            )
            raise DebuggerCompareTensorError(
                "Failed to compare tensors as the MindSpore is not in waiting state."
            )
        self.validate_tensor_param(name, detail)
        parsed_shape = self.parse_shape(shape)
        node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(
            name)
        tolerance = to_float(tolerance, 'tolerance')
        tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR)
        if detail == 'data':
            if node_type == NodeTypeEnum.PARAMETER.value:
                reply = tensor_stream.get_tensors_diff(tensor_name,
                                                       parsed_shape, tolerance)
            else:
                raise DebuggerParamValueError(
                    "The node type must be parameter, but got {}.".format(
                        node_type))
        else:
            raise DebuggerParamValueError(
                "The value of detail: {} is not support.".format(detail))
        return reply
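
A sketch of a comparison request for a parameter tensor; the UI tensor name and shape are placeholders:

# Compare the current and previous step of a sliced parameter tensor.
diff = server.tensor_comparisons(
    name='Default/network/conv1.weight:0',  # hypothetical UI tensor name
    shape='[0:2, 0]',
    detail='data',
    tolerance='0.01'
)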

    def retrieve(self, mode, filter_condition=None):
        """
        Retrieve data according to mode and params.

        Args:
            mode (str): The type of info message.
            filter_condition (dict): The filter condition.

        Returns:
            dict, the retrieved data.
        """
        log.info(
            "receive retrieve request for mode:%s\n, filter_condition: %s",
            mode, filter_condition)
        # validate watchpoint_id

        mode_mapping = {
            'all': self._retrieve_all,
            'node': self._retrieve_node,
            'watchpoint': self._retrieve_watchpoint,
            'watchpoint_hit': self._retrieve_watchpoint_hit
        }
        # validate param <mode>
        if mode not in mode_mapping.keys():
            log.error(
                "Invalid param <mode>. <mode> should be in ['all', 'node', 'watchpoint', "
                "'watchpoint_hit', 'tensor'], but got %s.", mode_mapping)
            raise DebuggerParamTypeError("Invalid mode.")
        filter_condition = {} if filter_condition is None else filter_condition
        reply = mode_mapping[mode](filter_condition)

        return reply
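
The mode dispatch above can be exercised as follows; the node scope is a placeholder:

overview = server.retrieve('all')  # metadata, root graph and watchpoint list
node_view = server.retrieve('node', {'name': 'Default/network', 'single_node': False})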

    def _retrieve_all(self, filter_condition=None):
        """Retrieve metadata, root graph and watchpoint list."""
        if filter_condition:
            log.error("No filter condition required for retrieve all request.")
            raise DebuggerParamTypeError("filter_condition should be empty.")
        result = {}
        self._watch_point_id = 0
        self.cache_store.clean_data()
        log.info("Clean data queue cache when retrieve all request.")
        for stream in [Streams.METADATA, Streams.GRAPH, Streams.WATCHPOINT]:
            sub_res = self.cache_store.get_stream_handler(stream).get()
            result.update(sub_res)

        return result

    def _retrieve_node(self, filter_condition):
        """
        Retrieve node info.

        Args:
            filter_condition (dict): Filter condition.

                - name (str): The name of single node.

                - single_node (bool): If False, return the sub-layers of the single node. If True, return
                    the node list from the root node to the single node.

        Returns:
            dict, the node info.
        """
        log.info("Retrieve node %s.", filter_condition)
        node_name = filter_condition.get('name')
        if node_name:
            # validate node name
            self.cache_store.get_stream_handler(
                Streams.GRAPH).get_node_type(node_name)
        filter_condition['single_node'] = bool(
            filter_condition.get('single_node'))
        reply = self._get_nodes_info(filter_condition)
        return reply

    def _get_nodes_info(self, filter_condition):
        """
        Get nodes info.

        Args:
            filter_condition (dict): The filter condition.

                - name (str): The node name.

                - single_node (bool): If False, return the sub-layers of the single node. If True, return
                    the node list from the root node to the single node.

        Returns:
            dict, reply with graph.
        """
        # get graph
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        reply = graph_stream.get(filter_condition)
        graph = reply.get('graph')
        # add watched label
        self.cache_store.get_stream_handler(
            Streams.WATCHPOINT).set_watch_nodes(graph, self._watch_point_id)
        return reply

    def retrieve_tensor_history(self, node_name):
        """
        Retrieve tensor history for leaf node.

        Args:
            node_name (str): The name of leaf node.

        Returns:
            dict, the tensor history and metadata.
        """
        log.info("Retrieve tensor history for node: %s.", node_name)
        self._validate_leaf_name(node_name)
        res = self._get_tensor_history(node_name)
        return res

    def _validate_leaf_name(self, node_name):
        """Validate if the node is a leaf node."""
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        node_type = graph_stream.get_node_type(node_name)
        if is_scope_type(node_type):
            log.error("Scope type node has no tensor history.")
            raise DebuggerParamValueError("Invalid leaf node name.")

    def _get_tensor_history(self, node_name):
        """
        Get tensor history for single node.

        Args:
            node_name (str): The name of leaf node.

        Returns:
            dict, the tensor history and metadata.
        """
        # get basic tensor history
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        tensor_history = graph_stream.get_tensor_history(node_name)
        # add tensor value for tensor history
        self._add_tensor_value_for_tensor_history(tensor_history, node_name)
        # add hit label for tensor history
        watchpoint_hit_stream = self.cache_store.get_stream_handler(
            Streams.WATCHPOINT_HIT)
        watchpoint_hit_stream.update_tensor_history(tensor_history)
        # add metadata
        metadata = self.cache_store.get_stream_handler(Streams.METADATA).get()
        tensor_history.update(metadata)
        return tensor_history

    def _add_tensor_value_for_tensor_history(self, tensor_history, node_name):
        """
        Add tensor values for the tensor history and send ViewCMD if any tensor value is missing.

        Args:
            tensor_history (list[dict]): A list of tensor info, including name and type.
            node_name (str): The UI node name.

        Returns:
            dict, the tensor info.
        """
        tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR)
        missed_tensors = tensor_stream.update_tensor_history(tensor_history)
        if missed_tensors:
            view_cmd = create_view_event_from_tensor_history(missed_tensors)
            self.cache_store.put_command({
                'view_cmd': view_cmd,
                'node_name': node_name
            })
            log.debug("Send view cmd.")

    def retrieve_tensor_value(self, name, detail, shape):
        """Retrieve the tensor value."""
        log.info("Retrieve tensor value: name: %s, detail: %s, shape: %s",
                 name, detail, shape)
        self.validate_tensor_param(name, detail)
        parsed_shape = self.parse_shape(shape)
        node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(
            name)
        reply = self.cache_store.get_stream_handler(Streams.TENSOR).get({
            'name': tensor_name,
            'node_type': node_type,
            'shape': parsed_shape
        })
        reply['tensor_value']['name'] = name

        return reply

    def _get_tensor_name_and_type_by_ui_name(self, name):
        """
        Get inner tensor name and type by UI name.

        Args:
            name (str): Node name shown in UI.

        Returns:
            str, node type of the tensor.
            str, full name of the tensor.
        """
        node_name, slot = name.rsplit(':', 1)
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        node_type = graph_stream.get_node_type(node_name)
        full_name = graph_stream.get_full_name(node_name)
        tensor_name = full_name + ':' + slot
        return node_type, tensor_name
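
A quick illustration of the UI naming convention assumed here, `<node name>:<slot>`; the node name is hypothetical:

# rsplit keeps any earlier colons inside the node name intact.
'Default/network/conv1:0'.rsplit(':', 1)   # -> ['Default/network/conv1', '0']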

    @staticmethod
    def validate_tensor_param(name, detail):
        """Validate params for retrieve tensor request."""
        # validate name
        if not isinstance(name, str) or ':' not in name:
            log.error("Invalid tensor name. Received: %s", name)
            raise DebuggerParamValueError("Invalid tensor name.")
        # validate data
        if detail != 'data':
            log.error("Invalid detail value. Received: %s", detail)
            raise DebuggerParamValueError("Invalid detail value.")

    @staticmethod
    def parse_shape(shape):
        """Parse shape."""
        if shape is None:
            return shape
        if not (isinstance(shape, str) and shape.startswith('[')
                and shape.endswith(']')):
            log.error("Invalid shape. Received: %s", shape)
            raise DebuggerParamValueError("Invalid shape.")
        shape = shape.strip('[]')
        if shape.count(':') > 2:
            log.error("Invalid shape. At most two dimensions are specified.")
            raise DebuggerParamValueError("Invalid shape.")
        parsed_shape = tuple(
            str_to_slice_or_int(dim)
            for dim in shape.split(',')) if shape else tuple()
        log.info("Parsed shape: %s from %s", parsed_shape, shape)
        return parsed_shape
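
`str_to_slice_or_int` is imported elsewhere in the module; a minimal stand-in consistent with the usage above might look like this sketch (an assumption, not the real helper):

# Hypothetical stand-in: '0:2' -> slice(0, 2), '1' -> 1, ':' -> slice(None, None).
def str_to_slice_or_int(dim):
    dim = dim.strip()
    if ':' in dim:
        return slice(*(int(part) if part else None for part in dim.split(':')))
    return int(dim)

# With the real helper in place:
DebuggerServer.parse_shape('[0:2, 1]')   # -> (slice(0, 2), 1)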

    def _retrieve_watchpoint(self, filter_condition):
        """
        Retrieve watchpoint.

        Args:
            filter_condition (dict): Filter condition.

                - watch_point_id (int): The id of watchpoint. If not given, return all watchpoints.

                - name (str): The name of single node.

                - single_node (bool): If False, return the sub-layers of the single node. If True, return
                    the node list from the root node to the single node.

        Returns:
            dict, watch point list or relative graph.
        """
        watchpoint_id = filter_condition.get('watch_point_id')
        watchpoint_stream = self.cache_store.get_stream_handler(
            Streams.WATCHPOINT)
        watchpoint_stream.validate_watchpoint_id(watchpoint_id)
        self._watch_point_id = watchpoint_id if watchpoint_id else 0
        if not watchpoint_id:
            reply = self.cache_store.get_stream_handler(
                Streams.WATCHPOINT).get()
            log.debug("Get condition of watchpoints.")
        else:
            reply = self._retrieve_node(filter_condition)
            log.debug("Get graph of %d-th watchpoint.", watchpoint_id)

        return reply

    def _retrieve_watchpoint_hit(self, filter_condition):
        """
        Retrieve watchpoint hit.

        Args:
            filter_condition (dict): Filter condition.

                - name (str): The name of single node.

                - single_node (bool): If False, return the sub-layers of the single node. If True, return
                    the node list from the root node to the single node.

        Returns:
            dict, watch point list or relative graph.
        """
        node_name = filter_condition.get('name')
        # get watchpoint hit list
        if node_name is None:
            reply = self.cache_store.get_stream_handler(
                Streams.WATCHPOINT_HIT).get()
            return reply

        self._validate_leaf_name(node_name)
        # get tensor history
        reply = self._get_tensor_history(node_name)
        log.debug("Get tensor history for watchpoint hit node.")
        # get single graph
        if filter_condition.get('single_node'):
            graph = self._get_nodes_info(filter_condition)
            reply.update(graph)
        log.debug("Get tensor history for watchpoint hit node.")

        return reply

    def create_watchpoint(self,
                          watch_condition,
                          watch_nodes=None,
                          watch_point_id=None):
        """
        Create watchpoint.

        Args:
            watch_condition (dict): The watch condition.

                - condition (str): Accept `INF` or `NAN`.

                - param (list[float]): Not defined yet.
            watch_nodes (list[str]): The list of node names.
            watch_point_id (int): The id of watchpoint.

        Returns:
            dict, the id of new watchpoint.
        """
        log.info("Received create watchpoint request. WatchCondition: %s",
                 watch_condition)
        metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
        if metadata_stream.state != ServerStatus.WAITING.value:
            log.error(
                "Failed to create watchpoint as the MindSpore is not in waiting state."
            )
            raise DebuggerCreateWatchPointError(
                "Failed to create watchpoint as the MindSpore is not in waiting state."
            )
        if metadata_stream.backend == 'GPU' and watch_condition.get(
                'condition') == 'OVERFLOW':
            log.error("GPU doesn't support OVERFLOW watch condition.")
            raise DebuggerParamValueError(
                "GPU doesn't support OVERFLOW watch condition.")

        watch_nodes = self._get_node_basic_infos(watch_nodes)
        watch_point_id = self.cache_store.get_stream_handler(
            Streams.WATCHPOINT).create_watchpoint(watch_condition, watch_nodes,
                                                  watch_point_id)
        self._watch_point_id = 0
        log.info("Create watchpoint %d", watch_point_id)
        return {'id': watch_point_id}

    def update_watchpoint(self, watch_point_id, watch_nodes, mode, name=None):
        """
        Update watchpoint.

        Args:
            watch_point_id (int): The id of watchpoint.
            watch_nodes (list[str]): The list of node names.
            mode (int): The update operation on nodes. 0 for removing nodes from watch nodes,
                1 for adding nodes to watch nodes.
            name (str): The search name. Default: None.

        Returns:
            dict, empty response.
        """
        if self.cache_store.get_stream_handler(
                Streams.METADATA).state != ServerStatus.WAITING.value:
            log.error(
                "Failed to update watchpoint as the MindSpore is not in waiting state."
            )
            raise DebuggerUpdateWatchPointError(
                "Failed to update watchpoint as the MindSpore is not in waiting state."
            )
        # validate
        if not watch_nodes or not watch_point_id:
            log.error("Invalid parameter for update watchpoint.")
            raise DebuggerParamValueError(
                "Invalid parameter for update watchpoint.")
        # update watch node
        if name is not None:
            watch_nodes = self._get_watch_nodes_by_search(watch_nodes)
        elif mode == 1:
            watch_nodes = self._get_node_basic_infos(watch_nodes)

        self.cache_store.get_stream_handler(
            Streams.WATCHPOINT).update_watchpoint(watch_point_id, watch_nodes,
                                                  mode)
        self._watch_point_id = watch_point_id
        log.info("Update watchpoint with id: %d", watch_point_id)
        return {}

    def _get_watch_nodes_by_search(self, watch_nodes):
        """Get watched leaf nodes by search name."""
        watched_leaf_nodes = []
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        # The searched node list does not change within this call, so fetch it once.
        search_nodes = graph_stream.get_searched_node_list()
        for search_name in watch_nodes:
            search_node_names = [
                NodeBasicInfo(name=node.name,
                              full_name=node.full_name,
                              type=node.type) for node in search_nodes
                if node.name.startswith(search_name)
            ]
            watched_leaf_nodes.extend(search_node_names)

        log.debug("Update nodes: %s", watched_leaf_nodes)

        return watched_leaf_nodes

    def delete_watchpoint(self, watch_point_id):
        """
        Delete watchpoint.

        Args:
            watch_point_id (int): The id of watchpoint.

        Returns:
            dict, empty response.
        """
        if self.cache_store.get_stream_handler(
                Streams.METADATA).state != ServerStatus.WAITING.value:
            log.error(
                "Failed to delete watchpoint as the MindSpore is not in waiting state."
            )
            raise DebuggerDeleteWatchPointError(
                "Failed to delete watchpoint as the MindSpore is not in waiting state."
            )
        self.cache_store.get_stream_handler(
            Streams.WATCHPOINT).delete_watchpoint(watch_point_id)
        self._watch_point_id = 0
        log.info("Delete watchpoint with id: %d", watch_point_id)
        return {}

    def _get_node_basic_infos(self, node_names):
        """Get node info according to node names."""
        if not node_names:
            return []
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        node_infos = []
        for node_name in node_names:
            node_type = graph_stream.get_node_type(node_name)
            # optimize later
            if node_type == NodeTypeEnum.AGGREGATION_SCOPE.value:
                sub_nodes = graph_stream.get_nodes(node_name)
                sub_infos = [
                    NodeBasicInfo(name=node.name,
                                  full_name=node.full_name,
                                  type=node.type) for node in sub_nodes
                ]
                node_infos.extend(sub_infos)
                continue
            full_name = graph_stream.get_full_name(node_name)
            node_infos.append(
                NodeBasicInfo(name=node_name,
                              full_name=full_name,
                              type=node_type))
        return node_infos

    def control(self, params=None):
        """
        Control the training process.

        Args:
            params (dict): The control params.

                - mode (str): Acceptable control command, including `continue`,
                    `pause` and `terminate`.

                - level (str): The control granularity, `node` level or `step` level.
                    Default: `step`.

                - steps (int): Specify the steps that training should run.
                    Used when `level` is `step`.

                - name (str): Specify the name of the node. Used when `level` is `node`.

        Returns:
            dict, the response.
        """
        log.info("Receive control request: %s.", params)
        mode = params.get('mode') if params else None
        metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
        if mode == 'continue':
            reply = self._continue(metadata_stream, params)
        elif mode in ['pause', 'terminate']:
            mode_mapping = {'pause': self._pause, 'terminate': self._terminate}
            reply = mode_mapping.get(mode)(metadata_stream)
        else:
            log.error("Invalid control mode %s", mode)
            raise DebuggerParamValueError("Invalid control mode.")

        return reply

    def _continue(self, metadata_stream, params):
        """
        Send RunCMD to MindSpore.

        Args:
            metadata_stream (MetadataHandler): The metadata_handler
            params (dict): The control params.
        """
        if metadata_stream.state != ServerStatus.WAITING.value:
            log.error("MindSpore is not ready to run. Current state is: %s",
                      metadata_stream.state)
            raise DebuggerContinueError(
                "MindSpore is not ready to run or is running currently.")
        metadata_stream.state = ServerStatus.RUNNING.value
        current_state = ServerStatus.RUNNING.value
        try:
            event = self._construct_run_event(params)
            self._send_watchpoints()
            self.cache_store.put_command(event)
        except MindInsightException as err:
            log.error("Failed to send run event.")
            log.exception(err)
            current_state = ServerStatus.WAITING.value
            metadata_stream.state = current_state
            raise DebuggerContinueError("Failed to send run command.")
        else:
            log.debug("Send the RunCMD to command queue.")

        return {'metadata': {'state': current_state}}

    def _validate_node_type(self, node_name):
        """Check the node type in node control."""
        if not node_name:
            return
        node_type = self.cache_store.get_stream_handler(
            Streams.GRAPH).get_node_type(node_name)
        unsupported_types = [item.value for item in list(NodeTypeEnum)]
        if node_type in unsupported_types:
            log.error("Invalid node type. %s", node_name)
            raise DebuggerParamValueError(
                f"The type of node {node_name} is unsupported for "
                "continue to command.")

    def _construct_run_event(self, params):
        """
        Construct run cmd from input control params.

        Args:
            params (dict): The control params.

                - level (str): The control granularity, `node` level or `step` level.
                    Default: `step`.

                - steps (int): Specify the steps that training should run.
                    Used when `level` is `step`.

                - full_name (str): Specify the name of the node. Used when `level` is `node`.

        Returns:
            EventReply, control event with run command.
        """
        level = params.get('level', 'step')
        event = get_ack_reply()
        if level == 'step':
            steps = params.get('steps')
            if not steps:
                steps = 1
            run_cmd = RunCMD(run_level='step', run_steps=steps)
        elif level == 'node':
            self._validate_node_type(params.get('name'))
            name = self.cache_store.get_stream_handler(
                Streams.GRAPH).get_full_name(params['name'])
            if not name:
                name = ''
            run_cmd = RunCMD(run_level='node', node_name=name)
        else:
            log.error(
                "Invalid Value. `level` should be `step` or `node`. Got %s",
                level)
            raise DebuggerParamValueError("level` should be `step` or `node`")

        event.run_cmd.CopyFrom(run_cmd)
        log.debug("Construct run event. %s", event)
        return event

    def _send_watchpoints(self):
        """Set watchpoints."""
        watchpoint_stream = self.cache_store.get_stream_handler(
            Streams.WATCHPOINT)
        watchpoints = watchpoint_stream.get(
            filter_condition=True).get('watch_points')
        if watchpoints:
            for watchpoint in watchpoints:
                event = get_ack_reply()
                event.set_cmd.CopyFrom(watchpoint)
                self.cache_store.put_command(event)
            watchpoint_stream.sync_set_cmd()
            log.debug("Send SetCMD to MindSpore. %s", event)

    def _pause(self, metadata_stream):
        """
        Pause the training.

        Args:
            metadata_stream (MetadataHandler): The metadata stream handler.
        """
        if metadata_stream.state != ServerStatus.RUNNING.value:
            log.error("The MindSpore is not running.")
            raise DebuggerPauseError("The MindSpore is not running.")
        metadata_stream.state = 'waiting'
        event = get_ack_reply()
        event.run_cmd.CopyFrom(RunCMD(run_level='step', run_steps=0))
        self.cache_store.put_command(event)
        log.debug("Send the Pause command")
        return {'metadata': {'state': 'waiting'}}

    def _terminate(self, metadata_stream):
        """
        Terminate the training.

        Args:
            metadata_stream (MetadataHandler): The metadata stream handler.
        """
        metadata_stream.state = 'pending'
        event = get_ack_reply()
        event.exit = True
        self.cache_store.put_command(event)
        log.debug("Send the ExitCMD.")
        return {'metadata': {'state': 'pending'}}

    def retrieve_node_by_bfs(self, node_name, ascend=False):
        """
        Get the graph of the next node according to node_name.

        Args:
            node_name (str): The name of current chosen leaf node.
            ascend (bool): If True, traverse the input nodes;
                If False, traverse the output nodes. Default: False.

        Returns:
            dict, the next node information.
        """
        log.info("Retrieve node <%s> by bfs, `ascend` is :%s", node_name,
                 ascend)
        reply = {}
        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
        next_node_name = graph_stream.get_node_by_bfs_order(node_name, ascend)
        # no next node
        if next_node_name is None:
            return reply
        # add graph and tensor history for next node
        filter_condition = {'name': next_node_name, 'single_node': True}
        search_graph = self._get_nodes_info(filter_condition)
        reply = {'name': next_node_name}
        reply.update(search_graph)

        return reply
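
Finally, a hedged sketch of BFS navigation; the starting node is a placeholder:

# Walk to the next node in BFS order; ascend=True traverses input nodes.
next_info = server.retrieve_node_by_bfs('Default/network/conv1', ascend=True)
if next_info:
    print(next_info['name'])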