def test_check_default_algo_params():
    """
    Verify the defaults of the filter pruning algorithm.

    Builds a filter pruning model from the basic pruning config without any
    extra algorithm params and checks the controller attributes and the
    scheduler type it falls back to.
    """
    config = get_basic_pruning_config()
    config['compression']['algorithm'] = 'filter_pruning'
    test_model = PruningTestModel()
    _, compression_ctrl = create_compressed_model_and_algo_for_test(test_model, config)
    assert isinstance(compression_ctrl, FilterPruningController)
    pruning_scheduler = compression_ctrl.scheduler

    # Default algorithm parameters
    assert compression_ctrl.prune_first is False
    assert compression_ctrl.prune_batch_norms is True
    assert compression_ctrl.prune_downsample_convs is False
    assert compression_ctrl.filter_importance is l2_filter_norm
    assert compression_ctrl.ranking_type == 'unweighted_ranking'
    assert compression_ctrl.pruning_quota == 0.9
    assert compression_ctrl.all_weights is False

    # Default scheduler
    assert isinstance(pruning_scheduler, ExponentialPruningScheduler)
def test_default_legr_init_struct():
    """Registering only an init loader must not create LeGR init args (lookup raises KeyError)."""
    config = get_basic_pruning_config()
    mock_loader = create_ones_mock_dataloader(config)
    registered_config = register_default_init_args(config, mock_loader)
    with pytest.raises(KeyError):
        registered_config.get_extra_struct(LeGRInitArgs)
def test_valid_masks_for_bn_after_concat(prune_bn):
    """
    Check batch norm masking and concat output masks for PruningTestModelConcatBN.

    :param prune_bn: value of the ``prune_batch_norms`` algorithm parameter.
    """
    config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    config['compression']['algorithm'] = 'filter_pruning'
    config['compression']['params']['prune_batch_norms'] = prune_bn
    config['compression']['params']['prune_first_conv'] = True
    config['compression']['pruning_init'] = 0.5
    model = PruningTestModelConcatBN()
    pruned_model, _ = create_compressed_model_and_algo_for_test(model, config)

    bn_modules = [pruned_model.bn, pruned_model.bn1, pruned_model.bn2]
    for bn_module in bn_modules:
        if prune_bn:
            # Check that mask was applied for batch_norm module: half of its entries are 1
            mask = bn_module.pre_ops['0'].op.binary_filter_pruning_mask
            assert sum(mask) == len(mask) * 0.5
        else:
            # Check that no mask was added to the layer
            assert len(bn_module.pre_ops) == 0

    # Check output mask of concat layers
    ref_concat_masks = [[0] * 8 + [1] * 8 + [0] * 8 + [1] * 8,
                        [1] * 8 + [0] * 16 + [1] * 8 + [0] * 8 + [1] * 8]
    graph = pruned_model.get_original_graph()
    concat_nodes = graph.get_nodes_by_types(['cat'])
    # Without this, a graph with no (or fewer) concat nodes would pass the loop vacuously
    assert len(concat_nodes) == len(ref_concat_masks)
    for node, ref_mask in zip(concat_nodes, ref_concat_masks):
        assert np.allclose(node.data['output_mask'].tensor.numpy(), ref_mask)
def test_default_distributed_init_struct():
    """Default init args must register callable distributed wrap/unwrap callbacks."""
    config = get_basic_pruning_config()
    mock_loader = create_ones_mock_dataloader(config)
    register_default_init_args(config, mock_loader)
    callbacks = config.get_extra_struct(DistributedCallbacksArgs)
    for callback in (callbacks.wrap_model, callbacks.unwrap_model):
        assert callable(callback)
def test_can_choose_scheduler(algo, scheduler, scheduler_class):
    """
    Check that the ``schedule`` param selects the expected scheduler class.

    :param algo: compression algorithm name.
    :param scheduler: value written to ``compression.params.schedule``.
    :param scheduler_class: expected type of the controller's scheduler.
    """
    config = get_basic_pruning_config()
    config['compression']['algorithm'] = algo
    config['compression']['params']['schedule'] = scheduler
    model = PruningTestModel()
    _, compression_ctrl = create_compressed_model_and_algo_for_test(model, config)
    # Separate name so the `scheduler` parameter (the schedule name) is not shadowed
    created_scheduler = compression_ctrl.scheduler
    assert isinstance(created_scheduler, scheduler_class)
