Example #1
    def prepare_optimizer(self, optimizer, model, loss=None):
        # Split BatchNorm parameters away from the rest of the model's
        # parameters (and from the loss's parameters, if the loss has any).
        bn_params, other_params = split_batchnorm_params(model)
        if loss is not None:
            bn_params_loss, params_loss = split_batchnorm_params(loss)
            bn_params = bn_params + bn_params_loss
            other_params = other_params + params_loss

        # Give the BatchNorm parameters their own copy of the scheduler
        # settings so that weight decay can be disabled just for them.
        bn_schedulers = self.optimizer_schedulers.copy()
        if not self.bn_weight_decay:
            bn_schedulers["weight_decay"] = 0

        param_groups = [{"params": other_params, **self.optimizer_schedulers}]
        if len(bn_params) > 0:
            param_groups.append({"params": bn_params, **bn_schedulers})
        self.optimizer.set_param_groups(param_groups)
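
Both examples rely on split_batchnorm_params to separate BatchNorm weights and biases from the rest of the parameters. Its implementation is not shown on this page; the following is only a minimal sketch, assuming the helper walks the module tree and routes the parameters of BatchNorm layers (nn.modules.batchnorm._BatchNorm subclasses) into their own list, which is consistent with the tests further down.

import torch.nn as nn

def split_batchnorm_params(model: nn.Module):
    # Sketch only: return (batchnorm_params, other_params) for `model`.
    bn_params, other_params = [], []
    for module in model.modules():
        params = list(module.parameters(recurse=False))
        if not params:
            continue
        if isinstance(module, nn.modules.batchnorm._BatchNorm):
            bn_params.extend(params)
        else:
            other_params.extend(params)
    return bn_params, other_params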
Example #2
    def prepare_optimizer(self, optimizer, model, loss=None):
        # Note: `chain` is the standard library helper (from itertools import chain).
        if not self.bn_weight_decay:
            # Put BatchNorm parameters (from the model and, if present, the
            # loss) into a separate group with weight decay disabled.
            bn_params, params = split_batchnorm_params(model)
            if loss is not None:
                bn_params_loss, params_loss = split_batchnorm_params(loss)
                bn_params = bn_params + bn_params_loss
                params = params + params_loss

            frozen_param_groups = (
                {"params": bn_params, "weight_decay": 0} if len(bn_params) > 0 else None
            )
            param_groups = {"params": params}
        else:
            # Weight decay applies to everything, so a single group suffices.
            frozen_param_groups = None
            params = model.parameters()
            if loss is not None:
                params = chain(params, loss.parameters())

            param_groups = {"params": list(params)}

        self.optimizer.set_param_groups(
            param_groups=param_groups, frozen_param_groups=frozen_param_groups
        )
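
The two examples differ only in how the groups are handed to set_param_groups (a flat list of group dicts versus a param_groups/frozen_param_groups pair), but the intent is the same: apply weight decay to regular parameters and skip it for BatchNorm parameters. As a rough illustration of that effect with a plain torch.optim.SGD (hypothetical hyperparameter values, not the optimizer API used above):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(2, 3), nn.BatchNorm1d(3))
bn_params, other_params = split_batchnorm_params(model)  # helper sketched above
optimizer = torch.optim.SGD(
    [
        {"params": other_params, "weight_decay": 1e-4},  # decay on regular params
        {"params": bn_params, "weight_decay": 0.0},      # no decay on BatchNorm
    ],
    lr=0.1,
)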
    def test_batchnorm_weight_decay(self):
        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin = nn.Linear(2, 3)
                self.relu = nn.ReLU()
                self.bn = nn.BatchNorm1d(3)

            def forward(self, x):
                return self.bn(self.relu(self.lin(x)))

        torch.manual_seed(1)
        model = MyModel()

        opt = build_optimizer(self._get_config())
        bn_params, lin_params = split_batchnorm_params(model)

        lin_param_before = model.lin.weight.detach().clone()
        bn_param_before = model.bn.weight.detach().clone()

        with torch.enable_grad():
            x = torch.tensor([[1.0, 1.0], [1.0, 2.0]])
            out = model(x).pow(2).sum()
            out.backward()

        opt.set_param_groups([
            {
                "params": lin_params,
                "lr": LinearParamScheduler(1, 2),
                "weight_decay": 0.5,
            },
            {
                "params": bn_params,
                "lr": 0,
                "weight_decay": 0
            },
        ])

        opt.step(where=0.5)

        # Make sure the linear parameters are trained but not the batch norm
        self.assertFalse(torch.allclose(model.lin.weight, lin_param_before))
        self.assertTrue(torch.allclose(model.bn.weight, bn_param_before))

        opt.step(where=0.5)

        # Same, but after another step and triggering the lr scheduler
        self.assertFalse(torch.allclose(model.lin.weight, lin_param_before))
        self.assertTrue(torch.allclose(model.bn.weight, bn_param_before))
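
The assertions above only check that the linear weights move while the BatchNorm weights stay fixed (lr=0 and weight_decay=0 for the BN group). The where=0.5 argument together with LinearParamScheduler(1, 2) suggests that the learning rate is interpolated linearly over training progress; under that assumption the effective lr at where=0.5 would be 1.5. A one-line sketch of that assumed behaviour:

def linear_lr(start, end, where):
    # Assumed semantics of LinearParamScheduler(start, end): linear
    # interpolation over training progress `where` in [0, 1].
    return start + (end - start) * where

assert linear_lr(1, 2, 0.5) == 1.5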
    def test_split_batchnorm_params(self):
        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin = nn.Linear(2, 3, bias=False)
                self.relu = nn.ReLU()
                self.bn = nn.BatchNorm1d(3)

            def forward(self, x):
                return self.bn(self.relu(self.lin(x)))

        torch.manual_seed(1)
        model = MyModel()

        bn_params, lin_params = split_batchnorm_params(model)

        self.assertEqual(len(bn_params), 2)
        self.assertEqual(len(lin_params), 1)

        self.assertTrue(torch.allclose(bn_params[0], model.bn.weight))
        self.assertTrue(torch.allclose(bn_params[1], model.bn.bias))
        self.assertTrue(torch.allclose(lin_params[0], model.lin.weight))