def prepare_optimizer(self, optimizer, model, loss=None):
    bn_params, other_params = split_batchnorm_params(model)
    if loss is not None:
        bn_params_loss, params_loss = split_batchnorm_params(loss)
        bn_params = bn_params + bn_params_loss
        other_params = other_params + params_loss

    # Batchnorm parameters get their own copy of the schedulers so that weight
    # decay can be disabled for them when bn_weight_decay is False
    bn_schedulers = self.optimizer_schedulers.copy()
    if not self.bn_weight_decay:
        bn_schedulers["weight_decay"] = 0

    param_groups = [{"params": other_params, **self.optimizer_schedulers}]
    if len(bn_params) > 0:
        param_groups.append({"params": bn_params, **bn_schedulers})
    self.optimizer.set_param_groups(param_groups)
def prepare_optimizer(self, optimizer, model, loss=None):
    # Requires `from itertools import chain` at module level
    if not self.bn_weight_decay:
        bn_params, params = split_batchnorm_params(model)
        if loss is not None:
            bn_params_loss, params_loss = split_batchnorm_params(loss)
            bn_params = bn_params + bn_params_loss
            params = params + params_loss

        # Batchnorm parameters go into a "frozen" group (options not scheduled)
        # with weight decay disabled
        frozen_param_groups = (
            {"params": bn_params, "weight_decay": 0} if len(bn_params) > 0 else None
        )
        param_groups = {"params": params}
    else:
        frozen_param_groups = None
        params = model.parameters()
        if loss is not None:
            params = chain(params, loss.parameters())
        param_groups = {"params": list(params)}

    self.optimizer.set_param_groups(
        param_groups=param_groups, frozen_param_groups=frozen_param_groups
    )
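# For context, a minimal sketch of what the param groups above amount to in plain
# PyTorch, assuming a stock torch.optim.SGD rather than the project's optimizer
# wrapper (the function name and default hyperparameters here are illustrative):
# batchnorm parameters get their own group so weight decay can be zeroed for them,
# while per-group options override the optimizer defaults.
def build_sgd_without_bn_decay(model, lr=0.1, weight_decay=1e-4):
    bn_params, other_params = split_batchnorm_params(model)
    param_groups = [{"params": other_params, "weight_decay": weight_decay}]
    if len(bn_params) > 0:
        # This group overrides the default weight decay for BN params only
        param_groups.append({"params": bn_params, "weight_decay": 0.0})
    return torch.optim.SGD(param_groups, lr=lr)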
def test_batchnorm_weight_decay(self):
    class MyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = nn.Linear(2, 3)
            self.relu = nn.ReLU()
            self.bn = nn.BatchNorm1d(3)

        def forward(self, x):
            return self.bn(self.relu(self.lin(x)))

    torch.manual_seed(1)
    model = MyModel()
    opt = build_optimizer(self._get_config())
    bn_params, lin_params = split_batchnorm_params(model)

    lin_param_before = model.lin.weight.detach().clone()
    bn_param_before = model.bn.weight.detach().clone()

    with torch.enable_grad():
        x = torch.tensor([[1.0, 1.0], [1.0, 2.0]])
        out = model(x).pow(2).sum()
        out.backward()

    opt.set_param_groups(
        [
            {
                "params": lin_params,
                "lr": LinearParamScheduler(1, 2),
                "weight_decay": 0.5,
            },
            {"params": bn_params, "lr": 0, "weight_decay": 0},
        ]
    )

    opt.step(where=0.5)

    # Make sure the linear parameters are trained but not the batch norm
    self.assertFalse(torch.allclose(model.lin.weight, lin_param_before))
    self.assertTrue(torch.allclose(model.bn.weight, bn_param_before))

    opt.step(where=0.5)

    # Same, but after another step and triggering the lr scheduler
    self.assertFalse(torch.allclose(model.lin.weight, lin_param_before))
    self.assertTrue(torch.allclose(model.bn.weight, bn_param_before))
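# The test above relies on "where"-based scheduling: the optimizer is told the
# fraction of training completed (where in [0, 1)) and asks each scheduler for
# the current hyperparameter value. A minimal sketch of a linear scheduler under
# that contract (an assumption about LinearParamScheduler's behavior, not its
# actual implementation):
class LinearSchedulerSketch:
    def __init__(self, start_value, end_value):
        self.start_value = start_value
        self.end_value = end_value

    def __call__(self, where):
        # Linear interpolation: where=0 -> start_value, where=1 -> end_value
        return self.start_value + (self.end_value - self.start_value) * where

# e.g. LinearSchedulerSketch(1, 2)(0.5) == 1.5, so at where=0.5 the first param
# group in the test would train with lr 1.5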
def test_split_batchnorm_params(self):
    class MyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = nn.Linear(2, 3, bias=False)
            self.relu = nn.ReLU()
            self.bn = nn.BatchNorm1d(3)

        def forward(self, x):
            return self.bn(self.relu(self.lin(x)))

    torch.manual_seed(1)
    model = MyModel()

    bn_params, lin_params = split_batchnorm_params(model)

    # The model has two batchnorm parameters (weight and bias) and one linear
    # parameter (weight only, since bias=False)
    self.assertEqual(len(bn_params), 2)
    self.assertEqual(len(lin_params), 1)
    self.assertTrue(torch.allclose(bn_params[0], model.bn.weight))
    self.assertTrue(torch.allclose(bn_params[1], model.bn.bias))
    self.assertTrue(torch.allclose(lin_params[0], model.lin.weight))
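# split_batchnorm_params is the helper exercised by both tests. A sketch that is
# consistent with the behavior the tests pin down (BN weight and bias in the
# first list, all other trainable parameters in the second); the real
# implementation may differ in details:
def split_batchnorm_params_sketch(model):
    bn_params = []
    other_params = []
    for module in model.modules():
        # recurse=False attributes each parameter to the module that owns it,
        # so parameters are not double-counted across the module hierarchy
        if isinstance(module, nn.modules.batchnorm._BatchNorm):
            bn_params.extend(
                p for p in module.parameters(recurse=False) if p.requires_grad
            )
        else:
            other_params.extend(
                p for p in module.parameters(recurse=False) if p.requires_grad
            )
    return bn_params, other_params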