def test_both_targets_assert():
    """Setting both ``pruning_target`` and ``pruning_flops_target`` must raise ValueError."""
    config = get_basic_pruning_config()
    compression = config['compression']
    compression['algorithm'] = 'filter_pruning'
    compression['params']['pruning_target'] = 0.3
    compression['params']['pruning_flops_target'] = 0.5
    test_model = PruningTestModel()
    with pytest.raises(ValueError):
        create_compressed_model_and_algo_for_test(test_model, config)
def create_nncf_model_and_pruning_builder(
        model: torch.nn.Module,
        config_params: Dict) -> Tuple[NNCFNetwork, FilterPruningBuilder]:
    """
    Wrap ``model`` into an NNCF network and create a filter pruning builder for it.

    :param model: the model to wrap.
    :param config_params: values written over ``compression.params`` of the basic
        pruning config (input sample size [1, 1, 8, 8]).
    :return: tuple of the NNCF-wrapped model and the filter pruning builder.
    """
    nncf_config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    nncf_config['compression']['algorithm'] = 'filter_pruning'
    algo_params = nncf_config['compression']['params']
    for param_name, param_value in config_params.items():
        algo_params[param_name] = param_value
    wrapped_model, builder = create_nncf_model_and_single_algo_builder(model, nncf_config)
    return wrapped_model, builder
def test_check_default_scheduler_params(algo, ref_scheduler, ref_scheduler_params):
    """
    Check the scheduler type and its default attributes for ``algo``.

    :param algo: name of the compression algorithm to build.
    :param ref_scheduler: expected scheduler class.
    :param ref_scheduler_params: mapping of scheduler attribute names to expected values.
    """
    config = get_basic_pruning_config()
    config['compression']['algorithm'] = algo
    test_model = PruningTestModel()
    _, compression_ctrl = create_compressed_model_and_algo_for_test(test_model, config)
    pruning_scheduler = compression_ctrl.scheduler
    assert isinstance(pruning_scheduler, ref_scheduler)
    for attr_name, expected_value in ref_scheduler_params.items():
        assert getattr(pruning_scheduler, attr_name) == expected_value
def test_init_params_for_flops_calculation(model, ref_params):
    """
    Check controller attributes used for FLOPs calculation after initialization.

    :param model: model class to instantiate.
    :param ref_params: mapping of controller attribute names to expected values.
    """
    config = get_basic_pruning_config()
    config['compression']['algorithm'] = 'filter_pruning'
    algo_params = config['compression']['params']
    algo_params['pruning_flops_target'] = 0.3
    algo_params['prune_first_conv'] = True
    model_instance = model()
    _, compression_ctrl = create_compressed_model_and_algo_for_test(model_instance, config)
    for attr_name, expected_value in ref_params.items():
        assert getattr(compression_ctrl, attr_name) == expected_value
def test_can_set_compression_rate_in_filter_pruning_algo():
    """
    Test setting the global pruning rate via the compression_rate property.
    """
    # Start from pruning_init=0.2, which differs from the rate set below
    config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    config['compression']['pruning_init'] = 0.2
    _, pruning_controller, _ = create_pruning_algo_with_config(config)

    pruning_controller.compression_rate = 0.65
    # The expected value goes inside pytest.approx so the relative tolerance is taken w.r.t. 0.65
    assert pruning_controller.compression_rate == pytest.approx(0.65, rel=1e-2)
def test_pruning_masks_correctness(all_weights, pruning_flops_target, prune_first, ref_masks):
    """
    Check the binary filter pruning masks set by the algorithm
    (_set_binary_masks_for_filters, _set_binary_masks_for_all_filters_together).

    :param all_weights: whether mask will be calculated for all weights in common or not
    :param pruning_flops_target: prune model by flops, if None then by number of channels
    :param prune_first: whether to prune first convolution or not
    :param ref_masks: reference mask values, consumed in order: conv1/conv_depthwise
        (only when prune_first is set), then conv2, then up
    """
    def check_mask(module, mask_idx):
        first_pre_op = list(module.pre_ops.values())[0]
        pruning_op = first_pre_op.operand
        assert hasattr(pruning_op, 'binary_filter_pruning_mask')
        assert torch.allclose(pruning_op.binary_filter_pruning_mask, ref_masks[mask_idx])

    config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    algo_params = config['compression']['params']
    algo_params['all_weights'] = all_weights
    algo_params['prune_first_conv'] = prune_first
    config['compression']['pruning_init'] = 0.5
    if pruning_flops_target:
        algo_params['pruning_flops_target'] = pruning_flops_target

    pruned_model, pruning_algo, _ = create_pruning_algo_with_config(config)
    pruned_modules = [node_info.module
                      for node_info in pruning_algo.pruned_module_groups_info.get_all_nodes()]

    assert pruning_algo.pruning_level == 0.5
    assert pruning_algo.all_weights is all_weights

    mask_idx = 0
    # conv1 and conv_depthwise share one reference mask and are checked only with prune_first
    conv1 = pruned_model.conv1
    conv_depthwise = pruned_model.conv_depthwise
    if prune_first:
        assert conv1 in pruned_modules
        assert conv_depthwise in pruned_modules
        check_mask(conv1, mask_idx)
        check_mask(conv_depthwise, mask_idx)
        mask_idx += 1

    # conv2
    conv2 = pruned_model.conv2
    assert conv2 in pruned_modules
    check_mask(conv2, mask_idx)
    mask_idx += 1

    # up (transposed convolution)
    up = pruned_model.up
    assert up in pruned_modules
    check_mask(up, mask_idx)
