Example 1
    def _sparsify_weights(
            self, target_model: NNCFNetwork) -> List[PTInsertionCommand]:
        device = next(target_model.parameters()).device
        sparsified_module_nodes = target_model.get_weighted_original_graph_nodes(
            nncf_module_names=self.compressed_nncf_module_names)
        insertion_commands = []
        for module_node in sparsified_module_nodes:
            node_name = module_node.node_name

            if not self._should_consider_scope(node_name):
                nncf_logger.info(
                    "Ignored adding Weight Sparsifier in scope: {}".format(
                        node_name))
                continue

            nncf_logger.info(
                "Adding Weight Sparsifier in scope: {}".format(node_name))
            compression_lr_multiplier = \
                self.config.get_redefinable_global_param_value_for_algo('compression_lr_multiplier',
                                                                        self.name)
            operation = self.create_weight_sparsifying_operation(
                module_node, compression_lr_multiplier)
            hook = operation.to(device)
            insertion_commands.append(
                PTInsertionCommand(
                    PTTargetPoint(TargetType.OPERATION_WITH_WEIGHTS,
                                  target_node_name=node_name), hook,
                    TransformationPriority.SPARSIFICATION_PRIORITY))
            sparsified_module = target_model.get_containing_module(node_name)
            self._sparsified_module_info.append(
                SparseModuleInfo(node_name, sparsified_module, hook))

        return insertion_commands
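The PTInsertionCommand list above is consumed by NNCF's transformation machinery; in spirit, each command installs a pre-op that masks the weight before every forward call. A rough stand-alone analogue in plain PyTorch, assuming a fixed hypothetical binary mask rather than a trainable sparsifying operation:

import torch

conv = torch.nn.Conv2d(3, 16, kernel_size=3)
# Hypothetical fixed binary mask with the same shape as the weight.
mask = (torch.rand_like(conv.weight) > 0.5).float()

def sparsify_pre_hook(module, inputs):
    # Re-apply the mask before every forward call, zeroing masked weights.
    module.weight.data.mul_(mask)

conv.register_forward_pre_hook(sparsify_pre_hook)
_ = conv(torch.randn(1, 3, 8, 8))
assert (conv.weight[mask == 0] == 0).all()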
Example 2
def add_adjust_padding_nodes(bitwidth_graph: nx.DiGraph,
                             model: NNCFNetwork) -> nx.DiGraph:
    # pylint:disable=protected-access

    NewNodeArgs = namedtuple('NewNodeArgs',
                             ('node_key', 'attr', 'parent_node_key'))
    nncf_graph = model.get_graph()
    args = []
    for node_key in bitwidth_graph.nodes:
        node = nncf_graph.get_node_by_key(node_key)
        module = model.get_containing_module(node.node_name)
        if isinstance(module, NNCFConv2d):
            adjust_padding_ops = filter(
                lambda x: isinstance(x, UpdatePaddingValue),
                module.pre_ops.values())
            for _ in adjust_padding_ops:
                new_node_key = f'{node_key}_apad'
                attr = dict(type='',
                            label='adjust_padding_value',
                            style='filled',
                            color='yellow')
                args.append(NewNodeArgs(new_node_key, attr, node_key))

    for arg in args:
        bitwidth_graph.add_node(arg.node_key, **arg.attr)
        bitwidth_graph.add_edge(arg.node_key, arg.parent_node_key)
    return bitwidth_graph
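Collecting the NewNodeArgs tuples first and mutating the graph afterwards matters: adding nodes while iterating over `bitwidth_graph.nodes` would invalidate the iterator. A minimal stand-alone reproduction of the pattern on a bare networkx graph:

from collections import namedtuple

import networkx as nx

NewNodeArgs = namedtuple('NewNodeArgs', ('node_key', 'attr', 'parent_node_key'))

g = nx.DiGraph()
g.add_edge('conv_0', 'relu_0')

args = []
for node_key in g.nodes:  # adding nodes here would break the iteration
    attr = dict(label='adjust_padding_value', style='filled', color='yellow')
    args.append(NewNodeArgs(f'{node_key}_apad', attr, node_key))

for arg in args:
    g.add_node(arg.node_key, **arg.attr)
    g.add_edge(arg.node_key, arg.parent_node_key)

assert g.number_of_nodes() == 4  # two original nodes plus two annotation nodes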
Example 3
    def test_operator_metatype_marking(self):
        from nncf.torch.graph.operator_metatypes import PTConv2dMetatype, PTBatchNormMetatype, PTRELUMetatype, \
            PTMaxPool2dMetatype, PTTransposeMetatype, PTConvTranspose2dMetatype, \
            PTDepthwiseConv2dSubtype, PTAddMetatype, PTAvgPool2dMetatype, PTLinearMetatype, \
            PTReshapeMetatype, PTInputNoopMetatype, PTOutputNoopMetatype
        ref_scope_vs_metatype_dict = {
            "/" + MODEL_INPUT_OP_NAME + "_0": PTInputNoopMetatype,
            "ModelForMetatypeTesting/NNCFConv2d[conv_regular]/conv2d_0": PTConv2dMetatype,
            "ModelForMetatypeTesting/NNCFBatchNorm[bn]/batch_norm_0": PTBatchNormMetatype,
            "ModelForMetatypeTesting/relu_0": PTRELUMetatype,
            "ModelForMetatypeTesting/transpose__0": PTTransposeMetatype,
            "ModelForMetatypeTesting/MaxPool2d[max_pool2d]/max_pool2d_0": PTMaxPool2dMetatype,
            "ModelForMetatypeTesting/NNCFConvTranspose2d[conv_transpose]/conv_transpose2d_0": PTConvTranspose2dMetatype,
            "ModelForMetatypeTesting/NNCFConv2d[conv_depthwise]/conv2d_0": PTDepthwiseConv2dSubtype,
            "ModelForMetatypeTesting/__iadd___0": PTAddMetatype,
            "ModelForMetatypeTesting/AdaptiveAvgPool2d[adaptive_avg_pool]/adaptive_avg_pool2d_0": PTAvgPool2dMetatype,
            "ModelForMetatypeTesting/NNCFLinear[linear]/linear_0": PTLinearMetatype,
            'ModelForMetatypeTesting/flatten_0': PTReshapeMetatype,
            "/" + MODEL_OUTPUT_OP_NAME + "_0": PTOutputNoopMetatype,
        }

        class ModelForMetatypeTesting(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv_regular = torch.nn.Conv2d(in_channels=3,
                                                    out_channels=16,
                                                    kernel_size=3)
                self.bn = torch.nn.BatchNorm2d(num_features=16)
                self.max_pool2d = torch.nn.MaxPool2d(kernel_size=2)
                self.conv_transpose = torch.nn.ConvTranspose2d(in_channels=16,
                                                               out_channels=8,
                                                               kernel_size=3)
                self.conv_depthwise = torch.nn.Conv2d(in_channels=8, out_channels=8,
                                                      kernel_size=5, groups=8)
                self.adaptive_avg_pool = torch.nn.AdaptiveAvgPool2d(output_size=1)
                self.linear = torch.nn.Linear(in_features=8, out_features=1)

            def forward(self, input_):
                x = self.conv_regular(input_)
                x = self.bn(x)
                x = torch.nn.functional.relu(x)
                x.transpose_(2, 3)
                x = self.max_pool2d(x)
                x = self.conv_transpose(x)
                x = self.conv_depthwise(x)
                x += torch.ones_like(x)
                x = self.adaptive_avg_pool(x)
                x = self.linear(x.flatten())
                return x

        model = ModelForMetatypeTesting()
        nncf_network = NNCFNetwork(model, [ModelInputInfo([1, 3, 300, 300])])
        nncf_graph = nncf_network.get_original_graph()

        for nncf_node in nncf_graph.get_all_nodes():  # type: NNCFNode
            assert nncf_node.node_name in ref_scope_vs_metatype_dict
            ref_metatype = ref_scope_vs_metatype_dict[nncf_node.node_name]
            assert nncf_node.metatype == ref_metatype
Example 4
def test_compressed_graph_models_hw(desc, hw_config_type):
    model = desc.model_builder()
    config = get_basic_quantization_config_with_hw_config_type(hw_config_type.value,
                                                               input_sample_size=desc.input_sample_sizes)
    input_info_list = create_input_infos(config)
    compressed_model = NNCFNetwork(model, input_infos=input_info_list)

    # pylint:disable=protected-access
    quantization_builder = QuantizationBuilder(config, should_init=False)
    single_config_quantizer_setup = quantization_builder._get_quantizer_setup(compressed_model)
    sketch_graph = compressed_model.get_original_graph()

    potential_quantizer_graph = prepare_potential_quantizer_graph(sketch_graph, single_config_quantizer_setup)
    check_nx_graph(potential_quantizer_graph, desc.dot_filename, _case_dir(hw_config_type.value), sort_dot_graph=False)
Example 5
def check_correct_nncf_modules_replacement(model: NNCFNetwork, compressed_model: NNCFNetwork) \
        -> Tuple[Dict[Scope, Module], Dict[Scope, Module]]:
    """
    Check that all convolutions in the model were replaced by NNCF convolutions.
    :param model: original model
    :param compressed_model: compressed model
    :return: dict of all convolutions in the original model and dict of all NNCF
        convolutions from the compressed model
    """
    NNCF_MODULES_REVERSED_MAP = {
        value: key
        for key, value in NNCF_MODULES_MAP.items()
    }
    original_modules = get_all_modules_by_type(model,
                                               list(NNCF_MODULES_MAP.values()))
    nncf_modules = get_all_modules_by_type(
        compressed_model.get_nncf_wrapped_model(),
        list(NNCF_MODULES_MAP.keys()))
    assert len(original_modules) == len(nncf_modules)
    print(original_modules, nncf_modules)
    for scope in original_modules.keys():
        sparse_scope = deepcopy(scope)
        elt = sparse_scope.pop()  # type: ScopeElement
        elt.calling_module_class_name = NNCF_MODULES_REVERSED_MAP[
            elt.calling_module_class_name]
        sparse_scope.push(elt)
        print(sparse_scope, nncf_modules)
        assert sparse_scope in nncf_modules
    return original_modules, nncf_modules
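get_all_modules_by_type is an NNCF helper; what it does can be approximated with named_modules, assuming matching by class name only (the real helper keys the result by Scope objects rather than dotted paths):

import torch

def collect_modules_by_class_name(model, class_names):
    # Maps dotted module path -> module for every module whose class name matches.
    return {name: module for name, module in model.named_modules()
            if type(module).__name__ in class_names}

model = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 3), torch.nn.ReLU())
assert len(collect_modules_by_class_name(model, {'Conv2d'})) == 1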
Example 6
def create_nncf_model_and_single_algo_builder(model: Module, config: NNCFConfig,
                                              dummy_forward_fn: Callable[[Module], Any] = None,
                                              wrap_inputs_fn: Callable[[Tuple, Dict], Tuple[Tuple, Dict]] = None) \
        -> Tuple[NNCFNetwork, PTCompressionAlgorithmController]:
    assert isinstance(config, NNCFConfig)
    NNCFConfig.validate(config)
    input_info_list = create_input_infos(config)
    scopes_without_shape_matching = config.get('scopes_without_shape_matching',
                                               [])
    ignored_scopes = config.get('ignored_scopes')
    target_scopes = config.get('target_scopes')

    compressed_model = NNCFNetwork(
        model,
        input_infos=input_info_list,
        dummy_forward_fn=dummy_forward_fn,
        wrap_inputs_fn=wrap_inputs_fn,
        ignored_scopes=ignored_scopes,
        target_scopes=target_scopes,
        scopes_without_shape_matching=scopes_without_shape_matching)

    algo_names = extract_algorithm_names(config)
    assert len(algo_names) == 1
    algo_name = next(iter(algo_names))
    builder_cls = PT_COMPRESSION_ALGORITHMS.get(algo_name)
    builder = builder_cls(config, should_init=True)
    return compressed_model, builder
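PT_COMPRESSION_ALGORITHMS is one of NNCF's name-to-class registries (example 13 below uses QUANTIZATION_MODULES.registry_dict from the same family). A minimal sketch of the registry pattern this lookup relies on, with hypothetical names; the real registry has more validation:

class Registry:
    def __init__(self):
        self.registry_dict = {}

    def register(self, name):
        def wrap(cls):
            self.registry_dict[name] = cls
            return cls
        return wrap

    def get(self, name):
        return self.registry_dict[name]

COMPRESSION_ALGORITHMS = Registry()  # hypothetical stand-in for PT_COMPRESSION_ALGORITHMS

@COMPRESSION_ALGORITHMS.register('magnitude_sparsity')
class MagnitudeSparsityBuilder:
    def __init__(self, config, should_init=True):
        self.config = config
        self.should_init = should_init

builder_cls = COMPRESSION_ALGORITHMS.get('magnitude_sparsity')
builder = builder_cls(config={}, should_init=True)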
Example 7
    def input_prune(cls, model: NNCFNetwork, node: NNCFNode,
                    graph: NNCFGraph) -> None:
        input_mask = node.data['input_masks'][0]
        if input_mask is None:
            return
        bool_mask = torch.tensor(input_mask, dtype=torch.bool)
        new_num_channels = int(torch.sum(input_mask))

        is_depthwise = is_prunable_depthwise_conv(node)
        node_module = model.get_containing_module(node.node_name)
        old_num_channels = int(node_module.weight.size(1))

        if is_depthwise:
            # For depthwise convolutions the output channels are pruned by the input
            # mask; here we only update the group and input channel counts accordingly.
            node_module.groups = new_num_channels
            node_module.in_channels = new_num_channels
            old_num_channels = int(node_module.weight.size(0))
        else:
            out_channels = node_module.weight.size(0)
            broadcasted_mask = bool_mask.repeat(out_channels).view(
                out_channels, bool_mask.size(0))
            new_weight_shape = list(node_module.weight.shape)
            new_weight_shape[1] = new_num_channels

            node_module.in_channels = new_num_channels
            node_module.weight = torch.nn.Parameter(
                node_module.weight[broadcasted_mask].view(new_weight_shape))

        nncf_logger.info(
            'Pruned Convolution {} by input mask: number of input filters changed '
            'from {} to {}.'.format(node.data['key'], old_num_channels, new_num_channels))
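The broadcast-and-view indexing above is equivalent to slicing the weight along its input-channel dimension, which can be verified stand-alone:

import torch

conv = torch.nn.Conv2d(8, 16, kernel_size=3)
bool_mask = torch.tensor([True, False] * 4)  # keep every other input channel

out_channels = conv.weight.size(0)
broadcasted_mask = bool_mask.repeat(out_channels).view(out_channels, bool_mask.size(0))
new_weight_shape = list(conv.weight.shape)
new_weight_shape[1] = int(bool_mask.sum())

pruned_a = conv.weight[broadcasted_mask].view(new_weight_shape)
pruned_b = conv.weight[:, bool_mask]  # the same result with plain dim-1 indexing
assert torch.equal(pruned_a, pruned_b)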
Example 8
    def output_prune(cls, model: NNCFNetwork, node: NNCFNode,
                     graph: NNCFGraph) -> None:
        output_mask = node.data['output_mask']
        if output_mask is None:
            return

        bool_mask = torch.tensor(output_mask, dtype=torch.bool)
        new_num_channels = int(torch.sum(bool_mask))

        node_module = model.get_containing_module(node.node_name)
        old_num_channels = int(node_module.weight.size(1))

        in_channels = node_module.weight.size(0)
        broadcasted_mask = bool_mask.repeat(in_channels).view(
            in_channels, bool_mask.size(0))
        new_weight_shape = list(node_module.weight.shape)
        new_weight_shape[1] = new_num_channels

        node_module.out_channels = new_num_channels
        node_module.weight = torch.nn.Parameter(
            node_module.weight[broadcasted_mask].view(new_weight_shape))

        if node_module.bias is not None:
            node_module.bias = torch.nn.Parameter(node_module.bias[bool_mask])

        nncf_logger.info(
            'Pruned ConvTranspose {} by pruning mask: number of output filters changed '
            'from {} to {}.'.format(node.data['key'], old_num_channels,
                                    node_module.out_channels))
Example 9
    def __init__(self, target_model: NNCFNetwork,
                 sparsified_module_info: List[SparseModuleInfo]):
        super().__init__(target_model)
        self._loss = ZeroCompressionLoss(
            next(target_model.parameters()).device)
        self._scheduler = BaseCompressionScheduler()
        self.sparsified_module_info = sparsified_module_info
Example 10
def get_all_node_op_addresses_in_block(
        compressed_model: NNCFNetwork,
        block: BuildingBlock) -> Set[OperationAddress]:
    """
    Returns the set of operation addresses of all layers included in the block.

    :param compressed_model: Target model.
    :param block: Building block.
    :return: Set of operation addresses for the building block.
    """
    graph = compressed_model.get_original_graph()
    nx_graph = graph.get_nx_graph_copy()
    start_node, end_node = block
    start_node_key, end_node_key = None, None
    #pylint: disable=protected-access
    for node in nx_graph._node.values():
        if start_node == str(node['node_name']):
            start_node_key = node['key']
        if end_node == str(node['node_name']):
            end_node_key = node['key']
    simple_paths = nx.all_simple_paths(nx_graph, start_node_key, end_node_key)
    op_addresses = set()
    for node_keys_in_path in simple_paths:
        for node_key in node_keys_in_path:
            op_addresses.add(
                OperationAddress.from_str(
                    nx_graph._node[node_key]['node_name']))
    start_op_address = OperationAddress.from_str(
        nx_graph._node[start_node_key]['node_name'])
    op_addresses.remove(start_op_address)
    return op_addresses
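nx.all_simple_paths enumerates every loop-free path between two nodes, so the union of the visited node keys covers exactly the nodes of the block; a minimal illustration on a toy graph:

import networkx as nx

g = nx.DiGraph()
g.add_edges_from([('start', 'a'), ('start', 'b'), ('a', 'end'), ('b', 'end')])

nodes_in_block = set()
for path in nx.all_simple_paths(g, 'start', 'end'):
    nodes_in_block.update(path)
nodes_in_block.remove('start')  # mirrors dropping start_op_address above

assert nodes_in_block == {'a', 'b', 'end'}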
Example 11
    def _binarize_weights_and_module_inputs(
            self, target_model: NNCFNetwork) -> List[PTInsertionCommand]:
        device = next(target_model.parameters()).device

        module_nodes = target_model.get_weighted_original_graph_nodes(
            nncf_module_names=self.compressed_nncf_module_names)

        insertion_commands = []
        for module_node in module_nodes:
            if not self._should_consider_scope(module_node.node_name):
                nncf_logger.info(
                    "Ignored adding binarizers in scope: {}".format(
                        module_node.node_name))
                continue

            nncf_logger.info("Adding Weight binarizer in scope: {}".format(
                module_node.node_name))
            op_weights = self.__create_binarize_module().to(device)

            nncf_logger.info("Adding Activation binarizer in scope: {}".format(
                module_node.node_name))

            compression_lr_multiplier = self.config.get_redefinable_global_param_value_for_algo(
                'compression_lr_multiplier', self.name)

            op_inputs = UpdateInputs(
                ActivationBinarizationScaleThreshold(
                    module_node.layer_attributes.get_weight_shape(),
                    compression_lr_multiplier=compression_lr_multiplier)).to(
                        device)

            ip_w = PTTargetPoint(TargetType.OPERATION_WITH_WEIGHTS,
                                 target_node_name=module_node.node_name)
            insertion_commands.append(
                PTInsertionCommand(
                    ip_w, op_weights,
                    TransformationPriority.QUANTIZATION_PRIORITY))

            ip_i = PTTargetPoint(TargetType.PRE_LAYER_OPERATION,
                                 target_node_name=module_node.node_name,
                                 input_port_id=0)
            insertion_commands.append(
                PTInsertionCommand(
                    ip_i, op_inputs,
                    TransformationPriority.QUANTIZATION_PRIORITY))
        return insertion_commands
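The binarizer modules themselves come from NNCF's binarization package. As a rough sketch of the underlying idea only — an assumption-level simplification, not the actual implementation behind __create_binarize_module — here is a sign-based weight binarizer with a straight-through gradient estimator:

import torch

class SignBinarize(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return torch.sign(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: pass the gradient through unchanged.
        return grad_output

weight = torch.randn(4, 4, requires_grad=True)
binary_weight = SignBinarize.apply(weight)
binary_weight.sum().backward()
assert torch.equal(weight.grad, torch.ones_like(weight))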
Example 12
def get_bn_for_conv_node_by_name(
        target_model: NNCFNetwork,
        conv_node_name: NNCFNodeName) -> Optional[torch.nn.Module]:
    """
    Returns the batch norm module in target_model that immediately follows a given
    convolution node in the model's NNCFGraph representation.
    :param target_model: NNCFNetwork to work with
    :param conv_node_name: name of the convolution node
    :return: batch norm module, or None if the convolution is not followed by one
    """
    graph = target_model.get_original_graph()
    conv_node = graph.get_node_by_name(conv_node_name)
    bn_node = get_bn_node_for_conv(graph, conv_node)
    if bn_node is None:
        return None
    bn_module = target_model.get_containing_module(bn_node.node_name)
    return bn_module
Example 13
    def __init__(self, model: NNCFNetwork,
                 weight_quantizers: Dict[WeightQuantizerId,
                                         WeightQuantizerInfo],
                 constraints: HardwareQuantizationConstraints):
        self._wq_affected_module_node_name_vs_qid_dict = {
            k.target_node_name: k
            for k in weight_quantizers.keys()
        }
        self._quantizer_module_scope_vs_qid_dict = {}  # type: Dict[Scope, WeightQuantizerId]
        self._skipped_quantized_weight_node_names = []
        self._skipped_weight_quantizers = {}  # type: Dict[WeightQuantizerId, BaseQuantizer]
        self._weight_quantizers_in_execution_order_per_scope = OrderedDict()  # type: Dict[Scope, BaseQuantizer]
        self._weight_quantizers_in_execution_order = OrderedDict()  # type: Dict[WeightQuantizerId, BaseQuantizer]

        quantization_types = [
            class_type.__name__
            for class_type in QUANTIZATION_MODULES.registry_dict.values()
        ]
        weight_module_dict = model.get_nncf_wrapped_model()
        quantizers_in_execution_order_per_scope = get_all_modules_by_type(
            weight_module_dict, quantization_types)

        for scope, quantizer in quantizers_in_execution_order_per_scope.items():
            if self.is_wq_scope(scope):
                affected_module_scope = self.get_owning_module_scope_from_wq_scope(scope)
                affected_module_node = model.get_original_graph().get_op_nodes_in_scope(affected_module_scope)[0]
                if affected_module_node.node_name in self._wq_affected_module_node_name_vs_qid_dict:
                    qid = self._wq_affected_module_node_name_vs_qid_dict[affected_module_node.node_name]
                    if len(constraints.get_all_unique_bitwidths(qid)) != 1:
                        self._weight_quantizers_in_execution_order_per_scope[scope] = quantizer
                        self._weight_quantizers_in_execution_order[qid] = quantizer
                    else:
                        self._skipped_quantized_weight_node_names.append(affected_module_node.node_name)
                        self._skipped_weight_quantizers[qid] = quantizer
Example 14
    def __init__(self, target_model: NNCFNetwork, config: NNCFConfig):
        super().__init__(target_model)

        self._loss = ZeroCompressionLoss(
            next(target_model.parameters()).device)
        scheduler_cls = QUANTIZATION_SCHEDULERS.get("staged")
        algo_config = extract_algo_specific_config(config, "binarization")
        self._scheduler = scheduler_cls(self, algo_config.get("params", {}))
        from nncf.torch.utils import is_main_process
        if is_main_process():
            self._compute_and_display_flops_binarization_rate()
Example 15
def test_pruning_node_selector(
        test_input_info_struct_: GroupPruningModulesTestStruct):
    model = test_input_info_struct_.model
    non_pruned_module_nodes = test_input_info_struct_.non_pruned_module_nodes
    pruned_groups_by_node_id = test_input_info_struct_.pruned_groups_by_node_id
    prune_first, prune_downsample = test_input_info_struct_.prune_params

    pruning_operations = [v.op_func_name for v in NNCF_PRUNING_MODULES_DICT]
    grouping_operations = PTElementwisePruningOp.get_all_op_aliases()
    from nncf.common.pruning.node_selector import PruningNodeSelector
    pruning_node_selector = PruningNodeSelector(PT_PRUNING_OPERATOR_METATYPES,
                                                pruning_operations,
                                                grouping_operations, None,
                                                None, prune_first,
                                                prune_downsample)
    model = model()
    model.eval()
    nncf_network = NNCFNetwork(model,
                               input_infos=[ModelInputInfo([1, 1, 8, 8])])
    graph = nncf_network.get_original_graph()
    pruning_groups = pruning_node_selector.create_pruning_groups(graph)

    # 1. Check all not pruned modules
    all_pruned_nodes = pruning_groups.get_all_nodes()
    all_pruned_modules = [
        nncf_network.get_containing_module(node.node_name)
        for node in all_pruned_nodes
    ]
    for node_name in non_pruned_module_nodes:
        module = nncf_network.get_containing_module(node_name)
        assert module is not None and module not in all_pruned_modules

    # 2. Check that all pruned groups are valid
    for group_by_id in pruned_groups_by_node_id:
        first_node_id = group_by_id[0]
        cluster = pruning_groups.get_cluster_containing_element(first_node_id)
        cluster_node_ids = [n.node_id for n in cluster.elements]
        cluster_node_ids.sort()

        assert Counter(cluster_node_ids) == Counter(group_by_id)
Example 16
def create_test_quantization_env(model_creator=BasicConvTestModel,
                                 input_info_cfg=None) -> QuantizationEnv:
    if input_info_cfg is None:
        input_info_cfg = {"input_info": {"sample_size": [1, 1, 4, 4]}}

    model = model_creator()
    nncf_network = NNCFNetwork(model,
                               input_infos=create_input_infos(input_info_cfg))
    hw_config_type = HWConfigType.VPU
    hw_config_path = HWConfig.get_path_to_hw_config(hw_config_type)
    hw_config = PTHWConfig.from_json(hw_config_path)
    setup = PropagationBasedQuantizerSetupGenerator(
        NNCFConfig(), nncf_network, hw_config=hw_config).generate_setup()
    dummy_multi_setup = MultiConfigQuantizerSetup.from_single_config_setup(
        setup)
    for qp in dummy_multi_setup.quantization_points.values():
        qconf_constraint_list = []
        qconf = qp.possible_qconfigs[0]
        bit_set = [8, 4, 2] if 'conv' in str(qp.insertion_point) else [8, 4]
        for bits in bit_set:
            adj_qconf = deepcopy(qconf)
            adj_qconf.num_bits = bits
            qconf_constraint_list.append(adj_qconf)
        qp.possible_qconfigs = qconf_constraint_list
    experimental_builder = ExperimentalQuantizationBuilder(
        dummy_multi_setup, setup, {}, hw_config)
    experimental_builder.apply_to(nncf_network)
    # pylint:disable=line-too-long
    experimental_ctrl = experimental_builder.build_controller(nncf_network)
    data_loader = create_ones_mock_dataloader(input_info_cfg)
    constraints = HardwareQuantizationConstraints()
    for qid, qp_id_set in experimental_ctrl.module_id_to_qp_id_translation_dict.items():
        first_qp_id_for_this_quantizer_module = next(iter(qp_id_set))
        qconfigs = dummy_multi_setup.quantization_points[
            first_qp_id_for_this_quantizer_module].possible_qconfigs
        constraints.add(qid, qconfigs)

    return QuantizationEnv(nncf_network,
                           experimental_ctrl,
                           constraints,
                           data_loader,
                           lambda *x: 0,
                           hw_config_type=HWConfigType.VPU,
                           params=QuantizationEnvParams(
                               compression_ratio=0.15,
                               eval_subset_ratio=1.0,
                               skip_constraint=False,
                               performant_bw=False,
                               finetune=False,
                               bits=[2, 4, 8],
                               dump_init_precision_data=False))
Example 17
def test_weight_normed_modules_are_replaced_correctly():
    nncf_model = NNCFNetwork(WeightNormedConvModel(), input_infos=[ModelInputInfo([1, 1, 10])])

    wrapped_conv = nncf_model.conv
    assert hasattr(wrapped_conv, "weight_g")
    assert hasattr(wrapped_conv, "weight_v")
    assert hasattr(wrapped_conv, "weight")

    assert isinstance(wrapped_conv.weight_g, torch.nn.Parameter)
    assert isinstance(wrapped_conv.weight_v, torch.nn.Parameter)
    assert not isinstance(wrapped_conv.weight, torch.nn.Parameter)

    #pylint:disable=protected-access
    assert len(wrapped_conv._forward_pre_hooks) == 1
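For reference, the attributes being checked are exactly what torch.nn.utils.weight_norm itself installs on a bare module, so the same assertions hold without NNCF:

import torch

conv = torch.nn.utils.weight_norm(torch.nn.Conv1d(1, 1, kernel_size=3))
assert isinstance(conv.weight_g, torch.nn.Parameter)
assert isinstance(conv.weight_v, torch.nn.Parameter)
# `weight` is recomputed from weight_g and weight_v by a forward pre-hook,
# so it is a plain tensor, not a Parameter.
assert not isinstance(conv.weight, torch.nn.Parameter)
# pylint:disable=protected-access
assert len(conv._forward_pre_hooks) == 1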
Example 18
def test_custom_module_registering():
    model = TwoConvTestModelWithUserModule()
    nncf_model = NNCFNetwork(model, input_infos=[ModelInputInfo([1, 1, 4, 4])])  # type: NNCFNetwork

    from nncf.torch.layers import UNWRAPPED_USER_MODULES
    assert ModuleOfUser in UNWRAPPED_USER_MODULES.registry_dict.values()

    # pylint: disable=protected-access
    assert isinstance(nncf_model.user_module, ModuleOfUser)
    assert isinstance(nncf_model.user_module, _NNCFModuleMixin)
    assert type(nncf_model.user_module).__name__ == "NNCFUserModuleOfUser"

    user_module_attrs = dir(nncf_model.user_module)
    for attr in dir(_NNCFModuleMixin):
        assert attr in user_module_attrs
Example 19
    def get_model_and_ctrl_with_applied_hw_config_quantization(
            model: torch.nn.Module,
            hw_config_dict: dict,
            should_be_quantize_inputs: bool = True):
        nncf_config = get_quantization_config_without_range_init(model_size=1)
        nncf_config["compression"].update(
            {"quantize_inputs": should_be_quantize_inputs})
        nncf_config["target_device"] = "ANY"  # for compatibility

        net = NNCFNetwork(model, input_infos=[ModelInputInfo([1, 2, 1, 1])])
        hw_config = PTHWConfig.from_dict(hw_config_dict)
        qbuilder = QuantizationBuilder(nncf_config, should_init=False)
        qbuilder.hw_config = hw_config
        net = qbuilder.apply_to(net)
        ctrl = qbuilder.build_controller(net)
        return net, ctrl
Example 20
def test_disable_shape_matching():
    class MatMulModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.dummy_param = torch.nn.Parameter(torch.ones([1]))

        def forward(self, inputs):
            half1, half2 = torch.chunk(inputs, 2, dim=2)
            return torch.bmm(half1, half2.transpose(1, 2))

    model = MatMulModel()

    input_shape_1 = (3, 32, 32)
    input_shape_2 = (4, 64, 64)

    qnet_no_shape = NNCFNetwork(deepcopy(model), input_infos=[ModelInputInfo(input_shape_1), ],
                                scopes_without_shape_matching=['MatMulModel'])  # type: NNCFNetwork

    context = qnet_no_shape.get_tracing_context()
    context.enable_trace_dynamic_graph()
    _ = qnet_no_shape(torch.zeros(*input_shape_1))
    graph_1 = deepcopy(qnet_no_shape.get_dynamic_graph())

    _ = qnet_no_shape(torch.zeros(*input_shape_2))
    graph_2 = deepcopy(qnet_no_shape.get_dynamic_graph())

    assert graph_1 == graph_2

    nodes_1 = list(graph_1.get_all_nodes())
    assert len(nodes_1) == 5  # 1 input node + 1 chunk + 1 transpose + 1 matmul + 1 output node

    qnet = NNCFNetwork(model, input_infos=[ModelInputInfo(input_shape_1), ])  # type: NNCFNetwork
    context = qnet.get_tracing_context()
    context.enable_trace_dynamic_graph()
    _ = qnet(torch.zeros(*input_shape_1))
    _ = qnet(torch.zeros(*input_shape_2))
    # The second forward run should have led to an increase in registered node counts
    # since disable_shape_matching was False and the network was run with a different
    # shape of input tensor
    assert qnet.get_dynamic_graph().get_nodes_count() > graph_1.get_nodes_count()
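The test relies on MatMulModel being shape-agnostic: chunk halves the last dimension and bmm contracts the matching dimensions, so both input sizes run through the same forward. A quick stand-alone check of the shapes:

import torch

def matmul_forward(inputs):
    half1, half2 = torch.chunk(inputs, 2, dim=2)
    return torch.bmm(half1, half2.transpose(1, 2))

assert matmul_forward(torch.zeros(3, 32, 32)).shape == (3, 32, 32)
assert matmul_forward(torch.zeros(4, 64, 64)).shape == (4, 64, 64)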
Example 21
    def input_prune(cls, model: NNCFNetwork, node: NNCFNode,
                    graph: NNCFGraph) -> None:
        input_mask = node.data['input_masks'][0]
        if input_mask is None:
            return
        bool_mask = torch.tensor(input_mask, dtype=torch.bool)

        node_module = model.get_containing_module(node.node_name)
        old_num_channels = int(node_module.weight.size(0))

        node_module.in_channels = int(torch.sum(bool_mask))
        node_module.weight = torch.nn.Parameter(node_module.weight[bool_mask])

        nncf_logger.info(
            'Pruned ConvTranspose {} by input mask: number of input filters changed '
            'from {} to {}.'.format(node.data['key'], old_num_channels,
                                    node_module.in_channels))
Example 22
def get_all_modules_in_blocks(
        compressed_model: NNCFNetwork,
        op_addresses_in_blocks: List[OperationAddress]
) -> List[torch.nn.Module]:
    """
    Returns a list of all modules included in the blocks.

    :param compressed_model: Target model.
    :param op_addresses_in_blocks: Set of operation addresses for the building blocks.
    :return: List of modules for the building blocks.
    """
    modules = []
    for op_address in op_addresses_in_blocks:
        if op_address.operator_name in NNCF_MODULES_OP_NAMES:
            modules.append(
                compressed_model.get_module_by_scope(op_address.scope_in_model))
    return modules
Example 23
    def __init__(self, target_model: NNCFNetwork, prunable_types: List[str],
                 pruned_module_groups_info: Clusterization[PrunedModuleInfo],
                 config: NNCFConfig):
        super().__init__(target_model)
        self._loss = ZeroCompressionLoss(
            next(target_model.parameters()).device)
        self._prunable_types = prunable_types
        self.config = config
        self.pruning_config = extract_algo_specific_config(
            config, 'filter_pruning')
        params = self.pruning_config.get('params', {})
        self.pruned_module_groups_info = pruned_module_groups_info
        self.prune_batch_norms = params.get('prune_batch_norms', True)
        self.prune_first = params.get('prune_first_conv', False)
        self.prune_downsample_convs = params.get('prune_downsample_convs', False)
        self.prune_flops = False
        self.check_pruning_level(params)
        self._hooks = []
Example 24
    def _handle_frozen_layers(self, target_model: NNCFNetwork):
        scopes_of_frozen_layers = []
        for weighted_node in target_model.get_weighted_original_graph_nodes():
            if not weighted_node.layer_attributes.weight_requires_grad:
                if self._should_consider_scope(weighted_node.node_name):
                    scopes_of_frozen_layers.append(weighted_node.node_name)
        scopes_to_print = '\n'.join(scopes_of_frozen_layers)
        if len(scopes_of_frozen_layers) > 0:
            is_allowed, reason = self._are_frozen_layers_allowed()
            if is_allowed:
                nncf_logger.warning(
                    '{}, compressing them without tuning weights.\n'
                    'Frozen layers:\n'
                    '{}'.format(reason, scopes_to_print))
            else:
                raise RuntimeError(
                    f'{reason}.\n'
                    f'Please unfreeze them or add them to the ignored scope.\n'
                    f'Frozen layers:\n'
                    f'{scopes_to_print}')
Example 25
    def output_prune(cls, model: NNCFNetwork, node: NNCFNode,
                     graph: NNCFGraph) -> None:
        mask = node.data['output_mask']
        if mask is None:
            return

        bool_mask = torch.tensor(mask, dtype=torch.bool)

        node_module = model.get_containing_module(node.node_name)
        old_num_channels = int(node_module.weight.size(0))

        node_module.out_channels = int(torch.sum(bool_mask))
        node_module.weight = torch.nn.Parameter(node_module.weight[bool_mask])

        if node_module.bias is not None:
            node_module.bias = torch.nn.Parameter(node_module.bias[bool_mask])

        nncf_logger.info(
            'Pruned Convolution {} by pruning mask: number of output filters changed '
            'from {} to {}.'.format(node.data['key'], old_num_channels,
                                    node_module.out_channels))
Example 26
    def input_prune(cls, model: NNCFNetwork, node: NNCFNode,
                    graph: NNCFGraph) -> None:
        input_mask = node.data['input_masks'][0]
        if input_mask is None:
            return

        bool_mask = torch.tensor(input_mask, dtype=torch.bool)
        node_module = model.get_containing_module(node.node_name)

        if isinstance(node_module, tuple(NNCF_WRAPPED_USER_MODULES_DICT)):
            assert node_module.target_weight_dim_for_compression == 0, \
                "Implemented only for target_weight_dim_for_compression == 0"
            old_num_channels = int(node_module.weight.size(0))
            new_num_channels = int(torch.sum(input_mask))
            node_module.weight = torch.nn.Parameter(
                node_module.weight[bool_mask])
            node_module.n_channels = new_num_channels

            nncf_logger.info(
                'Pruned Elementwise {} by input mask: number of features changed '
                'from {} to {}.'.format(node.data['key'], old_num_channels,
                                        new_num_channels))
Example 27
    def __init__(self, target_model: NNCFNetwork, original_model: nn.Module,
                 kd_type: str, scale: float, temperature: float):
        super().__init__()
        original_model.train()
        if kd_type == 'softmax':

            def kd_loss_fn(teacher_output: torch.Tensor,
                           student_output: torch.Tensor):
                if len(student_output.shape) != 2 or len(
                        teacher_output.shape) != 2:
                    nncf_logger.debug(
                        "Incompatible number of dimensions of the model output tensor for softmax KD "
                        "(student - {}, teacher - {}, number of dims {} should be == 2)"
                        " - ignoring!".format(student_output.shape,
                                              teacher_output.shape,
                                              len(student_output.shape)))
                    return torch.zeros([1]).to(student_output.device)
                return scale * -(nn.functional.log_softmax(student_output / temperature, dim=1) *
                                 nn.functional.softmax(teacher_output / temperature, dim=1)).mean() \
                       * (student_output.shape[1] * temperature * temperature)
        else:

            def kd_loss_fn(teacher_output: torch.Tensor,
                           student_output: torch.Tensor):
                mse = torch.nn.MSELoss()
                if len(teacher_output.shape) < 2:
                    nncf_logger.debug(
                        "Incompatible number of dimensions of the model output tensor for MSE KD "
                        "(student - {}, teacher - {}, number of dims {} should be > 1)"
                        " (most likely loss) - ignoring!".format(
                            student_output.shape, teacher_output.shape,
                            len(student_output.shape)))
                    return torch.zeros([1]).to(student_output.device)
                return scale * mse(teacher_output, student_output)

        self._kd_loss_handler = target_model.create_knowledge_distillation_loss_handler(
            original_model,
            partial(KnowledgeDistillationLoss._calculate,
                    kd_loss_fn=kd_loss_fn))
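The softmax branch is Hinton-style distillation: the cross-entropy between temperature-softened teacher and student distributions, rescaled by the class count times T^2 because .mean() averages over both the batch and class dimensions. The same expression can be evaluated stand-alone:

import torch
import torch.nn.functional as F

scale, temperature = 1.0, 4.0
student_output = torch.randn(8, 10)  # (batch, classes)
teacher_output = torch.randn(8, 10)

kd_loss = scale * -(F.log_softmax(student_output / temperature, dim=1) *
                    F.softmax(teacher_output / temperature, dim=1)).mean() \
          * (student_output.shape[1] * temperature * temperature)
assert kd_loss.ndim == 0  # a scalar loss value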
Example 28
    def input_prune(cls, model: NNCFNetwork, node: NNCFNode,
                    graph: NNCFGraph) -> None:
        input_mask = node.data['input_masks'][0]
        if input_mask is None:
            return

        node_module = model.get_containing_module(node.node_name)

        bool_mask = torch.tensor(input_mask, dtype=torch.bool)
        old_num_channels = int(node_module.weight.size(0))
        new_num_channels = int(torch.sum(input_mask))

        node_module.num_channels = new_num_channels
        node_module.num_groups = new_num_channels

        node_module.weight = torch.nn.Parameter(node_module.weight[bool_mask])
        node_module.bias = torch.nn.Parameter(node_module.bias[bool_mask])

        nncf_logger.info(
            'Pruned GroupNorm {} by input mask: number of features changed '
            'from {} to {}.'.format(node.data['key'], old_num_channels,
                                    new_num_channels))
Example 29
def test_gnmt_quantization(_case_config):
    model = GNMT(vocab_size=32)
    model = replace_lstm(model)
    forward_fn_ = gnmt_forward_fn(seq_len=10, batch_size=3, vocab_size=32)

    config = get_basic_quantization_config(_case_config.quant_type)
    config["input_info"] = [
        {
            "sample_size": [3, 10],
            "type": "long"
        },
        {
            "sample_size": [3],
            "type": "long"
        },
        {
            "sample_size": [3, 10],
            "type": "long"
        }
    ]
    config["compression"].update({
        "ignored_scopes": ["GNMT/ResidualRecurrentEncoder[encoder]/Embedding[embedder]",
                           "GNMT/ResidualRecurrentDecoder[decoder]/Embedding[embedder]"]})

    compressed_model = NNCFNetwork(model,
                                   input_infos=create_input_infos(config),
                                   dummy_forward_fn=forward_fn_,
                                   wrap_inputs_fn=gnmt_wrap_inputs_fn,
                                   scopes_without_shape_matching=
                                   ['GNMT/ResidualRecurrentDecoder[decoder]/RecurrentAttention[att_rnn]/'
                                    'BahdanauAttention[attn]'])

    builder = QuantizationBuilder(config, should_init=False)
    builder.apply_to(compressed_model)

    check_model_graph(compressed_model, 'gnmt_variable.dot', _case_config.graph_dir)
Example 30
def test_get_op_nodes_in_scope():
    model = TwoConvTestModel()
    nncf_model = NNCFNetwork(deepcopy(model), input_infos=[ModelInputInfo([1, 1, 4, 4])])  # type: NNCFNetwork
    nncf_graph = nncf_model.get_original_graph()

    # Valid scopes should be successfully found
    valid_nncf_modules = nncf_model.get_nncf_modules()
    nodes_list = list(nncf_graph.get_all_node_ids())
    for module_scope, _ in valid_nncf_modules.items():
        matching_nncf_nodes = nncf_graph.get_op_nodes_in_scope(module_scope)
        assert len(matching_nncf_nodes) == 1
        node = matching_nncf_nodes[0]
        assert isinstance(node, NNCFNode)
        assert node.node_id in nodes_list

    fake_model = BasicConvTestModel()
    fake_nncf_model = NNCFNetwork(deepcopy(fake_model), input_infos=[ModelInputInfo([1, 1, 4, 4])])

    # Invalid scopes should not be found
    fake_nncf_modules = fake_nncf_model.get_nncf_modules()
    for module_scope, _ in fake_nncf_modules.items():
        matching_nncf_nodes = nncf_graph.get_op_nodes_in_scope(module_scope)
        assert not matching_nncf_nodes