Пример #1
0
 def test_optimizer_feed(self, resnet18):
     """SGD accepts the grouped parameters and applies per-group settings.

     A bare ``optimizer.step()`` is a no-op while every ``.grad`` is None,
     so a constant gradient is injected first. The step must then move each
     parameter of group ``i`` according to that group's ``lr`` and
     ``weight_decay``.
     """
     import copy

     import torch

     lrs = [0.01, 0.001]
     momentums = [0.02, 0.002]
     weight_decays = [0.03, 0.003]
     # Work on a copy so the injected gradients and the weight update do not
     # leak into the (possibly shared) fixture model.
     model = copy.deepcopy(resnet18)
     param_groups = helper.group_parameters(
         model, [['conv1.*weight'], ['downsample.*.weight']], lrs,
         momentums, weight_decays)
     optimizer = optim.SGD(param_groups)

     # Read parameters back from the optimizer, which stores them as lists;
     # the helper's 'params' entries may be one-shot iterators.
     groups = optimizer.param_groups[:len(lrs)]
     assert len(groups) == len(lrs)
     before = []
     for group in groups:
         for param in group['params']:
             param.grad = torch.ones_like(param)
         before.append([param.detach().clone() for param in group['params']])

     optimizer.step()

     for group, lr, weight_decay, snapshot in zip(groups, lrs, weight_decays,
                                                  before):
         assert group['lr'] == lr
         assert group['params'], "pattern matched no parameters"
         for param, old in zip(group['params'], snapshot):
             # On the first SGD step the momentum buffer equals the gradient,
             # so momentum does not enter the update:
             #   p <- p - lr * (grad + weight_decay * p), with grad == 1.
             expected = old - lr * (1 + weight_decay * old)
             assert torch.allclose(param.detach(), expected)
Пример #2
0
    def test_raises(self):
        """group_parameters rejects malformed patterns and length mismatches."""
        # A bare string appears where a list of patterns is expected.
        malformed = [['downsample.1.weight'], 'conv1']
        with pytest.raises(TypeError,
                           match="must be list of list of patterns"):
            helper.group_parameters(None, malformed)

        # Two pattern groups but a single hyper-parameter value: each
        # per-group keyword argument must be rejected on its own.
        for keyword in ('lrs', 'momentums', 'weight_decays'):
            two_groups = [['downsample.1.weight'], ['conv1']]
            with pytest.raises(TypeError, match="must match"):
                helper.group_parameters(None, two_groups, **{keyword: [0.1]})
Пример #3
0
    def test_lr_momentum_decay(self, resnet18):
        """Each parameter group carries its own lr, momentum and decay."""
        lrs = [0.01, 0.001]
        momentums = [0.02, 0.002]
        weight_decays = [0.03, 0.003]
        param_groups = helper.group_parameters(
            resnet18, [['conv1.*weight'], ['downsample.*.weight']], lrs,
            momentums, weight_decays)

        # Check every hyper-parameter key, group by group, in the same order
        # as the values were passed in.
        expectations = (
            ('lr', lrs),
            ('momentum', momentums),
            ('weight_decay', weight_decays),
        )
        for key, values in expectations:
            for index, expected in enumerate(values):
                assert param_groups[index][key] == expected
Пример #4
0
 def test_single_key(self, resnet18):
     """A single pattern group collects every matching parameter."""
     groups = helper.group_parameters(resnet18, [['downsample.1.weight']])
     # 'params' is materialised because it may be a lazy iterable.
     matched = list(groups[0]['params'])
     # NOTE(review): 3 presumably corresponds to one downsample branch per
     # down-sampling stage of the fixture model — confirm against resnet18.
     assert len(matched) == 3