def test_pruning_export_eltwise_model(tmp_path, prune_first, ref_shapes):
    """
    Export a pruned PruningTestModelEltwise to ONNX and check conv1..conv4 parameter shapes.

    :param prune_first: value of the ``prune_first_conv`` algorithm parameter.
    :param ref_shapes: per-conv shape arguments (for conv1..conv4, in order) passed
        on to check_bias_and_weight_shape.
    """
    model = PruningTestModelEltwise()
    nncf_config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    compression = nncf_config['compression']
    compression['algorithm'] = 'filter_pruning'
    compression['params']['prune_first_conv'] = prune_first
    compression['pruning_init'] = 0.5
    onnx_model_proto = load_exported_onnx_version(nncf_config, model, path_to_storage_dir=tmp_path)
    for conv_idx in range(1, 5):
        expected_shapes = ref_shapes[conv_idx - 1]
        check_bias_and_weight_shape(f"nncf_module.conv{conv_idx}", onnx_model_proto, *expected_shapes)
def test_pruning_export_groupnorm_model(tmp_path):
    """Export a pruned TestModelGroupNorm to ONNX and check conv1/conv2 parameter shapes."""
    model = TestModelGroupNorm()
    nncf_config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    compression = nncf_config['compression']
    compression['algorithm'] = 'filter_pruning'
    compression['params']['prune_first_conv'] = True
    compression['pruning_init'] = 0.5
    onnx_model_proto = load_exported_onnx_version(nncf_config, model, path_to_storage_dir=tmp_path)

    expected = [
        ("nncf_module.conv1", [8, 1, 1, 1], [8]),
        ("nncf_module.conv2", [16, 8, 1, 1], [16]),
    ]
    for param_prefix, *shapes in expected:
        check_bias_and_weight_shape(param_prefix, onnx_model_proto, *shapes)
def test_prune_flops_param(pruning_target, pruning_flops_target, prune_flops_ref, pruning_target_ref):
    """
    Check whether the controller prunes by FLOPs and which target level its scheduler gets.

    :param pruning_target: channel pruning target; written to the config only when truthy.
    :param pruning_flops_target: FLOPs pruning target; written to the config only when truthy.
    :param prune_flops_ref: expected value of ``compression_ctrl.prune_flops``.
    :param pruning_target_ref: expected ``scheduler.target_level``.
    """
    config = get_basic_pruning_config()
    config['compression']['algorithm'] = 'filter_pruning'
    algo_params = config['compression']['params']
    if pruning_target:
        algo_params['pruning_target'] = pruning_target
    if pruning_flops_target:
        algo_params['pruning_flops_target'] = pruning_flops_target
    algo_params['prune_first_conv'] = True

    test_model = PruningTestModel()
    _, compression_ctrl = create_compressed_model_and_algo_for_test(test_model, config)
    assert compression_ctrl.prune_flops is prune_flops_ref
    assert compression_ctrl.scheduler.target_level == pruning_target_ref
def test_legr_coeffs_saving(tmp_path):
    """
    Check the controller's ranking coefficients and that they are written to the
    file given by ``save_ranking_coeffs_path``.
    """
    file_name = tmp_path / 'ranking_coeffs.json'
    model = PruningTestModel()
    ref_ranking_coeffs = {model.CONV_1_NODE_NAME: (1, 0), model.CONV_2_NODE_NAME: (1, 0)}
    config = get_basic_pruning_config()
    config['compression']['algorithm'] = 'filter_pruning'
    config['compression']['params']['prune_first_conv'] = True
    config['compression']['params']['save_ranking_coeffs_path'] = str(file_name)
    _, compression_ctrl = create_compressed_model_and_algo_for_test(model, config)

    assert compression_ctrl.ranking_coeffs == ref_ranking_coeffs

    # Check that the file holds exactly the reference coeffs ((1, 0) in case of not-legr);
    # comparing whole dicts also catches an empty or partial file.
    with open(file_name, 'r', encoding='utf8') as f:
        saved_coeffs_in_file = json.load(f)
    # JSON stores tuples as lists, so convert back before comparing
    saved_coeffs = {key: tuple(value) for key, value in saved_coeffs_in_file.items()}
    assert saved_coeffs == ref_ranking_coeffs
