def _initialize_debugger_server():
    """Create a debugger server when the debugger is switched on in settings.

    Returns:
        A new ``DebuggerServer`` when ``settings.ENABLE_DEBUGGER`` exists and
        is truthy; ``None`` when the setting is missing or falsy.
    """
    debugger_on = getattr(settings, 'ENABLE_DEBUGGER', False)
    if not debugger_on:
        return None
    return DebuggerServer()
# Example #2
def _initialize_debugger_server():
    """Create a debugger server on the configured port, if enabled.

    Returns:
        ``DebuggerServer(port)`` when ``settings.DEBUGGER_PORT`` and
        ``settings.ENABLE_DEBUGGER`` are both present and truthy; ``None``
        otherwise (including when either setting is absent).
    """
    debugger_port = getattr(settings, 'DEBUGGER_PORT', None)
    debugger_on = getattr(settings, 'ENABLE_DEBUGGER', False)
    if not (debugger_port and debugger_on):
        return None
    return DebuggerServer(debugger_port)
 def setup_method(self):
     """Give the test instance its own freshly constructed DebuggerServer."""
     # NOTE(review): this looks like a stray copy of
     # TestDebuggerServer.setup_method sitting outside any class; its
     # one-space indent is not valid at module level -- consider removing.
     new_server = DebuggerServer()
     self._server = new_server
