import numpy as np
import pytest

from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.optim import SGD

# LeNet5 is assumed to be defined in (or imported into) this test module.


def test_get_lr_parameter_with_group():
    """Each group with its own 'lr' gets a distinct learning-rate Parameter."""
    net = LeNet5()
    conv_lr = 0.1
    default_lr = 0.3
    conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
    no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
    group_params = [{'params': conv_params, 'lr': conv_lr},
                    {'params': no_conv_params, 'lr': default_lr}]
    opt = SGD(group_params)
    assert opt.is_group_lr is True
    for param in opt.parameters:
        lr = opt.get_lr_parameter(param)
        if 'conv' in param.name:
            cur_name = 'learning_rate_group_0'
        else:
            cur_name = 'learning_rate_group_1'
        assert lr.name == cur_name

    lr_list = opt.get_lr_parameter(conv_params)
    for lr in lr_list:
        assert lr.name == 'learning_rate_group_0'

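
# Illustrative sketch (not part of the test suite): because get_lr_parameter
# returns the live learning-rate Parameter(s), user code can adjust a group's
# lr in place between training phases. Assumes the same grouped LeNet5/SGD
# setup as test_get_lr_parameter_with_group above; the helper name and the
# 0.01 value are hypothetical.
def _decay_conv_lr(opt, conv_params, new_lr=0.01):
    # get_lr_parameter accepts a list of Parameters and returns their lr Parameters
    for lr in opt.get_lr_parameter(conv_params):
        lr.set_data(Tensor(new_lr, mstype.float32))
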

def test_get_lr_parameter_with_no_group():
    net = LeNet5()
    conv_weight_decay = 0.8
    conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
    no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
    group_params = [{'params': conv_params, 'weight_decay': conv_weight_decay},
                    {'params': no_conv_params}]
    opt = SGD(group_params)
    assert opt.is_group_lr is False
    for param in opt.parameters:
        lr = opt.get_lr_parameter(param)
        assert lr.name == opt.learning_rate.name

    params_error = [1, 2, 3]
    with pytest.raises(TypeError):
        opt.get_lr_parameter(params_error)


def test_get_lr_parameter_with_order_group():
    net = LeNet5()
    conv_lr = 0.1
    conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
    group_params = [{'params': conv_params, 'lr': conv_lr},
                    {'order_params': net.trainable_params()}]
    opt = SGD(group_params)
    assert opt.is_group_lr is True
    for param in opt.parameters:
        lr = opt.get_lr_parameter(param)
        assert lr.name == 'lr_' + param.name

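
# Naming sketch for the ordered case exercised above: with an 'order_params'
# entry, every parameter gets its own lr Parameter named 'lr_' + param.name,
# and opt.parameters is expected to follow the order_params order. A minimal
# check, assuming the same LeNet5 fixture; the 0.1 value is illustrative.
def _order_params_naming_sketch():
    net = LeNet5()
    opt = SGD([{'params': net.trainable_params(), 'lr': 0.1},
               {'order_params': net.trainable_params()}])
    assert [p.name for p in opt.parameters] == [p.name for p in net.trainable_params()]
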

def test_order_params_2():
    """Hyperparameters follow each group; parameter order follows 'order_params'."""
    net = LeNet5()
    conv_weight_decay = 0.01
    fc1_lr = (0.5, 0.4, 0.3)
    default_lr = 0.1
    default_wd = 0.0
    conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
    fc1_params = list(filter(lambda x: 'fc1' in x.name, net.trainable_params()))
    group_params = [{'params': fc1_params, 'lr': fc1_lr},
                    {'params': conv_params, 'weight_decay': conv_weight_decay},
                    {'order_params': fc1_params + conv_params}]
    opt = SGD(group_params, learning_rate=default_lr, weight_decay=default_wd)
    assert opt.is_group is True
    assert opt.is_group_lr is True
    assert opt.is_group_params_ordered is True

    all_lr = opt.get_lr_parameter(fc1_params + conv_params)
    for weight_decay, decay_flags, lr, param, order_param in zip(
            opt.weight_decay, opt.decay_flags, all_lr, opt.parameters,
            fc1_params + conv_params):
        if 'conv' in param.name:
            # The conv group inherits the default lr; since fc1's lr is dynamic
            # (length 3), the default is expanded to the same length.
            assert np.all(lr.data.asnumpy() ==
                          Tensor(np.array([default_lr] * 3), mstype.float32).asnumpy())
            assert weight_decay == conv_weight_decay
            assert decay_flags is True
        elif 'fc1' in param.name:
            assert np.all(lr.data.asnumpy() == Tensor(fc1_lr, mstype.float32).asnumpy())
            assert weight_decay == default_wd
            assert decay_flags is False
        else:
            assert np.all(lr.data.asnumpy() ==
                          Tensor(np.array([default_lr] * 3), mstype.float32).asnumpy())
            assert weight_decay == default_wd
            assert decay_flags is False
        # The optimizer must keep the order given by 'order_params'.
        assert param.name == order_param.name
        if 'conv' in param.name:
            assert lr.name == 'learning_rate'
        elif 'fc1' in param.name:
            assert lr.name == 'learning_rate_group_0'

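
# Debugging aid (hypothetical helper, not MindSpore API): dump the resolved
# per-parameter hyperparameters of a grouped optimizer, using only attributes
# the tests above already rely on (parameters, weight_decay, decay_flags,
# get_lr_parameter). Assumes opt.is_group is True, so weight_decay and
# decay_flags are per-parameter sequences as in test_order_params_2.
def _dump_group_hypers(opt):
    for param, wd, flag in zip(opt.parameters, opt.weight_decay, opt.decay_flags):
        lr = opt.get_lr_parameter(param)
        print(f'{param.name}: lr={lr.data.asnumpy()}, weight_decay={wd}, decay={flag}')
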