Example #1
0
 def test_register_model_params_without_restore_state(
         self, optimizer: OptimizerAdapter, model, false_restore_state):
     """Registering params with restore_state disabled rebuilds the optimizer.

     Per the intended contract, ``optimizer._optimizer`` is reinitialized from
     the given model params whether or not it already existed.
     ``false_restore_state`` is presumably a fixture yielding a false value —
     confirm against the fixture definition.
     """
     named_params = dict(model.named_parameters())
     optimizer.register_model_params(model_params=named_params,
                                     restore_state=false_restore_state)
     groups = optimizer._optimizer.param_groups
     assert len(groups) > 0
Example #2
0
    def optimizer_state_dict(self, model, data_batch) -> Dict[str, Any]:
        """Produce a populated optimizer state dict by taking one SGD step.

        A fresh ``OptimizerAdapter`` (SGD, lr=1.0, momentum=0.9) is attached
        to ``model``'s named parameters and stepped once on ``data_batch``
        with a summed MSE loss, so per-parameter state (such as momentum
        buffers) exists. The step presumably also updates ``model``'s
        parameters in place.

        Returns:
            A deep copy of the adapter's state dict after the step.
        """
        adapter = OptimizerAdapter(SGD, {"lr": 1.0, "momentum": 0.9})
        adapter.register_model_params(
            model_params=dict(model.named_parameters()))

        inputs, targets = data_batch
        criterion = torch.nn.MSELoss(reduction='sum')

        adapter.zero_grad()
        predictions = model(inputs)
        criterion(predictions, targets).backward()
        adapter.step()

        # Deep-copy so later mutation of the adapter cannot alter this value.
        return deepcopy(adapter.state_dict())
Example #3
0
    def test_load_state_dict(self, optimizer: OptimizerAdapter, model,
                             optimizer_state_dict):
        """A state dict loaded before params are registered is applied lazily.

        While ``optimizer._optimizer`` is None, ``load_state_dict`` should only
        stash the dict in ``optimizer._state_dict``. The next
        ``register_model_params`` call should build the wrapped optimizer,
        load the stashed state into it, and clear ``_state_dict``.
        """

        def momentum_buffer(state_dict):
            # Momentum buffer of the first parameter in a state dict.
            return state_dict["state"][0]["momentum_buffer"]

        # Phase 1: no wrapped optimizer yet, so the state dict is only stored.
        assert optimizer._optimizer is None
        optimizer.load_state_dict(state_dict=optimizer_state_dict)
        # NOTE(review): dict ``==`` compares the nested tensors too; it only
        # yields a plain bool if state_dict() hands back the stored objects
        # (identity short-circuit). Confirm against OptimizerAdapter.
        assert optimizer.state_dict() == optimizer_state_dict

        # Phase 2: registering params creates the wrapped optimizer, which
        # should pick up the pending state dict.
        assert optimizer._optimizer is None
        assert optimizer._state_dict is not None
        optimizer.register_model_params(
            model_params=dict(model.named_parameters()))

        # torch.equal checks shape as well as values. (a == b).all() would
        # broadcast, and could pass on mismatched shapes or raise instead.
        inner_buffer = momentum_buffer(optimizer._optimizer.state_dict())
        adapter_buffer = momentum_buffer(optimizer.state_dict())
        assert torch.equal(inner_buffer, adapter_buffer)
        assert torch.equal(adapter_buffer,
                           momentum_buffer(optimizer_state_dict))
        # The pending state dict is consumed once applied.
        assert optimizer._state_dict is None
Example #4
0
    def test_optimizer_state_change(self, data_batch, model,
                                    optimizer: OptimizerAdapter):
        """Each optimizer step must change the recorded optimizer state.

        Snapshots the state dict before any step, after one step and after a
        second step. It then checks that the first snapshot differs from the
        initial one, and that the second step changed every element of the
        first parameter's momentum buffer.
        """
        optimizer.register_model_params(
            model_params=dict(model.named_parameters()))
        snapshots = [deepcopy(optimizer.state_dict())]

        inputs, targets = data_batch
        criterion = torch.nn.MSELoss(reduction='sum')

        for _ in range(2):
            optimizer.zero_grad()
            predictions = model(inputs)
            criterion(predictions, targets).backward()
            optimizer.step()
            snapshots.append(deepcopy(optimizer.state_dict()))

        initial, after_first, after_second = snapshots
        assert initial != after_first
        first_buffer = after_first["state"][0]["momentum_buffer"]
        second_buffer = after_second["state"][0]["momentum_buffer"]
        assert (first_buffer != second_buffer).all()
