Пример #1
0
def test_check_default_algo_params():
    """
    Verify the defaults a filter pruning controller falls back to when the
    config only names the algorithm and sets none of its parameters.
    """
    # Only the algorithm name is filled in; everything else is left default
    config = get_basic_pruning_config()
    config['compression']['algorithm'] = 'filter_pruning'
    _, ctrl = create_compressed_model_and_algo_for_test(PruningTestModel(), config)

    assert isinstance(ctrl, FilterPruningController)
    sched = ctrl.scheduler

    # Boolean switches that are expected to be off by default
    for flag_name in ('prune_first', 'prune_last', 'prune_batch_norms'):
        assert getattr(ctrl, flag_name) is False
    assert ctrl.filter_importance is l2_filter_norm

    assert ctrl.all_weights is False
    assert ctrl.zero_grad is True

    # Default scheduler type
    assert isinstance(sched, BaselinePruningScheduler)
Пример #2
0
def test_applying_masks(prune_bn):
    """
    Check that every pruned conv's weight and bias already equal their
    filter-masked versions and, when ``prune_bn`` is set, that the model's
    batch norm is masked with the filter mask of ``conv2``.

    :param prune_bn: value written to the ``prune_batch_norms`` config param
    """
    def first_pre_op_mask(module):
        # The filter mask is stored on the operand of the module's first pre-op
        return list(module.pre_ops.values())[0].operand.binary_filter_pruning_mask

    config = get_basic_pruning_config(input_sample_size=(1, 1, 8, 8))
    params = config['compression']['params']
    params['prune_batch_norms'] = prune_bn
    params['prune_first_conv'] = True
    params['prune_last_conv'] = True

    pruned_model, pruning_algo, nncf_modules = create_pruning_algo_with_config(config)
    pruned_modules = [info.module for info in pruning_algo.pruned_module_info]

    # With first and last convs pruned, every NNCF module is expected to be pruned
    assert len(pruned_modules) == len(nncf_modules)

    for module in pruned_modules:
        mask = first_pre_op_mask(module)
        weight_after_mask = apply_filter_binary_mask(mask, module.weight)
        bias_after_mask = apply_filter_binary_mask(mask, module.bias)
        assert torch.allclose(module.weight, weight_after_mask)
        assert torch.allclose(module.bias, bias_after_mask)

    # The graph has a single BN node; it is checked against conv2's mask
    bn_module = pruned_model.bn
    bn_mask = first_pre_op_mask(pruned_model.conv2)
    if prune_bn:
        bn_weight_after_mask = apply_filter_binary_mask(bn_mask, bn_module.weight)
        bn_bias_after_mask = apply_filter_binary_mask(bn_mask, bn_module.bias)
        assert torch.allclose(bn_module.weight, bn_weight_after_mask)
        assert torch.allclose(bn_module.bias, bn_bias_after_mask)
Пример #3
0
def test_pruning_masks_correctness(all_weights, prune_first, ref_masks):
    """
    Verify the binary filter masks produced by the pruning algorithm
    (_set_binary_masks_for_filters, _set_binary_masks_for_all_filters_together).

    :param all_weights: whether masks are computed over all weights jointly or not
    :param prune_first: whether the first convolution (conv1) is pruned or not
    :param ref_masks: expected mask values, indexed in the order the pruned
        convs are checked (conv1 first when it is pruned, then conv2)
    """
    def check_mask(module, idx):
        pre_op = list(module.pre_ops.values())[0]
        assert hasattr(pre_op.operand, 'binary_filter_pruning_mask')
        assert torch.allclose(pre_op.operand.binary_filter_pruning_mask,
                              ref_masks[idx])

    config = get_basic_pruning_config(input_sample_size=(1, 1, 8, 8))
    params = config['compression']['params']
    params['all_weights'] = all_weights
    params['prune_first_conv'] = prune_first

    pruned_model, pruning_algo, _ = create_pruning_algo_with_config(config)
    pruned_modules = [info.module for info in pruning_algo.pruned_module_info]
    assert pruning_algo.pruning_rate == 0.5
    assert pruning_algo.all_weights is all_weights

    conv1 = pruned_model.conv1
    conv2 = pruned_model.conv2
    # ref_masks holds an entry for conv1 only when conv1 is pruned
    convs_to_check = [conv1, conv2] if prune_first else [conv2]
    for idx, conv in enumerate(convs_to_check):
        assert conv in pruned_modules
        check_mask(conv, idx)
Пример #4
0
def test_get_last_pruned_layers(model, ref_last_module_names):
    """
    Check that get_last_pruned_modules reports exactly the expected modules.

    :param model: model class, instantiated with no arguments and compressed
    :param ref_last_module_names: attribute names (on the compressed model) of the
        modules expected to be reported as the last pruned modules
    """
    config = get_basic_pruning_config(input_sample_size=(1, 1, 8, 8))
    config['compression']['algorithm'] = 'filter_pruning'
    pruned_model, _ = create_compressed_model_and_algo_for_test(model(), config)

    pruned_types = FilterPruningBuilder(config).get_types_of_pruned_modules()
    # Named after the function under test: these are the *last* pruned modules
    last_pruned_modules = get_last_pruned_modules(pruned_model, pruned_types)
    ref_last_modules = [getattr(pruned_model, module_name) for module_name in ref_last_module_names]
    assert set(last_pruned_modules) == set(ref_last_modules)
