Example #1
    def WaitCMD(self, request, context):
        """Wait for a command in DebuggerCache."""
        # check whether the graph has already been received.
        log.info("Received WaitCMD at %s-th step.", request.cur_step)
        if self._status == ServerStatus.PENDING:
            log.warning("No graph received before WaitCMD.")
            reply = get_ack_reply(1)
            return reply
        self._send_received_tensor_tag()
        # send graph if it has not been sent before
        self._pre_process(request)
        # deal with old command
        reply = self._deal_with_old_command()
        if reply:
            log.info("Reply to WaitCMD with old command: %s", reply)
            return reply
        # continue multiple steps training
        if self._continue_steps:
            reply = get_ack_reply()
            reply.run_cmd.run_steps = 1
            reply.run_cmd.run_level = 'step'
            self._continue_steps = self._continue_steps - 1 if self._continue_steps > 0 else -1
            self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
            log.debug("Send RunCMD. Clean watchpoint hit.")
        # wait for command
        else:
            reply = self._wait_for_next_command()

        if reply is None:
            reply = get_ack_reply(1)
            log.warning("Failed to get command event.")
        else:
            log.info("Reply to WaitCMD: %s", reply)
        return reply
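Every snippet in this listing builds its reply through get_ack_reply, which none of them defines. A minimal sketch of what such a helper could look like, assuming an EventReply protobuf whose Status enum exposes OK, FAILED and PENDING (the enum member names and the import path are assumptions, not taken from this listing):

# Hypothetical sketch of the undisplayed helper; the import path and the
# Status enum member names are assumptions.
from mindinsight.debugger.proto.debug_grpc_pb2 import EventReply

def get_ack_reply(state=0):
    """Build an ack EventReply; 0 -> OK, 1 -> FAILED, -1 -> PENDING (assumed)."""
    reply = EventReply()
    state_mapping = {
        0: EventReply.Status.OK,
        1: EventReply.Status.FAILED,
        -1: EventReply.Status.PENDING,
    }
    reply.status = state_mapping[state]
    return reply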
Example #2
    def SendHeartbeat(self, request, context):
        """Deal with heartbeat sent from training client."""
        # only single-card training is supported for now
        if self._heartbeat is None or not self._heartbeat.is_alive():
            period = request.period
            min_period_seconds = 5
            max_period_seconds = 3600
            if not min_period_seconds <= period <= max_period_seconds:
                log.error("Invalid period time which should be in [5, 3600].")
                return get_ack_reply(-1)
            self._heartbeat = HeartbeatListener(self._cache_store,
                                                request.period)
            self._heartbeat.start()
        self._heartbeat.send()
        return get_ack_reply()
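SendHeartbeat relies on a HeartbeatListener with start, send and is_alive, which is not part of this listing. A minimal, hypothetical stand-in built on threading.Thread, assuming the listener only needs to notice when heartbeats stop arriving within the agreed period:

import threading

class HeartbeatListener(threading.Thread):
    """Hypothetical sketch: watch for heartbeats that stop arriving."""

    def __init__(self, cache_store, period):
        super().__init__(daemon=True)
        self._cache_store = cache_store
        self._period = period
        self._beat = threading.Event()

    def send(self):
        # Record that a heartbeat was received.
        self._beat.set()

    def run(self):
        # Allow some slack beyond the declared period (an assumption);
        # how a dead client is reported to the cache store is also assumed.
        while self._beat.wait(timeout=self._period * 2):
            self._beat.clear()
        log.warning("Heartbeat lost; client is considered offline.")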
Example #3
    def SendTensors(self, request_iterator, context):
        """Send tensors into DebuggerCache."""
        log.info("Received tensor.")
        tensor_construct = []
        tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR)
        metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
        tensor_names = []
        step = metadata_stream.step
        for tensor in request_iterator:
            tensor_construct.append(tensor)
            if tensor.finished:
                update_flag = tensor_stream.put({
                    'step': step,
                    'tensor_protos': tensor_construct
                })
                if self._received_view_cmd.get('wait_for_tensor') and update_flag:
                    # update_flag is used to avoid querying empty tensors again
                    self._received_view_cmd['wait_for_tensor'] = False
                    log.debug("Set wait for tensor flag to False.")
                tensor_construct = []
                tensor_names.append(':'.join([tensor.node_name, tensor.slot]))
                continue
        reply = get_ack_reply()
        return reply
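On the wire, SendTensors consumes a stream of tensor messages where the last piece of each tensor carries finished=True. A hedged sketch of the sending side, with a dataclass standing in for the real tensor proto (its field names mirror what the server reads; the chunk size is arbitrary):

from dataclasses import dataclass

@dataclass
class TensorPiece:
    """Stand-in for the real tensor proto consumed by SendTensors."""
    node_name: str
    slot: str
    tensor_content: bytes
    finished: bool = False

def stream_tensor(node_name, slot, raw_bytes, chunk_size=64 * 1024):
    """Hypothetical generator yielding tensor pieces for SendTensors."""
    for offset in range(0, len(raw_bytes), chunk_size):
        yield TensorPiece(
            node_name=node_name,
            slot=slot,
            tensor_content=raw_bytes[offset:offset + chunk_size],
            finished=offset + chunk_size >= len(raw_bytes),
        )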
Example #4
    def SendWatchpointHits(self, request_iterator, context):
        """Send watchpoint hits info into DebuggerCache."""
        log.info("Received WatchpointHits. Left run cmd %s changed to empty.",
                 self._old_run_cmd)
        self._old_run_cmd.clear()
        if self._cache_store.get_stream_handler(
                Streams.METADATA).state == ServerStatus.RUNNING.value:
            # if the client session is running a script, all cached commands
            # should be cleared when watchpoint hits are received.
            self._cache_store.clean_command()

        # save the watchpoint_hits data
        watchpoint_hits = []
        watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT)
        graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH)
        for watchpoint_hit_proto in request_iterator:
            node_full_name = watchpoint_hit_proto.tensor.node_name
            graph_name = graph_stream.get_graph_id_by_full_name(node_full_name)
            if not graph_name:
                log.warning("Cannot find node %s in graph. Skip it.",
                            node_full_name)
                continue
            ui_node_name = graph_stream.get_node_name_by_full_name(
                node_full_name, graph_name)
            log.debug("Receive watchpoint hit: %s", watchpoint_hit_proto)
            if not ui_node_name:
                log.info("Cannot show %s on graph.", node_full_name)
                continue
            watchpoint_hit = {
                'tensor_proto': watchpoint_hit_proto.tensor,
                'watchpoint': copy.deepcopy(
                    watchpoint_stream.get_watchpoint_by_id(watchpoint_hit_proto.id)),
                'node_name': ui_node_name,
                'graph_name': graph_name
            }
            hit_params = {}
            for param in watchpoint_hit_proto.watch_condition.params:
                if param.actual_value:
                    hit_params[param.name] = param.actual_value
            for i, param in enumerate(
                    watchpoint_hit['watchpoint'].condition['params']):
                # fill in the actual value for each condition param, None if absent
                watchpoint_hit['watchpoint'].condition['params'][i][
                    'actual_value'] = hit_params.get(param['name'])
            if watchpoint_hit_proto.error_code:
                watchpoint_hit['error_code'] = watchpoint_hit_proto.error_code
            watchpoint_hits.append(watchpoint_hit)
        self._received_hit = watchpoint_hits
        reply = get_ack_reply()
        return reply
Example #5
    def test_send_watchpoint_hit(self, *args):
        """Test SendWatchpointHits interface."""
        args[0].side_effect = [None, 'mock_full_name']
        watchpoint_hit = MockDataGenerator.get_watchpoint_hit()
        res = self._server.SendWatchpointHits([watchpoint_hit, watchpoint_hit],
                                              MagicMock())
        assert res == get_ack_reply()
Example #6
    def SendGraph(self, request_iterator, context):
        """Send graph into DebuggerCache."""
        log.info("Received graph.")
        reply = get_ack_reply()
        if self._status == ServerStatus.MISMATCH:
            log.info("MindSpore and MindInsight versions are mismatched; "
                     "waiting for the user to terminate the service.")
            return reply
        serial_graph = b""
        for chunk in request_iterator:
            serial_graph += chunk.buffer
        graph = GraphProto.FromString(serial_graph)
        log.debug("Deserialize the graph %s. Receive %s nodes", graph.name,
                  len(graph.node))
        graph_dict = {graph.name: graph}
        self._cache_store.get_stream_handler(Streams.GRAPH).put({0: graph_dict})
        self._cache_store.get_stream_handler(Streams.GRAPH).parse_stack_infos()
        self._cache_store.get_stream_handler(
            Streams.TENSOR).get_tensor_handler_by_rank_id(0).put_const_vals(
                graph.const_vals)
        self._cache_store.get_stream_handler(Streams.METADATA).graph_name = graph.name
        self._record_parameter_names()
        self._status = ServerStatus.RECEIVE_GRAPH
        log.debug("Send the reply for graph.")
        return reply
Example #7
    def test_send_metadata_with_mismatched(self):
        """Test the SendMetadata interface with a mismatched version."""
        res = self._server.SendMetadata(
            MagicMock(training_done=False, ms_version='1.0.0'), MagicMock())
        expect_reply = get_ack_reply()
        expect_reply.version_matched = False
        assert res == expect_reply
Example #8
    def test_waitcmd_with_next_command_is_none(self, *args):
        """Test the WaitCMD interface when the next command is None."""
        args[0].return_value = None
        setattr(self._server, '_status', ServerStatus.RECEIVE_GRAPH)
        res = self._server.WaitCMD(MagicMock(cur_step=1, cur_node=''),
                                   MagicMock())
        assert res == get_ack_reply(1)
Example #9
    def SendWatchpointHits(self, request_iterator, context):
        """Send watchpoint hits info into DebuggerCache."""
        log.info("Received WatchpointHits. Left steps %d changed to 0.", self._continue_steps)
        self._continue_steps = 0
        watchpoint_hit_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
        watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT)
        graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH)
        for watchpoint_hit_proto in request_iterator:
            ui_node_name = graph_stream.get_node_name_by_full_name(
                watchpoint_hit_proto.tensor.node_name)
            log.debug("Receive watchpoint hit: %s", watchpoint_hit_proto)
            if not ui_node_name:
                log.info("Cannot show %s on graph.", watchpoint_hit_proto.tensor.node_name)
                continue
            watchpoint_hit = {
                'tensor_proto': watchpoint_hit_proto.tensor,
                'watchpoint': watchpoint_stream.get_watchpoint_by_id(watchpoint_hit_proto.id),
                'node_name': ui_node_name
            }
            watchpoint_hit_stream.put(watchpoint_hit)
        watchpoint_hits_info = watchpoint_hit_stream.get()
        self._cache_store.put_data(watchpoint_hits_info)
        log.info("Send the watchpoint hits to DataQueue.\nSend the reply.")
        reply = get_ack_reply()
        return reply
Example #10
    def SendMultiGraphs(self, request_iterator, context):
        """Send graph into DebuggerCache."""
        log.info("Received multi_graphs.")
        reply = get_ack_reply()
        if self._status == ServerStatus.MISMATCH:
            log.info("MindSpore and MindInsight versions are mismatched; "
                     "waiting for the user to terminate the service.")
            return reply
        serial_graph = b""
        graph_dict = {}
        for chunk in request_iterator:
            serial_graph += chunk.buffer
            if chunk.finished:
                sub_graph = GraphProto.FromString(serial_graph)
                graph_dict[sub_graph.name] = sub_graph
                log.debug("Deserialize the graph %s. Receive %s nodes",
                          sub_graph.name, len(sub_graph.node))
                serial_graph = b""
                self._cache_store.get_stream_handler(
                    Streams.TENSOR).put_const_vals(sub_graph.const_vals)

        self._cache_store.get_stream_handler(Streams.GRAPH).put(graph_dict)
        self._record_parameter_names()
        self._status = ServerStatus.RECEIVE_GRAPH
        log.debug("Send the reply for graph.")
        return reply
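Both SendGraph and SendMultiGraphs reassemble serialized GraphProto bytes from streamed chunks; SendMultiGraphs additionally uses chunk.finished to mark the boundary of each sub-graph. A sketch of the sending side, with a dataclass standing in for the real chunk message (buffer and finished are the field names the server reads; the chunk size is arbitrary):

from dataclasses import dataclass

@dataclass
class Chunk:
    """Stand-in for the real chunk proto consumed by SendGraph/SendMultiGraphs."""
    buffer: bytes
    finished: bool = False

def chunk_serialized_graph(serial_graph, chunk_size=1024 * 1024):
    """Split one serialized graph into chunks; the last chunk is marked finished."""
    for offset in range(0, len(serial_graph), chunk_size):
        yield Chunk(
            buffer=serial_graph[offset:offset + chunk_size],
            finished=offset + chunk_size >= len(serial_graph),
        )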
Example #11
def get_view_cmd():
    """Get view command."""
    view_event = get_ack_reply()
    ms_tensor = view_event.view_cmd.tensors.add()
    ms_tensor.node_name, ms_tensor.slot = 'mock_node_name', '0'
    event = {'view_cmd': view_event, 'node_name': 'mock_node_name', 'graph_name': 'mock_graph_name'}
    return event
Example #12
def get_run_cmd(steps=0, level='step', node_name=''):
    """Get run command."""
    event = get_ack_reply()
    event.run_cmd.run_level = level
    if level == 'node':
        event.run_cmd.node_name = node_name
    else:
        event.run_cmd.run_steps = steps
    return event
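For reference, the two event shapes the helper above produces (the node name is made up for illustration):

# Illustrative usage of get_run_cmd; 'Default/conv1' is a made-up node name.
step_event = get_run_cmd(steps=2)                                  # run_level='step', run_steps=2
node_event = get_run_cmd(level='node', node_name='Default/conv1')  # run_level='node'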
Example #13
    def stop(self):
        """Stop server."""
        self._is_running_flag = False
        self._command_listener.stop()
        self._cache_store.clean()
        event = get_ack_reply()
        event.exit = True
        self._cache_store.put_command(event)
        log.info("Stop debugger offline manager.")
Example #14
    def _send_watchpoints(self):
        """Send watchpoints to client."""
        set_commands = self._watchpoint_stream.get_pending_commands(self._graph_stream)
        if not set_commands:
            return
        for set_cmd in set_commands:
            event = get_ack_reply()
            event.set_cmd.CopyFrom(set_cmd)
            self._cache_store.put_command(event)
            log.debug("Send SetCMD to MindSpore. %s", event)
Example #15
    def _terminate(self, metadata_stream):
        """
        Terminate the training.

        Args:
            metadata_stream (MetadataHandler): The metadata stream handler.
        """
        metadata_stream.state = 'pending'
        event = get_ack_reply()
        event.exit = True
        self.cache_store.put_command(event)
        log.debug("Send the ExitCMD.")
        return {'metadata': {'state': 'pending'}}
Example #16
    def _send_watchpoints(self):
        """Set watchpoints."""
        watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
        watchpoints = watchpoint_stream.get(filter_condition=True).get('watch_points')
        if watchpoints:
            for watchpoint in watchpoints:
                event = get_ack_reply()
                event.set_cmd.CopyFrom(watchpoint)
                self.cache_store.put_command(event)
            watchpoint_stream.sync_set_cmd()
            log.debug("Send SetCMD to MindSpore. %s", event)
Example #17
    def SendGraph(self, request_iterator, context):
        """Send graph into DebuggerCache."""
        log.info("Received graph.")
        serial_graph = b""
        for chunk in request_iterator:
            serial_graph += chunk.buffer
        graph = GraphProto.FromString(serial_graph)
        log.debug("Deserialize the graph. Receive %s nodes", len(graph.node))
        self._cache_store.get_stream_handler(Streams.GRAPH).put(graph)
        self._cache_store.get_stream_handler(Streams.TENSOR).put_const_vals(graph.const_vals)
        self._status = ServerStatus.RECEIVE_GRAPH
        reply = get_ack_reply()
        log.info("Send the reply for graph.")
        return reply
Example #18
    def WaitCMD(self, request, context):
        """Wait for a command in DebuggerCache."""
        # check whether the graph has already been received.
        log.info("Received WaitCMD at %s-th step.", request.cur_step)
        if self._status == ServerStatus.PENDING:
            log.warning("No graph received before WaitCMD.")
            reply = get_ack_reply(1)
            return reply

        # send graph if it has not been sent before
        self._pre_process(request)
        # deal with old command
        reply = self._deal_with_old_command()
        # wait for next command
        if reply is None:
            reply = self._wait_for_next_command()
        # check the reply
        if reply is None:
            reply = get_ack_reply(1)
            log.warning("Failed to get command event.")
        else:
            log.debug("Reply to WaitCMD: %s", reply)
        return reply
Example #19
    def _pause(self, metadata_stream):
        """
        Pause the training.

        Args:
            metadata_stream (MetadataHandler): The metadata stream handler.
        """
        if metadata_stream.state != ServerStatus.RUNNING.value:
            log.error("The MindSpore is not running.")
            raise DebuggerPauseError("The MindSpore is not running.")
        metadata_stream.state = 'waiting'
        event = get_ack_reply()
        event.run_cmd.CopyFrom(RunCMD(run_level='step', run_steps=0))
        self.cache_store.put_command(event)
        log.debug("Send the Pause command")
        return {'metadata': {'state': 'waiting'}}
Example #20
    def SendTensors(self, request_iterator, context):
        """Send tensors into DebuggerCache."""
        log.info("Received tensor.")
        tensor_contents = []
        tensor_stream = self._cache_store.get_stream_handler(
            Streams.TENSOR).get_tensor_handler_by_rank_id(0)
        metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
        step = metadata_stream.step
        oversize = False
        node_info = self._received_view_cmd.get('node_info')
        if node_info and node_info.get('load'):
            load_info = node_info.get('load')
            load_info['graph_name'] = node_info.get('graph_name')
            load_info['node_name'] = node_info.get('node_name')
            # The rank_id of the online debugger is 0
            load_tensor(load_info=load_info,
                        step=step,
                        request_iterator=request_iterator,
                        cache_store=self._cache_store,
                        rank_id=0)
        else:
            for tensor in request_iterator:
                tensor_contents.append(tensor.tensor_content)
                if len(tensor_contents) >= MAX_SINGLE_TENSOR_CACHE or oversize:
                    oversize = True
                    tensor_contents = []
                if tensor.finished:
                    value = {
                        'step': step,
                        'tensor_proto': tensor,
                        'tensor_contents': tensor_contents,
                        'oversize': oversize,
                        'stats': bool(node_info and node_info.get('stats'))
                    }
                    update_flag = tensor_stream.put(value)
                    if self._received_view_cmd.get('wait_for_tensor') and update_flag:
                        # update_flag is used to avoid querying empty tensors again
                        self._received_view_cmd['wait_for_tensor'] = False
                        log.debug("Set wait for tensor flag to False.")
                    tensor_contents = []
                    oversize = False
                    continue
        reply = get_ack_reply()
        return reply
Example #21
    def terminate_training(self):
        """
        Terminate the training.

        Returns:
            dict, metadata info.
        """
        metadata_stream = self._metadata_stream
        metadata_stream.state = ServerStatus.SENDING.value
        self._cache_store.clean_data()
        self._cache_store.clean_command()
        event = get_ack_reply()
        event.exit = True
        self._cache_store.put_command(event)
        metadata_stream.enable_recheck = False
        log.debug("Send the ExitCMD.")
        return metadata_stream.get(['state', 'enable_recheck'])
Example #22
    def pause_training(self):
        """
        Pause the training.

        Returns:
            dict, metadata info.
        """
        metadata_stream = self._metadata_stream
        if metadata_stream.state != ServerStatus.RUNNING.value:
            log.error("The MindSpore is not running.")
            raise DebuggerPauseError("The MindSpore is not running.")
        metadata_stream.state = 'waiting'
        event = get_ack_reply()
        event.run_cmd.CopyFrom(RunCMD(run_level='step', run_steps=0))
        self._cache_store.put_command(event)
        metadata_stream.enable_recheck = False
        log.debug("Send the Pause command")
        return metadata_stream.get(['state', 'enable_recheck'])
Example #23
    def SendMetadata(self, request, context):
        """Send metadata into DebuggerCache."""
        log.info("Received Metadata.")
        if self._status != ServerStatus.PENDING:
            log.info("Re-initialize cache store when new session comes.")
            self.init()

        client_ip = context.peer().split(':', 1)[-1]
        metadata_stream = self._cache_store.get_stream_handler(
            Streams.METADATA)
        reply = get_ack_reply()
        if request.training_done:
            log.info("The training from %s has finished.", client_ip)
        else:
            ms_version = request.ms_version
            if not ms_version:
                ms_version = '1.0.x'
            if version_match(ms_version, mindinsight.__version__) is False:
                log.info(
                    "Version mismatch: MindSpore is %s, MindInsight is %s",
                    ms_version, mindinsight.__version__)
                self._status = ServerStatus.MISMATCH
                reply.version_matched = False
                metadata_stream.state = 'mismatch'
            else:
                log.info("version is matched.")
                reply.version_matched = True

            metadata_stream.debugger_version = {
                'ms': ms_version,
                'mi': mindinsight.__version__
            }
            log.debug("Put ms_version from %s into cache.", client_ip)

            metadata_stream.put(request)
            metadata_stream.client_ip = client_ip
            log.debug("Put new metadata from %s into cache.", client_ip)
        # put metadata into data queue
        metadata = metadata_stream.get()
        self._cache_store.put_data(metadata)
        log.debug("Send the reply to %s.", client_ip)
        return reply
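version_match itself does not appear in this listing. A plausible stand-in, under the assumption that two versions match when their major and minor components agree (the real check may be stricter):

def version_match(ms_version, mi_version):
    """Hypothetical: versions match when their major.minor parts agree."""
    if not ms_version or not mi_version:
        return False
    return ms_version.split('.')[:2] == mi_version.split('.')[:2]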
Example #24
    def _deal_with_left_continue_node(self, node_name):
        """
        Construct run command with left continue nodes.

        Args:
            node_name (str): The target node name.

        Returns:
            Union[None, Event], the run command event.
        """
        cur_full_name = self._cache_store.get_stream_handler(Streams.METADATA).full_name
        if cur_full_name == node_name:
            log.info("Execute to target node: %s", node_name)
            self._old_run_cmd.clear()
            return None
        event = get_ack_reply()
        event.run_cmd.run_level = 'node'
        event.run_cmd.node_name = ''
        log.debug("Send old node RunCMD, cur node: %s, target node: %s", cur_full_name, node_name)
        return event
Example #25
    def _deal_with_left_continue_step(self, left_step_count):
        """
        Construct run command with left continue step count.

        Args:
            left_step_count (int): The count of left steps to be executed.

        Returns:
            Event, the run command event.
        """
        event = get_ack_reply()
        event.run_cmd.run_steps = 1
        event.run_cmd.run_level = 'step'
        left_step_count = left_step_count - 1 if left_step_count > 0 else -1
        if not left_step_count:
            self._old_run_cmd.clear()
        else:
            self._old_run_cmd['left_step_count'] = left_step_count
        log.debug("Send old step RunCMD. Left step count: %s", left_step_count)
        return event
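The decrement rule is easy to misread: a positive count ticks down toward zero (and zero clears the old run command), while a non-positive count is pinned at -1 so stepping continues indefinitely. A small self-check mirroring that rule:

def next_left_count(left_step_count):
    """Mirror of the decrement rule in _deal_with_left_continue_step."""
    return left_step_count - 1 if left_step_count > 0 else -1

assert next_left_count(3) == 2
assert next_left_count(1) == 0    # zero clears _old_run_cmd
assert next_left_count(-1) == -1  # negative means keep stepping indefinitely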
Example #26
    def _construct_run_event(self, params):
        """
        Construct run cmd from input control params.

        Args:
            params (dict): The control params.

                - level (str): The control granularity, `node` level or `step` level.
                    Default: `step`.

                - steps (int): Specify the steps that training should run.
                    Used when `level` is `step`.

                - full_name (str): Specify the name of the node. Used when `level` is `node`.

        Returns:
            EventReply, control event with run command.
        """
        level = params.get('level', 'step')
        event = get_ack_reply()
        if level == 'step':
            steps = params.get('steps')
            if not steps:
                steps = 1
            run_cmd = RunCMD(run_level='step', run_steps=steps)
        elif level == 'node':
            self._validate_node_type(params.get('name'))
            name = self.cache_store.get_stream_handler(
                Streams.GRAPH).get_full_name(params['name'])
            if not name:
                name = ''
            run_cmd = RunCMD(run_level='node', node_name=name)
        else:
            log.error(
                "Invalid Value. `level` should be `step` or `node`. Got %s",
                level)
            raise DebuggerParamValueError("`level` should be `step` or `node`")

        event.run_cmd.CopyFrom(run_cmd)
        log.debug("Construct run event. %s", event)
        return event
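Typical control params for the two supported levels; `server` stands for an instance of the surrounding class and the node name is made up:

# Hypothetical usage of _construct_run_event.
step_event = server._construct_run_event({'level': 'step', 'steps': 5})
node_event = server._construct_run_event({'level': 'node', 'name': 'conv1'})
# Any other level raises DebuggerParamValueError.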
Example #27
    def SendMetadata(self, request, context):
        """Send metadata into DebuggerCache."""
        log.info("Received Metadata.")
        if self._status != ServerStatus.PENDING:
            log.info("Re-initialize cache store when new session comes.")
            self.init()

        client_ip = context.peer().split(':', 1)[-1]
        metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
        if request.training_done:
            log.info("The training from %s has finished.", client_ip)
        else:
            metadata_stream.put(request)
            metadata_stream.client_ip = client_ip
            log.info("Put new metadata from %s into cache.", client_ip)
        # put metadata into data queue
        metadata = metadata_stream.get()
        self._cache_store.put_data(metadata)
        reply = get_ack_reply()
        log.info("Send the reply to %s.", client_ip)
        return reply
Example #28
    def _construct_run_event(self, params):
        """
        Construct run cmd from input control params.

        Args:
            params (dict): The control params.

                - level (str): The control granularity, `node`, `step` or `recheck` level.
                    Default: `step`.
                - steps (int): Specify the steps that training should run.
                    Used when `level` is `step`.
                - name (str): Specify the name of the node. Used when `level` is `node`.
                - graph_name (str): The graph name.

        Returns:
            EventReply, control event with run command.
        """
        level = params.get('level', 'step')
        # construct run command events
        event = get_ack_reply()
        if level == 'step':
            steps = params.get('steps', 1)
            run_cmd = RunCMD(run_level='step', run_steps=steps)
        elif level == 'node':
            name = params.get('name', '')
            graph_name = params.get('graph_name')
            if name:
                rank_id = params.get('rank_id', 0)
                name = self._multi_card_graph_stream.get_graph_handler_by_rank_id(
                    rank_id).get_full_name(name, graph_name)
            run_cmd = RunCMD(run_level='node', node_name=name)
        else:
            run_cmd = RunCMD(run_level='recheck')

        event.run_cmd.CopyFrom(run_cmd)
        log.debug("Construct run event. %s", event)
        return event
Example #29
def get_set_cmd():
    """Get set command."""
    event = get_ack_reply()
    event.set_cmd.CopyFrom(SetCMD(id=1, watch_condition=1))
    return event
Example #30
def get_exit_cmd():
    """Get exit command."""
    event = get_ack_reply()
    event.exit = True
    return event