Code example #1
0
def test_applying_masks(prune_bn):
    """Check that the binary filter-pruning masks are already applied to the
    weights and biases of every pruned module, i.e. re-applying a module's own
    mask is a no-op. When ``prune_bn`` is set, the single BN layer must be
    masked with the mask of the conv feeding it as well.
    """
    config = get_basic_pruning_config(input_sample_size=(1, 1, 8, 8))
    params = config['compression']['params']
    params['prune_batch_norms'] = prune_bn
    params['prune_first_conv'] = True
    params['prune_last_conv'] = True

    pruned_model, pruning_algo, nncf_modules = create_pruning_algo_with_config(config)
    pruned_modules = [info.module for info in pruning_algo.pruned_module_info]

    # Every NNCF module in the model is expected to be pruned.
    assert len(pruned_modules) == len(nncf_modules)

    for pruned_module in pruned_modules:
        pre_op = next(iter(pruned_module.pre_ops.values()))
        pruning_mask = pre_op.operand.binary_filter_pruning_mask
        # Masking an already-masked parameter must leave it unchanged.
        assert torch.allclose(
            pruned_module.weight,
            apply_filter_binary_mask(pruning_mask, pruned_module.weight))
        assert torch.allclose(
            pruned_module.bias,
            apply_filter_binary_mask(pruning_mask, pruned_module.bias))

    # The test model has exactly one BN node in the graph, fed by conv2.
    bn = pruned_model.bn
    conv_pre_op = next(iter(pruned_model.conv2.pre_ops.values()))
    bn_mask = conv_pre_op.operand.binary_filter_pruning_mask
    if prune_bn:
        assert torch.allclose(bn.weight, apply_filter_binary_mask(bn_mask, bn.weight))
        assert torch.allclose(bn.bias, apply_filter_binary_mask(bn_mask, bn.bias))
Code example #2
0
File: test_layers.py Project: zmaslova/nncf_pytorch
def test_assert_broadcastable_mask_and_weight_shape():
    """A mask whose length does not match the module's filter count must make
    both the in-place and the copying mask-application helpers raise
    RuntimeError.
    """
    module = NNCFConv2d(1, 2, 2)
    fill_conv_weight(module, 1)
    fill_bias(module, 1)

    # 10 mask entries vs. 2 output filters: not broadcastable onto the weight.
    bad_mask = torch.zeros(10)

    with pytest.raises(RuntimeError):
        inplace_apply_filter_binary_mask(bad_mask, module.weight.data)
    with pytest.raises(RuntimeError):
        apply_filter_binary_mask(bad_mask, module.weight.data)
Code example #3
0
File: test_layers.py Project: zmaslova/nncf_pytorch
    def test_apply_filter_binary_mask(mask, reference_weight, reference_bias):
        """
        Verify that apply_filter_binary_mask leaves its input tensor intact
        and returns the expected masked weight/bias.
        """
        module = NNCFConv2d(1, 2, 2)
        fill_conv_weight(module, 1)
        fill_bias(module, 1)

        weight_before = module.weight.data.detach().clone()
        bias_before = module.bias.data.detach().clone()

        masked_weight = apply_filter_binary_mask(mask, module.weight.data)
        # Only the returned copy is masked; the module's weight is untouched.
        assert torch.allclose(module.weight, weight_before)
        assert torch.allclose(masked_weight, reference_weight)

        masked_bias = apply_filter_binary_mask(mask, module.bias.data)
        assert torch.allclose(masked_bias, reference_bias)
        assert torch.allclose(module.bias, bias_before)
Code example #4
0
def test_zeroing_gradients(zero_grad):
    """
    Test for zeroing gradients functionality (zero_grads_for_pruned_modules in base algo)
    :param zero_grad: zero grad or not
    """
    # Batch size 2 so the data loader yields real mini-batches.
    config = get_basic_pruning_config(input_sample_size=(2, 1, 8, 8))
    config['compression']['params']['prune_first_conv'] = True
    config['compression']['params']['prune_last_conv'] = True
    config['compression']['params']['zero_grad'] = zero_grad

    pruned_model, pruning_algo, _ = create_pruning_algo_with_config(config)
    # The algo must pick the flag up from the config verbatim (identity check).
    assert pruning_algo.zero_grad is zero_grad

    pruned_module_info = pruning_algo.pruned_module_info
    pruned_modules = [minfo.module for minfo in pruned_module_info]

    device = next(pruned_model.parameters()).device
    data_loader = create_dataloader(config)

    # Initialization computes the pruning masks from the data loader.
    pruning_algo.initialize(data_loader)

    params_to_optimize = get_parameter_groups(pruned_model, config)
    optimizer, lr_scheduler = make_optimizer(params_to_optimize, config)

    lr_scheduler.step(0)

    pruned_model.train()
    for input_, target in data_loader:
        input_ = input_.to(device)
        target = target.to(device).view(1)

        output = pruned_model(input_)

        # Simple differentiable loss; its exact value is irrelevant here —
        # only the gradients it produces matter for this test.
        loss = torch.sum(target.to(torch.float32) - output)

        optimizer.zero_grad()
        loss.backward()

        # In case of zero_grad = True gradients should be masked
        if zero_grad:
            for module in pruned_modules:
                op = list(module.pre_ops.values())[0]
                mask = op.operand.binary_filter_pruning_mask
                grad = module.weight.grad
                # If gradients were zeroed for pruned filters, masking the
                # gradient again must change nothing.
                masked_grad = apply_filter_binary_mask(mask, grad)
                assert torch.allclose(masked_grad, grad)
Code example #5
0
 def hook(grad, mask):
     """Gradient hook: return the gradient with the pruning mask applied,
     moving the mask onto the gradient's device first."""
     device_mask = mask.to(grad.device)
     return apply_filter_binary_mask(device_mask, grad)