Example #1
0
    def test_load_state_dict(self, optimizer: OptimizerAdapter, model,
                             optimizer_state_dict):
        """Check that state loaded before param registration is applied later.

        1. While ``optimizer._optimizer`` is still ``None``, call
           ``load_state_dict``; the adapter is expected to hold the dict in
           ``optimizer._state_dict`` and report it from ``state_dict()``.
        2. Register the model parameters, which creates
           ``optimizer._optimizer``; the pending ``_state_dict`` is expected
           to be loaded into it and then cleared.
        """
        import torch  # for torch.equal; avoids relying on a module import

        # Phase 1: no wrapped optimizer exists yet, so loading a state dict
        # should only leave optimizer._state_dict populated.
        assert optimizer._optimizer is None
        optimizer.load_state_dict(state_dict=optimizer_state_dict)
        assert optimizer.state_dict() == optimizer_state_dict

        # Phase 2: registering params initializes optimizer._optimizer, which
        # should load the pending state dict from optimizer._state_dict.
        assert optimizer._optimizer is None and optimizer._state_dict is not None
        optimizer.register_model_params(
            model_params=dict(model.named_parameters()))

        # Check the wrapped optimizer's state matches the loaded state dict.
        # torch.equal also requires identical shapes, whereas
        # `(a == b).all()` broadcasts and can pass on mismatched buffers.
        assert torch.equal(
            optimizer._optimizer.state_dict()["state"][0]["momentum_buffer"],
            optimizer.state_dict()["state"][0]["momentum_buffer"])
        assert torch.equal(
            optimizer.state_dict()["state"][0]["momentum_buffer"],
            optimizer_state_dict["state"][0]["momentum_buffer"])
        # Once applied, the pending copy on the adapter should be cleared.
        assert optimizer._state_dict is None
Example #2
0
    def test_optimizer_state_recovery(self, data_batch, model,
                                      optimizer: OptimizerAdapter):
        """Check that a saved optimizer state can be restored after a reset.

        One optimization step populates the optimizer state and a snapshot is
        taken. Re-registering the params with ``restore_state=False`` should
        leave a state that differs from the snapshot. Loading the snapshot
        back should then reproduce the saved momentum buffer.
        """
        optimizer.register_model_params(
            model_params=dict(model.named_parameters()))

        # Run a single training step so the optimizer accumulates state
        # (including the momentum buffer checked at the end).
        x, y = data_batch
        loss_fn = torch.nn.MSELoss(reduction='sum')

        optimizer.zero_grad()
        output = model(x)
        loss = loss_fn(output, y)
        loss.backward()
        optimizer.step()

        # Deep-copy so the snapshot's tensors are independent of the live
        # optimizer state.
        optimizer_state = deepcopy(optimizer.state_dict())
        optimizer.register_model_params(
            model_params=dict(model.named_parameters()), restore_state=False)
        assert optimizer.state_dict() != optimizer_state

        optimizer.load_state_dict(optimizer_state)
        # torch.equal also checks shapes; `(a == b).all()` would broadcast
        # and could pass on mismatched buffers.
        assert torch.equal(
            optimizer.state_dict()["state"][0]["momentum_buffer"],
            optimizer_state["state"][0]["momentum_buffer"])