def __init__(self, model, config, input_infos: ModelInputInfo = None,
             dummy_forward_fn=None, **kwargs):
    super().__init__(model, config, input_infos, dummy_forward_fn)
    self.sparsity_level = self.threshold = 0
    self.ignored_scopes = self.config.get('ignored_scopes')
    self.target_scopes = self.config.get('target_scopes')

    self.dummy_forward_fn = dummy_forward_fn
    if self.dummy_forward_fn is None:
        self.dummy_forward_fn = create_dummy_forward_fn(input_infos)

    params = self.config.get("params", {})
    device = next(model.parameters()).device

    # look up the configured weight-importance criterion (default: 'normed_abs')
    self.weight_importance = WEIGHT_IMPORTANCE_FUNCTIONS.get(
        self.config.get('weight_importance', 'normed_abs'))

    # modules must be wrapped with their NNCF counterparts before the
    # sparsifying operations are registered on them
    self._replace_sparsifying_modules_by_nncf_modules(
        device, self.ignored_scopes, self.target_scopes, logger)
    self._register_weight_sparsifying_operations(
        device, self.ignored_scopes, self.target_scopes, logger)

    scheduler_cls = SPARSITY_SCHEDULERS.get(params.get("schedule", "polynomial"))
    self._scheduler = scheduler_cls(self, self.config)
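# A minimal sketch of the kind of criterion WEIGHT_IMPORTANCE_FUNCTIONS could
# map 'normed_abs' to; the exact signature the registry expects is an
# assumption here, not taken from the snippet above.
import torch


def normed_abs_importance(weight: torch.Tensor) -> torch.Tensor:
    # score each weight by absolute magnitude, scaled to the largest entry
    abs_w = weight.abs()
    return abs_w / (abs_w.max() + torch.finfo(abs_w.dtype).eps)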
def test_quantize_has_proper_is_weights_flag():
    class Model(nn.Module):
        def __init__(self, size=1):
            super().__init__()
            self.size = size
            self.conv = nn.Conv2d(size, size, size)

        def forward(self, x):
            return self.conv(x)

    model = Model()
    reset_context('orig')
    reset_context('quantized_graphs')
    quant_model = QuantizedNetwork(model, create_quantize_module,
                                   inputs_shape=(1, 1, 2, 2),
                                   dummy_forward_fn=create_dummy_forward_fn((1, 1, 2, 2)))

    for module in quant_model.modules():
        if isinstance(module, NNCFConv2d):
            for op in module.pre_ops.values():
                assert isinstance(op, (UpdateWeight, UpdateInputs))
                # weight quantizers arrive via UpdateWeight, input quantizers
                # via UpdateInputs; is_weights must match the wrapper type
                assert op.operand.is_weights is isinstance(op, UpdateWeight)

    for aq in quant_model.activation_quantizers.values():
        assert aq.is_weights is False
def test_quantize_network(self, model_name, model_builder, input_size, _quantize_config):
    net = model_builder()
    reset_context('orig')
    ctx = reset_context('quantized_graphs')
    qnet = QuantizedNetwork(net, _quantize_config.quantizer, input_size,
                            dummy_forward_fn=create_dummy_forward_fn(input_size))
    _ = qnet(torch.zeros(*input_size))
    _ = qnet(torch.zeros(*input_size))
    check_graph(to_networkx(ctx), model_name, _quantize_config.graph_dir)
def test_output_quantization(_quantize_config):
    net = test_models.UNet()
    reset_context('orig')
    ctx = reset_context('quantized_graphs')

    input_shape = (1, 3, 360, 480)
    qnet = QuantizedNetwork(net, _quantize_config.quantizer, input_shape,
                            dummy_forward_fn=create_dummy_forward_fn(input_shape),
                            quantize_outputs=True)
    _ = qnet(torch.zeros(*input_shape))
    _ = qnet(torch.zeros(*input_shape))

    check_graph(to_networkx(ctx), 'unet_qoutput.dot', _quantize_config.graph_dir)
def test_resnet18__with_ignore(_quantize_config):
    net = test_models.ResNet18()
    reset_context('orig')
    ctx = reset_context('quantized_graphs')

    input_shape = (1, 3, 32, 32)
    qnet = QuantizedNetwork(net, _quantize_config.quantizer, input_shape,
                            dummy_forward_fn=create_dummy_forward_fn(input_shape),
                            ignored_scopes=['ResNet/Sequential[layer3]'])
    _ = qnet(torch.zeros(*input_shape))
    _ = qnet(torch.zeros(*input_shape))

    check_graph(to_networkx(ctx), 'resnet18_ignore.dot', _quantize_config.graph_dir)
def test_resnet18__with_not_qinput(_quantize_config):
    net = test_models.ResNet18()
    reset_context('orig')
    ctx = reset_context('quantized_graphs')

    input_shape = (1, 3, 32, 32)
    qnet = QuantizedNetwork(net, _quantize_config.quantizer, input_shape,
                            dummy_forward_fn=create_dummy_forward_fn(input_shape),
                            quantize_inputs=False)
    _ = qnet(torch.zeros(*input_shape))
    _ = qnet(torch.zeros(*input_shape))

    check_graph(to_networkx(ctx), 'resnet18_no_qinput.dot', _quantize_config.graph_dir)
def test_custom_quantizable_subgraph_patterns(_quantize_config):
    net = test_models.SENet18()
    reset_context('orig')
    ctx = reset_context('quantized_graphs')

    input_shape = (1, 3, 32, 32)
    qnet = QuantizedNetwork(net, _quantize_config.quantizer, input_shape,
                            dummy_forward_fn=create_dummy_forward_fn(input_shape),
                            quantize_outputs=False,
                            quantizable_subgraph_patterns=(("sigmoid", "__mul__"),
                                                           ("__iadd__", "batch_norm")))
    _ = qnet(torch.zeros(*input_shape))
    _ = qnet(torch.zeros(*input_shape))

    check_graph(to_networkx(ctx), 'senet_custom_patterns.dot', _quantize_config.graph_dir)
def get_all_node_names(model, input_sample_size, graph_scope=None, builder=None):
    if graph_scope is None:
        graph_scope = 'utils'
    reset_context(graph_scope)
    if not builder:
        builder = GraphBuilder(create_dummy_forward_fn([ModelInputInfo(input_sample_size)]))
    graph = builder.build_graph(model, graph_scope)
    # node keys have the form '<numeric key> <scope-qualified name>';
    # return only the name part
    return [node_name.split(' ', 1)[1] for node_name in graph.get_all_node_keys()]
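# A usage sketch for get_all_node_names, assuming the test_models helpers from
# the other snippets are importable; the printout is illustrative only.
def _print_resnet_node_names():
    names = get_all_node_names(test_models.ResNet18(), [1, 3, 32, 32])
    for name in names:
        print(name)  # scope-qualified names, e.g. 'ResNet/Sequential[layer1]/...'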
def test_ambiguous_function():
    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.ModuleList([
                nn.Conv2d(1, 1, 1),
                nn.Conv2d(1, 1, 1)
            ])

        def forward(self, x):
            # the same function (F.relu) is called from two different layer scopes
            for layer in self.layers:
                x = F.relu(layer(x))
            return x

    reset_context('orig')
    reset_context('quantized_graphs')
    mod = Model()
    QuantizedNetwork(mod, create_quantize_module, inputs_shape=(1, 1, 1, 1),
                     dummy_forward_fn=create_dummy_forward_fn((1, 1, 1, 1)))
def __init__(self, module, quantize_module_creator_fn, input_infos=None,
             dummy_forward_fn=None, ignored_scopes=None, target_scopes=None,
             quantize_inputs=True, quantize_outputs=False,
             quantizable_subgraph_patterns=None, scopes_without_shape_matching=None,
             disable_function_quantization_hooks=False):
    super().__init__()
    self.set_nncf_wrapped_module(module)
    self.quantize_inputs = quantize_inputs
    self.quantize_outputs = quantize_outputs
    self.input_infos = input_infos
    self.ignored_scopes = ignored_scopes
    self.target_scopes = target_scopes
    self.activation_quantizers = nn.ModuleDict()
    self.function_quantizers = nn.ModuleDict()
    self.quantized_weight_modules = OrderedDict()
    self.quantized_activation_modules = OrderedDict()
    self.quantize_module_creator_fn = quantize_module_creator_fn
    self.quantizable_subgraph_patterns = quantizable_subgraph_patterns
    self._dummy_forward_fn = dummy_forward_fn
    self._nncf_module_scopes = []  # type: List[Scope]
    self.debug_interface = QuantizationDebugInterface() if is_debug() else None
    self.scopes_without_shape_matching = scopes_without_shape_matching

    device = next(module.parameters()).device

    self.all_quantizations = OrderedDict()
    self._processed_input_agnostic_op_exec_contexts = set()
    self._processed_function_quantizers = set()

    # all modules should be replaced prior to graph building
    self._replace_quantized_modules_by_nncf_modules(device)
    self._register_weight_quantization_operations(device)

    if self._dummy_forward_fn is None:
        self._dummy_forward_fn = create_dummy_forward_fn(self.input_infos)

    self._graph_builder = GraphBuilder(custom_forward_fn=self._dummy_forward_fn)

    self._context_name = "orig"
    if self.scopes_without_shape_matching:
        get_context(self._context_name).add_node_comparators(
            scopes_without_shape_matching, ShapeIgnoringTensorMetaComparator())
    self._original_graph = self._graph_builder.build_graph(self, self._context_name)

    self._context_name = "quantized_graphs"
    self._ctx = get_context("quantized_graphs")
    if self.scopes_without_shape_matching:
        get_context(self._context_name).add_node_comparators(
            scopes_without_shape_matching, ShapeIgnoringTensorMetaComparator())

    self._register_activation_quantization_hooks(device)
    if self.quantize_inputs:
        self._register_input_quantization_operations(device)

    if not disable_function_quantization_hooks:
        self._register_function_quantization_hooks(device)

    quantization_types = [class_type.__name__
                          for class_type in QUANTIZATION_MODULES.registry_dict.values()]
    self.all_quantizations = get_state_dict_names_with_modules(self, quantization_types)
    self.load_listener = LoadStateListener(self, self.all_quantizations)
    if self.debug_interface is not None:
        self.debug_interface.init_actual(self.all_quantizations.keys(),
                                         self.activation_quantizers.keys(),
                                         self.function_quantizers.keys())
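# A minimal construction sketch mirroring the tests above; create_quantize_module
# and the (1, 3, 32, 32) shape are illustrative choices, not constructor
# requirements.
def _build_quantized_resnet18():
    reset_context('orig')
    reset_context('quantized_graphs')
    input_shape = (1, 3, 32, 32)
    qnet = QuantizedNetwork(test_models.ResNet18(), create_quantize_module, input_shape,
                            dummy_forward_fn=create_dummy_forward_fn(input_shape),
                            quantize_inputs=True,    # default: quantize model inputs
                            quantize_outputs=False)  # default: leave outputs in float
    _ = qnet(torch.zeros(*input_shape))  # forward pass populates the quantized graph context
    return qnet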