Example 1
    def continue_training(self, params):
        """
        Send RunCMD to MindSpore.

        Args:
            params (dict): The control params.

        Returns:
            dict, metadata info.
        """
        metadata_stream = self._metadata_stream
        if metadata_stream.state != ServerStatus.WAITING.value:
            log.error("MindSpore is not ready to run. Current state is: %s",
                      metadata_stream.state)
            raise DebuggerContinueError(
                "MindSpore is not ready to run or is running currently.")
        metadata_stream.state = ServerStatus.RUNNING.value
        try:
            self._validate_continue_params(params)
            event = self._construct_run_event(params)
            self._send_watchpoints()
            self._cache_store.put_command(event)
        except MindInsightException as err:
            log.error("Failed to send run event.")
            log.exception(err)
            metadata_stream.state = ServerStatus.WAITING.value
            raise DebuggerContinueError("Failed to send run command.")
        else:
            metadata_stream.enable_recheck = False
            log.debug("Send the RunCMD to command queue.")
        return metadata_stream.get(['state', 'enable_recheck'])
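
# A standalone sketch (names assumed, not from this source) of the state guard
# used by continue_training above: commands are accepted only in the WAITING
# state, the state flips to RUNNING optimistically, and rolls back to WAITING
# when sending fails so the client may retry.
from enum import Enum

class ServerStatus(Enum):
    WAITING = 'waiting'
    RUNNING = 'running'

class SendError(Exception):
    """Hypothetical stand-in for DebuggerContinueError."""

def send_run_command(metadata, send):
    """Guarded send with rollback, mirroring the pattern above."""
    if metadata['state'] != ServerStatus.WAITING.value:
        raise SendError("Not ready to run.")
    metadata['state'] = ServerStatus.RUNNING.value
    try:
        send()
    except SendError:
        # Roll back so that a later attempt is allowed.
        metadata['state'] = ServerStatus.WAITING.value
        raise

meta = {'state': ServerStatus.WAITING.value}
send_run_command(meta, send=lambda: None)
assert meta['state'] == ServerStatus.RUNNING.value
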
    def download(self, name, prev, graph_name=None, rank_id=0):
        """
        Download the tensor value.

        Args:
            name (str): Node name shown in UI.
            prev (bool): The previous step or current step.
            graph_name (Union[str, None]): The graph name. Default: None.
            rank_id (int): The id of rank. Default: 0.

        Returns:
            str, the file path.
            str, the file name.
        """
        if not isinstance(name, str) or ':' not in name:
            log.error("Invalid tensor name. Received: %s", name)
            raise DebuggerParamValueError("Invalid tensor name.")
        _, tensor_name, graph_name = self._get_tensor_name_and_type_by_ui_name(
            name, graph_name, rank_id)
        log.info("Download the tensor value: name: %s", tensor_name)
        tensor_stream = self.cache_store.get_stream_handler(
            Streams.TENSOR).get_tensor_handler_by_rank_id(rank_id)
        step = tensor_stream.cur_step
        if prev:
            step -= 1
        tensor_info = {
            "tensor_name": tensor_name,
            "graph_name": graph_name,
            "step": step,
            "rank_id": rank_id
        }
        return tensor_stream.download_mgr.get(**tensor_info)

    def get(self, filter_condition=None, rank_id=0):
        """Get the watchpoint hits of the specified device."""
        if rank_id in self.watchpoint_hit_handlers:
            return self.watchpoint_hit_handlers.get(rank_id).get(
                filter_condition)
        log.error("There is no rank id %d.", rank_id)
        raise ValueError("There is no rank id {}.".format(rank_id))
    def _retrieve_all(self, filter_condition=None):
        """Retrieve metadata, root graph and watchpoint list."""
        if filter_condition:
            log.error("The retrieve-all request does not accept a filter condition.")
            raise DebuggerParamTypeError("filter_condition should be empty.")
        self.cache_store.clean_data()
        log.info("Cleaned the data queue cache for the retrieve-all request.")
        result = {}
        for stream in [Streams.METADATA, Streams.GRAPH, Streams.DEVICE]:
            sub_res = self.cache_store.get_stream_handler(stream).get()
            result.update(sub_res)

        devices = result['devices']
        if not devices:
            graph = result['graph']
            metadata = result['metadata']
            device = {
                'rank_id': 0,
                'server_ip': metadata.get('ip', 'localhost'),
                'device_id': metadata.get('device_name', ''),
                'graph_names': graph.get('graph_names', [])
            }
            devices.append(device)
        sub_res = self._hide_parameters_for_ui()
        result.update(sub_res)

        return result
    def recheck(self):
        """
        Recheck all watchpoints.

        Returns:
            dict, metadata info.
        """
        metadata_stream = self._metadata_stream
        # validate backend status is able to recheck watchpoint
        if not metadata_stream.enable_recheck:
            log.error("Recheck is not available.")
            raise DebuggerRecheckError("Recheck is not available.")
        metadata_stream.state = ServerStatus.SENDING.value
        metadata_stream.enable_recheck = False
        # send updated watchpoint and recheck command
        try:
            event = self._construct_run_event({'level': 'recheck'})
            self._send_watchpoints()
            self._cache_store.put_command(event)
        except MindInsightException as err:
            log.error("Failed to send recheck event.")
            log.exception(err)
            metadata_stream.state = ServerStatus.WAITING.value
            metadata_stream.enable_recheck = True
            raise DebuggerContinueError("Failed to send recheck command.")
        else:
            log.debug("Send the recheck to command queue.")
        return metadata_stream.get(['state', 'enable_recheck'])
Example 6
    def initialize(self):
        """Initialize the data_mode and net_dir of DataLoader."""
        self.load_rank_dirs()
        if not self._rank_dirs:
            log.error("No rank directory found under %s",
                      str(self._debugger_base_dir))
            raise RankDirNotFound(str(self._debugger_base_dir))
        rank_dir = self._rank_dirs[0].path
        dump_config = self._load_json_file(rank_dir / self.DUMP_METADATA /
                                           'data_dump.json')

        def _set_net_name():
            nonlocal dump_config
            common_settings = dump_config.get(
                DumpSettings.COMMON_DUMP_SETTINGS.value, {})
            try:
                self._net_name = common_settings['net_name']
            except KeyError:
                raise DebuggerJsonFileParseError("data_dump.json")

        def _set_dump_mode_and_device_target():
            nonlocal dump_config
            config_json = self._load_json_file(rank_dir / self.DUMP_METADATA /
                                               'config.json')
            self._device_target = config_json.get('device_target', 'Ascend')
            # Sync dump when the device target is GPU, or when e2e dump is enabled.
            e2e_settings = dump_config.get(DumpSettings.E2E_DUMP_SETTINGS.value)
            if self._device_target == 'GPU' or (e2e_settings and e2e_settings['enable']):
                self._is_sync = True
            else:
                self._is_sync = False

        _set_net_name()
        _set_dump_mode_and_device_target()
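
A minimal sketch of the configuration the loader above appears to read; the file layout and the literal key strings ('common_dump_settings', 'e2e_dump_settings', 'device_target') are inferred from the accesses in the code and may differ from the real dump files.

# Hypothetical data_dump.json contents (keys inferred from the code above).
dump_config = {
    "common_dump_settings": {"net_name": "LeNet"},
    "e2e_dump_settings": {"enable": True},
}
# Hypothetical config.json contents.
config_json = {"device_target": "Ascend"}

net_name = dump_config["common_dump_settings"]["net_name"]
# Sync dump when the target is GPU, or when e2e dump is enabled.
e2e_settings = dump_config.get("e2e_dump_settings")
is_sync = (config_json.get("device_target", "Ascend") == "GPU"
           or bool(e2e_settings and e2e_settings["enable"]))
print(net_name, is_sync)  # LeNet True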
Example 7
    def put(self, value):
        """
        Put value into event_cache.

        Args:
            value (dict): The event to be put into cache.
        """
        if not isinstance(value, dict):
            log.error("Dict type required when putting an event message.")
            raise DebuggerParamValueError(
                "Dict type required when putting an event message.")

        with self._lock:
            log.debug(
                "Put the %d-th message into queue. \n %d requests are waiting.",
                self._next_idx, len(self._pending_requests))
            cur_pos = self._next_idx
            # update next pos
            self._next_idx += 1
            if self._next_idx >= self.max_limit:
                self._next_idx = 0
                self._prev_flag = self._cur_flag
                self._cur_flag = str(uuid.uuid4())
            # set next pos
            if not value.get('metadata'):
                value['metadata'] = {}
            value['metadata']['pos'] = self.next_pos
            self._event_cache[cur_pos] = value
            # feed the value for pending requests
            self.clean_pending_requests(value)
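
# A standalone sketch (assumed simplification, not the source class) of the
# circular index update in put() above: the write position wraps at max_limit
# and a fresh UUID flag marks each pass over the buffer, so a reader holding a
# stale position can be detected.
import uuid

class RingIndex:
    def __init__(self, max_limit):
        self.max_limit = max_limit
        self._next_idx = 0
        self._cur_flag = str(uuid.uuid4())
        self._prev_flag = None

    def advance(self):
        """Return the current write position and move to the next one."""
        pos = self._next_idx
        self._next_idx += 1
        if self._next_idx >= self.max_limit:
            self._next_idx = 0
            self._prev_flag, self._cur_flag = self._cur_flag, str(uuid.uuid4())
        return pos

ring = RingIndex(max_limit=3)
print([ring.advance() for _ in range(7)])  # [0, 1, 2, 0, 1, 2, 0]
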
    def reset_training_step(self, step_id):
        """
        Reset the training step.

        Args:
            step_id (int): The target step_id.

        Returns:
            dict, metadata info.
        """
        metadata_stream = self._metadata_stream
        if metadata_stream.debugger_type == DebuggerServerMode.ONLINE.value:
            log.error(
                "'step_id' cannot be changed manually in online debugger.")
            return metadata_stream.get(['state', 'enable_recheck', 'step'])
        if step_id > metadata_stream.max_step_num:
            log.error("Invalid step_id, step_id should be no greater than %d.",
                      metadata_stream.max_step_num)
            raise DebuggerParamValueError("Invalid step_id.")
        metadata_stream.state = ServerStatus.SENDING.value
        metadata_stream.step = step_id
        self._cache_store.get_stream_handler(Streams.TENSOR).set_step(step_id)
        self._cache_store.clean_data()
        self._cache_store.clean_command()
        metadata_stream.enable_recheck = True
        metadata_stream.state = ServerStatus.WAITING.value
        self._cache_store.get_stream_handler(Streams.WATCHPOINT).set_outdated()
        log.debug("Send the Change_training_step CMD.")
        return metadata_stream.get(['state', 'enable_recheck', 'step'])
Example 9
    def search_watchpoint_hits(self, group_condition):
        """
        Retrieve watchpoint hit.

        Args:
            group_condition (dict): Filter condition.

                - limit (int): The limit of each page.
                - offset (int): The offset of current page.
                - node_name (str): The retrieved node name.
                - graph_name (str): The retrieved graph name.

        Returns:
            dict, the grouped watchpoint hits, or the metadata if the backend is pending.
        """
        if not isinstance(group_condition, dict):
            log.error(
                "Group condition for watchpoint-hits request should be a dict.")
            raise DebuggerParamTypeError(
                "Group condition for watchpoint-hits request should be a dict.")

        metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
        if metadata_stream.state == ServerStatus.PENDING.value:
            log.info("The backend is in pending status.")
            return metadata_stream.get()

        reply = self.cache_store.get_stream_handler(
            Streams.WATCHPOINT_HIT).group_by(group_condition)
        reply['outdated'] = self.cache_store.get_stream_handler(
            Streams.WATCHPOINT).is_recheckable()
        return reply
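
A hypothetical group_condition payload built from the keys documented above; the node and graph names are invented for illustration.

group_condition = {
    'limit': 10,    # watchpoint hits per page
    'offset': 0,    # zero-based page index
    'node_name': 'Default/network/conv1',  # hypothetical node name
    'graph_name': 'graph_0',               # hypothetical graph name
}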
Example 10
    def get(self, filter_condition=None):
        """
        Get full tensor value.

        Args:
            filter_condition (dict): Filter condition.

                - name (str): The full name of tensor.
                - node_type (str): The type of the node.
                - prev (bool): Whether to get previous tensor.
                - shape (tuple): The specified shape for slicing the tensor value.

        Returns:
            dict, the tensor_value.
        """
        name = filter_condition.get('name')
        node_type = filter_condition.get('node_type')
        shape = filter_condition.get('shape')
        if filter_condition.get('prev'):
            step = self.prev_step
        else:
            step = self.cur_step
        tensor = self._get_tensor(name, node_type, step)
        if not tensor:
            log.error("No tensor named %s at the step %s", name, step)
            raise DebuggerParamValueError("No tensor named {}".format(name))
        tensor_info = tensor.get_full_info(shape)
        self._update_has_prev_step_field(tensor_info, name, node_type)
        return {'tensor_value': tensor_info}
Example 11
    def create_session(self, session_type, train_job=None):
        """
        Create the session by the train job info or session type if the session doesn't exist.

        Args:
            session_type (str): The session_type.
            train_job (str): The train job info.

        Returns:
            str, session id.
        """
        with self._lock:
            if self._exiting:
                logger.info("System is exiting, will terminate the thread.")
                _thread.exit()

            if session_type == self.ONLINE_TYPE:
                if self.ONLINE_SESSION_ID not in self.sessions:
                    logger.error(
                        'Online session is unavailable. Set --enable-debugger to true/1 to enable the debugger '
                        'when starting the MindInsight server.')
                    raise DebuggerOnlineSessionUnavailable()
                return self.ONLINE_SESSION_ID

            if train_job in self.train_jobs:
                return self.train_jobs.get(train_job)

            return self._create_offline_session(train_job)
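
A standalone sketch (assumed names) of the reuse logic in create_session above: a known train job returns its existing session id, otherwise a new id is allocated under the lock.

import itertools
import threading

class SessionPool:
    """Simplified stand-in for the session manager above."""

    def __init__(self):
        self._lock = threading.Lock()
        self._train_jobs = {}
        self._ids = itertools.count(1)

    def create_session(self, train_job):
        with self._lock:
            if train_job in self._train_jobs:
                return self._train_jobs[train_job]
            session_id = str(next(self._ids))
            self._train_jobs[train_job] = session_id
            return session_id

pool = SessionPool()
assert pool.create_session('job-a') == pool.create_session('job-a')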
Example 12
    def get_tensor_history(self, node_name, graph_name=None, depth=0):
        """
        Get the tensor history of a specified node.

        Args:
            node_name (str): The debug name of the node.
            graph_name (str): The graph_name. Default: None.
            depth (int): The number of layers the user
                wants to trace. Default is 0.

        Returns:
            dict, basic tensor history, including only the tensor name, tensor type and node type.
        """
        graph_name, node_name = self._parse_node_name(node_name, graph_name)
        graph = self._get_graph(graph_name=graph_name, node_name=node_name)
        # validate node type, scope node has no tensor history
        node_type = graph.get_node_type(node_name)
        if is_scope_type(node_type):
            log.error("Scope type node has no tensor history.")
            raise DebuggerParamValueError("Invalid leaf node name.")
        # get tensor history
        tensor_history, cur_outputs_nums = graph.get_tensor_history(
            node_name, depth)
        # add the tensor type for tensor history
        self._update_tensor_history(tensor_history[0:cur_outputs_nums],
                                    'output', graph_name)
        self._update_tensor_history(tensor_history[cur_outputs_nums:], 'input',
                                    graph_name)
        log.debug("Get %d tensors in tensor history for node <%s>.",
                  len(tensor_history), node_name)
        return {'tensor_history': tensor_history}
Example 13
    def retrieve(self, mode, filter_condition=None):
        """
        Retrieve data according to mode and params.

        Args:
            mode (str): The type of info message.
            filter_condition (dict): The filter condition.

        Returns:
            dict, the retrieved data.
        """
        log.info(
            "Received retrieve request for mode: %s, filter_condition: %s",
            mode, filter_condition)
        mode_mapping = {
            'all': self._retrieve_all,
            'node': self._retrieve_node,
            'watchpoint': self._retrieve_watchpoint,
        }
        # validate param <mode>
        if mode not in mode_mapping:
            log.error(
                "Invalid param <mode>. <mode> should be in %s, but got %s.",
                list(mode_mapping), mode)
            raise DebuggerParamValueError("Invalid mode.")
        # validate backend status
        metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
        if metadata_stream.state == ServerStatus.PENDING.value:
            log.info("The backend is in pending status.")
            return metadata_stream.get()

        filter_condition = {} if filter_condition is None else filter_condition
        reply = mode_mapping[mode](filter_condition)

        return reply
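
# A standalone sketch (assumed simplification) of the dispatch-table pattern in
# retrieve() above: the mode is validated against the table before the matching
# handler is called with the filter condition.
def _retrieve_all_stub(condition):
    return {'mode': 'all', 'condition': condition}

MODE_MAPPING = {'all': _retrieve_all_stub}

def retrieve(mode, filter_condition=None):
    if mode not in MODE_MAPPING:
        raise ValueError("Invalid mode: {}".format(mode))
    return MODE_MAPPING[mode](filter_condition or {})

print(retrieve('all'))  # {'mode': 'all', 'condition': {}}
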
    def _get_by_offset(self, group_condition):
        """Return the list of watchpoint hits on the offset page."""
        limit = group_condition.get('limit')
        offset = group_condition.get('offset')
        if not isinstance(limit, int) or not isinstance(offset, int):
            log.error("Param limit or offset is not an integer.")
            raise DebuggerParamValueError(
                "Param limit or offset is not an integer.")
        watch_point_hits = []

        total = len(self._ordered_hits)

        if limit * offset >= total and offset != 0:
            log.error("Param offset out of bounds")
            raise DebuggerParamValueError("Param offset out of bounds")

        if total == 0:
            return {}

        for watchpoint_hits in self._ordered_hits[limit * offset:
                                                  limit * (offset + 1)]:
            self._get_tensors(watchpoint_hits, watch_point_hits)

        return {
            'watch_point_hits': watch_point_hits,
            'offset': offset,
            'total': total
        }
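
A standalone check of the paging arithmetic used in _get_by_offset: page offset covers indices [limit * offset, limit * (offset + 1)), and a non-zero offset whose start lies at or beyond the total is rejected.

def page_bounds(total, limit, offset):
    """Return the slice bounds for one page, as in _get_by_offset above."""
    if limit * offset >= total and offset != 0:
        raise ValueError("Param offset out of bounds")
    return limit * offset, min(limit * (offset + 1), total)

print(page_bounds(total=25, limit=10, offset=2))  # (20, 25)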
Example 15
    def get(self, filter_condition=None):
        """
        Get full tensor value.

        Args:
            filter_condition (dict): Filter condition.

                - name (str): The full name of tensor.
                - node_type (str): The type of the node.
                - prev (bool): Whether to get previous tensor.
                - shape (tuple): The specified shape for slicing the tensor value.

        Returns:
            dict, the tensor_value and whether need to send view_command.
        """
        name = filter_condition.get('name')
        node_type = filter_condition.get('node_type')
        shape = filter_condition.get('shape')
        if filter_condition.get('prev'):
            step = self.prev_step
        else:
            step = self.cur_step
        tensor = self._get_tensor(name, node_type, step)
        if not tensor:
            log.error("No tensor named %s at the step %s", name, step)
            raise DebuggerParamValueError("No tensor named {}".format(name))
        tensor_info = tensor.get_full_info(shape)
        self._update_has_prev_step_field(tensor_info, name, node_type,
                                         self.cur_step)
        res = {'tensor_value': tensor_info, 'view_cmd': False}
        if tensor.status == TensorStatusEnum.UNCACHED.value:
            self._add_hold_value_tensors(name, step)
            res['view_cmd'] = True
        return res
    def create_watchpoint(self, params):
        """
        Create watchpoint.

        Args:
            params (dict): Params for creating the watchpoint.

                - watch_condition (dict): The watch condition. The format is like:

                    {
                        "id": "tensor_too_large",
                        "params": [
                            {
                                "name": "abs_mean_gt",
                                "value": 1.1
                            }
                        ]
                    }

                    - id (str): Id of condition.
                    - params (list[dict]): The list of param for this condition.
                - watch_nodes (list[str]): The list of node names.
                - watch_point_id (int): The id of watchpoint.
                - search_pattern (dict): The search pattern.
                - graph_name (str): The relative graph_name of the watched node.

        Returns:
            dict, the id of new watchpoint and metadata info.
        """
        watch_condition = params.get('watch_condition')
        log.info("Received create watchpoint request. WatchCondition: %s", watch_condition)
        metadata_stream = self._metadata_stream
        if metadata_stream.state != ServerStatus.WAITING.value:
            log.error("Failed to create watchpoint as the MindSpore is not in waiting state.")
            raise DebuggerCreateWatchPointError(
                "Failed to create watchpoint as the MindSpore is not in waiting state.")
        self._validate_watch_condition(watch_condition)

        watch_nodes = self._get_watch_node_with_basic_info(
            node_names=params.get('watch_nodes'),
            search_pattern=params.get('search_pattern'),
            graph_name=params.get('graph_name'))

        validate_watch_condition(self._condition_mgr, watch_condition)
        condition_id = watch_condition.get('id')
        condition = self._condition_mgr.get_condition(condition_id)
        condition_context = ConditionContext(metadata_stream.backend, metadata_stream.step)
        if not condition.is_available(condition_context):
            log.error("Failed to create watchpoint as the condition is not available.")
            raise DebuggerConditionUnavailableError(
                "Failed to create watchpoint as the condition is not available.")

        watchpoint_stream = self._watchpoint_stream
        watch_point_id = watchpoint_stream.create_watchpoint(
            self._condition_mgr, watch_condition, watch_nodes, params.get('watch_point_id'))
        log.info("Create watchpoint %d", watch_point_id)

        metadata_stream.enable_recheck = watchpoint_stream.is_recheckable()
        res = metadata_stream.get(['state', 'enable_recheck'])
        res['id'] = watch_point_id
        return res
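
A hypothetical create-watchpoint payload assembled from the keys documented above; the condition id comes from the docstring example, while the node and graph names are invented.

params = {
    'watch_condition': {
        'id': 'tensor_too_large',
        'params': [{'name': 'abs_mean_gt', 'value': 1.1}],
    },
    'watch_nodes': ['Default/network/conv1.weight'],  # hypothetical node name
    'watch_point_id': 0,       # hypothetical id
    'graph_name': 'graph_0',   # hypothetical graph name
}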
Example 17
    def get(self, **tensor_info):
        """Get the temp file path."""
        with self._lock:
            if self.tensor_info == tensor_info:
                self.status = DownloadStatusEnum.SENDING.value
                return self.file_name, self.file_path, self.clean
        log.error("No such tensor to download")
        raise DebuggerDownloadTensorNotExist()
Example 18
    def validate_graph_name(self, graph_name):
        """Validate graph_name."""
        if graph_name and self._graph.get(graph_name) is None:
            log.error("No graph named %s in debugger cache.", graph_name)
            raise DebuggerGraphNotExistError
        if not graph_name and len(self._graph) == 1:
            graph_name = self.graph_names[0]
        return graph_name

    def _validate_continue_node_name(self, node_name, graph_name):
        """Validate if the node is a leaf node."""
        if not node_name:
            return
        node_type = self._graph_stream.get_node_type(node_name, graph_name)
        if is_scope_type(node_type):
            log.error("Scope type node has no tensor history.")
            raise DebuggerParamValueError("Invalid leaf node name.")
Example 20
    def remove(self, sub_name):
        """Remove sub node."""
        try:
            self._children.pop(sub_name)
        except KeyError as err:
            log.error("Failed to find node %s. %s", sub_name, err)
            raise DebuggerParamValueError(
                "Failed to find node {}".format(sub_name))
Example 21
    def _validate_watch_condition(self, watch_condition):
        """Validate watch condition."""
        metadata_stream = self._metadata_stream
        if metadata_stream.backend == 'GPU' and watch_condition.get(
                'id') == ConditionIdEnum.OPERATOR_OVERFLOW.value:
            log.error("GPU doesn't support overflow watch condition.")
            raise DebuggerParamValueError(
                "GPU doesn't support overflow watch condition.")

def _get_dbg_service_module():
    """Get dbg service module from MindSpore."""
    try:
        dbg_services_module = import_module(
            'mindspore.offline_debug.dbg_services')
    except (ModuleNotFoundError, ImportError) as err:
        log.error("Failed to find module dbg_services. %s", err)
        raise DebuggerModuleNotFoundError("dbg_services")
    return dbg_services_module
Example 23
    def get_tensors_diff(self, tensor_name, shape, tolerance=0):
        """
        Get tensor comparison data for the given name, shape and tolerance.

        Args:
            tensor_name (str): The name of tensor for cache.
            shape (tuple): Specify concrete dimensions of shape.
            tolerance (str): Specify the tolerance of the difference between the current step tensor
                and the previous step tensor. Default value is 0. It is a percentage. The boundary
                value is equal to max(abs(min), abs(max)) * tolerance, where min and max are the
                minimum and maximum of the result of subtracting the previous step tensor from the
                current step tensor. If the absolute value of a result element is less than or equal
                to the boundary value, that element is set to zero.

        Raises:
            DebuggerParamValueError: If getting the current step node or previous step node failed,
                or the type of the tensor value is not numpy.ndarray.

        Returns:
            dict, the retrieved data.
        """
        curr_tensor = self.get_valid_tensor_by_name(tensor_name)
        prev_tensor = self.get_valid_tensor_by_name(tensor_name, prev=True)
        if not (curr_tensor and prev_tensor):
            log.error("Get current step and previous step for this tensor name %s failed.", tensor_name)
            raise DebuggerParamValueError(f"Get current step and previous step for this tensor name "
                                          f"{tensor_name} failed.")
        curr_tensor_slice = curr_tensor.get_tensor_value_by_shape(shape)
        prev_tensor_slice = prev_tensor.get_tensor_value_by_shape(shape)
        # get tensor comparison basic info
        tensor_info = curr_tensor.get_basic_info()
        tensor_info.pop('has_prev_step')
        tensor_info.pop('value')
        # calculate tensor comparison object
        tensor_comparison = curr_tensor.tensor_comparison
        if not tensor_comparison or tensor_comparison.tolerance != tolerance:
            if curr_tensor.value.shape != prev_tensor.value.shape:
                raise DebuggerParamValueError("The shapes of the two step tensors are not the same.")
            tensor_diff = TensorUtils.calc_diff_between_two_tensor(curr_tensor.value, prev_tensor.value, tolerance)
            stats = TensorUtils.get_statistics_from_tensor(tensor_diff)
            tensor_comparison = TensorComparison(tolerance, stats, tensor_diff)
            curr_tensor.update_tensor_comparisons(tensor_comparison)
        # calculate diff value
        # the type of curr_tensor_slice is one of np.ndarray or str
        if isinstance(curr_tensor_slice, np.ndarray) and isinstance(prev_tensor_slice, np.ndarray):
            if not shape:
                tensor_diff_slice = tensor_comparison.value
            else:
                tensor_diff_slice = tensor_comparison.value[shape]
            result = np.stack([prev_tensor_slice, curr_tensor_slice, tensor_diff_slice], axis=-1)
            tensor_info['diff'] = result.tolist()
        elif isinstance(curr_tensor_slice, str):
            tensor_info['diff'] = curr_tensor_slice
        # add comparison statistics
        tensor_info.update(self._get_comparison_statistics(curr_tensor, prev_tensor))
        reply = {'tensor_value': tensor_info}
        return reply
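
A numpy sketch of the tolerance rule described in the docstring above, with tolerance taken as a fraction; this is an illustration, not the actual TensorUtils.calc_diff_between_two_tensor implementation.

import numpy as np

def diff_with_tolerance(curr, prev, tolerance):
    """Zero out diffs within max(abs(min), abs(max)) * tolerance."""
    diff = curr - prev
    boundary = max(abs(diff.min()), abs(diff.max())) * tolerance
    diff[np.abs(diff) <= boundary] = 0
    return diff

curr = np.array([1.0, 2.0, 3.0])
prev = np.array([0.99, 2.0, 1.0])
print(diff_with_tolerance(curr, prev, tolerance=0.01))  # [0. 0. 2.]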
Example 24
    @staticmethod
    def validate_tensor_param(name, detail):
        """Validate params for retrieve tensor request."""
        # validate name
        if not isinstance(name, str) or ':' not in name:
            log.error("Invalid tensor name. Received: %s", name)
            raise DebuggerParamValueError("Invalid tensor name.")
        # validate detail
        if detail != 'data':
            log.error("Invalid detail value. Received: %s", detail)
            raise DebuggerParamValueError("Invalid detail value.")
Example 25
    def get_full_name_by_node_name(self, node_name):
        """Get full name by node name."""
        if not node_name:
            return ''
        node = self._normal_node_map.get(node_name)
        if not node:
            log.error("Node <%s> is not in graph.", node_name)
            raise DebuggerNodeNotInGraphError(node_name=node_name)

        return node.full_name
Example 26
    def add(self, file_name, file_path, temp_dir, **tensor_info):
        """Add the temp file path."""
        with self._lock:
            if self.status != DownloadStatusEnum.SENDING.value:
                self.file_name = file_name
                self.file_path = file_path
                self.temp_dir = temp_dir
                self.tensor_info = tensor_info
                return
        log.error("There is already a tensor download in progress.")
        raise DebuggerDownloadOverQueue()
Example 27
    def _graph_exists(self):
        """
        Check if the graph has been loaded in the debugger cache.

        Raises:
            DebuggerGraphNotExistError: If the graph does not exist.
        """
        if not self._graph:
            log.error('The graph does not exist. Please start the '
                      'training script and try again.')
            raise DebuggerGraphNotExistError
Example 28
    def validate_watchpoint_id(self, watch_point_id):
        """Validate watchpoint id."""
        if not isinstance(watch_point_id, int):
            log.error(
                "Invalid watchpoint id %s. The watch point id should be int.",
                watch_point_id)
            raise DebuggerParamTypeError("Watchpoint id should be int type.")
        if watch_point_id and watch_point_id not in self._watchpoints:
            log.error("Invalid watchpoint id: %d.", watch_point_id)
            raise DebuggerParamValueError(
                "Invalid watchpoint id: {}".format(watch_point_id))
Example 29
    def tensor_comparisons(self,
                           name,
                           shape,
                           detail='data',
                           tolerance='0',
                           rank_id=0,
                           graph_name=None):
        """
        Get tensor comparisons data for given name, detail, shape and tolerance.

        Args:
            name (str): The name of tensor for ui.
            shape (str): Specify concrete dimensions of shape.
            detail (str): Specify which data to query. Current available value is 'data' which means
                          concrete tensor data. Histogram or unique count can be supported in the future.
            tolerance (str): Specify tolerance of difference between current step tensor and previous
                             step tensor. Default value is 0.
            rank_id (int): The id of rank. Default: 0.
            graph_name (str): The graph name. Default: None.

        Returns:
            dict, the retrieved data.
        """
        if self.cache_store.get_stream_handler(
                Streams.METADATA).state != ServerStatus.WAITING.value:
            log.error(
                "Failed to compare tensors as MindSpore is not in the waiting state."
            )
            raise DebuggerCompareTensorError(
                "Failed to compare tensors as MindSpore is not in the waiting state."
            )
        self.validate_tensor_param(name, detail)
        # Limit to query max two dimensions for tensor in table view.
        parsed_shape = TensorUtils.parse_shape(shape,
                                               limit=MAX_DIMENSIONS_FOR_TENSOR)
        node_type, tensor_name, graph_name = self._get_tensor_name_and_type_by_ui_name(
            name, graph_name, rank_id)
        tolerance = to_float(tolerance, 'tolerance')
        tensor_stream = self.cache_store.get_stream_handler(
            Streams.TENSOR).get_tensor_handler_by_rank_id(rank_id)
        cur_step = self.cache_store.get_stream_handler(Streams.METADATA).step
        if node_type == NodeTypeEnum.PARAMETER.value:
            reply = tensor_stream.get_tensors_diff(tensor_name, parsed_shape,
                                                   tolerance, cur_step)
        else:
            raise DebuggerParamValueError(
                "The node type must be parameter, but got {}.".format(
                    node_type))
        if reply.pop('view_cmd', False):
            self._send_view_cmd(name, graph_name, rank_id, tensor_name,
                                node_type)
        return reply
Example 30
    def get_tensor_handler_by_rank_id(self,
                                      rank_id=0,
                                      create_if_not_exit=False):
        """Get the tensor handler by rank id."""
        if rank_id in self.tensor_handlers:
            return self.tensor_handlers.get(rank_id)
        if create_if_not_exit:
            tensor_handler = TensorHandler(self._memory_mgr,
                                           self._download_mgr,
                                           rank_id=rank_id)
            self.tensor_handlers[rank_id] = tensor_handler
            return tensor_handler
        log.error("There is no rank id %d in MultiCardTensorHandler.", rank_id)
        raise ValueError(
            "There is no rank id {} in MultiCardTensorHandler.".format(rank_id))
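
A standalone sketch (assumed names) of the get-or-create pattern above: handlers are cached per rank and created on demand only when the caller opts in.

class HandlerPool:
    def __init__(self):
        self.handlers = {}

    def get(self, rank_id=0, create_if_missing=False):
        if rank_id in self.handlers:
            return self.handlers[rank_id]
        if create_if_missing:
            # object() stands in for a real TensorHandler here.
            self.handlers[rank_id] = object()
            return self.handlers[rank_id]
        raise ValueError("There is no rank id {}.".format(rank_id))

pool = HandlerPool()
handler = pool.get(1, create_if_missing=True)
assert pool.get(1) is handler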