def test_build_graph(self, desc: ModelDesc):
    """Build the graph of the model described by ``desc`` and compare it with the reference dot file.

    ``desc.input_sample_sizes`` is either a tuple holding one sample size per
    model input, or a single sample size for a single-input model.
    ``desc.dummy_forward_fn`` is used for tracing when it is set; otherwise a
    dummy forward function is generated from the input infos.
    """
    model = desc.model_builder()
    sample_sizes = desc.input_sample_sizes
    if isinstance(sample_sizes, tuple):
        input_infos = [ModelInputInfo(size) for size in sample_sizes]
    else:
        input_infos = [ModelInputInfo(sample_sizes)]

    # Fall back to an auto-generated forward function when the descriptor does not supply one.
    forward_fn = desc.dummy_forward_fn or create_dummy_forward_fn(input_infos, desc.wrap_inputs_fn)

    graph = GraphBuilder(custom_forward_fn=forward_fn).build_graph(model)
    check_graph(graph, desc.dot_filename, 'original')
    def test_operator_metatype_marking(self):
        """Check that every traced operation of a small model gets the expected operator metatype.

        The comparison runs in both directions: each node of the traced graph must have a
        reference entry with a matching metatype, and each reference entry must correspond
        to a node that was actually traced.
        """
        from nncf.torch.graph.operator_metatypes import PTConv2dMetatype, PTBatchNormMetatype, PTRELUMetatype, \
            PTMaxPool2dMetatype, PTTransposeMetatype, \
            PTConvTranspose2dMetatype, PTDepthwiseConv2dSubtype, PTAddMetatype, PTAvgPool2dMetatype, PTLinearMetatype
        ref_scope_vs_metatype_dict = {
            "/" + MODEL_INPUT_OP_NAME + "_0": PTInputNoopMetatype,
            "ModelForMetatypeTesting/NNCFConv2d[conv_regular]/conv2d_0": PTConv2dMetatype,
            "ModelForMetatypeTesting/NNCFBatchNorm[bn]/batch_norm_0": PTBatchNormMetatype,
            "ModelForMetatypeTesting/relu_0": PTRELUMetatype,
            "ModelForMetatypeTesting/transpose__0": PTTransposeMetatype,
            "ModelForMetatypeTesting/MaxPool2d[max_pool2d]/max_pool2d_0": PTMaxPool2dMetatype,
            "ModelForMetatypeTesting/NNCFConvTranspose2d[conv_transpose]/conv_transpose2d_0": PTConvTranspose2dMetatype,
            "ModelForMetatypeTesting/NNCFConv2d[conv_depthwise]/conv2d_0": PTDepthwiseConv2dSubtype,
            "ModelForMetatypeTesting/__iadd___0": PTAddMetatype,
            "ModelForMetatypeTesting/AdaptiveAvgPool2d[adaptive_avg_pool]/adaptive_avg_pool2d_0": PTAvgPool2dMetatype,
            "ModelForMetatypeTesting/NNCFLinear[linear]/linear_0": PTLinearMetatype,
            "ModelForMetatypeTesting/flatten_0": PTReshapeMetatype,
            "/" + MODEL_OUTPUT_OP_NAME + "_0": PTOutputNoopMetatype,
        }

        class ModelForMetatypeTesting(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv_regular = torch.nn.Conv2d(in_channels=3,
                                                    out_channels=16,
                                                    kernel_size=3)
                self.bn = torch.nn.BatchNorm2d(num_features=16)
                self.max_pool2d = torch.nn.MaxPool2d(kernel_size=2)
                self.conv_transpose = torch.nn.ConvTranspose2d(in_channels=16,
                                                               out_channels=8,
                                                               kernel_size=3)
                self.conv_depthwise = torch.nn.Conv2d(in_channels=8, out_channels=8,
                                                      kernel_size=5, groups=8)
                self.adaptive_avg_pool = torch.nn.AdaptiveAvgPool2d(output_size=1)
                self.linear = torch.nn.Linear(in_features=8, out_features=1)

            def forward(self, input_):
                x = self.conv_regular(input_)
                x = self.bn(x)
                x = torch.nn.functional.relu(x)
                x.transpose_(2, 3)
                x = self.max_pool2d(x)
                x = self.conv_transpose(x)
                x = self.conv_depthwise(x)
                x += torch.ones_like(x)
                x = self.adaptive_avg_pool(x)
                x = self.linear(x.flatten())
                return x

        model = ModelForMetatypeTesting()
        nncf_network = NNCFNetwork(model, [ModelInputInfo([1, 3, 300, 300])])
        nncf_graph = nncf_network.get_original_graph()

        actual_metatypes = {nncf_node.node_name: nncf_node.metatype
                            for nncf_node in nncf_graph.get_all_nodes()}

        # Checking only "traced node -> reference" would let a node that silently stops
        # being traced go unnoticed, so the full name sets are compared first.
        assert set(actual_metatypes) == set(ref_scope_vs_metatype_dict)
        for node_name, metatype in actual_metatypes.items():
            assert metatype == ref_scope_vs_metatype_dict[node_name], node_name
def test_disable_shape_matching():
    """Check that ``scopes_without_shape_matching`` makes dynamic-graph tracing ignore input shapes.

    With shape matching disabled for the whole ``MatMulModel`` scope, running the model on
    two inputs of different shapes must produce equal dynamic graphs. With default
    settings, the second, differently-shaped run must register additional nodes.
    """
    class MatMulModel(nn.Module):
        def __init__(self):
            super().__init__()
            # Not used in forward; presumably keeps the model from being parameterless - TODO confirm.
            self.dummy_param = torch.nn.Parameter(torch.ones([1]))

        def forward(self, inputs):
            # Split along dim 2 and batch-multiply one half by the transpose of the other.
            left, right = torch.chunk(inputs, 2, dim=2)
            return torch.bmm(left, right.transpose(1, 2))

    model = MatMulModel()

    shape_a = (3, 32, 32)
    shape_b = (4, 64, 64)

    shapeless_net = NNCFNetwork(deepcopy(model), input_infos=[ModelInputInfo(shape_a), ],
                                scopes_without_shape_matching=['MatMulModel'])  # type: NNCFNetwork
    shapeless_net.get_tracing_context().enable_trace_dynamic_graph()

    shapeless_net(torch.zeros(*shape_a))
    first_graph = deepcopy(shapeless_net.get_dynamic_graph())

    shapeless_net(torch.zeros(*shape_b))
    second_graph = deepcopy(shapeless_net.get_dynamic_graph())

    assert first_graph == second_graph

    # input + chunk + transpose + bmm + output
    assert len(list(first_graph.get_all_nodes())) == 5

    shaped_net = NNCFNetwork(model, input_infos=[ModelInputInfo(shape_a), ])  # type: NNCFNetwork
    shaped_net.get_tracing_context().enable_trace_dynamic_graph()
    shaped_net(torch.zeros(*shape_a))
    shaped_net(torch.zeros(*shape_b))
    # No scopes_without_shape_matching here, so the run with a different input shape
    # must have registered nodes beyond those produced by a single shape-agnostic run.
    assert shaped_net.get_dynamic_graph().get_nodes_count() > first_graph.get_nodes_count()
def test_get_op_nodes_in_scope():
    """Check that ``get_op_nodes_in_scope`` matches exactly the model's own NNCF module scopes.

    Each NNCF module scope of ``TwoConvTestModel`` must match exactly one node of its
    graph, while scopes taken from an unrelated ``BasicConvTestModel`` must match none.
    """
    model = TwoConvTestModel()
    nncf_model = NNCFNetwork(deepcopy(model), input_infos=[ModelInputInfo([1, 1, 4, 4])])  # type: NNCFNetwork
    nncf_graph = nncf_model.get_original_graph()

    # Valid scopes should be successfully found
    valid_nncf_modules = nncf_model.get_nncf_modules()
    # Without this guard an empty module mapping would make the loop below pass vacuously.
    assert valid_nncf_modules
    node_ids = set(nncf_graph.get_all_node_ids())
    for module_scope in valid_nncf_modules:
        matching_nncf_nodes = nncf_graph.get_op_nodes_in_scope(module_scope)
        assert len(matching_nncf_nodes) == 1
        node = matching_nncf_nodes[0]
        assert isinstance(node, NNCFNode)
        assert node.node_id in node_ids

    fake_model = BasicConvTestModel()
    fake_nncf_model = NNCFNetwork(deepcopy(fake_model), input_infos=[ModelInputInfo([1, 1, 4, 4])])

    # Not valid scopes shouldn't be found
    fake_nncf_modules = fake_nncf_model.get_nncf_modules()
    # Same vacuous-pass guard for the negative check.
    assert fake_nncf_modules
    for module_scope in fake_nncf_modules:
        matching_nncf_nodes = nncf_graph.get_op_nodes_in_scope(module_scope)
        assert not matching_nncf_nodes
def test_weight_normed_modules_are_replaced_correctly():
    """Check that a weight-normed conv keeps its weight-norm attributes after NNCF wrapping.

    ``weight_g`` and ``weight_v`` must stay ``torch.nn.Parameter`` objects, ``weight``
    must exist but not be a Parameter, and exactly one forward pre-hook must be registered
    on the wrapped conv (presumably the weight-norm recomputation hook - TODO confirm).
    """
    nncf_model = NNCFNetwork(WeightNormedConvModel(), input_infos=[ModelInputInfo([1, 1, 10])])
    conv = nncf_model.conv

    for attr_name in ("weight_g", "weight_v", "weight"):
        assert hasattr(conv, attr_name)

    for param_name in ("weight_g", "weight_v"):
        assert isinstance(getattr(conv, param_name), torch.nn.Parameter)
    assert not isinstance(conv.weight, torch.nn.Parameter)

    #pylint:disable=protected-access
    assert len(conv._forward_pre_hooks) == 1
def test_custom_module_registering():
    """Check that a user-defined module class is registered and wrapped as an NNCF module.

    ``ModuleOfUser`` must appear in the ``UNWRAPPED_USER_MODULES`` registry, and the
    ``user_module`` inside the wrapped network must be an instance of both the user class
    and ``_NNCFModuleMixin``, be named ``NNCFUserModuleOfUser``, and expose every
    attribute of ``_NNCFModuleMixin``.
    """
    nncf_model = NNCFNetwork(TwoConvTestModelWithUserModule(),
                             input_infos=[ModelInputInfo([1, 1, 4, 4])])  # type: NNCFNetwork

    from nncf.torch.layers import UNWRAPPED_USER_MODULES
    assert ModuleOfUser in UNWRAPPED_USER_MODULES.registry_dict.values()

    user_module = nncf_model.user_module
    # pylint: disable=protected-access
    assert isinstance(user_module, ModuleOfUser)
    assert isinstance(user_module, _NNCFModuleMixin)
    assert type(user_module).__name__ == "NNCFUserModuleOfUser"

    # Every attribute of the mixin must be visible on the wrapped user module.
    missing_attrs = set(dir(_NNCFModuleMixin)) - set(dir(user_module))
    assert not missing_attrs, missing_attrs
Example #7
0
    def get_model_and_ctrl_with_applied_hw_config_quantization(
            model: torch.nn.Module,
            hw_config_dict: dict,
            should_be_quantize_inputs: bool = True):
        """Wrap ``model`` into an NNCFNetwork and quantize it according to a hardware config.

        :param model: Model to wrap and quantize.
        :param hw_config_dict: Hardware config in dict form, parsed with ``PTHWConfig.from_dict``.
        :param should_be_quantize_inputs: Value for the ``quantize_inputs`` option of the
            compression section of the config.
        :return: Tuple ``(quantized NNCFNetwork, quantization controller)``.
        """
        config = get_quantization_config_without_range_init(model_size=1)
        config["compression"].update({"quantize_inputs": should_be_quantize_inputs})
        config["target_device"] = "ANY"  # for compatibility

        compressed_model = NNCFNetwork(model, input_infos=[ModelInputInfo([1, 2, 1, 1])])
        parsed_hw_config = PTHWConfig.from_dict(hw_config_dict)
        # should_init=False: the builder is not asked to run initialization.
        builder = QuantizationBuilder(config, should_init=False)
        builder.hw_config = parsed_hw_config
        compressed_model = builder.apply_to(compressed_model)
        controller = builder.build_controller(compressed_model)
        return compressed_model, controller
Example #8
0
def test_pruning_node_selector(
        test_input_info_struct_: GroupPruningModulesTestStruct):
    """Check the pruning groups that ``PruningNodeSelector`` builds against reference data.

    Verifies that (1) every module listed as non-pruned exists and belongs to no pruning
    group, and (2) every reference group of node ids equals, as a multiset, the cluster
    that contains its first node id.
    """
    struct = test_input_info_struct_
    model_builder = struct.model  # a callable that produces the model instance
    prune_first, prune_downsample = struct.prune_params

    from nncf.common.pruning.node_selector import PruningNodeSelector
    selector = PruningNodeSelector(PT_PRUNING_OPERATOR_METATYPES,
                                   [op.op_func_name for op in NNCF_PRUNING_MODULES_DICT],
                                   PTElementwisePruningOp.get_all_op_aliases(),
                                   None, None,
                                   prune_first, prune_downsample)

    model = model_builder()
    model.eval()
    nncf_network = NNCFNetwork(model,
                               input_infos=[ModelInputInfo([1, 1, 8, 8])])
    pruning_groups = selector.create_pruning_groups(nncf_network.get_original_graph())

    # 1. Modules listed as non-pruned must exist and must not belong to any pruning group.
    pruned_modules = [nncf_network.get_containing_module(node.node_name)
                      for node in pruning_groups.get_all_nodes()]
    for node_name in struct.non_pruned_module_nodes:
        module = nncf_network.get_containing_module(node_name)
        assert module is not None and module not in pruned_modules

    # 2. Each reference group must equal (as a multiset) the cluster containing its first node.
    for ref_group in struct.pruned_groups_by_node_id:
        cluster = pruning_groups.get_cluster_containing_element(ref_group[0])
        assert Counter(node.node_id for node in cluster.elements) == Counter(ref_group)
def forward(arg1=None, arg2=None, arg3=None, arg4=None, arg5=TENSOR_DEFAULT):
    """Do nothing and return ``None``; only the signature of this function matters.

    Presumably it stands in for a model ``forward`` in the input-wrapping test cases,
    which pass values positionally and by keyword (e.g. ``arg2``, ``arg3``) - TODO confirm.
    """


class InputWrappingTestStruct:
    """Plain record describing one input-wrapping test case.

    All constructor arguments are stored unchanged as attributes of the same name:
    the input descriptions (``input_infos``), the positional and keyword arguments the
    model is called with (``model_args``, ``model_kwargs``), and the reference sequence
    of values expected to be wrapped (``ref_wrapping_sequence``).
    """

    def __init__(self, input_infos, model_args, model_kwargs,
                 ref_wrapping_sequence):
        # How the model is invoked.
        self.model_args = model_args
        self.model_kwargs = model_kwargs
        # What the inputs look like and what wrapping is expected to produce.
        self.input_infos = input_infos
        self.ref_wrapping_sequence = ref_wrapping_sequence


INPUT_WRAPPING_TEST_CASES = [
    InputWrappingTestStruct(
        input_infos=[ModelInputInfo([1])],
        model_args=(TENSOR_1, ),
        model_kwargs={},
        ref_wrapping_sequence=[TENSOR_1],
    ),
    InputWrappingTestStruct(
        input_infos=[ModelInputInfo([1], keyword="arg2")],
        model_args=(),
        model_kwargs={"arg2": TENSOR_2},
        ref_wrapping_sequence=[TENSOR_2],
    ),
    InputWrappingTestStruct(
        input_infos=[
            ModelInputInfo([1]),
            ModelInputInfo([1]),
            ModelInputInfo([1], keyword="arg3"),
Example #10
0
 def setup(self):
     """Store an NNCFNetwork-wrapped ``InsertionPointTestModel`` as ``self.compressed_model``."""
     model = InsertionPointTestModel()
     input_infos = [ModelInputInfo([1, 1, 10, 10])]
     self.compressed_model = NNCFNetwork(model, input_infos)  # type: NNCFNetwork
Example #11
0
def test_check_correct_modules_replacement():
    """Check that the NNCF modules reported by ``check_correct_nncf_modules_replacement`` are the network's own."""
    original_model = TwoConvTestModel()
    # A separate instance is wrapped so that ``original_model`` is not the object handed
    # to NNCFNetwork (presumably because wrapping replaces modules in place - TODO confirm).
    nncf_model = NNCFNetwork(TwoConvTestModel(), input_infos=[ModelInputInfo([1, 1, 4, 4])])  # type: NNCFNetwork

    _, found_nncf_modules = check_correct_nncf_modules_replacement(original_model, nncf_model)
    expected_nncf_modules = set(nncf_model.get_nncf_modules())
    assert set(found_nncf_modules) == expected_nncf_modules