def test_setting_pruning_rate(all_weights, pruning_rate_to_set, ref_pruning_rates, ref_global_pruning_rate):
    """
    Test setting global and groupwise pruning rates via the set_pruning_level method.

    :param all_weights: value of the ``all_weights`` algorithm parameter.
    :param pruning_rate_to_set: value passed to ``set_pruning_level``.
    :param ref_pruning_rates: expected groupwise pruning levels.
    :param ref_global_pruning_rate: expected global pruning level.
    """
    # Start from pruning_init=0.2 before setting the new level
    config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    config['compression']['pruning_init'] = 0.2
    config['compression']['params']['all_weights'] = all_weights
    _, pruning_controller, _ = create_pruning_algo_with_config(config)
    assert isinstance(pruning_controller, FilterPruningController)

    pruning_controller.set_pruning_level(pruning_rate_to_set)
    groupwise_levels = list(pruning_controller.current_groupwise_pruning_level.values())
    groupwise_match = np.isclose(groupwise_levels, ref_pruning_rates)
    assert groupwise_match.all()
    global_match = np.isclose(pruning_controller.pruning_level, ref_global_pruning_rate)
    assert global_match.all()
def test_valid_legr_init_struct():
    """register_default_init_args must store the given LeGR loaders, train function and config."""
    config = get_basic_pruning_config()
    train_loader = create_ones_mock_dataloader(config)
    val_loader = create_ones_mock_dataloader(config)

    def train_steps_fn(*_args):
        return None

    def validate_fn(*_args):
        return 0, 0, 0

    nncf_config = register_default_init_args(config,
                                             train_loader=train_loader,
                                             train_steps_fn=train_steps_fn,
                                             val_loader=val_loader,
                                             validate_fn=validate_fn)
    legr_init_args = config.get_extra_struct(LeGRInitArgs)
    assert legr_init_args.config == nncf_config
    assert legr_init_args.train_loader == train_loader
    assert legr_init_args.val_loader == val_loader
    assert legr_init_args.train_steps_fn == train_steps_fn
def test_disconnected_graph():
    """Both convolutions of DisconectedGraphModel must get output masks whose entries sum to 8."""
    config = get_basic_pruning_config([1, 1, 8, 8])
    compression = config['compression']
    compression['algorithm'] = 'filter_pruning'
    compression['pruning_init'] = 0.5
    compression['params']['pruning_target'] = 0.5
    compression['params']['prune_first_conv'] = True

    test_model = DisconectedGraphModel()
    pruned_model, _ = create_compressed_model_and_algo_for_test(test_model, config)
    graph = pruned_model.get_original_graph()

    conv_node_names = (
        'DisconectedGraphModel/NNCFConv2d[conv1]/conv2d_0',
        'DisconectedGraphModel/NNCFConv2d[conv2]/conv2d_0',
    )
    conv_nodes = [graph.get_node_by_name(node_name) for node_name in conv_node_names]
    for conv_node in conv_nodes:
        assert sum(conv_node.data['output_mask'].tensor) == 8
def test_get_last_pruned_layers(model, ref_last_module_names):
    """
    Check that the last nodes of pruned op types map to the expected modules.

    :param model: model class to instantiate.
    :param ref_last_module_names: attribute names of the expected last pruned modules.
    """
    config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    config['compression']['algorithm'] = 'filter_pruning'
    pruned_model, _ = create_compressed_model_and_algo_for_test(model(), config)

    original_graph = pruned_model.get_original_graph()
    pruned_op_types = FilterPruningBuilder(config).get_op_types_of_pruned_modules()
    last_pruned_nodes = get_last_nodes_of_type(original_graph, pruned_op_types)

    actual_modules = {pruned_model.get_containing_module(node.node_name) for node in last_pruned_nodes}
    expected_modules = {getattr(pruned_model, module_name) for module_name in ref_last_module_names}
    assert actual_modules == expected_modules
