def test_all_parameters_requires_gradient(self, filter_requires_grad): pgroups = [{ 'params': [torch.zeros(1, requires_grad=True), torch.zeros(1, requires_grad=True)], 'lr': 0.1 }, { 'params': [torch.zeros(1, requires_grad=True)] }] filter_pgroups = list(filter_requires_grad(pgroups)) assert len(filter_pgroups) == 2 assert len(list(filter_pgroups[0]['params'])) == 2 assert len(list((filter_pgroups[1]['params']))) == 1 assert filter_pgroups[0]['lr'] == 0.1
def test_does_not_drop_group_when_requires_grad_is_false( self, filter_requires_grad): pgroups = [{ 'params': [ torch.zeros(1, requires_grad=False), torch.zeros(1, requires_grad=False) ], 'lr':0.1 }, { 'params': [torch.zeros(1, requires_grad=False)] }] filter_pgroups = list(filter_requires_grad(pgroups)) assert len(filter_pgroups) == 2 assert len(list(filter_pgroups[0]['params'])) == 0 assert len(list(filter_pgroups[1]['params'])) == 0 assert filter_pgroups[0]['lr'] == 0.1
def test_some_params_requires_gradient(self, filter_requires_grad):
    """Only parameters with ``requires_grad=True`` survive the filter.

    The first group mixes one trainable and one frozen tensor and sets
    ``'lr'`` to 0.1. The second group holds a single frozen tensor. The
    call is expected to emit a ``DeprecationWarning``.

    Afterwards:
    - both groups must still exist;
    - the first group must keep exactly one parameter, and it must be the
      trainable one;
    - the second group must be empty;
    - the first group's ``'lr'`` must be preserved.
    """
    pgroups = [{
        'params': [
            torch.zeros(1, requires_grad=True),
            torch.zeros(1, requires_grad=False),
        ],
        'lr': 0.1,
    }, {
        'params': [torch.zeros(1, requires_grad=False)],
    }]

    with pytest.warns(DeprecationWarning):
        filter_pgroups = list(filter_requires_grad(pgroups))

    assert len(filter_pgroups) == 2
    # Materialize once: 'params' may be a one-shot iterable, so it must not
    # be consumed by one check and then inspected again by another.
    kept = list(filter_pgroups[0]['params'])
    assert len(kept) == 1
    # A count alone would also pass if the frozen tensor had been kept and
    # the trainable one dropped, so check which tensor survived.
    assert kept[0].requires_grad
    assert not list(filter_pgroups[1]['params'])
    assert filter_pgroups[0]['lr'] == 0.1
def test_does_not_drop_group_when_requires_grad_is_false(
        self, filter_requires_grad):
    """Frozen-only groups are emptied, not removed, and a warning is raised.

    All tensors are created with ``requires_grad=False``. The filter call
    must emit a ``DeprecationWarning``. Both groups must come back with no
    parameters, and the first group's ``'lr'`` must be preserved.

    NOTE(review): SOURCE contains another test with this exact name (the
    variant without ``pytest.warns``). If both are defined in the same
    class, this later definition shadows the earlier one. Confirm they
    live in different classes.
    """
    groups_in = [
        {
            'params': [
                torch.zeros(1, requires_grad=False),
                torch.zeros(1, requires_grad=False),
            ],
            'lr': 0.1,
        },
        {'params': [torch.zeros(1, requires_grad=False)]},
    ]

    with pytest.warns(DeprecationWarning):
        groups_out = list(filter_requires_grad(groups_in))

    assert len(groups_out) == 2
    emptied = [not list(group['params']) for group in groups_out]
    assert all(emptied)
    assert groups_out[0]['lr'] == 0.1