Example #1
0
class TestTrainingControlOperator:
    """Tests for ``TrainingControlOperator``."""

    @classmethod
    def setup_class(cls):
        """Initialize class-level state shared by all tests."""
        # Placeholder only; setup_method rebinds ``_server`` on each test instance.
        cls._server = None

    def setup_method(self):
        """Build a fresh operator backed by an initialized cache before each test."""
        cache_store = DebuggerCache()
        cache_store.initialize()
        self._server = TrainingControlOperator(cache_store)

    @mock.patch.object(GraphHandler, 'get_node_type')
    def test_validate_leaf_name(self, mock_get_node_type):
        """A node whose type is reported as 'name_scope' is rejected as a continue target."""
        mock_get_node_type.return_value = 'name_scope'
        # ``match`` is applied with re.search, so the trailing dot is escaped to
        # require a literal period instead of matching any character.
        with pytest.raises(DebuggerParamValueError,
                           match=r'Invalid leaf node name\.'):
            self._server._validate_continue_node_name(
                node_name='mock_node_name', graph_name='mock_graph_name')

    @pytest.mark.parametrize('mode, cur_state, state', [
        ('continue', 'waiting', 'sending'),
        ('pause', 'running', 'sending'),
        ('terminate', 'waiting', 'sending')])
    def test_control(self, mode, cur_state, state):
        """Issuing a control mode from a given metadata state yields the expected response.

        ``cur_state`` is patched onto ``MetadataHandler.state`` for the duration
        of the call; ``state`` is the value expected in the returned metadata.
        """
        with mock.patch.object(MetadataHandler, 'state', cur_state):
            res = self._server.control(mode=mode, params={})
            assert res == {'metadata': {'enable_recheck': False, 'state': state}}

    def test_construct_run_event(self):
        """A node-level request produces a RunCMD with run_level 'node' and an empty node name."""
        res = self._server._construct_run_event({'level': 'node'})
        assert res.run_cmd == RunCMD(run_level='node', node_name='')
    def recheck(self):
        """
        Delegate a recheck of all watchpoints to a training control operator.

        A new ``TrainingControlOperator`` is built over ``self.cache_store``
        and its ``recheck`` result is passed straight back to the caller.

        Returns:
            dict, metadata info.
        """
        operator = TrainingControlOperator(self.cache_store)
        return operator.recheck()
    def control(self, params=None):
        """
        Control the training process.

        Args:
            params (dict): The control params. The caller's dict is not modified;
                ``mode`` is removed only from an internal shallow copy before the
                remaining entries are forwarded to the operator.

                - mode (str): Acceptable control command, including `continue`,
                    `pause` and `terminate`.
                - level (str): The control granularity, `node` level or `step` level.
                    Default: `step`.
                - steps (int): Specify the steps that training should run.
                    Used when `level` is `step`.
                - name (str): Specify the name of the node. Used when `level` is `node`.
                - graph_name (str): The graph name.

        Returns:
            dict, the response.
        """
        log.info("Receive control request: %s.", params)
        if params:
            # Copy before popping so the caller's dict keeps its 'mode' entry.
            params = dict(params)
            mode = params.pop('mode', None)
        else:
            # Falsy params (None or {}) carry no mode and are forwarded as-is.
            mode = None
        training_controller = TrainingControlOperator(self.cache_store)
        # NOTE(review): validate_mode presumably raises for an unsupported or
        # missing mode — confirm in TrainingControlOperator.
        training_controller.validate_mode(mode)
        return training_controller.control(mode, params)
Example #4
0
 def setup_method(self):
     """Create a TrainingControlOperator over a freshly initialized DebuggerCache."""
     debugger_cache = DebuggerCache()
     # Initialize the cache before handing it to the operator.
     debugger_cache.initialize()
     self._server = TrainingControlOperator(debugger_cache)