def test_groups(test_input_info_struct_: GroupPruningModulesTestStruct):
    """Check which modules filter pruning skips and how it clusters the rest.

    The model is built from ``test_input_info_struct_.model`` and compressed
    with ``filter_pruning`` using the three flags unpacked from
    ``prune_params`` (``prune_first_conv``, ``prune_last_conv``,
    ``prune_downsample_convs``). Two things are then verified:

    1. no scope listed in ``not_pruned_modules`` resolves to a module that
       appears among the controller's pruned nodes;
    2. for every expected group in ``pruned_groups``, the cluster looked up by
       the group's first scope holds exactly the group's modules, in order.

    :param test_input_info_struct_: test case description providing ``model``
        (a callable returning the model), ``not_pruned_modules``,
        ``pruned_groups`` and ``prune_params``.
    """
    model_factory = test_input_info_struct_.model
    not_pruned_modules = test_input_info_struct_.not_pruned_modules
    pruned_groups = test_input_info_struct_.pruned_groups
    prune_first, prune_last, prune_downsample = test_input_info_struct_.prune_params

    model = model_factory()
    nncf_config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    nncf_config['compression']['algorithm'] = 'filter_pruning'
    nncf_config['compression']['params']['prune_first_conv'] = prune_first
    nncf_config['compression']['params']['prune_last_conv'] = prune_last
    nncf_config['compression']['params']['prune_downsample_convs'] = prune_downsample

    compressed_model, compression_ctrl = create_compressed_model_and_algo_for_test(model, nncf_config)

    # 1. Check all not pruned modules
    clusters = compression_ctrl.pruned_module_groups_info
    all_pruned_modules_info = clusters.get_all_nodes()
    all_pruned_modules = [info.module for info in all_pruned_modules_info]
    # Kept only for failure diagnostics: the scopes are reported in the assertion
    # message instead of being printed unconditionally on every run.
    pruned_scopes = [minfo.module_scope for minfo in all_pruned_modules_info]
    for module_name in not_pruned_modules:
        module = compressed_model.get_module_by_scope(Scope.from_str(module_name))
        assert module not in all_pruned_modules, \
            'Module {} must not be pruned; pruned scopes: {}'.format(module_name, pruned_scopes)

    # 2. Check that all pruned groups are valid
    for group in pruned_groups:
        # The first scope of the expected group is used as the key to find its cluster.
        first_node_scope = Scope.from_str(group[0])
        cluster = clusters.get_cluster_by_node_id(first_node_scope)
        cluster_modules = [n.module for n in cluster.nodes]
        group_modules = [compressed_model.get_module_by_scope(Scope.from_str(module_scope)) for module_scope in group]

        # List equality: both membership and ordering have to match.
        assert cluster_modules == group_modules, \
            'Cluster found via {} does not match the expected group {}'.format(group[0], group)
Beispiel #2
0
def test_can_choose_scheduler(algo, scheduler, scheduler_class):
    """Verify that the ``schedule`` config entry selects the scheduler class.

    :param algo: value written to the ``algorithm`` key of the compression config.
    :param scheduler: value written to the ``schedule`` key of the algorithm params.
    :param scheduler_class: class the controller's scheduler must be an instance of.
    """
    config = get_basic_pruning_config()
    compression_section = config['compression']
    compression_section['algorithm'] = algo
    compression_section['params']['schedule'] = scheduler

    # Only the controller is needed; the compressed model itself is discarded.
    compression_ctrl = create_compressed_model_and_algo_for_test(PruningTestModel(), config)[1]

    assert isinstance(compression_ctrl.scheduler, scheduler_class)
Beispiel #3
0
def test_both_targets_assert():
    """Expect a ``ValueError`` when both pruning targets are configured at once.

    Both ``pruning_target`` and ``pruning_flops_target`` are set in the
    ``filter_pruning`` params before the compressed model is created.
    """
    config = get_basic_pruning_config()
    config['compression']['algorithm'] = 'filter_pruning'

    algo_params = config['compression']['params']
    for target_key, target_value in (('pruning_target', 0.3),
                                     ('pruning_flops_target', 0.5)):
        algo_params[target_key] = target_value

    model = PruningTestModel()
    with pytest.raises(ValueError):
        create_compressed_model_and_algo_for_test(model, config)
