コード例 #1
0
 def test_no_wd_param_group(self):
     """Check how the trainer's optimizer splits parameters into two groups.

     The model nests a ``ModuleList`` inside a ``Sequential`` so parameter
     names carry prefixes such as ``0.`` and ``1.0.``. Param group 0 must hold
     exactly the listed linear weights and param group 1 everything else
     (presumably the with/without weight-decay groups -- confirm against
     ``Trainer.create_optimizer``).
     """
     # Build the sub-modules in the same order the nested expression would.
     head = TstLayer(128)
     tail = torch.nn.ModuleList([TstLayer(128), TstLayer(128)])
     model = torch.nn.Sequential(head, tail)

     trainer = Trainer(model=model)
     trainer.create_optimizer_and_scheduler(10)

     # Fully-qualified names of the parameters expected in the first group.
     expected_first_group_names = {
         "0.linear1.weight",
         "0.linear2.weight",
         "1.0.linear1.weight",
         "1.0.linear2.weight",
         "1.1.linear1.weight",
         "1.1.linear2.weight",
     }

     # Partition in a single pass, preserving named_parameters() order.
     first_group, second_group = [], []
     for name, param in model.named_parameters():
         if name in expected_first_group_names:
             first_group.append(param)
         else:
             second_group.append(param)

     groups = trainer.optimizer.param_groups
     self.assertListEqual(groups[0]["params"], first_group)
     self.assertListEqual(groups[1]["params"], second_group)