def test_func_calulation_flops_for_conv(model):
    """
    Check that _calculate_output_shape (used for disconnected graphs) reproduces the
    output shapes the pruning algorithm recorded from the traced graph.

    :param model: model class to instantiate and prune.
    """
    config = get_basic_pruning_config([1, 1, 8, 8])
    config['compression']['algorithm'] = 'filter_pruning'
    config['compression']['pruning_init'] = 0.4
    config['compression']['params']['pruning_flops_target'] = 0.4
    model_instance = model()
    pruned_model, pruning_algo = create_compressed_model_and_algo_for_test(model_instance, config)

    graph = pruned_model.get_original_graph()
    # pylint:disable=protected-access
    for node_name, ref_shape in pruning_algo._modules_out_shapes.items():
        # ref_shape comes from tracing the graph
        node = graph.get_node_by_name(node_name)
        shape = pruning_algo._calculate_output_shape(graph, node)
        assert ref_shape == shape, f"Incorrect output shape calculation for {node_name}"
def test_get_bn_for_conv_node():
    """get_bn_for_conv_node_by_name must return the BN after conv1/conv2 and None for up/conv3."""
    config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    config['compression']['algorithm'] = 'filter_pruning'
    pruned_model, _ = create_compressed_model_and_algo_for_test(BigPruningTestModel(), config)

    expected_bn_by_node = [
        ('BigPruningTestModel/NNCFConv2d[conv1]/conv2d_0', pruned_model.bn1),
        ('BigPruningTestModel/NNCFConv2d[conv2]/conv2d_0', pruned_model.bn2),
        ('BigPruningTestModel/NNCFConvTranspose2d[up]/conv_transpose2d_0', None),
        ('BigPruningTestModel/NNCFConv2d[conv3]/conv2d_0', None),
    ]
    for conv_node_name, expected_bn in expected_bn_by_node:
        found_bn = get_bn_for_conv_node_by_name(pruned_model, conv_node_name)
        if expected_bn is None:
            assert found_bn is None
        else:
            assert found_bn == expected_bn
def test_pruning_export_simple_model(tmp_path):
    """Export a pruned BigPruningTestModel to ONNX and check the pruned parameter shapes."""
    model = BigPruningTestModel()
    nncf_config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    compression = nncf_config['compression']
    compression['pruning_init'] = 0.5
    compression['algorithm'] = 'filter_pruning'
    onnx_model_proto = load_exported_onnx_version(nncf_config, model, path_to_storage_dir=tmp_path)

    # WARNING: starting from at least torch 1.7.0, torch.onnx.export will fuses BN into previous
    # convs if torch.onnx.export is done with `training=False`, so this test might fail.
    expected = [
        ('nncf_module.conv2', [16, 16, 3, 3], [16]),  # conv2 + BN pruned by output filters
        ('nncf_module.bn', [16], [16]),
        ('nncf_module.up', [16, 32, 3, 3], [32]),  # up pruned by input filters
        ('nncf_module.conv3', [1, 32, 5, 5], [1]),  # conv3 pruned by input filters
    ]
    for param_prefix, *shapes in expected:
        check_bias_and_weight_shape(param_prefix, onnx_model_proto, *shapes)
def test_clusters_for_multiple_forward(repeat_seq_of_shared_convs, ref_second_cluster, additional_last_shared_layers):
    """
    Check pruning clusters for a model whose nodes are called several times per forward.

    :param repeat_seq_of_shared_convs: forwarded to TestModelMultipleForward.
    :param ref_second_cluster: expected sorted node ids of the second cluster.
    :param additional_last_shared_layers: forwarded to TestModelMultipleForward; when set,
        a second cluster is expected.
    """
    config = get_basic_pruning_config(input_sample_size=[1, 2, 8, 8])
    config['compression']['algorithm'] = 'filter_pruning'
    algo_params = config['compression']['params']
    algo_params['all_weights'] = False
    algo_params['prune_first_conv'] = True
    config['compression']['pruning_init'] = 0.5

    test_model = TestModelMultipleForward(repeat_seq_of_shared_convs, additional_last_shared_layers)
    _, pruning_algo = create_compressed_model_and_algo_for_test(test_model, config)
    clusters = pruning_algo.pruned_module_groups_info.clusters

    expected_cluster_count = 2 if additional_last_shared_layers else 1
    assert len(clusters) == expected_cluster_count
    # Convolutions before one node that forwards several times should be in one cluster
    first_cluster_ids = sorted(element.nncf_node_id for element in clusters[0].elements)
    assert first_cluster_ids == [1, 2, 3]
    if additional_last_shared_layers:
        # Nodes that associate with one module should be in one cluster
        second_cluster_ids = sorted(element.nncf_node_id for element in clusters[1].elements)
        assert second_cluster_ids == ref_second_cluster
