def test_lr_values(self, input_param, expected_values, expected_groups):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = Unet(
        spatial_dims=3, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=1
    ).to(device)

    # build parameter groups from the test parameters and feed them to the optimizer
    params = generate_param_groups(network=net, **input_param)
    optimizer = torch.optim.Adam(params, 100)

    # each param group should get the expected learning rate
    for param_group, value in zip(optimizer.param_groups, ensure_tuple(expected_values)):
        torch.testing.assert_allclose(param_group["lr"], value)

    # each group should contain the expected number of parameters
    n = [len(p["params"]) for p in params]
    self.assertListEqual(n, expected_groups)
def test_wrong(self):
    """overlapping layer selections should raise an error when building the optimizer."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = Unet(
        spatial_dims=3, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=1
    ).to(device)
    # the same layer is selected twice, so its parameters appear in two groups
    params = generate_param_groups(
        network=net,
        layer_matches=[lambda x: x.model[-1], lambda x: x.model[-1]],
        match_types="select",
        lr_values=0.1,
    )
    with self.assertRaises(ValueError):
        torch.optim.Adam(params, 100)
def test_lr_values(self, input_param, expected_values):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = Unet(
        spatial_dims=3, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=1
    ).to(device)

    params = generate_param_groups(network=net, **input_param)
    optimizer = torch.optim.Adam(params, 100)

    # each param group should get the expected learning rate
    for param_group, value in zip(optimizer.param_groups, ensure_tuple(expected_values)):
        torch.testing.assert_allclose(param_group["lr"], value)

    # either all parameters are covered by the groups, or every group is non-empty
    n = [len(p["params"]) for p in params]
    assert sum(n) == 26 or all(n), "should have either full model or non-empty subsets."
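
# A minimal sketch of the kind of parameterized case that could drive test_lr_values above,
# assuming the generate_param_groups API (layer_matches, match_types, lr_values).
# The selector lambdas, learning rates, and expected counts are illustrative assumptions,
# not values taken from the actual test suite.
TEST_CASE_SELECT_EXAMPLE = [
    {
        # hypothetical selectors: pick the first and last blocks of the UNet model
        "layer_matches": [lambda x: x.model[0], lambda x: x.model[-1]],
        "match_types": ["select", "select"],
        "lr_values": [0.1, 0.01],
    },
    # expected_values: per-group learning rates, with 100 (the optimizer default) for the rest
    (0.1, 0.01, 100),
    # expected_groups: parameter counts per group for the given network (illustrative only)
    (5, 2, 19),
]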