def create_nncf_model_and_builder(model, config_params):
    """Build an NNCF model plus its single filter-pruning algorithm builder.

    :param model: model passed through to ``create_nncf_model_and_algo_builder``.
    :param config_params: mapping whose items are copied one by one into the
        ``params`` section of the ``filter_pruning`` config.
    :return: tuple ``(nncf_model, algo_builder)``.
    :raises AssertionError: if the number of returned builders is not exactly one.
    """
    nncf_config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    nncf_config['compression']['algorithm'] = 'filter_pruning'

    algo_params = nncf_config['compression']['params']
    for param_name, param_value in config_params.items():
        algo_params[param_name] = param_value

    nncf_model, algo_builders_list = create_nncf_model_and_algo_builder(model, nncf_config)

    # A single algorithm was configured, so exactly one builder is expected back.
    assert len(algo_builders_list) == 1
    return nncf_model, algo_builders_list[0]
Beispiel #5
0
def test_init_params_for_flops_calculation(model, ref_params):
    """Compare controller attributes against reference values.

    The model is compressed with ``filter_pruning`` using a
    ``pruning_flops_target`` of 0.3 and ``prune_first_conv`` enabled.

    :param model: callable that returns the model to compress.
    :param ref_params: mapping of controller attribute name to the value that
        attribute must equal after the controller is created.
    """
    config = get_basic_pruning_config()
    compression_section = config['compression']
    compression_section['algorithm'] = 'filter_pruning'
    compression_section['params']['pruning_flops_target'] = 0.3
    compression_section['params']['prune_first_conv'] = True

    network = model()
    compression_ctrl = create_compressed_model_and_algo_for_test(network, config)[1]

    for attr_name, expected in ref_params.items():
        actual = getattr(compression_ctrl, attr_name)
        assert actual == expected
Beispiel #6
0
def test_check_default_scheduler_params(algo, ref_scheduler,
                                        ref_scheduler_params):
    """Check the scheduler created when the config does not name one.

    :param algo: value written to the ``algorithm`` key of the compression config.
    :param ref_scheduler: class the controller's scheduler must be an instance of.
    :param ref_scheduler_params: mapping of scheduler attribute name to the
        value that attribute must equal.
    """
    config = get_basic_pruning_config()
    config['compression']['algorithm'] = algo

    # No 'schedule' entry is set here; the scheduler is whatever the algorithm picks.
    compression_ctrl = create_compressed_model_and_algo_for_test(PruningTestModel(), config)[1]
    default_scheduler = compression_ctrl.scheduler

    assert isinstance(default_scheduler, ref_scheduler)
    for attr_name, expected in ref_scheduler_params.items():
        actual = getattr(default_scheduler, attr_name)
        assert actual == expected
def test_pruning_export_simple_model(tmp_path):
    """Export a filter-pruned ``BigPruningTestModel`` to ONNX and check shapes.

    :param tmp_path: directory handed to ``load_exported_onnx_version`` as
        ``path_to_storage_dir``.
    """
    model = BigPruningTestModel()
    nncf_config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    nncf_config['compression']['algorithm'] = 'filter_pruning'
    onnx_model_proto = load_exported_onnx_version(nncf_config, model, path_to_storage_dir=tmp_path)

    # Each entry: node name followed by the two reference shapes passed on to
    # check_bias_and_weight_shape (presumably weight shape, then bias shape — confirm).
    expected_shapes = (
        # conv2 + BN were pruned by output filters
        ('nncf_module.conv2', [16, 16, 3, 3], [16]),
        ('nncf_module.bn', [16], [16]),
        # conv3 was pruned by input filters
        ('nncf_module.conv3', [1, 16, 5, 5], [1]),
    )
    for node_name, weight_shape, bias_shape in expected_shapes:
        check_bias_and_weight_shape(node_name, onnx_model_proto, weight_shape, bias_shape)
def test_pruning_export_eltwise_model(tmp_path, prune_first, prune_last, ref_shapes):
    """Export a filter-pruned ``PruningTestModelEltwise`` and check conv1..conv4.

    :param tmp_path: directory handed to ``load_exported_onnx_version`` as
        ``path_to_storage_dir``.
    :param prune_first: value for the ``prune_first_conv`` param.
    :param prune_last: value for the ``prune_last_conv`` param.
    :param ref_shapes: indexable of reference shape groups; entry ``k`` is
        unpacked into the check for ``nncf_module.conv{k + 1}``.
    """
    model = PruningTestModelEltwise()
    nncf_config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    nncf_config['compression']['algorithm'] = 'filter_pruning'

    algo_params = nncf_config['compression']['params']
    algo_params['prune_first_conv'] = prune_first
    algo_params['prune_last_conv'] = prune_last

    onnx_model_proto = load_exported_onnx_version(nncf_config, model, path_to_storage_dir=tmp_path)

    # Conv layers are named from 1, the reference shapes are indexed from 0.
    for shape_idx in range(4):
        conv_name = "nncf_module.conv{}".format(shape_idx + 1)
        check_bias_and_weight_shape(conv_name, onnx_model_proto, *ref_shapes[shape_idx])