Пример #5
0
def test_can_choose_scheduler(algo, scheduler, scheduler_class):
    """
    Check that the ``schedule`` config param selects the controller's scheduler class.

    :param algo: compression algorithm name written to the config
    :param scheduler: value written to the ``schedule`` config param
    :param scheduler_class: class the controller's scheduler must be an instance of
    """
    config = get_basic_pruning_config()
    config['compression']['algorithm'] = algo
    config['compression']['params']['schedule'] = scheduler
    model = PruningTestModel()
    _, compression_ctrl = create_compressed_model_and_algo_for_test(
        model, config)
    # Keep the `scheduler` argument (the config value) intact instead of
    # rebinding it to the scheduler object
    actual_scheduler = compression_ctrl.scheduler
    assert isinstance(actual_scheduler, scheduler_class)
Пример #6
0
def test_check_default_scheduler_params(algo, ref_scheduler,
                                        ref_scheduler_params):
    """
    Compress a test model with only the algorithm name set and check the
    scheduler's type and the values of its default attributes.

    :param algo: compression algorithm name written to the config
    :param ref_scheduler: expected scheduler class
    :param ref_scheduler_params: mapping of scheduler attribute name to expected value
    """
    config = get_basic_pruning_config()
    config['compression']['algorithm'] = algo
    _, ctrl = create_compressed_model_and_algo_for_test(PruningTestModel(), config)

    sched = ctrl.scheduler
    assert isinstance(sched, ref_scheduler)
    for attr_name in ref_scheduler_params:
        expected = ref_scheduler_params[attr_name]
        assert getattr(sched, attr_name) == expected
Пример #7
0
def test_get_bn_for_module_scope():
    """
    Check get_bn_for_module_scope on BigPruningTestModel: conv2 maps to the
    model's ``bn`` module, while conv1 and conv3 have no batch norm.
    """
    config = get_basic_pruning_config(input_sample_size=(1, 1, 8, 8))
    config['compression']['algorithm'] = 'filter_pruning'
    pruned_model, _ = create_compressed_model_and_algo_for_test(BigPruningTestModel(), config)

    def bn_for(scope_str):
        # Resolve the scope string and look up the BN attached to that module
        return get_bn_for_module_scope(pruned_model, Scope.from_str(scope_str))

    assert bn_for('BigPruningTestModel/NNCFConv2d[conv1]') is None
    assert bn_for('BigPruningTestModel/NNCFConv2d[conv2]') == pruned_model.bn
    assert bn_for('BigPruningTestModel/NNCFConv2d[conv3]') is None
Пример #8
0
def test_zeroing_gradients(zero_grad):
    """
    Test the gradient zeroing option of the pruning algorithm
    (zero_grads_for_pruned_modules in the base algo).

    When ``zero_grad`` is enabled, after each backward pass applying a pruned
    module's filter mask to its weight gradient must leave the gradient
    unchanged, i.e. the gradient is already masked.

    :param zero_grad: value written to the ``zero_grad`` config param
    """
    def assert_grads_are_masked(modules):
        for mod in modules:
            mask = list(mod.pre_ops.values())[0].operand.binary_filter_pruning_mask
            weight_grad = mod.weight.grad
            assert torch.allclose(apply_filter_binary_mask(mask, weight_grad), weight_grad)

    config = get_basic_pruning_config(input_sample_size=(2, 1, 8, 8))
    params = config['compression']['params']
    params['prune_first_conv'] = True
    params['prune_last_conv'] = True
    params['zero_grad'] = zero_grad

    pruned_model, pruning_algo, _ = create_pruning_algo_with_config(config)
    assert pruning_algo.zero_grad is zero_grad

    pruned_modules = [info.module for info in pruning_algo.pruned_module_info]

    device = next(pruned_model.parameters()).device
    data_loader = create_dataloader(config)

    pruning_algo.initialize(data_loader)

    optimizer, lr_scheduler = make_optimizer(get_parameter_groups(pruned_model, config), config)

    lr_scheduler.step(0)

    pruned_model.train()
    for batch_input, batch_target in data_loader:
        batch_input = batch_input.to(device)
        batch_target = batch_target.to(device).view(1)

        prediction = pruned_model(batch_input)
        loss = torch.sum(batch_target.to(torch.float32) - prediction)

        optimizer.zero_grad()
        loss.backward()

        # Gradients are only expected to be masked when zero_grad is enabled
        if zero_grad:
            assert_grads_are_masked(pruned_modules)
Пример #9
0
def test_valid_modules_replacement_and_pruning(prune_first, prune_last):
    """
    Test that all conv modules in the model were replaced by NNCF modules and
    that pruning pre-ops were added exactly to the convs selected for pruning.

    :param prune_first: whether to prune the first convolution (conv1) or not
    :param prune_last: whether to prune the last convolution (conv3) or not
    """
    def check_that_module_is_pruned(module):
        # Exactly one pre-op: an UpdateWeight wrapping a FilterPruningBlock
        assert len(module.pre_ops) == 1
        op = next(iter(module.pre_ops.values()))
        assert isinstance(op, UpdateWeight)
        assert isinstance(op.operand, FilterPruningBlock)

    config = get_basic_pruning_config(input_sample_size=(1, 1, 8, 8))
    config['compression']['params']['prune_first_conv'] = prune_first
    config['compression']['params']['prune_last_conv'] = prune_last

    pruned_model, pruning_algo, nncf_modules = create_pruning_algo_with_config(
        config)
    pruned_modules = [minfo.module for minfo in pruning_algo.pruned_module_info]
    replaced_modules = list(nncf_modules.values())

    # conv2 is always pruned; conv1 / conv3 only when requested
    expectations = (
        (pruned_model.conv1, prune_first),
        (pruned_model.conv2, True),
        (pruned_model.conv3, prune_last),
    )
    for conv, should_be_pruned in expectations:
        # Every conv must be replaced by an NNCF module, pruned or not
        assert conv in replaced_modules
        if should_be_pruned:
            assert conv in pruned_modules
            check_that_module_is_pruned(conv)
        else:
            assert conv not in pruned_modules