def test_calculation_of_flops(all_weights, pruning_flops_target, ref_flops, ref_params_num):
    """
    Test FLOPs and weight-count calculation for the pruned model.

    Checks the values reported by the controller, then recomputes them from the
    collected pruning masks (calculate_in_out_channels_by_masks + count_flops_and_weights)
    and checks that the recomputation agrees.

    :param all_weights: whether mask will be calculated for all weights in common or not
    :param pruning_flops_target: prune model by flops, if None then by number of channels
    :param ref_flops: reference FLOPs of the pruned model
    :param ref_params_num: reference number of weights of the pruned model
    """
    config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    config['compression']['params']['all_weights'] = all_weights
    config['compression']['params']['prune_first_conv'] = True
    config['compression']['pruning_init'] = 0.5
    if pruning_flops_target:
        config['compression']['params']['pruning_flops_target'] = pruning_flops_target

    _, pruning_algo, _ = create_pruning_algo_with_config(config)
    assert pruning_algo.current_flops == ref_flops
    assert pruning_algo.current_params_num == ref_params_num

    # Recompute channel counts from the masks and recount FLOPs/weights independently
    # pylint:disable=protected-access
    tmp_in_channels, tmp_out_channels = calculate_in_out_channels_by_masks(
        pruning_algo.pruned_module_groups_info.get_all_clusters(),
        masks=pruning_algo._collect_pruning_masks(),
        tensor_processor=PTNNCFCollectorTensorProcessor,
        full_input_channels=pruning_algo._modules_in_channels,
        full_output_channels=pruning_algo._modules_out_channels,
        pruning_groups_next_nodes=pruning_algo.next_nodes)

    cur_flops, cur_params_num = count_flops_and_weights(
        pruning_algo._model.get_original_graph(),
        pruning_algo._modules_in_shapes,
        pruning_algo._modules_out_shapes,
        input_channels=tmp_in_channels,
        output_channels=tmp_out_channels,
        conv_op_metatypes=GENERAL_CONV_LAYER_METATYPES,
        linear_op_metatypes=LINEAR_LAYER_METATYPES)
    assert (cur_flops, cur_params_num) == (ref_flops, ref_params_num)
def test_distributed_init_struct():
    """Custom distributed wrap/unwrap callbacks must be registered and round-trip the model."""
    class FakeModelClass:
        def __init__(self, model_: nn.Module):
            self.model = model_

        def unwrap(self):
            return self.model

    def unwrapper_callback(wrapped):
        return wrapped.unwrap()

    config = get_basic_pruning_config()
    init_loader = create_ones_mock_dataloader(config)
    nncf_config = register_default_init_args(
        config, init_loader, distributed_callbacks=(FakeModelClass, unwrapper_callback))
    dist_callbacks = nncf_config.get_extra_struct(DistributedCallbacksArgs)

    model = PruningTestModel()
    wrapped_model = dist_callbacks.wrap_model(model)
    assert isinstance(wrapped_model, FakeModelClass)
    assert dist_callbacks.unwrap_model(wrapped_model) == model
def test_groups(test_input_info_struct_: GroupPruningModulesTestStruct):
    """
    Check which modules are left unpruned and how pruned modules are grouped into clusters.

    :param test_input_info_struct_: model class, expected unpruned nodes, expected pruned
        groups and (prune_first, prune_downsample) flags.
    """
    model_cls = test_input_info_struct_.model
    non_pruned_module_nodes = test_input_info_struct_.non_pruned_module_nodes
    pruned_groups = test_input_info_struct_.pruned_groups
    prune_first, prune_downsample = test_input_info_struct_.prune_params

    model = model_cls()
    nncf_config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    nncf_config['compression']['algorithm'] = 'filter_pruning'
    algo_params = nncf_config['compression']['params']
    algo_params['prune_first_conv'] = prune_first
    algo_params['prune_downsample_convs'] = prune_downsample
    compressed_model, compression_ctrl = create_compressed_model_and_algo_for_test(model, nncf_config)

    # 1. Check all not pruned modules
    clusters = compression_ctrl.pruned_module_groups_info
    all_pruned_modules = [node_info.module for node_info in clusters.get_all_nodes()]
    for node_name in non_pruned_module_nodes:
        module = compressed_model.get_containing_module(node_name)
        assert module is not None and module not in all_pruned_modules

    # 2. Check that all pruned groups are valid (same modules, with multiplicity)
    for group in pruned_groups:
        cluster = clusters.get_cluster_containing_element(group[0])
        cluster_modules = [element.module for element in cluster.elements]
        group_modules = [compressed_model.get_containing_module(name) for name in group]
        assert Counter(cluster_modules) == Counter(group_modules)
    assert len(pruned_groups) == len(clusters.get_all_clusters())
