def test_lite_optimizer_state_dict():
    """Check that ``_LiteOptimizer.state_dict()`` delegates to the strategy's ``optimizer_state``."""
    wrapped = Mock()
    mock_strategy = Mock()

    # Build the wrapper and immediately request its state; the strategy should receive the raw optimizer.
    _LiteOptimizer(optimizer=wrapped, strategy=mock_strategy).state_dict()

    mock_strategy.optimizer_state.assert_called_with(wrapped)
def test_lite_optimizer_wraps():
    """Check that ``_LiteOptimizer`` keeps a handle on the optimizer and is an instance of its class."""
    sgd_cls = torch.optim.SGD
    inner = Mock(spec=sgd_cls)

    wrapper = _LiteOptimizer(inner, Mock())

    # The original optimizer object is exposed unchanged ...
    assert wrapper.optimizer is inner
    # ... and the wrapper passes isinstance checks against the wrapped optimizer's class.
    assert isinstance(wrapper, sgd_cls)
def setup(
    self,
    model: nn.Module,
    *optimizers: Optimizer,
    move_to_device: bool = True,
) -> Any:  # no specific return because the way we want our API to look does not play well with mypy
    """Setup a model and its optimizers for accelerated training.

    Args:
        model: A model to setup
        *optimizers: The optimizer(s) to setup (no optimizers is also possible)
        move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
            and alternatively use :meth:`to_device` manually.

    Returns:
        If no optimizers are passed, only the wrapped model. Otherwise a list whose first element is the
        wrapped model, followed by the wrapped optimizers in the same order they were passed in.
    """
    self._validate_setup(model, optimizers)

    if move_to_device:
        model = self._move_model_to_device(model=model, optimizers=list(optimizers))

    # Let accelerator/plugin wrap and connect the models and optimizers
    model, strategy_optimizers = self._strategy._setup_model_and_optimizers(model, list(optimizers))

    # Distinct names keep the raw strategy outputs separate from the Lite wrappers handed back to the user.
    lite_model = _LiteModule(model, self._precision_plugin)
    lite_optimizers = [
        _LiteOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in strategy_optimizers
    ]

    self._models_setup += 1

    if lite_optimizers:
        # join both types in a list for API convenience
        return [lite_model, *lite_optimizers]
    return lite_model
def test_lite_optimizer_steps():
    """Test that ``_LiteOptimizer.step()`` is routed through the strategy's ``optimizer_step``.

    The strategy must be called exactly once, receiving the wrapped optimizer, ``opt_idx=0``, some closure,
    and the strategy's model.
    """
    optimizer = Mock()
    strategy = Mock()
    lite_optimizer = _LiteOptimizer(optimizer=optimizer, strategy=strategy)

    lite_optimizer.step()

    # One combined check: exactly one call, and that call used these arguments.
    strategy.optimizer_step.assert_called_once_with(optimizer, opt_idx=0, closure=ANY, model=strategy.model)