def test_lr_values(self, input_param, expected_values, expected_groups):
        """Check learning rates and sizes of the groups built by ``generate_param_groups``.

        Args:
            input_param: keyword arguments forwarded to ``generate_param_groups``.
            expected_values: expected ``lr`` of each optimizer param group, in order;
                a scalar is accepted and normalised via ``ensure_tuple``.
            expected_groups: expected number of parameter tensors in each group, in order.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        net = Unet(
            spatial_dims=3, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=1
        ).to(device)

        params = generate_param_groups(network=net, **input_param)
        # 100 is the optimizer-wide default lr; torch fills it into any group without its own "lr".
        optimizer = torch.optim.Adam(params, 100)

        # The lrs are plain Python numbers, so a unittest scalar comparison suffices and avoids
        # the deprecated ``torch.testing.assert_allclose``.
        # NOTE: zip stops at the shorter sequence, so extra groups are not lr-checked here.
        for param_group, value in zip(optimizer.param_groups, ensure_tuple(expected_values)):
            self.assertAlmostEqual(param_group["lr"], value)

        # Count from the optimizer's own groups: torch stores each group's "params" as a list,
        # so ``len`` is valid whatever iterable generate_param_groups produced.
        n = [len(group["params"]) for group in optimizer.param_groups]
        self.assertListEqual(n, expected_groups)
    def test_wrong(self):
        """Selecting the same layer twice must be rejected when building the optimizer.

        Both matchers pick ``model.model[-1]``, so the same parameters end up in two
        groups; constructing ``torch.optim.Adam`` from them is expected to raise ValueError.
        """
        target = "cuda" if torch.cuda.is_available() else "cpu"
        model = Unet(
            spatial_dims=3,
            in_channels=1,
            out_channels=3,
            channels=(16, 32, 64),
            strides=(2, 2),
            num_res_units=1,
        )
        model = model.to(target)

        overlapping = generate_param_groups(
            network=model,
            layer_matches=[lambda m: m.model[-1], lambda m: m.model[-1]],
            match_types="select",
            lr_values=0.1,
        )
        with self.assertRaises(ValueError):
            torch.optim.Adam(overlapping, lr=100)
Example no. 3
0
    def test_lr_values(self, input_param, expected_values):
        """Check per-group learning rates produced by ``generate_param_groups``.

        Args:
            input_param: keyword arguments forwarded to ``generate_param_groups``.
            expected_values: expected ``lr`` of each optimizer param group, in order;
                a scalar is accepted and normalised via ``ensure_tuple``.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # NOTE(review): newer MONAI versions name this argument ``spatial_dims``;
        # ``dimensions`` is kept here for the MONAI version this test targets — confirm.
        net = Unet(
            dimensions=3,
            in_channels=1,
            out_channels=3,
            channels=(16, 32, 64),
            strides=(2, 2),
            num_res_units=1,
        ).to(device)

        params = generate_param_groups(network=net, **input_param)
        # 100 is the optimizer-wide default lr; torch fills it into any group without its own "lr".
        optimizer = torch.optim.Adam(params, 100)

        # Plain Python numbers: a unittest scalar comparison avoids the deprecated
        # ``torch.testing.assert_allclose``.
        for param_group, value in zip(optimizer.param_groups,
                                      ensure_tuple(expected_values)):
            self.assertAlmostEqual(param_group["lr"], value)

        # Parameter-tensor count of the whole network, computed rather than hard-coded
        # so the "full model" case tracks the Unet configuration above.
        total = len(list(net.parameters()))
        n = [len(group["params"]) for group in optimizer.param_groups]
        # self.assertTrue (unlike a bare assert) is not stripped under ``python -O``.
        self.assertTrue(
            sum(n) == total or all(n),
            "should have either full model or non-empty subsets.")