def test_flops_calulation_for_spec_layers(model, all_weights, pruning_flops_target,
                                          ref_full_flops, ref_current_flops, ref_sizes):
    """
    Check full/current FLOPs and per-conv mask sums after pruning by FLOPs target.

    Models with large layers are needed here: with small layers, different pruning
    rates can give the same final model size.

    :param ref_sizes: expected sum of the filter pruning mask for conv1, conv2, ... in order.
    """
    config = get_basic_pruning_config([1, 1, 8, 8])
    compression = config['compression']
    compression['algorithm'] = 'filter_pruning'
    compression['pruning_init'] = pruning_flops_target
    compression['params']['pruning_flops_target'] = pruning_flops_target
    compression['params']['prune_first_conv'] = True
    compression['params']['all_weights'] = all_weights

    model_instance = model()
    compressed_model, compression_ctrl = create_compressed_model_and_algo_for_test(model_instance, config)
    assert compression_ctrl.full_flops == ref_full_flops
    assert compression_ctrl.current_flops == ref_current_flops

    for conv_idx, ref_size in enumerate(ref_sizes, start=1):
        conv_module = getattr(compressed_model, f"conv{conv_idx}")
        first_pre_op = list(conv_module.pre_ops.values())[0]
        mask = first_pre_op.operand.binary_filter_pruning_mask
        assert int(sum(mask)) == ref_size
def create_default_legr_config():
    """Return the basic pruning config set to filter pruning with 'learned_ranking' interlayer ranking."""
    config = get_basic_pruning_config()
    compression = config['compression']
    compression['algorithm'] = 'filter_pruning'
    compression['params']['interlayer_ranking_type'] = 'learned_ranking'
    return config
def test_pruning_masks_applying_correctness(all_weights, pruning_flops_target, prune_first, ref_masks):
    """
    Test for pruning masks check (_set_binary_masks_for_filters, _set_binary_masks_for_all_filters_together).

    Also checks that masked modules produce masked outputs and that compression leaves
    the model weights unchanged.

    :param all_weights: whether mask will be calculated for all weights in common or not.
    :param pruning_flops_target: prune model by flops, if None then by number of channels.
    :param prune_first: whether to prune first convolution or not.
    :param ref_masks: reference masks values, consumed in order: conv1/conv_depthwise/bn1
        (only when prune_first is set), then conv2/bn2, then up.
    """
    input_shapes = {
        'conv1': [1, 1, 8, 8],
        'conv_depthwise': [1, 16, 7, 7],
        'conv2': [1, 16, 8, 8],
        'bn1': [1, 16, 8, 8],
        'bn2': [1, 32, 8, 8],
        'up': [1, 32, 8, 8]
    }

    def check_mask(module, num):
        # Mask for weights (first pre-op of the module)
        pruning_op = list(module.pre_ops.values())[0].operand
        assert hasattr(pruning_op, 'binary_filter_pruning_mask')
        assert torch.allclose(pruning_op.binary_filter_pruning_mask, ref_masks[num])
        # NOTE(review): the bias mask is not verified here; only the first pre-op is checked.

    def check_module_output(module, name, num):
        """
        Checks that output of module are masked.
        """
        mask = ref_masks[num]
        input_ = torch.ones(input_shapes[name])
        output = module(input_)
        ref_output = apply_filter_binary_mask(mask, output, dim=1)
        assert torch.allclose(output, ref_output)

    def check_model_weights(model_state_dict, ref_state_dict):
        # The compressed model stores the original parameters under the 'nncf_module.' prefix
        for key in ref_state_dict.keys():
            assert torch.allclose(model_state_dict['nncf_module.' + key], ref_state_dict[key])

    config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    config['compression']['algorithm'] = 'filter_pruning'
    config['compression']['params']['all_weights'] = all_weights
    config['compression']['params']['prune_first_conv'] = prune_first
    config['compression']['pruning_init'] = 0.5
    if pruning_flops_target:
        config['compression']['params']['pruning_flops_target'] = pruning_flops_target

    model = BigPruningTestModel()
    ref_state_dict = deepcopy(model.state_dict())
    # Compress the same instance whose weights were snapshotted, so the weights check
    # below verifies that compression leaves this model's weights unchanged.
    pruned_model, pruning_algo = create_compressed_model_and_algo_for_test(model, config)

    pruned_module_info = pruning_algo.pruned_module_groups_info.get_all_nodes()
    pruned_modules = [minfo.module for minfo in pruned_module_info]
    assert pruning_algo.pruning_level == 0.5
    assert pruning_algo.all_weights is all_weights

    # Checking that model weights remain unchanged
    check_model_weights(pruned_model.state_dict(), ref_state_dict)

    i = 0
    # Check for conv1 and conv_depthwise (masked only when the first conv is pruned)
    conv1 = pruned_model.conv1
    conv_depthwise = pruned_model.conv_depthwise
    if prune_first:
        assert conv1 in pruned_modules
        assert conv_depthwise in pruned_modules

        check_mask(conv1, i)
        check_mask(conv_depthwise, i)

        check_module_output(conv1, 'conv1', i)
        check_module_output(conv_depthwise, 'conv_depthwise', i)

    # Check for bn1
    bn1 = pruned_model.bn1
    if prune_first:
        check_mask(bn1, i)
        check_module_output(bn1, 'bn1', i)
        i += 1

    # Check for conv2
    conv2 = pruned_model.conv2
    assert conv2 in pruned_modules
    check_mask(conv2, i)
    check_module_output(conv2, 'conv2', i)

    # Check for bn2
    bn2 = pruned_model.bn2
    check_mask(bn2, i)
    check_module_output(bn2, 'bn2', i)
    i += 1

    # Check for up conv
    up = pruned_model.up
    assert up in pruned_modules
    check_mask(up, i)
    check_module_output(up, 'up', i)
