def test_applying_masks(prune_bn):
    config = get_basic_pruning_config(input_sample_size=(1, 1, 8, 8))
    config['compression']['params']['prune_batch_norms'] = prune_bn
    config['compression']['params']['prune_first_conv'] = True
    config['compression']['params']['prune_last_conv'] = True
    pruned_model, pruning_algo, nncf_modules = create_pruning_algo_with_config(config)

    pruned_module_info = pruning_algo.pruned_module_info
    pruned_modules = [minfo.module for minfo in pruned_module_info]
    assert len(pruned_modules) == len(nncf_modules)

    # The binary masks should already be applied to the weights and biases
    # of every pruned module: re-applying them must change nothing.
    for module in pruned_modules:
        op = list(module.pre_ops.values())[0]
        mask = op.operand.binary_filter_pruning_mask
        masked_weight = apply_filter_binary_mask(mask, module.weight)
        masked_bias = apply_filter_binary_mask(mask, module.bias)
        assert torch.allclose(module.weight, masked_weight)
        assert torch.allclose(module.bias, masked_bias)

    # There is only one BN node in the graph
    bn_module = pruned_model.bn
    conv_for_bn = pruned_model.conv2
    bn_mask = list(conv_for_bn.pre_ops.values())[0].operand.binary_filter_pruning_mask
    if prune_bn:
        masked_bn_weight = apply_filter_binary_mask(bn_mask, bn_module.weight)
        masked_bn_bias = apply_filter_binary_mask(bn_mask, bn_module.bias)
        assert torch.allclose(bn_module.weight, masked_bn_weight)
        assert torch.allclose(bn_module.bias, masked_bn_bias)

def test_assert_broadcastable_mask_and_weight_shape():
    nncf_module = NNCFConv2d(1, 2, 2)
    fill_conv_weight(nncf_module, 1)
    fill_bias(nncf_module, 1)

    # The conv has only 2 output filters, so a length-10 mask is not
    # broadcastable over the filter dimension and must be rejected.
    mask = torch.zeros(10)

    with pytest.raises(RuntimeError):
        inplace_apply_filter_binary_mask(mask, nncf_module.weight.data)

    with pytest.raises(RuntimeError):
        apply_filter_binary_mask(mask, nncf_module.weight.data)

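# A minimal companion sketch (not part of the original suite) of the shape
# contract checked above: the mask needs one entry per output filter, so for
# NNCFConv2d(1, 2, 2) a length-2 mask broadcasts without error.
def _example_matching_mask_shape():
    module = NNCFConv2d(1, 2, 2)
    fill_conv_weight(module, 1)
    matching_mask = torch.zeros(2)  # one entry per output filter
    # Broadcasts over the (out_channels, in_channels, kH, kW) weight shape
    return apply_filter_binary_mask(matching_mask, module.weight.data)
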
def test_apply_filter_binary_mask(mask, reference_weight, reference_bias):
    """
    Test that apply_filter_binary_mask does not change the input weight
    and returns a valid result.
    """
    nncf_module = NNCFConv2d(1, 2, 2)
    fill_conv_weight(nncf_module, 1)
    fill_bias(nncf_module, 1)

    original_weight = nncf_module.weight.data.detach().clone()
    original_bias = nncf_module.bias.data.detach().clone()

    result = apply_filter_binary_mask(mask, nncf_module.weight.data)
    assert torch.allclose(nncf_module.weight, original_weight)
    assert torch.allclose(result, reference_weight)

    result_bias = apply_filter_binary_mask(mask, nncf_module.bias.data)
    assert torch.allclose(result_bias, reference_bias)
    assert torch.allclose(nncf_module.bias, original_bias)

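# A reference sketch of the semantics verified above, assuming the actual
# NNCF implementation differs in details: the mask is viewed so that it
# broadcasts over the filter (first) dimension, and a new tensor is returned
# while the input stays untouched.
def _reference_apply_filter_binary_mask(mask, tensor):
    broadcast_shape = [1] * tensor.dim()
    broadcast_shape[0] = mask.size(0)  # one mask entry per output filter
    # Out-of-place multiply; raises RuntimeError on non-broadcastable shapes
    return tensor * mask.view(broadcast_shape)
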
def test_zeroing_gradients(zero_grad):
    """
    Test the gradient-zeroing functionality
    (zero_grads_for_pruned_modules in the base algo).
    :param zero_grad: whether gradients of pruned filters should be zeroed
    """
    config = get_basic_pruning_config(input_sample_size=(2, 1, 8, 8))
    config['compression']['params']['prune_first_conv'] = True
    config['compression']['params']['prune_last_conv'] = True
    config['compression']['params']['zero_grad'] = zero_grad

    pruned_model, pruning_algo, _ = create_pruning_algo_with_config(config)
    assert pruning_algo.zero_grad is zero_grad

    pruned_module_info = pruning_algo.pruned_module_info
    pruned_modules = [minfo.module for minfo in pruned_module_info]

    device = next(pruned_model.parameters()).device
    data_loader = create_dataloader(config)
    pruning_algo.initialize(data_loader)

    params_to_optimize = get_parameter_groups(pruned_model, config)
    optimizer, lr_scheduler = make_optimizer(params_to_optimize, config)
    lr_scheduler.step(0)

    pruned_model.train()
    for input_, target in data_loader:
        input_ = input_.to(device)
        target = target.to(device).view(1)
        output = pruned_model(input_)
        loss = torch.sum(target.to(torch.float32) - output)
        optimizer.zero_grad()
        loss.backward()

        # When zero_grad is True, the gradients should already be masked:
        # applying the filter mask to them must change nothing.
        if zero_grad:
            for module in pruned_modules:
                op = list(module.pre_ops.values())[0]
                mask = op.operand.binary_filter_pruning_mask
                grad = module.weight.grad
                masked_grad = apply_filter_binary_mask(mask, grad)
                assert torch.allclose(masked_grad, grad)

def hook(grad, mask):
    mask = mask.to(grad.device)
    return apply_filter_binary_mask(mask, grad)

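# A usage sketch, assuming a hook like the one above is what the algorithm
# installs for gradient zeroing (the exact wiring inside NNCF may differ):
# binding the mask with functools.partial and registering the result on the
# weight makes autograd mask the gradient on every backward pass.
def _example_register_grad_masking_hook(module, mask):
    from functools import partial
    return module.weight.register_hook(partial(hook, mask=mask))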