Beispiel #9
0
def test_prune_flops_param(pruning_target, pruning_flops_target, prune_flops_ref, pruning_target_ref):
    """Check how the configured targets map onto the controller and scheduler.

    :param pruning_target: written to the ``pruning_target`` param when truthy;
        a falsy value leaves the key unset.
    :param pruning_flops_target: written to the ``pruning_flops_target`` param
        when truthy; a falsy value leaves the key unset.
    :param prune_flops_ref: object the controller's ``prune_flops`` must be
        identical to (compared with ``is``).
    :param pruning_target_ref: value the scheduler's ``pruning_target`` must equal.
    """
    config = get_basic_pruning_config()
    config['compression']['algorithm'] = 'filter_pruning'

    algo_params = config['compression']['params']
    optional_targets = (
        ('pruning_target', pruning_target),
        ('pruning_flops_target', pruning_flops_target),
    )
    for target_key, target_value in optional_targets:
        # Only truthy targets make it into the config.
        if target_value:
            algo_params[target_key] = target_value
    algo_params['prune_first_conv'] = True

    compression_ctrl = create_compressed_model_and_algo_for_test(PruningTestModel(), config)[1]

    assert compression_ctrl.prune_flops is prune_flops_ref
    assert compression_ctrl.scheduler.pruning_target == pruning_target_ref
Beispiel #10
0
def test_get_last_pruned_layers(model, ref_last_module_names):
    """Check the result of ``get_last_pruned_modules`` against named attributes.

    :param model: callable that returns the model to compress.
    :param ref_last_module_names: attribute names on the compressed model whose
        modules must make up exactly the returned set (order is ignored).
    """
    config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    config['compression']['algorithm'] = 'filter_pruning'
    pruned_model = create_compressed_model_and_algo_for_test(model(), config)[0]

    pruned_op_types = FilterPruningBuilder(config).get_op_types_of_pruned_modules()
    last_pruned_modules = get_last_pruned_modules(pruned_model, pruned_op_types)

    expected_modules = set()
    for module_name in ref_last_module_names:
        expected_modules.add(getattr(pruned_model, module_name))

    assert set(last_pruned_modules) == expected_modules
Beispiel #11
0
def test_get_bn_for_module_scope():
    """Check ``get_bn_for_module_scope`` for the three convs of ``BigPruningTestModel``.

    The lookup must return ``None`` for ``conv1`` and ``conv3`` and the model's
    ``bn`` attribute for ``conv2``.
    """
    config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    config['compression']['algorithm'] = 'filter_pruning'
    pruned_model = create_compressed_model_and_algo_for_test(BigPruningTestModel(), config)[0]

    def bn_for(conv_name):
        # Resolve the conv's scope string and run the lookup under test.
        scope = Scope.from_str('BigPruningTestModel/NNCFConv2d[{}]'.format(conv_name))
        return get_bn_for_module_scope(pruned_model, scope)

    assert bn_for('conv1') is None
    assert bn_for('conv2') == pruned_model.bn
    assert bn_for('conv3') is None
Beispiel #12
0
def test_pruning_export_simple_model(tmp_path):
    """Export a filter-pruned ``BigPruningTestModel`` (``pruning_init`` 0.5) and check shapes.

    :param tmp_path: directory handed to ``load_exported_onnx_version`` as
        ``path_to_storage_dir``.
    """
    model = BigPruningTestModel()
    nncf_config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    compression_section = nncf_config['compression']
    compression_section['pruning_init'] = 0.5
    compression_section['algorithm'] = 'filter_pruning'
    onnx_model_proto = load_exported_onnx_version(nncf_config, model, path_to_storage_dir=tmp_path)

    # Each entry: node name followed by the two reference shapes passed on to
    # check_bias_and_weight_shape (presumably weight shape, then bias shape — confirm).
    expected_shapes = (
        # conv2 + BN were pruned by output filters.
        # WARNING: starting from at least torch 1.7.0, torch.onnx.export will fuse BN into the
        # previous convs if the export is done with `training=False`, so this test might fail.
        ('nncf_module.conv2', [16, 16, 3, 3], [16]),
        ('nncf_module.bn', [16], [16]),
        # up was pruned by input filters
        ('nncf_module.up', [16, 32, 3, 3], [32]),
        # conv3 was pruned by input filters
        ('nncf_module.conv3', [1, 32, 5, 5], [1]),
    )
    for node_name, weight_shape, bias_shape in expected_shapes:
        check_bias_and_weight_shape(node_name, onnx_model_proto, weight_shape, bias_shape)