Example #1
def test_check_default_algo_params():
    """
    Test default algorithm params: create an empty config and check that valid default
    parameters are set.
    """
    # Creating algorithm with empty config
    config = get_basic_pruning_config()
    config['compression']['algorithm'] = 'filter_pruning'
    model = PruningTestModel()
    _, compression_ctrl = create_compressed_model_and_algo_for_test(
        model, config)

    assert isinstance(compression_ctrl, FilterPruningController)
    scheduler = compression_ctrl.scheduler
    # Check default algo params
    assert compression_ctrl.prune_first is False
    assert compression_ctrl.prune_batch_norms is True
    assert compression_ctrl.prune_downsample_convs is False
    assert compression_ctrl.filter_importance is l2_filter_norm
    assert compression_ctrl.ranking_type == 'unweighted_ranking'
    assert compression_ctrl.pruning_quota == 0.9

    assert compression_ctrl.all_weights is False

    # Check default scheduler params
    assert isinstance(scheduler, ExponentialPruningScheduler)
Example #2
def test_default_legr_init_struct():
    config = get_basic_pruning_config()
    init_loader = create_ones_mock_dataloader(config)
    nncf_config = register_default_init_args(config, init_loader)

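    # No LeGR init args were registered, so fetching the struct must raise KeyError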
    with pytest.raises(KeyError):
        nncf_config.get_extra_struct(LeGRInitArgs)
Example #3
def test_valid_masks_for_bn_after_concat(prune_bn):
    config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    config['compression']['algorithm'] = 'filter_pruning'
    config['compression']['params']['prune_batch_norms'] = prune_bn
    config['compression']['params']['prune_first_conv'] = True
    config['compression']['pruning_init'] = 0.5
    model = PruningTestModelConcatBN()
    pruned_model, _ = create_compressed_model_and_algo_for_test(model, config)

    bn_modules = [pruned_model.bn, pruned_model.bn1, pruned_model.bn2]
    for bn_module in bn_modules:
        if prune_bn:
            # Check that a mask was applied to the batch_norm module
            mask = bn_module.pre_ops['0'].op.binary_filter_pruning_mask
            assert sum(mask) == len(mask) * 0.5
        else:
            # Check that no mask was added to the layer
            assert len(bn_module.pre_ops) == 0

    # Check output mask of concat layers
    ref_concat_masks = [[0] * 8 + [1] * 8 + [0] * 8 + [1] * 8,
                        [1] * 8 + [0] * 16 + [1] * 8 + [0] * 8 + [1] * 8]
    graph = pruned_model.get_original_graph()
    for i, node in enumerate(graph.get_nodes_by_types(['cat'])):
        assert np.allclose(node.data['output_mask'].tensor.numpy(),
                           ref_concat_masks[i])
Example #4
def test_default_distributed_init_struct():
    config = get_basic_pruning_config()
    init_loader = create_ones_mock_dataloader(config)
    register_default_init_args(config, init_loader)

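    # Default distributed callbacks should be registered even when none are passed explicitly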
    dist_callbacks = config.get_extra_struct(DistributedCallbacksArgs)
    assert callable(dist_callbacks.wrap_model)
    assert callable(dist_callbacks.unwrap_model)
Example #5
def test_can_choose_scheduler(algo, scheduler, scheduler_class):
    config = get_basic_pruning_config()
    config['compression']['algorithm'] = algo
    config['compression']['params']['schedule'] = scheduler
    model = PruningTestModel()
    _, compression_ctrl = create_compressed_model_and_algo_for_test(
        model, config)
    scheduler = compression_ctrl.scheduler
    assert isinstance(scheduler, scheduler_class)
Example #6
def test_both_targets_assert():
    config = get_basic_pruning_config()
    config['compression']['algorithm'] = 'filter_pruning'
    config['compression']['params']['pruning_target'] = 0.3
    config['compression']['params']['pruning_flops_target'] = 0.5

    model = PruningTestModel()
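    # Specifying both a filter-level target and a FLOPs target is contradictory and must fail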
    with pytest.raises(ValueError):
        create_compressed_model_and_algo_for_test(model, config)
Example #7
def create_nncf_model_and_pruning_builder(
        model: torch.nn.Module,
        config_params: Dict) -> Tuple[NNCFNetwork, FilterPruningBuilder]:
    nncf_config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    nncf_config['compression']['algorithm'] = 'filter_pruning'
    for key, value in config_params.items():
        nncf_config['compression']['params'][key] = value
    nncf_model, pruning_builder = create_nncf_model_and_single_algo_builder(
        model, nncf_config)
    return nncf_model, pruning_builder
Example #8
def test_check_default_scheduler_params(algo, ref_scheduler,
                                        ref_scheduler_params):
    config = get_basic_pruning_config()
    config['compression']['algorithm'] = algo
    model = PruningTestModel()
    _, compression_ctrl = create_compressed_model_and_algo_for_test(
        model, config)
    scheduler = compression_ctrl.scheduler
    assert isinstance(scheduler, ref_scheduler)
    for key, value in ref_scheduler_params.items():
        assert getattr(scheduler, key) == value
Example #9
def test_init_params_for_flops_calculation(model, ref_params):
    config = get_basic_pruning_config()
    config['compression']['algorithm'] = 'filter_pruning'
    config['compression']['params']['pruning_flops_target'] = 0.3
    config['compression']['params']['prune_first_conv'] = True

    model = model()
    _, compression_ctrl = create_compressed_model_and_algo_for_test(
        model, config)
    for key, value in ref_params.items():
        assert getattr(compression_ctrl, key) == value
Example #10
def test_can_set_compression_rate_in_filter_pruning_algo():
    """
    Test setting the global pruning rate via the compression_rate property.
    """
    # Create an algorithm with a basic pruning config
    config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    config['compression']['pruning_init'] = 0.2

    _, pruning_controller, _ = create_pruning_algo_with_config(config)
    pruning_controller.compression_rate = 0.65
    assert pytest.approx(pruning_controller.compression_rate, 1e-2) == 0.65
Example #11
def test_pruning_masks_correctness(all_weights, pruning_flops_target,
                                   prune_first, ref_masks):
    """
    Test that pruning masks are computed correctly (_set_binary_masks_for_filters, _set_binary_masks_for_all_filters_together).
    :param all_weights: whether the mask is calculated over all weights together or per layer
    :param pruning_flops_target: prune the model by FLOPs; if None, prune by number of filters
    :param prune_first: whether to prune the first convolution
    :param ref_masks: reference mask values
    """
    def check_mask(module, num):
        pruning_op = list(module.pre_ops.values())[0].operand
        assert hasattr(pruning_op, 'binary_filter_pruning_mask')
        assert torch.allclose(pruning_op.binary_filter_pruning_mask,
                              ref_masks[num])

    config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    config['compression']['params']['all_weights'] = all_weights
    config['compression']['params']['prune_first_conv'] = prune_first
    config['compression']['pruning_init'] = 0.5
    if pruning_flops_target:
        config['compression']['params'][
            'pruning_flops_target'] = pruning_flops_target

    pruned_model, pruning_algo, _ = create_pruning_algo_with_config(config)
    pruned_module_info = pruning_algo.pruned_module_groups_info.get_all_nodes()
    pruned_modules = [minfo.module for minfo in pruned_module_info]
    assert pruning_algo.pruning_level == 0.5
    assert pruning_algo.all_weights is all_weights

    i = 0
    # Check ref_masks for conv1
    conv1 = pruned_model.conv1
    conv_depthwise = pruned_model.conv_depthwise
    if prune_first:
        assert conv1 in pruned_modules
        assert conv_depthwise in pruned_modules

        check_mask(conv1, i)
        check_mask(conv_depthwise, i)
        i += 1

    # Check for conv2
    conv2 = pruned_model.conv2
    assert conv2 in pruned_modules
    check_mask(conv2, i)
    i += 1

    # Check for conv3
    up = pruned_model.up
    assert up in pruned_modules
    check_mask(up, i)
Example #12
def test_pruning_export_eltwise_model(tmp_path, prune_first, ref_shapes):
    model = PruningTestModelEltwise()
    nncf_config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    nncf_config['compression']['algorithm'] = 'filter_pruning'

    nncf_config['compression']['params']['prune_first_conv'] = prune_first
    nncf_config['compression']['pruning_init'] = 0.5
    onnx_model_proto = load_exported_onnx_version(nncf_config,
                                                  model,
                                                  path_to_storage_dir=tmp_path)
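    # Each exported conv should have pruned weight and bias shapes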
    for i in range(1, 5):
        conv_name = "nncf_module.conv{}".format(i)
        check_bias_and_weight_shape(conv_name, onnx_model_proto,
                                    *ref_shapes[i - 1])
Example #13
def test_pruning_export_groupnorm_model(tmp_path):
    model = TestModelGroupNorm()
    nncf_config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    nncf_config['compression']['algorithm'] = 'filter_pruning'

    nncf_config['compression']['params']['prune_first_conv'] = True
    nncf_config['compression']['pruning_init'] = 0.5
    onnx_model_proto = load_exported_onnx_version(nncf_config,
                                                  model,
                                                  path_to_storage_dir=tmp_path)

    check_bias_and_weight_shape("nncf_module.conv1", onnx_model_proto,
                                [8, 1, 1, 1], [8])
    check_bias_and_weight_shape("nncf_module.conv2", onnx_model_proto,
                                [16, 8, 1, 1], [16])
Example #14
def test_prune_flops_param(pruning_target, pruning_flops_target,
                           prune_flops_ref, pruning_target_ref):
    config = get_basic_pruning_config()
    config['compression']['algorithm'] = 'filter_pruning'
    if pruning_target:
        config['compression']['params']['pruning_target'] = pruning_target
    if pruning_flops_target:
        config['compression']['params'][
            'pruning_flops_target'] = pruning_flops_target
    config['compression']['params']['prune_first_conv'] = True

    model = PruningTestModel()
    _, compression_ctrl = create_compressed_model_and_algo_for_test(
        model, config)
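    # FLOPs-based pruning should be selected only when a FLOPs target is specified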
    assert compression_ctrl.prune_flops is prune_flops_ref
    assert compression_ctrl.scheduler.target_level == pruning_target_ref
Example #15
def test_legr_coeffs_saving(tmp_path):
    file_name = tmp_path / 'ranking_coeffs.json'
    model = PruningTestModel()
    ref_ranking_coeffs = {model.CONV_1_NODE_NAME: (1, 0), model.CONV_2_NODE_NAME: (1, 0)}

    config = get_basic_pruning_config()
    config['compression']['algorithm'] = 'filter_pruning'
    config['compression']['params']['prune_first_conv'] = True
    config['compression']['params']['save_ranking_coeffs_path'] = str(file_name)
    _, compression_ctrl = create_compressed_model_and_algo_for_test(model, config)
    assert compression_ctrl.ranking_coeffs == ref_ranking_coeffs

    # Check that some coefficients were saved to the specified file ((1, 0) in the non-LeGR case)
    with open(file_name, 'r', encoding='utf8') as f:
        saved_coeffs_in_file = json.load(f)
    assert all(ref_ranking_coeffs[key] == tuple(saved_coeffs_in_file[key]) for key in saved_coeffs_in_file)
Example #16
def test_setting_pruning_rate(all_weights, pruning_rate_to_set, ref_pruning_rates, ref_global_pruning_rate):
    """
    Test setting global and groupwise pruning rates via the set_pruning_rate method.
    """
    # Create an algorithm with a basic pruning config
    config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    config['compression']['pruning_init'] = 0.2
    config['compression']['params']['all_weights'] = all_weights

    _, pruning_controller, _ = create_pruning_algo_with_config(config)
    assert isinstance(pruning_controller, FilterPruningController)

    pruning_controller.set_pruning_level(pruning_rate_to_set)
    groupwise_pruning_rates = list(pruning_controller.current_groupwise_pruning_level.values())
    assert np.isclose(groupwise_pruning_rates, ref_pruning_rates).all()
    assert np.isclose(pruning_controller.pruning_level, ref_global_pruning_rate).all()
Example #17
def test_valid_legr_init_struct():
    config = get_basic_pruning_config()
    train_loader = create_ones_mock_dataloader(config)
    val_loader = create_ones_mock_dataloader(config)
    train_steps_fn = lambda *x: None
    validate_fn = lambda *x: (0, 0, 0)
    nncf_config = register_default_init_args(config,
                                             train_loader=train_loader,
                                             train_steps_fn=train_steps_fn,
                                             val_loader=val_loader,
                                             validate_fn=validate_fn)

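    # With loaders and train/validate functions provided, LeGR init args are registered implicitly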
    legr_init_args = config.get_extra_struct(LeGRInitArgs)
    assert legr_init_args.config == nncf_config
    assert legr_init_args.train_loader == train_loader
    assert legr_init_args.val_loader == val_loader
    assert legr_init_args.train_steps_fn == train_steps_fn
Example #18
def test_disconnected_graph():
    config = get_basic_pruning_config([1, 1, 8, 8])
    config['compression']['algorithm'] = 'filter_pruning'
    config['compression']['pruning_init'] = 0.5
    config['compression']['params']['pruning_target'] = 0.5
    config['compression']['params']['prune_first_conv'] = True
    model = DisconectedGraphModel()
    pruned_model, _ = create_compressed_model_and_algo_for_test(model, config)
    graph = pruned_model.get_original_graph()

    conv1 = graph.get_node_by_name(
        'DisconectedGraphModel/NNCFConv2d[conv1]/conv2d_0')
    conv2 = graph.get_node_by_name(
        'DisconectedGraphModel/NNCFConv2d[conv2]/conv2d_0')

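    # Pruning masks should still be computed correctly for a model with a disconnected graph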
    assert sum(conv1.data['output_mask'].tensor) == 8
    assert sum(conv2.data['output_mask'].tensor) == 8
Example #19
def test_get_last_pruned_layers(model, ref_last_module_names):
    config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    config['compression']['algorithm'] = 'filter_pruning'
    pruned_model, _ = create_compressed_model_and_algo_for_test(
        model(), config)

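    # Find the last prunable operations in the graph and map them back to their modules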
    last_pruned_nodes = get_last_nodes_of_type(
        pruned_model.get_original_graph(),
        FilterPruningBuilder(config).get_op_types_of_pruned_modules())
    last_pruned_modules = [
        pruned_model.get_containing_module(n.node_name)
        for n in last_pruned_nodes
    ]
    ref_last_modules = [
        getattr(pruned_model, module_name)
        for module_name in ref_last_module_names
    ]
    assert set(last_pruned_modules) == set(ref_last_modules)
Example #20
def test_func_calulation_flops_for_conv(model):
    # Check _calculate_output_shape, which is used for disconnected graphs
    config = get_basic_pruning_config([1, 1, 8, 8])
    config['compression']['algorithm'] = 'filter_pruning'
    config['compression']['pruning_init'] = 0.4
    config['compression']['params']['pruning_flops_target'] = 0.4
    model = model()
    pruned_model, pruning_algo = create_compressed_model_and_algo_for_test(
        model, config)

    graph = pruned_model.get_original_graph()

    # pylint:disable=protected-access
    for node_name, ref_shape in pruning_algo._modules_out_shapes.items():
        # ref_shape is taken from the traced graph
        node = graph.get_node_by_name(node_name)
        shape = pruning_algo._calculate_output_shape(graph, node)
        assert ref_shape == shape, f"Incorrect output shape calculation for {node_name}"
Example #21
def test_get_bn_for_conv_node():
    config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    config['compression']['algorithm'] = 'filter_pruning'
    pruned_model, _ = create_compressed_model_and_algo_for_test(
        BigPruningTestModel(), config)

    conv1_name = 'BigPruningTestModel/NNCFConv2d[conv1]/conv2d_0'
    bn = get_bn_for_conv_node_by_name(pruned_model, conv1_name)
    assert bn == pruned_model.bn1

    conv2_name = 'BigPruningTestModel/NNCFConv2d[conv2]/conv2d_0'
    bn = get_bn_for_conv_node_by_name(pruned_model, conv2_name)
    assert bn == pruned_model.bn2

    up_name = 'BigPruningTestModel/NNCFConvTranspose2d[up]/conv_transpose2d_0'
    bn = get_bn_for_conv_node_by_name(pruned_model, up_name)
    assert bn is None

    conv3_name = 'BigPruningTestModel/NNCFConv2d[conv3]/conv2d_0'
    bn = get_bn_for_conv_node_by_name(pruned_model, conv3_name)
    assert bn is None
Example #22
def test_pruning_export_simple_model(tmp_path):
    model = BigPruningTestModel()
    nncf_config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    nncf_config['compression']['pruning_init'] = 0.5
    nncf_config['compression']['algorithm'] = 'filter_pruning'
    onnx_model_proto = load_exported_onnx_version(nncf_config,
                                                  model,
                                                  path_to_storage_dir=tmp_path)
    # Check that conv2 + BN were pruned by output filters
    # WARNING: starting from at least torch 1.7.0, torch.onnx.export fuses BN into the preceding
    # convs when exporting with `training=False`, so this test might fail.
    check_bias_and_weight_shape('nncf_module.conv2', onnx_model_proto,
                                [16, 16, 3, 3], [16])
    check_bias_and_weight_shape('nncf_module.bn', onnx_model_proto, [16], [16])

    # Check that up was pruned by input filters
    check_bias_and_weight_shape('nncf_module.up', onnx_model_proto,
                                [16, 32, 3, 3], [32])

    # Check that conv3 was pruned by input filters
    check_bias_and_weight_shape('nncf_module.conv3', onnx_model_proto,
                                [1, 32, 5, 5], [1])
Example #23
def test_clusters_for_multiple_forward(repeat_seq_of_shared_convs,
                                       ref_second_cluster,
                                       additional_last_shared_layers):
    config = get_basic_pruning_config(input_sample_size=[1, 2, 8, 8])
    config['compression']['algorithm'] = 'filter_pruning'
    config['compression']['params']['all_weights'] = False
    config['compression']['params']['prune_first_conv'] = True
    config['compression']['pruning_init'] = 0.5
    model = TestModelMultipleForward(repeat_seq_of_shared_convs,
                                     additional_last_shared_layers)
    _, pruning_algo = create_compressed_model_and_algo_for_test(model, config)

    clusters = pruning_algo.pruned_module_groups_info.clusters
    ref_num_clusters = 2 if additional_last_shared_layers else 1
    assert len(clusters) == ref_num_clusters
    # Convolutions preceding a node that is forwarded several times should be in one cluster
    assert sorted([n.nncf_node_id for n in clusters[0].elements]) == [1, 2, 3]
    # In case of two clusters
    if additional_last_shared_layers:
        # Nodes associated with the same module should be in one cluster
        assert sorted([n.nncf_node_id
                       for n in clusters[1].elements]) == ref_second_cluster
Example #24
def test_calculation_of_flops(all_weights, pruning_flops_target, ref_flops,
                              ref_params_num):
    """
    Test that FLOPs and the number of parameters are calculated correctly after pruning.
    :param all_weights: whether the mask is calculated over all weights together or per layer
    :param pruning_flops_target: prune the model by FLOPs; if None, prune by number of filters
    :param ref_flops: reference number of FLOPs in the pruned model
    :param ref_params_num: reference number of parameters in the pruned model
    """
    config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    config['compression']['params']['all_weights'] = all_weights
    config['compression']['params']['prune_first_conv'] = True
    config['compression']['pruning_init'] = 0.5
    if pruning_flops_target:
        config['compression']['params'][
            'pruning_flops_target'] = pruning_flops_target

    _, pruning_algo, _ = create_pruning_algo_with_config(config)

    assert pruning_algo.current_flops == ref_flops
    assert pruning_algo.current_params_num == ref_params_num
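    # Cross-check: recompute FLOPs and parameter counts directly from the pruning masks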
    # pylint:disable=protected-access
    tmp_in_channels, tmp_out_channels = calculate_in_out_channels_by_masks(
        pruning_algo.pruned_module_groups_info.get_all_clusters(),
        masks=pruning_algo._collect_pruning_masks(),
        tensor_processor=PTNNCFCollectorTensorProcessor,
        full_input_channels=pruning_algo._modules_in_channels,
        full_output_channels=pruning_algo._modules_out_channels,
        pruning_groups_next_nodes=pruning_algo.next_nodes)

    cur_flops, cur_params_num = count_flops_and_weights(
        pruning_algo._model.get_original_graph(),
        pruning_algo._modules_in_shapes,
        pruning_algo._modules_out_shapes,
        input_channels=tmp_in_channels,
        output_channels=tmp_out_channels,
        conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
        linear_op_metatypes=LINEAR_LAYER_METATYPES)
    assert (cur_flops, cur_params_num) == (ref_flops, ref_params_num)
Example #25
def test_distributed_init_struct():
    class FakeModelClass:
        def __init__(self, model_: nn.Module):
            self.model = model_

        def unwrap(self):
            return self.model

    config = get_basic_pruning_config()
    init_loader = create_ones_mock_dataloader(config)
    wrapper_callback = FakeModelClass
    unwrapper_callback = lambda x: x.unwrap()
    nncf_config = register_default_init_args(
        config,
        init_loader,
        distributed_callbacks=(wrapper_callback, unwrapper_callback))

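    # The retrieved callbacks should wrap and unwrap the model using the provided pair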
    dist_callbacks = nncf_config.get_extra_struct(DistributedCallbacksArgs)
    model = PruningTestModel()
    wrapped_model = dist_callbacks.wrap_model(model)
    assert isinstance(wrapped_model, FakeModelClass)
    unwrapped_model = dist_callbacks.unwrap_model(wrapped_model)
    assert unwrapped_model == model
Example #26
def test_groups(test_input_info_struct_: GroupPruningModulesTestStruct):
    model = test_input_info_struct_.model
    non_pruned_module_nodes = test_input_info_struct_.non_pruned_module_nodes
    pruned_groups = test_input_info_struct_.pruned_groups
    prune_first, prune_downsample = test_input_info_struct_.prune_params

    model = model()
    nncf_config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    nncf_config['compression']['algorithm'] = 'filter_pruning'
    nncf_config['compression']['params']['prune_first_conv'] = prune_first
    nncf_config['compression']['params'][
        'prune_downsample_convs'] = prune_downsample

    compressed_model, compression_ctrl = create_compressed_model_and_algo_for_test(
        model, nncf_config)

    # 1. Check all modules that should not be pruned
    clusters = compression_ctrl.pruned_module_groups_info
    all_pruned_modules_info = clusters.get_all_nodes()
    all_pruned_modules = [info.module for info in all_pruned_modules_info]

    for node_name in non_pruned_module_nodes:
        module = compressed_model.get_containing_module(node_name)
        assert module is not None and module not in all_pruned_modules

    # 2. Check that all pruned groups are valid
    for group in pruned_groups:
        first_node_name = group[0]
        cluster = clusters.get_cluster_containing_element(first_node_name)
        cluster_modules = [n.module for n in cluster.elements]
        group_modules = [
            compressed_model.get_containing_module(node_name)
            for node_name in group
        ]

        assert Counter(cluster_modules) == Counter(group_modules)
    assert len(pruned_groups) == len(clusters.get_all_clusters())
Example #27
def test_flops_calulation_for_spec_layers(model, all_weights,
                                          pruning_flops_target, ref_full_flops,
                                          ref_current_flops, ref_sizes):
    # Models with large layers are needed here because otherwise different
    # pruning rates produce the same final model size
    config = get_basic_pruning_config([1, 1, 8, 8])
    config['compression']['algorithm'] = 'filter_pruning'
    config['compression']['pruning_init'] = pruning_flops_target
    config['compression']['params'][
        'pruning_flops_target'] = pruning_flops_target
    config['compression']['params']['prune_first_conv'] = True
    config['compression']['params']['all_weights'] = all_weights
    model = model()
    compressed_model, compression_ctrl = create_compressed_model_and_algo_for_test(
        model, config)

    assert compression_ctrl.full_flops == ref_full_flops
    assert compression_ctrl.current_flops == ref_current_flops

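    # Each conv's mask should keep the expected number of filters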
    for i, ref_size in enumerate(ref_sizes):
        node = getattr(compressed_model, f"conv{i+1}")
        op = list(node.pre_ops.values())[0]
        mask = op.operand.binary_filter_pruning_mask
        assert int(sum(mask)) == ref_size
Example #28
def create_default_legr_config():
    config = get_basic_pruning_config()
    config['compression']['algorithm'] = 'filter_pruning'
    config['compression']['params']['interlayer_ranking_type'] = 'learned_ranking'
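    # 'learned_ranking' selects the LeGR (learned global ranking) interlayer ranking type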
    return config
Example #29
def test_pruning_masks_applying_correctness(all_weights, pruning_flops_target,
                                            prune_first, ref_masks):
    """
    Test that pruning masks are applied correctly (_set_binary_masks_for_filters, _set_binary_masks_for_all_filters_together).
    :param all_weights: whether the mask is calculated over all weights together or per layer.
    :param pruning_flops_target: prune the model by FLOPs; if None, prune by number of filters.
    :param prune_first: whether to prune the first convolution.
    :param ref_masks: reference mask values.
    """
    input_shapes = {
        'conv1': [1, 1, 8, 8],
        'conv_depthwise': [1, 16, 7, 7],
        'conv2': [1, 16, 8, 8],
        'bn1': [1, 16, 8, 8],
        'bn2': [1, 32, 8, 8],
        'up': [1, 32, 8, 8]
    }

    def check_mask(module, num):
        # Mask for weights
        pruning_op = list(module.pre_ops.values())[0].operand
        assert hasattr(pruning_op, 'binary_filter_pruning_mask')
        assert torch.allclose(pruning_op.binary_filter_pruning_mask,
                              ref_masks[num])

        # Mask for bias
        # pruning_op = list(module.pre_ops.values())[1].operand
        # assert hasattr(pruning_op, 'binary_filter_pruning_mask')
        # assert torch.allclose(pruning_op.binary_filter_pruning_mask, ref_masks[num])

    def check_module_output(module, name, num):
        """
        Checks that the output of the module is masked.
        """
        mask = ref_masks[num]
        input_ = torch.ones(input_shapes[name])
        output = module(input_)
        ref_output = apply_filter_binary_mask(mask, output, dim=1)
        assert torch.allclose(output, ref_output)

    def check_model_weights(model_state_dict, ref_state_dict):
        for key in ref_state_dict.keys():
            assert torch.allclose(model_state_dict['nncf_module.' + key],
                                  ref_state_dict[key])

    config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    config['compression']['algorithm'] = 'filter_pruning'
    config['compression']['params']['all_weights'] = all_weights
    config['compression']['params']['prune_first_conv'] = prune_first
    config['compression']['pruning_init'] = 0.5
    if pruning_flops_target:
        config['compression']['params'][
            'pruning_flops_target'] = pruning_flops_target

    model = BigPruningTestModel()
    ref_state_dict = deepcopy(model.state_dict())
    pruned_model, pruning_algo = create_compressed_model_and_algo_for_test(
        BigPruningTestModel(), config)

    pruned_module_info = pruning_algo.pruned_module_groups_info.get_all_nodes()
    pruned_modules = [minfo.module for minfo in pruned_module_info]
    assert pruning_algo.pruning_level == 0.5
    assert pruning_algo.all_weights is all_weights

    # Checking that model weights remain unchanged
    check_model_weights(pruned_model.state_dict(), ref_state_dict)

    i = 0
    # Check ref_masks for conv1
    conv1 = pruned_model.conv1
    conv_depthwise = pruned_model.conv_depthwise
    if prune_first:
        assert conv1 in pruned_modules
        assert conv_depthwise in pruned_modules

        check_mask(conv1, i)
        check_mask(conv_depthwise, i)

        check_module_output(conv1, 'conv1', i)
        check_module_output(conv_depthwise, 'conv_depthwise', i)

    # Check for bn1
    bn1 = pruned_model.bn1
    if prune_first:
        check_mask(bn1, i)
        check_module_output(bn1, 'bn1', i)
        i += 1

    # Check for conv2
    conv2 = pruned_model.conv2
    assert conv2 in pruned_modules
    check_mask(conv2, i)
    check_module_output(conv2, 'conv2', i)

    # Check for bn2
    bn2 = pruned_model.bn2
    check_mask(bn2, i)
    check_module_output(bn2, 'bn2', i)
    i += 1

    # Check for up conv
    up = pruned_model.up
    assert up in pruned_modules
    check_mask(up, i)
    check_module_output(up, 'up', i)
Example #30
def test_valid_modules_replacement_and_pruning(prune_first, prune_batch_norms):
    """
    Test that all conv modules in the model were replaced by NNCF modules and that
    pruning pre-ops were added correctly.
    :param prune_first: whether to prune the first convolution.
    :param prune_batch_norms: whether to prune batch norm layers.
    """
    def check_that_module_is_pruned(module):
        assert len(module.pre_ops.values()) == 1
        pre_ops = list(module.pre_ops.values())
        assert isinstance(pre_ops[0], UpdateWeightAndBias)
        pruning_op = pre_ops[0].operand
        assert isinstance(pruning_op, FilterPruningMask)

    def check_that_module_is_not_pruned(module):
        assert len(module.pre_ops) == 0
        assert len(module.post_ops) == 0

    config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    config['compression']['params']['prune_first_conv'] = prune_first
    config['compression']['params']['prune_batch_norms'] = prune_batch_norms

    pruned_model, pruning_algo, nncf_modules = create_pruning_algo_with_config(
        config)
    pruned_module_info = pruning_algo.pruned_module_groups_info.get_all_nodes()
    pruned_modules = [minfo.module for minfo in pruned_module_info]

    # Check for conv1 and conv_depthwise
    conv1 = pruned_model.conv1
    conv_depthwise = pruned_model.conv_depthwise
    if prune_first:
        assert conv1 in pruned_modules
        assert conv1 in nncf_modules.values()
        check_that_module_is_pruned(conv1)

        assert conv_depthwise in nncf_modules.values()
        check_that_module_is_pruned(conv_depthwise)
    else:
        check_that_module_is_not_pruned(conv1)
        check_that_module_is_not_pruned(conv_depthwise)

    # Check for bn1
    bn1 = pruned_model.bn1
    if prune_first and prune_batch_norms:
        assert bn1 in nncf_modules.values()
        check_that_module_is_pruned(bn1)
    else:
        check_that_module_is_not_pruned(bn1)

    # Check for conv2
    conv2 = pruned_model.conv2
    assert conv2 in pruned_modules
    assert conv2 in nncf_modules.values()
    check_that_module_is_pruned(conv2)

    # Check for bn2
    bn2 = pruned_model.bn2
    if prune_batch_norms:
        assert bn2 in nncf_modules.values()
        check_that_module_is_pruned(bn2)
    else:
        check_that_module_is_not_pruned(bn2)

    # Check for conv3
    up = pruned_model.up
    assert up in pruned_modules
    assert up in nncf_modules.values()
    check_that_module_is_pruned(up)

    # Check for conv3
    conv3 = pruned_model.conv3
    check_that_module_is_not_pruned(conv3)