def test_optimizer_feed(self, resnet18):
    """Feed grouped parameters to SGD and verify one update step per group.

    Every grouped parameter is given a gradient of ones, so a single SGD step
    must apply ``p <- p - lr * (grad + weight_decay * p)`` using that group's
    own ``lr`` and ``weight_decay``. On the first step the momentum buffer is
    initialised to the gradient itself, so momentum does not alter the result
    yet.
    """
    import torch

    lrs = [0.01, 0.001]
    momentums = [0.02, 0.002]
    weight_decays = [0.03, 0.003]
    param_groups = helper.group_parameters(
        resnet18, [['conv1.*weight'], ['downsample.*.weight']],
        lrs, momentums, weight_decays)
    optimizer = optim.SGD(param_groups)

    # SGD skips parameters whose .grad is None, so without explicit gradients
    # step() is a no-op and nothing is tested. Walk optimizer.param_groups
    # (materialised lists) instead of param_groups, whose 'params' entry may
    # be an iterator already consumed by the optimizer.
    before = []
    for group in optimizer.param_groups:
        for p in group['params']:
            p.grad = torch.ones_like(p)
        before.append([p.detach().clone() for p in group['params']])
    assert any(before), "no parameters matched the grouping patterns"

    optimizer.step()

    for group, lr, wd, old_params in zip(
            optimizer.param_groups, lrs, weight_decays, before):
        for p, p_old in zip(group['params'], old_params):
            expected = p_old - lr * (1.0 + wd * p_old)
            assert torch.allclose(p.detach(), expected, atol=1e-6)
def test_raises(self):
    """group_parameters rejects malformed pattern lists and length mismatches.

    Validation is expected before the model is touched, so ``None`` is passed
    in place of a model throughout.
    """
    # The second entry is a bare string rather than a list of patterns.
    with pytest.raises(TypeError, match="must be list of list of patterns"):
        helper.group_parameters(None, [['downsample.1.weight'], 'conv1'])

    # Two pattern groups, but each hyperparameter list carries one value only.
    for option in ('lrs', 'momentums', 'weight_decays'):
        with pytest.raises(TypeError, match="must match"):
            helper.group_parameters(
                None, [['downsample.1.weight'], ['conv1']], **{option: [0.1]})
def test_lr_momentum_decay(self, resnet18):
    """Each parameter group carries the lr, momentum and weight decay given for it."""
    # Option key in the resulting group -> per-group values passed in.
    expected = {
        'lr': [0.01, 0.001],
        'momentum': [0.02, 0.002],
        'weight_decay': [0.03, 0.003],
    }
    param_groups = helper.group_parameters(
        resnet18, [['conv1.*weight'], ['downsample.*.weight']],
        expected['lr'], expected['momentum'], expected['weight_decay'])

    for key, values in expected.items():
        for index, value in enumerate(values):
            assert param_groups[index][key] == value
def test_single_key(self, resnet18):
    """A single-pattern group collects exactly three matching parameters.

    NOTE(review): presumably the downsample BatchNorm weights of layer2-4 in
    the resnet18 fixture — confirm against the fixture definition.
    """
    groups = helper.group_parameters(resnet18, [['downsample.1.weight']])
    matched = groups[0]['params']
    assert sum(1 for _ in matched) == 3