コード例 #1
0
def test_lite_optimizer_state_dict():
    """Collecting the optimizer state must be delegated to the strategy's ``optimizer_state``."""
    inner_optimizer = Mock()
    mock_strategy = Mock()
    wrapper = _LiteOptimizer(optimizer=inner_optimizer, strategy=mock_strategy)
    wrapper.state_dict()
    mock_strategy.optimizer_state.assert_called_with(inner_optimizer)
コード例 #2
0
def test_lite_optimizer_wraps():
    """The wrapper must expose the wrapped optimizer and also pass ``isinstance`` checks for its class."""
    base_cls = torch.optim.SGD
    inner_optimizer = Mock(spec=base_cls)
    wrapper = _LiteOptimizer(inner_optimizer, Mock())
    # The original object stays reachable, and the wrapper impersonates its type.
    assert wrapper.optimizer is inner_optimizer
    assert isinstance(wrapper, base_cls)
コード例 #3
0
    def setup(
        self,
        model: nn.Module,
        *optimizers: Optimizer,
        move_to_device: bool = True,
    ) -> Any:  # return type kept loose on purpose: the desired API shape does not play well with mypy
        """Set up a model and its optimizers for accelerated training.

        Args:
            model: A model to set up.
            *optimizers: The optimizer(s) to set up (passing no optimizers is also valid).
            move_to_device: If ``True`` (default), the model is moved to the correct device.
                Pass ``False`` and call :meth:`to_device` manually instead.

        Returns:
            The wrapped model alone, or a list of the wrapped model followed by the wrapped
            optimizers in the order they were passed in.
        """
        self._validate_setup(model, optimizers)

        if move_to_device:
            model = self._move_model_to_device(model=model, optimizers=list(optimizers))

        # Hand model and optimizers to the strategy so the accelerator/plugin can wrap and connect them.
        model, wrapped_optimizers = self._strategy._setup_model_and_optimizers(model, list(optimizers))
        wrapped_model = _LiteModule(model, self._precision_plugin)
        wrapped_optimizers = [
            _LiteOptimizer(optimizer=opt, strategy=self._strategy) for opt in wrapped_optimizers
        ]
        self._models_setup += 1
        if not wrapped_optimizers:
            return wrapped_model
        # Join both kinds in one list for API convenience.
        return [wrapped_model] + wrapped_optimizers  # type: ignore
コード例 #4
0
def test_lite_optimizer_steps():
    """``step()`` must be routed through the strategy's ``optimizer_step`` with the wrapped optimizer."""
    inner_optimizer = Mock()
    mock_strategy = Mock()
    wrapper = _LiteOptimizer(optimizer=inner_optimizer, strategy=mock_strategy)
    wrapper.step()
    mock_strategy.optimizer_step.assert_called_once()
    mock_strategy.optimizer_step.assert_called_with(
        inner_optimizer, opt_idx=0, closure=ANY, model=mock_strategy.model
    )