Example #5
0
 def test_register_model_params(self, optimizer: OptimizerAdapter, model):
     """Registering model params creates the wrapped optimizer with groups."""
     assert optimizer._optimizer is None
     named_params = dict(model.named_parameters())
     optimizer.register_model_params(model_params=named_params)
     groups = optimizer._optimizer.param_groups
     assert len(groups) > 0
Example #6
0
 def optimizer(self) -> OptimizerAdapter:
     """Return a new OptimizerAdapter for SGD (lr=1.0, momentum=0.9).

     No model params are registered here; callers do that themselves.
     """
     return OptimizerAdapter(SGD, {"lr": 1.0, "momentum": 0.9})
Example #7
0
    def test_optimizer_state_recovery(self, data_batch, model,
                                      optimizer: OptimizerAdapter):
        """Optimizer state can be reloaded after the optimizer is rebuilt.

        The test takes one step to populate state and snapshots it. It then
        re-registers params with ``restore_state=False``, which should yield
        a state dict different from the snapshot. Finally it reloads the
        snapshot and checks the momentum buffer is restored exactly.
        """
        optimizer.register_model_params(
            model_params=dict(model.named_parameters()))

        x, y = data_batch
        loss_fn = torch.nn.MSELoss(reduction='sum')

        optimizer.zero_grad()
        output = model(x)
        loss = loss_fn(output, y)
        loss.backward()
        optimizer.step()

        optimizer_state = deepcopy(optimizer.state_dict())

        # Rebuild the wrapped optimizer without carrying the old state over.
        optimizer.register_model_params(
            model_params=dict(model.named_parameters()), restore_state=False)
        assert optimizer.state_dict() != optimizer_state

        optimizer.load_state_dict(optimizer_state)
        # torch.equal checks shape as well as values; (a == b).all() would
        # broadcast and could pass on mismatched shapes.
        restored = optimizer.state_dict()["state"][0]["momentum_buffer"]
        expected = optimizer_state["state"][0]["momentum_buffer"]
        assert torch.equal(restored, expected)
Example #8
0
 def get_optimizer(cls, optimizer_key: str,
                   params: Dict) -> OptimizerAdapter:
     """Build an OptimizerAdapter for the optimizer registered under a key.

     Args:
         optimizer_key: Key looked up in ``cls.optimizer_map``.
         params: Passed to ``OptimizerAdapter`` as ``optimizer_params``
             (presumably the optimizer's constructor kwargs — confirm).

     Returns:
         A new ``OptimizerAdapter`` wrapping the mapped optimizer class.

     Raises:
         KeyError: If ``optimizer_key`` is not in ``cls.optimizer_map``. The
             message lists the supported keys.
     """
     try:
         optimizer_class = cls.optimizer_map[optimizer_key]
     except KeyError:
         # Keep the KeyError type callers may catch, but make it actionable.
         supported = ", ".join(repr(key) for key in cls.optimizer_map)
         raise KeyError(
             f"Unknown optimizer key {optimizer_key!r}; "
             f"supported keys: {supported}") from None
     return OptimizerAdapter(optimizer_class=optimizer_class,
                             optimizer_params=params)
Example #9
0
 def train_batch(self, batch: DatasetBatch, model: NNModel, optimizer: OptimizerAdapter, device: torch.device):
     """Run one optimization step on ``batch``.

     The method clears the model's gradients and moves the batch to
     ``device``. It then computes the loss with ``self.calc_loss``,
     backpropagates the summed loss and steps the optimizer.
     Returns nothing.
     """
     model.zero_grad()
     batch.to_device(device)
     # Summed because calc_loss may return a non-scalar (e.g. per-item) loss.
     total_loss = self.calc_loss(model, batch).sum()
     total_loss.backward()
     optimizer.step()