示例#1
0
    def test_all_parameters_requires_gradient(self, filter_requires_grad):
        pgroups = [{
            'params': [torch.zeros(1, requires_grad=True),
                       torch.zeros(1, requires_grad=True)],
            'lr': 0.1
        }, {
            'params': [torch.zeros(1, requires_grad=True)]
        }]

        filter_pgroups = list(filter_requires_grad(pgroups))
        assert len(filter_pgroups) == 2
        assert len(list(filter_pgroups[0]['params'])) == 2
        assert len(list((filter_pgroups[1]['params']))) == 1

        assert filter_pgroups[0]['lr'] == 0.1
示例#2
0
    def test_does_not_drop_group_when_requires_grad_is_false(
            self, filter_requires_grad):
        pgroups = [{
            'params': [
                torch.zeros(1, requires_grad=False),
                torch.zeros(1, requires_grad=False)
            ], 'lr':0.1
        }, {
            'params': [torch.zeros(1, requires_grad=False)]
        }]

        filter_pgroups = list(filter_requires_grad(pgroups))
        assert len(filter_pgroups) == 2
        assert len(list(filter_pgroups[0]['params'])) == 0
        assert len(list(filter_pgroups[1]['params'])) == 0

        assert filter_pgroups[0]['lr'] == 0.1
示例#3
0
    def test_some_params_requires_gradient(self, filter_requires_grad):
        """Only grad-requiring params survive; emptied groups are kept.

        Also pins *which* param survives: the tensor created with
        ``requires_grad=True``, passed through as the same object. A count
        alone would also pass if the frozen tensor were kept by mistake.
        """
        trainable = torch.zeros(1, requires_grad=True)
        pgroups = [{
            'params': [
                trainable,
                torch.zeros(1, requires_grad=False)
            ], 'lr': 0.1
        }, {
            'params': [torch.zeros(1, requires_grad=False)]
        }]

        # This version of filter_requires_grad is expected to emit a
        # DeprecationWarning.
        with pytest.warns(DeprecationWarning):
            filter_pgroups = list(filter_requires_grad(pgroups))
        assert len(filter_pgroups) == 2

        # 'params' may be a one-shot iterator, so materialize it once.
        kept = list(filter_pgroups[0]['params'])
        assert len(kept) == 1
        assert kept[0] is trainable
        assert not list(filter_pgroups[1]['params'])

        # Group-level options such as the learning rate are carried over.
        assert filter_pgroups[0]['lr'] == 0.1
示例#4
0
    def test_does_not_drop_group_when_requires_grad_is_false(
            self, filter_requires_grad):
        """Frozen-only groups are kept (with no params) and the call warns."""
        first_group = {
            'params': [torch.zeros(1, requires_grad=False),
                       torch.zeros(1, requires_grad=False)],
            'lr': 0.1,
        }
        second_group = {'params': [torch.zeros(1, requires_grad=False)]}
        pgroups = [first_group, second_group]

        # filter_requires_grad is expected to emit a DeprecationWarning here.
        with pytest.warns(DeprecationWarning):
            filtered = list(filter_requires_grad(pgroups))

        # No group is dropped, yet every group is left without parameters.
        assert len(filtered) == 2
        assert all(not list(group['params']) for group in filtered)

        # Group-level options are preserved.
        assert filtered[0]['lr'] == 0.1