class TestDebuggerServer:
    """Unit tests for ``DebuggerServer``.

    Mock injection order: stacked ``mock.patch.object`` decorators pass
    their mocks to the test bottom-up, so ``args[0]`` is the mock created by
    the decorator nearest the ``def``. A patch that supplies an explicit
    replacement value as its third positional argument (for example
    ``'waiting'``) passes no argument to the test at all.
    """

    @classmethod
    def setup_class(cls):
        """Initialize shared state for the test class."""
        cls._server = None

    def setup_method(self):
        """Create a fresh ``DebuggerServer`` before every test."""
        self._server = DebuggerServer()

    @mock.patch.object(signal, 'signal')
    @mock.patch.object(Thread, 'join')
    @mock.patch.object(Thread, 'start')
    @mock.patch.object(grpc_server_base, 'add_EventListenerServicer_to_server')
    @mock.patch.object(grpc, 'server')
    def test_stop_server(self, *args):
        """Test stopping the debugger server through its stop handler."""
        mock_grpc_server_manager = MagicMock()
        # args[0] is the ``grpc.server`` mock (innermost decorator).
        args[0].return_value = mock_grpc_server_manager
        self._server.start()
        self._server._stop_handler(MagicMock(), MagicMock())
        assert self._server.back_server is not None
        assert self._server.grpc_server_manager == mock_grpc_server_manager

    @mock.patch.object(DebuggerCache, 'get_data')
    def test_poll_data(self, *args):
        """Test that poll_data returns what the cache hands back."""
        mock_data = {'pos': 'mock_data'}
        # args[0] is the ``DebuggerCache.get_data`` mock.
        args[0].return_value = mock_data
        res = self._server.poll_data('0')
        assert res == mock_data

    def test_poll_data_with_exept(self):
        """Test that poll_data rejects a non-string position."""
        with pytest.raises(DebuggerParamValueError,
                           match='Pos should be string.'):
            self._server.poll_data(1)

    @mock.patch.object(GraphHandler, 'search_nodes')
    def test_search(self, *args):
        """Test that search returns the graph handler's search result."""
        mock_graph = {'nodes': ['mock_nodes']}
        # args[0] is the ``GraphHandler.search_nodes`` mock.
        args[0].return_value = mock_graph
        res = self._server.search({'name': 'mock_name'})
        assert res == mock_graph

    def test_tensor_comparision_with_wrong_status(self):
        """Test that tensor comparison fails outside the waiting state."""
        # No state patch here: the freshly built server is expected not to
        # be in the waiting state yet.
        with pytest.raises(
                DebuggerCompareTensorError,
                match=
                'Failed to compare tensors as the MindSpore is not in waiting state.'
        ):
            self._server.tensor_comparisons(name='mock_node_name:0',
                                            shape='[:, :]')

    @mock.patch.object(MetadataHandler, 'state', 'waiting')
    @mock.patch.object(GraphHandler,
                       'get_node_type',
                       return_value='name_scope')
    @mock.patch.object(GraphHandler, 'get_graph_id_by_name')
    @mock.patch.object(GraphHandler,
                       'get_full_name',
                       return_value='mock_node_name')
    def test_tensor_comparision_with_wrong_type(self, *args):
        """Test that tensor comparison rejects a non-parameter node."""
        # The non-parameter node type is pinned on the ``get_node_type``
        # decorator itself. Configuring it through ``args[i]`` is error-prone:
        # ``args[1]`` here is the ``get_graph_id_by_name`` mock, not
        # ``get_node_type``.
        with pytest.raises(DebuggerParamValueError,
                           match='The node type must be parameter'):
            self._server.tensor_comparisons(name='mock_node_name:0',
                                            shape='[:, :]')

    @mock.patch.object(MetadataHandler, 'state', 'waiting')
    @mock.patch.object(GraphHandler, 'get_graph_id_by_name')
    @mock.patch.object(GraphHandler, 'get_node_type', return_value='Parameter')
    @mock.patch.object(GraphHandler,
                       'get_full_name',
                       return_value='mock_node_name')
    @mock.patch.object(TensorHandler, 'get_tensors_diff')
    def test_tensor_comparision(self, *args):
        """Test tensor comparison of a parameter node in waiting state."""
        mock_diff_res = {'tensor_value': {}}
        # args[0] is the ``TensorHandler.get_tensors_diff`` mock.
        args[0].return_value = mock_diff_res
        res = self._server.tensor_comparisons('mock_node_name:0', '[:, :]')
        assert res == mock_diff_res

    def test_retrieve_with_pending(self):
        """Test that retrieve reports the pending state of a new server."""
        res = self._server.retrieve(mode='all')
        assert res.get('metadata', {}).get('state') == 'pending'

    @mock.patch.object(MetadataHandler, 'state', 'waiting')
    def test_retrieve_all(self):
        """Test retrieve in 'all' mode against the stored expected result."""
        res = self._server.retrieve(mode='all')
        compare_debugger_result_with_file(res,
                                          'debugger_server/retrieve_all.json')

    def test_retrieve_with_invalid_mode(self):
        """Test that retrieve rejects an unknown mode."""
        with pytest.raises(DebuggerParamValueError, match='Invalid mode.'):
            self._server.retrieve(mode='invalid_mode')

    @mock.patch.object(GraphHandler, 'get')
    @mock.patch.object(GraphHandler,
                       'get_node_type',
                       return_value='name_scope')
    @mock.patch.object(GraphHandler,
                       'get_full_name',
                       return_value='mock_node_name')
    def test_retrieve_node(self, *args):
        """Test retrieving node information for a name-scope node."""
        mock_graph = {'graph': {}}
        # args[2] is the ``GraphHandler.get`` mock (outermost decorator).
        args[2].return_value = mock_graph
        res = self._server._retrieve_node({'name': 'mock_node_name'})
        assert res == mock_graph

    def test_retrieve_tensor_history_with_pending(self):
        """Test that tensor history retrieval reports the pending state."""
        res = self._server.retrieve_tensor_history('mock_node_name')
        assert res.get('metadata', {}).get('state') == 'pending'

    @mock.patch.object(MetadataHandler, 'state', 'waiting')
    @mock.patch.object(GraphHandler, 'get_tensor_history')
    @mock.patch.object(GraphHandler, 'get_node_type', return_value='Parameter')
    def test_retrieve_tensor_history(self, *args):
        """Test tensor history retrieval against the stored expected result."""
        # args[1] is the ``GraphHandler.get_tensor_history`` mock.
        args[1].return_value = mock_tensor_history()
        res = self._server.retrieve_tensor_history('mock_node_name')
        compare_debugger_result_with_file(
            res, 'debugger_server/retrieve_tensor_history.json')

    @mock.patch.object(TensorHandler, 'get')
    @mock.patch.object(DebuggerServer, '_get_tensor_name_and_type_by_ui_name')
    def test_retrieve_tensor_value(self, *args):
        """Test that retrieve_tensor_value returns the tensor handler result."""
        mock_tensor_value = {'tensor_value': {'name': 'mock_name:0'}}
        # args[0] is the ``_get_tensor_name_and_type_by_ui_name`` mock;
        # args[1] is the ``TensorHandler.get`` mock.
        args[0].return_value = ('Parameter', 'mock_node_name')
        args[1].return_value = mock_tensor_value
        res = self._server.retrieve_tensor_value('mock_name:0', 'data',
                                                 '[:, :]')
        assert res == mock_tensor_value

    @mock.patch.object(WatchpointHandler, 'get')
    def test_retrieve_watchpoints(self, *args):
        """Test retrieving all watchpoints when no id is given."""
        mock_watchpoint = {'watch_points': {}}
        # args[0] is the ``WatchpointHandler.get`` mock.
        args[0].return_value = mock_watchpoint
        res = self._server._retrieve_watchpoint({})
        assert res == mock_watchpoint

    @mock.patch.object(DebuggerServer, '_retrieve_node')
    def test_retrieve_watchpoint(self, *args):
        """Test retrieving a single watchpoint by id."""
        mock_watchpoint = {'nodes': {}}
        # args[0] is the ``DebuggerServer._retrieve_node`` mock.
        args[0].return_value = mock_watchpoint
        res = self._server._retrieve_watchpoint({'watch_point_id': 1})
        assert res == mock_watchpoint

    def test_create_watchpoint_with_wrong_state(self):
        """Test that creating a watchpoint fails outside the waiting state."""
        with pytest.raises(DebuggerCreateWatchPointError,
                           match='Failed to create watchpoint'):
            self._server.create_watchpoint({'watch_condition': {'id': 'inf'}})

    @mock.patch.object(MetadataHandler, 'state', 'waiting')
    @mock.patch.object(MetadataHandler, 'backend', 'GPU')
    @mock.patch.object(GraphHandler,
                       'get_node_basic_info',
                       return_value=MagicMock())
    @mock.patch.object(GraphHandler,
                       'get_node_type',
                       return_value='aggregation_scope')
    @mock.patch.object(watchpoint_operator,
                       'get_basic_node_info',
                       return_value=MagicMock())
    @mock.patch.object(WatchpointHandler, 'create_watchpoint')
    def test_create_watchpoint(self, *args):
        """Test creating a watchpoint in waiting state on a GPU backend."""
        # args[0] is the ``WatchpointHandler.create_watchpoint`` mock; its
        # return value is the new watchpoint id.
        args[0].return_value = 1
        res = self._server.create_watchpoint({
            'watch_condition': {
                'id': 'tensor_too_large',
                'params': [{
                    'name': 'max_gt',
                    'value': 1.0
                }]
            },
            'watch_nodes': ['watch_node_name']
        })
        assert res == {
            'id': 1,
            'metadata': {
                'enable_recheck': False,
                'state': 'waiting'
            }
        }

    @mock.patch.object(MetadataHandler, 'state', 'waiting')
    @mock.patch.object(GraphHandler,
                       'validate_graph_name',
                       return_value='kernel_graph_0')
    @mock.patch.object(GraphHandler, 'get_node_basic_info')
    @mock.patch.object(GraphHandler, 'search_nodes')
    @mock.patch.object(WatchpointHandler, 'validate_watchpoint_id')
    @mock.patch.object(WatchpointHandler, 'update_watchpoint')
    def test_update_watchpoint(self, *args):
        """Test updating a watchpoint's nodes via a search pattern."""
        # args[2] is the ``GraphHandler.search_nodes`` mock.
        args[2].return_value = {'nodes': [{'name': 'mock_name', 'nodes': []}]}
        res = self._server.update_watchpoint({
            'watch_point_id': 1,
            'watch_nodes': ['search_name'],
            'mode': 1,
            'search_pattern': {
                'name': 'search_name'
            },
            'graph_name': 'kernel_graph_0'
        })
        assert res == {
            'metadata': {
                'enable_recheck': False,
                'state': 'waiting'
            }
        }

    def test_delete_watchpoint_with_wrong_state(self):
        """Test that deleting a watchpoint fails outside the waiting state."""
        with pytest.raises(DebuggerDeleteWatchPointError,
                           match='Failed to delete watchpoint'):
            self._server.delete_watchpoint(watch_point_id=1)

    @mock.patch.object(MetadataHandler, 'enable_recheck', True)
    @mock.patch.object(WatchpointHandler, 'is_recheckable', return_value=True)
    @mock.patch.object(WatchpointHandler, 'delete_watchpoint')
    def test_delete_watchpoint(self, *args):
        """Test deleting a watchpoint in waiting state with recheck enabled."""
        # The state is set on this server's own metadata handler instance;
        # setup_method builds a new server per test, so nothing leaks.
        self._server.cache_store.get_stream_handler(
            Streams.METADATA).state = 'waiting'
        # args[0] is the ``WatchpointHandler.delete_watchpoint`` mock.
        args[0].return_value = None
        res = self._server.delete_watchpoint(1)
        assert res == {
            'metadata': {
                'enable_recheck': True,
                'state': 'waiting'
            }
        }