Example #1
    def __init__(self,
                 model,
                 config,
                 input_infos: ModelInputInfo = None,
                 dummy_forward_fn=None,
                 **kwargs):
        super().__init__(model, config, input_infos, dummy_forward_fn)
        self.sparsity_level = self.threshold = 0

        self.ignored_scopes = self.config.get('ignored_scopes')
        self.target_scopes = self.config.get('target_scopes')
        self.dummy_forward_fn = dummy_forward_fn
        if self.dummy_forward_fn is None:
            self.dummy_forward_fn = create_dummy_forward_fn(input_infos)

        params = self.config.get("params", {})
        device = next(model.parameters()).device

        self.weight_importance = WEIGHT_IMPORTANCE_FUNCTIONS.get(
            self.config.get('weight_importance', 'normed_abs'))

        # Replace sparsifiable modules with their NNCF-wrapped counterparts, then
        # register the weight sparsifying operations on the wrapped modules.
        self._replace_sparsifying_modules_by_nncf_modules(
            device, self.ignored_scopes, self.target_scopes, logger)
        self._register_weight_sparsifying_operations(device,
                                                     self.ignored_scopes,
                                                     self.target_scopes,
                                                     logger)

        scheduler_cls = SPARSITY_SCHEDULERS.get(
            params.get("schedule", "polynomial"))
        self._scheduler = scheduler_cls(self, self.config)
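# Illustrative config sketch (hypothetical values, not taken from the source):
# the constructor above reads only these keys -- 'ignored_scopes', 'target_scopes',
# 'weight_importance' (default 'normed_abs') and params['schedule'] (default 'polynomial').
example_sparsity_config = {
    "ignored_scopes": ["ResNet/Sequential[layer3]"],
    "weight_importance": "normed_abs",
    "params": {"schedule": "polynomial"},
}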
def test_quantize_has_proper_is_weights_flag():
    class Model(nn.Module):
        def __init__(self, size=1):
            super().__init__()
            self.size = size
            self.conv = nn.Conv2d(size, size, size)

        def forward(self, x):
            return self.conv(x)

    model = Model()
    reset_context('orig')
    reset_context('quantized_graphs')

    quant_model = QuantizedNetwork(model,
                                   create_quantize_module,
                                   inputs_shape=(1, 1, 2, 2),
                                   dummy_forward_fn=create_dummy_forward_fn(
                                       (1, 1, 2, 2)))
    for module in quant_model.modules():
        if isinstance(module, NNCFConv2d):
            for op in module.pre_ops.values():
                assert isinstance(op, (UpdateWeight, UpdateInputs))
                assert op.operand.is_weights is isinstance(op, UpdateWeight)
    for _, aq in quant_model.activation_quantizers.items():
        assert aq.is_weights is False
    def test_quantize_network(self, model_name, model_builder, input_size, _quantize_config):
        net = model_builder()
        ctx = reset_context('orig')
        ctx = reset_context('quantized_graphs')
        qnet = QuantizedNetwork(net, _quantize_config.quantizer, input_size,
                                dummy_forward_fn=create_dummy_forward_fn(input_size))
        _ = qnet(torch.zeros(*input_size))
        _ = qnet(torch.zeros(*input_size))
        check_graph(to_networkx(ctx), model_name, _quantize_config.graph_dir)
def test_output_quantization(_quantize_config):
    net = test_models.UNet()
    ctx = reset_context('orig')
    ctx = reset_context('quantized_graphs')
    input_shape = (1, 3, 360, 480)
    qnet = QuantizedNetwork(net, _quantize_config.quantizer, input_shape,
                            dummy_forward_fn=create_dummy_forward_fn(input_shape),
                            quantize_outputs=True)
    _ = qnet(torch.zeros(*input_shape))
    _ = qnet(torch.zeros(*input_shape))

    check_graph(to_networkx(ctx), 'unet_qoutput.dot', _quantize_config.graph_dir)
def test_resnet18__with_ignore(_quantize_config):
    net = test_models.ResNet18()
    ctx = reset_context('orig')
    ctx = reset_context('quantized_graphs')
    input_shape = (1, 3, 32, 32)
    qnet = QuantizedNetwork(net, _quantize_config.quantizer, input_shape,
                            dummy_forward_fn=create_dummy_forward_fn(input_shape),
                            ignored_scopes=['ResNet/Sequential[layer3]'])
    _ = qnet(torch.zeros(*input_shape))
    _ = qnet(torch.zeros(*input_shape))

    check_graph(to_networkx(ctx), 'resnet18_ignore.dot', _quantize_config.graph_dir)
def test_resnet18__with_not_qinput(_quantize_config):
    net = test_models.ResNet18()
    ctx = reset_context('orig')
    ctx = reset_context('quantized_graphs')
    input_shape = (1, 3, 32, 32)
    qnet = QuantizedNetwork(net, _quantize_config.quantizer, input_shape,
                            dummy_forward_fn=create_dummy_forward_fn(input_shape),
                            quantize_inputs=False)
    _ = qnet(torch.zeros(*input_shape))
    _ = qnet(torch.zeros(*input_shape))

    check_graph(to_networkx(ctx), 'resnet18_no_qinput.dot', _quantize_config.graph_dir)
def test_custom_quantizable_subgraph_patterns(_quantize_config):
    net = test_models.SENet18()
    ctx = reset_context('orig')
    ctx = reset_context('quantized_graphs')
    input_shape = (1, 3, 32, 32)
    qnet = QuantizedNetwork(net, _quantize_config.quantizer, input_shape,
                            dummy_forward_fn=create_dummy_forward_fn(input_shape),
                            quantize_outputs=False,
                            quantizable_subgraph_patterns=(("sigmoid", "__mul__"),
                                                           ("__iadd__", "batch_norm")))
    _ = qnet(torch.zeros(*input_shape))
    _ = qnet(torch.zeros(*input_shape))

    check_graph(to_networkx(ctx), 'senet_custom_patterns.dot', _quantize_config.graph_dir)
def get_all_node_names(model,
                       input_sample_size,
                       graph_scope=None,
                       builder=None):
    if graph_scope is None:
        graph_scope = 'utils'
    reset_context(graph_scope)
    if not builder:
        builder = GraphBuilder(
            create_dummy_forward_fn([
                ModelInputInfo(input_sample_size),
            ]))
    graph = builder.build_graph(model, graph_scope)
    return [
        node_name.split(' ', 1)[1] for node_name in graph.get_all_node_keys()
    ]
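# A minimal usage sketch (assumption: the ResNet18 test model and the
# (1, 3, 32, 32) input size from the tests above are available in this scope);
# the returned entries are scope strings in the style of 'ResNet/Sequential[layer3]'.
resnet_node_names = get_all_node_names(test_models.ResNet18(), (1, 3, 32, 32))
assert resnet_node_names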
def test_ambiguous_function():
    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.ModuleList(
                [nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1)])

        def forward(self, x):
            for layer in self.layers:
                x = F.relu(layer(x))
            return x

    reset_context('orig')
    reset_context('quantized_graphs')
    mod = Model()
    QuantizedNetwork(mod,
                     create_quantize_module,
                     inputs_shape=(1, 1, 1, 1),
                     dummy_forward_fn=create_dummy_forward_fn((1, 1, 1, 1)))
    def __init__(self,
                 module,
                 quantize_module_creator_fn,
                 input_infos=None,
                 dummy_forward_fn=None,
                 ignored_scopes=None,
                 target_scopes=None,
                 quantize_inputs=True,
                 quantize_outputs=False,
                 quantizable_subgraph_patterns=None,
                 scopes_without_shape_matching=None,
                 disable_function_quantization_hooks=False):
        super().__init__()
        self.set_nncf_wrapped_module(module)
        self.quantize_inputs = quantize_inputs
        self.quantize_outputs = quantize_outputs
        self.input_infos = input_infos
        self.ignored_scopes = ignored_scopes
        self.target_scopes = target_scopes
        self.activation_quantizers = nn.ModuleDict()
        self.function_quantizers = nn.ModuleDict()
        self.quantized_weight_modules = OrderedDict()
        self.quantized_activation_modules = OrderedDict()
        self.quantize_module_creator_fn = quantize_module_creator_fn
        self.quantizable_subgraph_patterns = quantizable_subgraph_patterns
        self._dummy_forward_fn = dummy_forward_fn
        self._nncf_module_scopes = []  # type: List[Scope]
        self.debug_interface = QuantizationDebugInterface() if is_debug() else None
        self.scopes_without_shape_matching = scopes_without_shape_matching

        device = next(module.parameters()).device

        self.all_quantizations = OrderedDict()
        self._processed_input_agnostic_op_exec_contexts = set()
        self._processed_function_quantizers = set()

        # all modules should be replaced prior to graph building
        self._replace_quantized_modules_by_nncf_modules(device)
        self._register_weight_quantization_operations(device)

        if self._dummy_forward_fn is None:
            self._dummy_forward_fn = create_dummy_forward_fn(self.input_infos)

        self._graph_builder = GraphBuilder(
            custom_forward_fn=self._dummy_forward_fn)

        self._context_name = "orig"
        if self.scopes_without_shape_matching:
            get_context(self._context_name).add_node_comparators(
                scopes_without_shape_matching,
                ShapeIgnoringTensorMetaComparator())

        self._original_graph = self._graph_builder.build_graph(
            self, self._context_name)

        self._context_name = "quantized_graphs"
        self._ctx = get_context("quantized_graphs")
        if self.scopes_without_shape_matching:
            get_context(self._context_name).add_node_comparators(
                scopes_without_shape_matching,
                ShapeIgnoringTensorMetaComparator())

        # With the original graph traced, insert activation quantizers, input
        # quantizers (if enabled) and function quantization hooks (unless disabled).
        self._register_activation_quantization_hooks(device)
        if self.quantize_inputs:
            self._register_input_quantization_operations(device)

        if not disable_function_quantization_hooks:
            self._register_function_quantization_hooks(device)

        quantization_types = [
            class_type.__name__
            for class_type in QUANTIZATION_MODULES.registry_dict.values()
        ]
        self.all_quantizations = get_state_dict_names_with_modules(
            self, quantization_types)
        self.load_listener = LoadStateListener(self, self.all_quantizations)
        if self.debug_interface is not None:
            self.debug_interface.init_actual(self.all_quantizations.keys(),
                                             self.activation_quantizers.keys(),
                                             self.function_quantizers.keys())
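# A minimal construction sketch mirroring the tests above (ResNet18, the
# (1, 3, 32, 32) input and create_quantize_module are all reused from them;
# no other API is assumed). The contexts are reset first because __init__
# traces the model in the 'orig' and 'quantized_graphs' contexts.
reset_context('orig')
reset_context('quantized_graphs')
qnet = QuantizedNetwork(test_models.ResNet18(),
                        create_quantize_module,
                        (1, 3, 32, 32),
                        dummy_forward_fn=create_dummy_forward_fn((1, 3, 32, 32)))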