def test_no_wd_param_group(self):
    """Verify the optimizer splits parameters into decay / no-decay groups.

    The model nests ``TstLayer`` instances inside a ``ModuleList`` so that the
    parameter names carry multi-level prefixes (``"1.0..."``, ``"1.1..."``).
    After ``create_optimizer_and_scheduler``, only the ``linear1`` /
    ``linear2`` weights are expected in the first param group. Every other
    parameter is expected in the second group. Both groups are compared in
    ``named_parameters()`` order.
    """
    net = torch.nn.Sequential(
        TstLayer(128),
        torch.nn.ModuleList([TstLayer(128), TstLayer(128)]),
    )
    trainer = Trainer(model=net)
    trainer.create_optimizer_and_scheduler(10)

    # fmt: off
    decayed_names = [
        "0.linear1.weight", "0.linear2.weight",
        "1.0.linear1.weight", "1.0.linear2.weight",
        "1.1.linear1.weight", "1.1.linear2.weight",
    ]
    # fmt: on

    # Partition in one pass while preserving named_parameters() ordering,
    # because assertListEqual compares the groups element-by-element.
    decayed = []
    not_decayed = []
    for param_name, param in net.named_parameters():
        if param_name in decayed_names:
            decayed.append(param)
        else:
            not_decayed.append(param)

    groups = trainer.optimizer.param_groups
    self.assertListEqual(groups[0]["params"], decayed)
    self.assertListEqual(groups[1]["params"], not_decayed)