The examples below are unit tests of get_lr_parameter on MindSpore's SGD optimizer with grouped parameters. All four share the imports and the LeNet5 test model sketched after Example #1.

Code Example #1
def test_get_lr_parameter_with_group():
    """Grouped lr: conv params map to lr group 0, all other params to group 1."""
    net = LeNet5()
    conv_lr = 0.1
    default_lr = 0.3
    conv_params = list(
        filter(lambda x: 'conv' in x.name, net.trainable_params()))
    no_conv_params = list(
        filter(lambda x: 'conv' not in x.name, net.trainable_params()))
    group_params = [{
        'params': conv_params,
        'lr': conv_lr
    }, {
        'params': no_conv_params,
        'lr': default_lr
    }]
    opt = SGD(group_params)
    assert opt.is_group_lr is True
    for param in opt.parameters:
        lr = opt.get_lr_parameter(param)
        if 'conv' in param.name:
            cur_name = 'learning_rate_group_0'
        else:
            cur_name = 'learning_rate_group_1'
        assert lr.name == cur_name

    lr_list = opt.get_lr_parameter(conv_params)
    for lr in lr_list:
        assert lr.name == 'learning_rate_group_0'
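All four examples assume the imports and model below. This is a minimal sketch assuming the standard mindspore.nn layer APIs; the LeNet5 used by the original test suite may differ in detail, but any nn.Cell whose parameter names start with 'conv' and 'fc1' fits these tests.

import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.optim import SGD


class LeNet5(nn.Cell):
    """LeNet-5 test model: two conv layers and three dense (fc) layers."""

    def __init__(self, num_class=10):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120)
        self.fc2 = nn.Dense(120, 84)
        self.fc3 = nn.Dense(84, num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)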
Code Example #2
def test_get_lr_parameter_with_no_group():
    """Grouping only weight_decay leaves a single shared lr parameter."""
    net = LeNet5()
    conv_weight_decay = 0.8

    conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
    no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
    group_params = [{'params': conv_params, 'weight_decay': conv_weight_decay},
                    {'params': no_conv_params}]
    opt = SGD(group_params)
    assert opt.is_group_lr is False
    for param in opt.parameters:
        lr = opt.get_lr_parameter(param)
        assert lr.name == opt.learning_rate.name

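    # get_lr_parameter accepts a Parameter or a list of Parameters; anything else raises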
    params_error = [1, 2, 3]
    with pytest.raises(TypeError):
        opt.get_lr_parameter(params_error)
Code Example #3
def test_get_lr_parameter_with_order_group():
    """With order_params, each parameter gets its own lr parameter ('lr_<name>')."""
    net = LeNet5()
    conv_lr = 0.1
    conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
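    # 'order_params' fixes the order of opt.parameters; parameters outside the
    # conv group fall back to the optimizer's default settings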
    group_params = [{'params': conv_params, 'lr': conv_lr},
                    {'order_params': net.trainable_params()}]
    opt = SGD(group_params)
    assert opt.is_group_lr is True
    for param in opt.parameters:
        lr = opt.get_lr_parameter(param)
        assert lr.name == 'lr_' + param.name
Code Example #4
def test_order_params_2():
    """Ordered groups: a dynamic lr for fc1, a weight_decay override for conv."""
    net = LeNet5()
    conv_weight_decay = 0.01
    fc1_lr = (0.5, 0.4, 0.3)
    default_lr = 0.1
    default_wd = 0.0
    conv_params = list(
        filter(lambda x: 'conv' in x.name, net.trainable_params()))
    fc1_params = list(filter(lambda x: 'fc1' in x.name,
                             net.trainable_params()))
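    # fc1 gets a per-step (dynamic) lr, conv overrides weight_decay, and
    # 'order_params' pins opt.parameters to fc1 params followed by conv params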
    group_params = [{
        'params': fc1_params,
        'lr': fc1_lr
    }, {
        'params': conv_params,
        'weight_decay': conv_weight_decay
    }, {
        'order_params': fc1_params + conv_params
    }]
    opt = SGD(group_params, learning_rate=default_lr, weight_decay=default_wd)
    assert opt.is_group is True
    assert opt.is_group_lr is True
    assert opt.is_group_params_ordered is True
    all_lr = opt.get_lr_parameter(fc1_params + conv_params)
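    # one dynamic-lr group makes every lr a length-3 tensor (len(fc1_lr) == 3)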
    for weight_decay, decay_flags, lr, param, order_param in zip(
            opt.weight_decay, opt.decay_flags, all_lr, opt.parameters,
            fc1_params + conv_params):
        if 'conv' in param.name:
            assert np.all(lr.data.asnumpy() == Tensor(
                np.array([default_lr] * 3), mstype.float32).asnumpy())
            assert weight_decay == conv_weight_decay
            assert decay_flags is True
        elif 'fc1' in param.name:
            assert np.all(
                lr.data.asnumpy() == Tensor(fc1_lr, mstype.float32).asnumpy())
            assert weight_decay == default_wd
            assert decay_flags is False
        else:
            assert np.all(lr.data.asnumpy() == Tensor(
                np.array([default_lr] * 3), mstype.float32).asnumpy())
            assert weight_decay == default_wd
            assert decay_flags is False

        assert param.name == order_param.name
        if 'conv' in param.name:
            assert lr.name == 'learning_rate'
        elif 'fc1' in param.name:
            assert lr.name == 'learning_rate_group_0'