def test_valid_modules_replacement_and_pruning(prune_first, prune_batch_norms):
    """
    Check that conv modules are replaced by NNCF modules and that filter pruning
    pre-ops are added exactly where expected.

    :param prune_first: whether to prune first convolution or not
    :param prune_batch_norms: whether to prune batch norm layers or not.
    """
    def check_that_module_is_pruned(module):
        # Exactly one pre-op: UpdateWeightAndBias wrapping a FilterPruningMask
        pre_ops = list(module.pre_ops.values())
        assert len(pre_ops) == 1
        assert isinstance(pre_ops[0], UpdateWeightAndBias)
        assert isinstance(pre_ops[0].operand, FilterPruningMask)

    def check_that_module_is_not_pruned(module):
        assert len(module.pre_ops) == 0
        assert len(module.post_ops) == 0

    config = get_basic_pruning_config(input_sample_size=[1, 1, 8, 8])
    algo_params = config['compression']['params']
    algo_params['prune_first_conv'] = prune_first
    algo_params['prune_batch_norms'] = prune_batch_norms
    pruned_model, pruning_algo, nncf_modules = create_pruning_algo_with_config(config)
    pruned_modules = [node_info.module
                      for node_info in pruning_algo.pruned_module_groups_info.get_all_nodes()]
    replaced_modules = nncf_modules.values()

    # conv1 and conv_depthwise: pruned only together with the first convolution
    conv1 = pruned_model.conv1
    conv_depthwise = pruned_model.conv_depthwise
    if not prune_first:
        check_that_module_is_not_pruned(conv1)
        check_that_module_is_not_pruned(conv_depthwise)
    else:
        assert conv1 in pruned_modules
        assert conv1 in replaced_modules
        check_that_module_is_pruned(conv1)
        assert conv_depthwise in replaced_modules
        check_that_module_is_pruned(conv_depthwise)

    # bn1: pruned only when both the first conv and batch norms are pruned
    bn1 = pruned_model.bn1
    if not (prune_first and prune_batch_norms):
        check_that_module_is_not_pruned(bn1)
    else:
        assert bn1 in replaced_modules
        check_that_module_is_pruned(bn1)

    # conv2: always pruned
    conv2 = pruned_model.conv2
    assert conv2 in pruned_modules
    assert conv2 in replaced_modules
    check_that_module_is_pruned(conv2)

    # bn2: pruned only when batch norms are pruned
    bn2 = pruned_model.bn2
    if not prune_batch_norms:
        check_that_module_is_not_pruned(bn2)
    else:
        assert bn2 in replaced_modules
        check_that_module_is_pruned(bn2)

    # up: always pruned
    up = pruned_model.up
    assert up in pruned_modules
    assert up in replaced_modules
    check_that_module_is_pruned(up)

    # conv3: never pruned
    conv3 = pruned_model.conv3
    check_that_module_is_not_pruned(conv3)