Example #1
    def get_average_traces(self, max_iter=100, tolerance=1e-3) -> Tensor:
        """
        Estimates the average Hessian trace for each parameter.
        :param max_iter: maximum number of iterations of the Hutchinson algorithm
        :param tolerance: minimum relative tolerance for stopping the algorithm,
        calculated between the mean average trace of the previous iteration and the current one
        :return: Tensor with the average Hessian trace per parameter
        """
        avg_total_trace = 0.
        avg_traces_per_iter = []  # type: List[Tensor]
        mean_avg_traces_per_param = None

        for i in range(max_iter):
            avg_traces_per_iter.append(self._calc_avg_traces_per_param())

            mean_avg_traces_per_param = self._get_mean(avg_traces_per_iter)
            mean_avg_total_trace = torch.sum(mean_avg_traces_per_param)

            diff_avg = abs(mean_avg_total_trace - avg_total_trace) / (
                avg_total_trace + self._diff_eps)
            if diff_avg < tolerance:
                return mean_avg_traces_per_param
            avg_total_trace = mean_avg_total_trace
            nncf_logger.info('{}# difference_avg={} avg_trace={}'.format(
                i, diff_avg, avg_total_trace))

        return mean_avg_traces_per_param
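For context, the Hutchinson estimator approximates trace(H) as E[v^T H v] over random Rademacher vectors v. The per-iteration helper `_calc_avg_traces_per_param` is not shown above, so the following is only a minimal sketch of one such step under assumed names (`hutchinson_step`, `loss`, `params` are illustrative, not NNCF API):

import torch

def hutchinson_step(loss, params):
    # First backward pass: gradients with the graph retained for a second differentiation
    grads = torch.autograd.grad(loss, params, create_graph=True)
    # Rademacher probe vectors with entries in {-1, +1}
    vs = [torch.randint_like(p, high=2) * 2.0 - 1.0 for p in params]
    # Hessian-vector products H v via a second backward pass
    hvs = torch.autograd.grad(grads, params, grad_outputs=vs)
    # v^T H v, averaged over elements, gives one trace sample per parameter tensor
    return torch.stack([(v * hv).sum() / v.numel() for v, hv in zip(vs, hvs)])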
Example #2
    def output_prune(self, model: NNCFNetwork, nx_node, graph: NNCFGraph,
                     nx_graph: nx.DiGraph):
        output_mask = nx_node['output_mask']
        if output_mask is None:
            return

        bool_mask = torch.tensor(output_mask, dtype=torch.bool)
        new_num_channels = int(torch.sum(bool_mask))

        nncf_node = graph._nx_node_to_nncf_node(nx_node)
        node_module = model.get_module_by_scope(
            nncf_node.op_exec_context.scope_in_model)
        old_num_channels = int(node_module.weight.size(1))

        in_channels = node_module.weight.size(0)
        broadcasted_mask = bool_mask.repeat(in_channels).view(
            in_channels, bool_mask.size(0))
        new_weight_shape = list(node_module.weight.shape)
        new_weight_shape[1] = new_num_channels

        node_module.out_channels = new_num_channels
        node_module.weight = torch.nn.Parameter(
            node_module.weight[broadcasted_mask].view(new_weight_shape))

        if node_module.bias is not None:
            node_module.bias = torch.nn.Parameter(node_module.bias[bool_mask])

        nncf_logger.info(
            'Pruned ConvTranspose {} by pruning mask. Old output filters number: {}, new filters number:'
            ' {}.'.format(nx_node['key'], old_num_channels,
                          node_module.out_channels))
Example #3
def load_state(model: torch.nn.Module, saved_state_dict: dict, is_resume: bool = False) -> int:
    """
    Used to load a checkpoint containing a compressed model into an NNCFNetwork object, but can also
    be used for any other PyTorch module. Matches saved_state_dict parameters to the model's state_dict
    parameters while discarding irrelevant prefixes added during wrapping in NNCFNetwork or
    DataParallel/DistributedDataParallel objects, and loads the matched parameters from
    saved_state_dict into the model's state dict.
    :param model: The target module for the saved_state_dict to be loaded to.
    :param saved_state_dict: A state dict containing the parameters to be loaded into the model.
    :param is_resume: Determines the behavior when the function cannot match parameters during loading.
    If True, the function raises an exception if it cannot match the saved_state_dict parameters to the
    model's parameters (i.e. if some parameters required by the model are missing in saved_state_dict,
    if saved_state_dict has parameters that could not be matched to model parameters, or if parameter
    shapes do not match). If False, no exception is raised.
    Usually is_resume is specified as False when loading an uncompressed model's weights into a model with
    compression algorithms already applied, and as True when loading a compressed model's weights into a model
    with compression algorithms applied in order to evaluate it.
    :return: The number of saved_state_dict entries successfully matched and loaded into model.
    """
    def key_normalizer(key):
        new_key = key
        match = re.search('(pre_ops|post_ops)\\.(\\d+?)\\.op', key)
        return new_key if not match else new_key.replace(match.group(), 'operation')

    if 'state_dict' in saved_state_dict:
        saved_state_dict = saved_state_dict['state_dict']
    state_dict = model.state_dict()

    new_dict, num_loaded_layers, problematic_keys = match_keys(is_resume, saved_state_dict, state_dict, key_normalizer)
    num_saved_layers = len(saved_state_dict)
    process_problematic_keys(is_resume, problematic_keys, num_loaded_layers == num_saved_layers)
    nncf_logger.info("Loaded {}/{} layers".format(num_loaded_layers, len(state_dict)))

    model.load_state_dict(new_dict, strict=False)
    return num_loaded_layers
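A hedged usage sketch of load_state (the checkpoint path and the compressed_model variable are assumptions for illustration): resuming a compressed checkpoint into a model that already has compression algorithms applied.

checkpoint = torch.load('compressed_ckpt.pth', map_location='cpu')  # hypothetical path
num_loaded = load_state(compressed_model, checkpoint, is_resume=True)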
Example #4
    def _apply_minmax_init(self,
                           min_values,
                           max_values,
                           log_module_name: str = None):
        if torch.any(torch.eq(min_values, np.inf)) or torch.any(
                torch.eq(max_values, -np.inf)):
            raise AttributeError(
                'Statistics are not collected for {}'.format(log_module_name))
        sign = torch.any(torch.lt(min_values, 0))
        if self.signedness_to_force is not None and sign != self.signedness_to_force:
            nncf_logger.warning("Forcing signed to {} for module {}".format(
                self.signedness_to_force, log_module_name))
            sign = self.signedness_to_force
        self.signed = int(sign)

        abs_max = torch.max(torch.abs(max_values), torch.abs(min_values))
        SCALE_LOWER_THRESHOLD = 0.1
        mask = torch.gt(abs_max, SCALE_LOWER_THRESHOLD)
        self._scale_param_storage.data = torch.where(
            mask, abs_max,
            SCALE_LOWER_THRESHOLD * torch.ones_like(self._scale_param_storage))
        if self._is_using_log_scale_storage:
            self._scale_param_storage.data.log_()

        nncf_logger.info("Set sign: {} and scale: {} for {}".format(
            self.signed, get_flat_tensor_contents_string(self.scale),
            log_module_name))
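As an illustration of the clamping above (the numbers are made up): for per-channel min_values = [-0.5, 0.01] and max_values = [0.3, 0.05], abs_max = [0.5, 0.05]; only the first channel exceeds SCALE_LOWER_THRESHOLD = 0.1, so the stored scale becomes [0.5, 0.1], and signed is set to 1 because a negative minimum was observed (unless signedness_to_force overrides it).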
Example #5
    def _sparsify_weights(self,
                          target_model: NNCFNetwork) -> List[InsertionCommand]:
        device = next(target_model.parameters()).device
        sparsified_modules = target_model.get_nncf_modules_by_module_names(
            self.compressed_nncf_module_names)
        insertion_commands = []
        for module_scope, module in sparsified_modules.items():
            scope_str = str(module_scope)

            if not self._should_consider_scope(scope_str):
                nncf_logger.info(
                    "Ignored adding Weight Sparsifier in scope: {}".format(
                        scope_str))
                continue

            nncf_logger.info(
                "Adding Weight Sparsifier in scope: {}".format(scope_str))
            operation = self.create_weight_sparsifying_operation(module)
            hook = UpdateWeight(operation).to(device)
            insertion_commands.append(
                InsertionCommand(
                    InsertionPoint(InsertionType.NNCF_MODULE_PRE_OP,
                                   module_scope=module_scope), hook,
                    OperationPriority.SPARSIFICATION_PRIORITY))
            self._sparsified_module_info.append(
                SparseModuleInfo(scope_str, module, hook.operand))

        return insertion_commands
Example #6
 def apply_mask(self):
     """
     Applies the propagated masks to all nodes in topological order:
     1. running the input_prune method for the node
     2. running the output_prune method for the node
     """
     sorted_nodes = [
         self.nx_graph.nodes[name]
         for name in nx.topological_sort(self.nx_graph)
     ]
     pruned_node_modules = list()
     with torch.no_grad():
         for node in sorted_nodes:
             node_type = self.graph.node_type_fn(node)
             node_cls = self.get_class_by_type_name(node_type)()
             nncf_node = self.graph._nx_node_to_nncf_node(node)
             node_module = self.model.get_module_by_scope(
                 nncf_node.op_exec_context.scope_in_model)
             # Some modules can be associated with several nodes
             if node_module not in pruned_node_modules:
                 node_cls.input_prune(self.model, node, self.graph,
                                      self.nx_graph)
                 node_cls.output_prune(self.model, node, self.graph,
                                       self.nx_graph)
                 pruned_node_modules.append(node_module)
     nncf_logger.info('Finished mask applying step')
Example #7
    def run_batchnorm_adaptation(self, config):
        initializer_params = config.get("initializer", {})
        init_bn_adapt_config = initializer_params.get('batchnorm_adaptation', {})
        num_bn_adaptation_samples = init_bn_adapt_config.get('num_bn_adaptation_samples', 0)
        num_bn_forget_samples = init_bn_adapt_config.get('num_bn_forget_samples', 0)
        try:
            bn_adaptation_args = config.get_extra_struct(BNAdaptationInitArgs)
            has_bn_adapt_init_args = True
        except KeyError:
            has_bn_adapt_init_args = False

        if not init_bn_adapt_config:
            if has_bn_adapt_init_args:
                nncf_logger.warning("Enabling quantization batch norm adaptation with default parameters.")
                num_bn_adaptation_samples = 2000
                num_bn_forget_samples = 1000

        if num_bn_adaptation_samples < 0:
            raise AttributeError('Number of adaptation samples must be >= 0')
        if num_bn_adaptation_samples > 0:
            if not has_bn_adapt_init_args:
                nncf_logger.info(
                    'Could not run batchnorm adaptation '
                    'as the adaptation data loader is not provided as an extra struct. '
                    'Refer to `NNCFConfig.register_extra_structs` and the `BNAdaptationInitArgs` class')
                return
            batch_size = bn_adaptation_args.data_loader.batch_size
            num_bn_forget_steps = numpy.ceil(num_bn_forget_samples / batch_size)
            num_bn_adaptation_steps = numpy.ceil(num_bn_adaptation_samples / batch_size)
            bn_adaptation_runner = DataLoaderBNAdaptationRunner(self._model, bn_adaptation_args.device,
                                                                num_bn_forget_steps)
            bn_adaptation_runner.run(bn_adaptation_args.data_loader, num_bn_adaptation_steps)
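The data loader mentioned in the log message above is supplied to NNCF as an extra struct before the compressed model is created. A minimal sketch, assuming the constructor keywords mirror the attributes read above (`train_loader` and the device value are illustrative):

nncf_config.register_extra_structs(
    [BNAdaptationInitArgs(data_loader=train_loader, device='cuda')])  # assumed kwargs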
Example #8
    def propagate_can_prune_attr_down(self):
        """
        Propagates the can_prune attribute down to fix all branching cases with one pruned and one
        non-pruned branch.
        """
        sorted_nodes = [
            self.nx_graph.nodes[name]
            for name in nx.topological_sort(self.nx_graph)
        ]
        for node in sorted_nodes:
            # Propagate the attribute only in the non-convolution case
            if self.node_propagate_can_prune_attr(node['key']):
                in_edges = self.nx_graph.in_edges(node['key'])
                can_prune = all(
                    self.nx_graph.nodes[key][ModelPruner.CAN_PRUNE_ATTR]
                    for key, _ in in_edges)
                can_prune_any = any(
                    self.nx_graph.nodes[key][ModelPruner.CAN_PRUNE_ATTR]
                    for key, _ in in_edges)

                if (not self.node_accept_different_inputs(node) and not can_prune) or \
                        (self.node_accept_different_inputs(node) and not can_prune_any):
                    node[ModelPruner.CAN_PRUNE_ATTR] = can_prune

        nncf_logger.info('Propagated can_prune attribute down')
Example #9
    def input_prune(self, model: NNCFNetwork, nx_node, graph: NNCFGraph,
                    nx_graph: nx.DiGraph):
        input_mask = nx_node['input_masks'][0]
        if input_mask is None:
            return

        nncf_node = graph._nx_node_to_nncf_node(nx_node)
        node_module = model.get_module_by_scope(
            nncf_node.op_exec_context.scope_in_model)

        bool_mask = torch.tensor(input_mask, dtype=torch.bool)
        old_num_channels = int(node_module.weight.size(0))
        new_num_channels = int(torch.sum(input_mask))

        node_module.num_features = new_num_channels
        node_module.weight = torch.nn.Parameter(node_module.weight[bool_mask])
        node_module.bias = torch.nn.Parameter(node_module.bias[bool_mask])
        node_module.running_mean = torch.nn.Parameter(
            node_module.running_mean[bool_mask], requires_grad=False)
        node_module.running_var = torch.nn.Parameter(
            node_module.running_var[bool_mask], requires_grad=False)

        nncf_logger.info(
            'Pruned BatchNorm {} by input mask. Old num features: {}, new num features:'
            ' {}.'.format(nx_node['key'], old_num_channels, new_num_channels))
Example #10
    def run_batchnorm_adaptation(self, config):
        initializer_params = config.get("initializer", {})
        init_bn_adapt_config = initializer_params.get('batchnorm_adaptation',
                                                      {})
        num_bn_adaptation_steps = init_bn_adapt_config.get(
            'num_bn_adaptation_steps', 0)
        num_bn_forget_steps = init_bn_adapt_config.get('num_bn_forget_steps',
                                                       5)

        if num_bn_adaptation_steps < 0:
            raise AttributeError(
                'Number of batch adaptation steps must be >= 0')
        if num_bn_adaptation_steps > 0:
            try:
                bn_adaptation_args = config.get_extra_struct(
                    BNAdaptationInitArgs)
            except KeyError:
                nncf_logger.info(
                    'Could not run batchnorm adaptation '
                    'as the adaptation data loader is not provided as an extra struct. '
                    'Refer to `NNCFConfig.register_extra_structs` and the `BNAdaptationInitArgs` class'
                )
                return

            bn_adaptation_runner = DataLoaderBNAdaptationRunner(
                self._model, bn_adaptation_args.device, num_bn_forget_steps)
            bn_adaptation_runner.run(bn_adaptation_args.data_loader,
                                     num_bn_adaptation_steps)
Example #11
    def apply_mask(self):
        """
        Applies the propagated masks to all nodes in topological order:
        if all inputs of a node can be pruned -> the input_prune method is run for this node
        if node[ModelPruner.CAN_PRUNE_ATTR] -> the output_prune method is run for this node
        """
        sorted_nodes = [
            self.nx_graph.nodes[name]
            for name in nx.topological_sort(self.nx_graph)
        ]
        with torch.no_grad():
            for node in sorted_nodes:
                node_type = self.graph.node_type_fn(node)
                node_cls = self.get_class_by_type_name(node_type)()

                in_edges = self.nx_graph.in_edges(node['key'])
                can_prune_input = all(
                    self.nx_graph.nodes[key][ModelPruner.CAN_PRUNE_ATTR]
                    for key, _ in in_edges)
                if can_prune_input:
                    node_cls.input_prune(self.model, node, self.graph,
                                         self.nx_graph)

                if node[ModelPruner.CAN_PRUNE_ATTR]:
                    node_cls.output_prune(self.model, node, self.graph,
                                          self.nx_graph)
        nncf_logger.info('Finished mask applying step')
Example #12
    def input_prune(self, model: NNCFNetwork, nx_node, graph: NNCFGraph,
                    nx_graph: nx.DiGraph):
        input_mask = nx_node['input_masks'][0]
        if input_mask is None:
            return
        bool_mask = torch.tensor(input_mask, dtype=torch.bool)
        new_num_channels = int(torch.sum(input_mask))

        nncf_node = graph._nx_node_to_nncf_node(nx_node)
        node_module = model.get_module_by_scope(
            nncf_node.op_exec_context.scope_in_model)
        is_depthwise = nx_node['is_depthwise']
        old_num_channels = int(node_module.weight.size(1))

        if is_depthwise:
            # In the depthwise case the output channels are pruned by the input mask; here only the number of input channels and groups is updated
            node_module.groups = new_num_channels
            node_module.in_channels = new_num_channels
        else:
            out_channels = node_module.weight.size(0)
            broadcasted_mask = bool_mask.repeat(out_channels).view(
                out_channels, bool_mask.size(0))
            new_weight_shape = list(node_module.weight.shape)
            new_weight_shape[1] = new_num_channels

            node_module.in_channels = new_num_channels
            node_module.weight = torch.nn.Parameter(
                node_module.weight[broadcasted_mask].view(new_weight_shape))

        nncf_logger.info(
            'Pruned Convolution {} by input mask. Old input filters number: {}, new filters number:'
            ' {}.'.format(nx_node['key'], old_num_channels, new_num_channels))
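For example (shapes are illustrative), with a non-depthwise weight of shape (8, 4, 3, 3) and a mask keeping 2 of the 4 input channels, broadcasted_mask has shape (8, 4), weight[broadcasted_mask] selects the surviving entries into a (16, 3, 3) tensor, and the final view restores the (8, 2, 3, 3) layout.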
Example #13
    def apply_minmax_init(self,
                          min_values,
                          max_values,
                          log_module_name: str = None):
        if self.initialized:
            nncf_logger.debug(
                "Skipped initializing {} - loaded from checkpoint".format(
                    log_module_name))
            return
        if torch.any(torch.eq(min_values, np.inf)) or torch.any(
                torch.eq(max_values, -np.inf)):
            raise AttributeError(
                'Statistics are not collected for {}'.format(log_module_name))
        sign = torch.any(torch.lt(min_values, 0))
        if self.signedness_to_force is not None and sign != self.signedness_to_force:
            nncf_logger.warning("Forcing signed to {} for module {}".format(
                self.signedness_to_force, log_module_name))
            sign = self.signedness_to_force
        self.signed = int(sign)

        abs_max = torch.max(torch.abs(max_values), torch.abs(min_values))
        SCALE_LOWER_THRESHOLD = 0.1
        self.scale.fill_(SCALE_LOWER_THRESHOLD)
        self.scale.masked_scatter_(torch.gt(abs_max, SCALE_LOWER_THRESHOLD),
                                   abs_max)

        nncf_logger.info("Set sign: {} and scale: {} for {}".format(
            self.signed, get_flat_tensor_contents_string(self.scale),
            log_module_name))
Example #14
 def dump_metric(self, configuration_metric: List[Tensor]):
     import matplotlib.pyplot as plt
     list_to_plot = [cm.item() for cm in configuration_metric]
     fig = plt.figure()
     fig.suptitle('Pareto Frontier')
     ax = fig.add_subplot(2, 1, 1)
     ax.set_yscale('log')
     ax.set_xlabel('Model Size (MB)')
     ax.set_ylabel('Metric value (total perturbation)')
     ax.scatter(self._model_sizes,
                list_to_plot,
                s=20,
                facecolors='none',
                edgecolors='r')
     cm = torch.Tensor(configuration_metric)
     cm_m = cm.median().item()
     configuration_index = configuration_metric.index(cm_m)
     ms_m = self._model_sizes[configuration_index]
     ax.scatter(ms_m,
                cm_m,
                s=30,
                facecolors='none',
                edgecolors='b',
                label='median from all metrics')
     ax.legend()
     plt.savefig(os.path.join(self._dump_dir, 'Pareto_Frontier'))
     nncf_logger.info(
         'Distribution of HAWQ metrics: min_value={:.3f}, max_value={:.3f}, median_value={:.3f}, '
         'median_index={}, total_number={}'.format(
             cm.min().item(),
             cm.max().item(), cm_m, configuration_index,
             len(configuration_metric)))
Example #15
    def apply_minmax_init(self,
                          min_values,
                          max_values,
                          log_module_name: str = None):
        if self.initialized:
            nncf_logger.debug(
                "Skipped initializing {} - loaded from checkpoint".format(
                    log_module_name))
            return
        if torch.any(torch.eq(min_values, np.inf)) or torch.any(
                torch.eq(max_values, -np.inf)):
            raise AttributeError(
                'Statistics are not collected for {}'.format(log_module_name))
        ranges = max_values - min_values
        max_range = torch.max(max_values - min_values)
        eps = 1e-2
        correction = (clamp(ranges, low=eps * max_range, high=max_range) -
                      ranges) * 0.5
        self.input_range.data = (ranges + 2 * correction).data
        self.input_low.data = (min_values - correction).data

        nncf_logger.info("Set input_low: {} and input_range: {} for {}".format(
            get_flat_tensor_contents_string(self.input_low),
            get_flat_tensor_contents_string(self.input_range),
            log_module_name))
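A numeric illustration of the range correction above (values are made up): with min_values = [-1.0, 0.0] and max_values = [1.0, 0.005], ranges = [2.0, 0.005] and max_range = 2.0, so the second channel's range is clamped up to eps * max_range = 0.02; correction = [0.0, 0.0075], giving input_range = [2.0, 0.02] and input_low = [-1.0, -0.0075]. Narrow ranges are thus widened symmetrically around the observed interval.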
Example #16
    def propagate_can_prune_attr_up(self):
        """
        Propagates the can_prune attribute in reversed topological order.
        This attribute depends on the accept_pruned_input and can_prune attributes of the output nodes.
        A node's can_prune is True if accept_pruned_input is True for all of its outputs and all of its
        outputs will be pruned (convolutions are excluded, because a conv can be pruned by input and output independently).
        """
        for node_name in self.nx_graph.nodes:
            self.nx_graph.nodes[node_name][ModelPruner.CAN_PRUNE_ATTR] = True

        reversed_sorted_nodes = reversed([
            self.nx_graph.nodes[name]
            for name in nx.topological_sort(self.nx_graph)
        ])
        for node in reversed_sorted_nodes:
            # Check all output nodes accept_pruned_input attribute
            out_edges = self.nx_graph.out_edges(node['key'])
            outputs_accept_pruned_input = all(
                self.nx_graph.nodes[key]['accept_pruned_input']
                for _, key in out_edges)

            # Check all output nodes can_prune attribute
            outputs_will_be_pruned = all([
                self.nx_graph.nodes[key][ModelPruner.CAN_PRUNE_ATTR]
                for _, key in out_edges
                if self.node_propagate_can_prune_attr(key)
            ])
            node[ModelPruner.CAN_PRUNE_ATTR] = (outputs_accept_pruned_input
                                                and outputs_will_be_pruned)

        nncf_logger.info('Propagated can_prune attribute up')
Example #17
    def _compute_and_display_flops_binarization_rate(self):
        net = self._model
        weight_list = {}
        state_dict = net.state_dict()
        for n, v in state_dict.items():
            weight_list[n] = v.clone()

        ops_dict = OrderedDict()

        def get_hook(name):
            def compute_flops_hook(self, input_, output):
                name_type = str(type(self).__name__)
                if isinstance(self, (nn.Conv2d, nn.ConvTranspose2d)):
                    ks = self.weight.data.shape
                    ops_count = ks[0] * ks[1] * ks[2] * ks[3] * output.shape[3] * output.shape[2]
                elif isinstance(self, nn.Linear):
                    ops_count = input_[0].shape[1] * output.shape[1]
                else:
                    return
                ops_dict[name] = (name_type, ops_count, isinstance(self, NNCFConv2d))

            return compute_flops_hook

        hook_list = [m.register_forward_hook(get_hook(n)) for n, m in net.named_modules()]

        net.do_dummy_forward(force_eval=True)

        for h in hook_list:
            h.remove()

        # restore all parameters that could have been corrupted during the forward pass
        for n, v in state_dict.items():
            state_dict[n].data.copy_(weight_list[n].data)

        ops_bin = 0
        ops_total = 0

        for layer_name, (layer_type, ops, is_binarized) in ops_dict.items():
            ops_total += ops
            if is_binarized:
                ops_bin += ops

        table = Texttable()
        header = ["Layer name", "Layer type", "Binarized", "MAC count", "MAC share"]
        table_data = [header]

        for layer_name, (layer_type, ops, is_binarized) in ops_dict.items():
            drow = {h: 0 for h in header}
            drow["Layer name"] = layer_name
            drow["Layer type"] = layer_type
            drow["Binarized"] = 'Y' if is_binarized else 'N'
            drow["MAC count"] = "{:.3f}G".format(ops * 1e-9)
            drow["MAC share"] = "{:2.1f}%".format(ops / ops_total * 100)
            row = [drow[h] for h in header]
            table_data.append(row)

        table.add_rows(table_data)
        nncf_logger.info(table.draw())
        nncf_logger.info("Total binarized MAC share: {:.1f}%".format(ops_bin / ops_total * 100))
Example #18
    def _run_quantization_pipeline(self, finetune=False) -> float:
        self.qctrl.run_batchnorm_adaptation(self.qctrl.quantization_config)

        if finetune:
            raise NotImplementedError(
                "Post-Quantization fine tuning is not implemented.")
        with torch.no_grad():
            quantized_score = self.eval_fn(self.qmodel, self.eval_loader)
            logger.info(
                "[Q.Env] Quantized Score: {:.3f}".format(quantized_score))
        return quantized_score
Example #19
 def prune_model(self):
     """
     The model pruner works in two stages:
     1. Mask propagation: propagate pruning masks through the graph.
     2. Mask application: apply the calculated masks.
     :return:
     """
     nncf_logger.info('Start pruning model')
     self.mask_propagation()
     self.apply_mask()
     nncf_logger.info('Finished pruning model')
Example #20
def replace_modules(model: nn.Module,
                    replace_fn,
                    affected_scopes,
                    ignored_scopes=None,
                    target_scopes=None,
                    memo=None,
                    current_scope=None):
    if memo is None:
        memo = set()
        current_scope = Scope()
        current_scope.push(ScopeElement(model.__class__.__name__))

    if model in memo:
        return model, affected_scopes

    memo.add(model)
    for name, module in model.named_children():
        if module is None:
            continue

        child_scope_element = ScopeElement(module.__class__.__name__, name)
        child_scope = current_scope.copy()
        child_scope.push(child_scope_element)
        replaced_module = replace_fn(module)

        if replaced_module is not None:
            replaced_scope_element = ScopeElement(
                replaced_module.__class__.__name__, name)
            replaced_scope = current_scope.copy()
            replaced_scope.push(replaced_scope_element)
            if module is not replaced_module:
                if in_scope_list(str(child_scope), ignored_scopes):
                    nncf_logger.info(
                        "Ignored wrapping modules in scope: {}".format(
                            child_scope))
                    continue

                if target_scopes is None or in_scope_list(
                        str(child_scope), target_scopes):
                    nncf_logger.info("Wrapping module {} by {}".format(
                        str(child_scope), str(replaced_scope)))
                    if isinstance(model, nn.Sequential):
                        # pylint: disable=protected-access
                        model._modules[name] = replaced_module
                    else:
                        setattr(model, name, replaced_module)
                    affected_scopes.append(replaced_scope)
            elif is_nncf_module(replaced_module):
                # Got an NNCF-wrapped module from previous compression stage, track its scope as well
                affected_scopes.append(replaced_scope)
        _, affected_scopes = replace_modules(module, replace_fn,
                                             affected_scopes, ignored_scopes,
                                             target_scopes, memo, child_scope)
    return model, affected_scopes
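A hedged usage sketch of replace_modules: replace_fn may return a replacement module, the same module (leave as-is), or None. The wrap_convs helper and NNCFConv2d.from_module below are illustrative assumptions, not guaranteed NNCF API:

def wrap_convs(module):
    # Hypothetical: wrap plain convolutions, keep everything else unchanged
    if isinstance(module, nn.Conv2d) and not isinstance(module, NNCFConv2d):
        return NNCFConv2d.from_module(module)  # assumed helper
    return module

model, affected_scopes = replace_modules(model, wrap_convs, affected_scopes=[])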
Example #21
    def apply_init(self):
        original_device = next(self._model.parameters()).device
        self._model.to(self._init_device)

        traces_per_layer = self._calc_traces(self._criterion,
                                             self._iter_number,
                                             self._tolerance)
        if not traces_per_layer:
            raise RuntimeError('Failed to calculate hessian traces!')

        num_weights = len(self._ordered_weight_quantizations)
        bits_configurations = self.get_configs_constrained_by_order(
            self._bits, num_weights)
        ordered_weight_quantization_ids = list(
            self._ordered_weight_quantizations.keys())
        bits_configurations = self._filter_configs_by_precision_constraints(
            bits_configurations, self._hw_precision_constraints,
            ordered_weight_quantization_ids,
            traces_per_layer.get_order_of_traces())
        if not bits_configurations:
            raise RuntimeError(
                'All bits configurations are incompatible with HW Config!')

        perturbations, weight_observers = self.calc_quantization_noise()

        configuration_metric = self.calc_hawq_metric_per_configuration(
            bits_configurations, perturbations, traces_per_layer,
            self._init_device)

        chosen_config_per_layer = self.choose_configuration(
            configuration_metric, bits_configurations,
            traces_per_layer.get_order_of_traces())
        self.set_chosen_config(chosen_config_per_layer)
        ordered_metric_per_layer = self.get_metric_per_layer(
            chosen_config_per_layer, perturbations, traces_per_layer)
        if is_debug():
            hawq_debugger = HAWQDebugger(bits_configurations, perturbations,
                                         weight_observers, traces_per_layer,
                                         self._bits)
            hawq_debugger.dump_metric(configuration_metric)
            hawq_debugger.dump_avg_traces()
            hawq_debugger.dump_density_of_quantization_noise()
            hawq_debugger.dump_perturbations_ratio()
            hawq_debugger.dump_bitwidth_graph(self._algo, self._model)

        self._model.rebuild_graph()
        str_bw = [str(element) for element in self.get_bitwidth_per_scope()]
        nncf_logger.info('\n'.join(
            ['\n\"bitwidth_per_scope\": [', ',\n'.join(str_bw), ']']))

        self._model.to(original_device)
        return ordered_metric_per_layer
Example #22
    def _evaluate_pretrained_model(self):
        logger.info("[Q.Env] Evaluating Pretrained Model")
        self.qctrl.disable_weight_quantization()
        self.qctrl.disable_activation_quantization()

        with torch.no_grad():
            self.pretrained_score = self.eval_fn(self.qmodel, self.eval_loader)
            logger.info("Pretrained Score: {:.3f}".format(
                self.pretrained_score))

        self.qctrl.enable_weight_quantization()
        self.qctrl.enable_activation_quantization()
        self.qmodel.rebuild_graph()
Example #23
 def choose_configuration(self, configuration_metric: List[Tensor], bits_configurations: List[List[int]],
                          traces_order: List[int]) -> List[int]:
     num_weights = len(traces_order)
     ordered_config = [0] * num_weights
     median_metric = torch.Tensor(configuration_metric).to(self._device).median()
     configuration_index = configuration_metric.index(median_metric)
     bit_configuration = bits_configurations[configuration_index]
     for i, bitwidth in enumerate(bit_configuration):
         ordered_config[traces_order[i]] = bitwidth
     if is_main_process():
         nncf_logger.info('Chosen HAWQ configuration (bitwidth per weightable layer)={}'.format(ordered_config))
         nncf_logger.debug('Order of the weightable layers in the HAWQ configuration={}'.format(traces_order))
     return ordered_config
Example #24
    def apply_init(self):
        disabled_gradients = self.disable_quantizer_gradients(
            self._all_quantizers_per_scope,
            self._algo.quantized_weight_modules_registry, self._model)

        traces_per_layer = self._calc_traces(self._criterion,
                                             self._iter_number,
                                             self._tolerance)
        if not traces_per_layer:
            raise RuntimeError('Failed to calculate hessian traces!')

        self.enable_quantizer_gradients(self._model,
                                        self._all_quantizers_per_scope,
                                        disabled_gradients)

        num_weights = len(self._ordered_weight_quantizations)
        bits_configurations = self.get_configs_constrained_by_order(
            self._bits, num_weights)
        ordered_weight_quantization_ids = list(
            self._ordered_weight_quantizations.keys())
        bits_configurations = self.filter_configs_by_precision_constraints(
            bits_configurations, self._hw_precision_constraints,
            ordered_weight_quantization_ids,
            traces_per_layer.get_order_of_traces())
        if not bits_configurations:
            raise RuntimeError(
                'All bits configurations are incompatible with HW Config!')

        perturbations, weight_observers = self.calc_quantization_noise()

        configuration_metric = self.calc_hawq_metric_per_configuration(
            bits_configurations, perturbations, traces_per_layer, self._device)

        chosen_config_per_layer = self.choose_configuration(
            configuration_metric, bits_configurations,
            traces_per_layer.get_order_of_traces())
        self.set_chosen_config(chosen_config_per_layer)
        ordered_metric_per_layer = self.get_metric_per_layer(
            chosen_config_per_layer, perturbations, traces_per_layer)
        if is_debug():
            self.HAWQDump(bits_configurations, configuration_metric,
                          perturbations, weight_observers, traces_per_layer,
                          self._bits).run()

        self._model.rebuild_graph()
        str_bw = [str(element) for element in self.get_bitwidth_per_scope()]
        nncf_logger.info('\n'.join(
            ['\n\"bitwidth_per_scope\": [', ',\n'.join(str_bw), ']']))

        return ordered_metric_per_layer
Example #25
    def binarize(self, x):
        if self.training and not self.is_scale_initialized:
            # init the scale from non-binarized activation statistics (min of the top 0.1% of values, i.e. roughly the 99.9th percentile)
            d = x.detach().data.contiguous().view(-1)
            top_num = max(1, round(d.shape[0]*0.001))
            topk_res = d.topk(top_num)
            scale = topk_res[0].min()
            nncf_logger.info("Binarized activation scale set to: {}".format(scale.item()))
            self.scale.data[:] = scale.log()
            self.is_scale_initialized = True

        x = self.bin(x, self.scale.exp(), self.threshold.sigmoid())

        return x
Example #26
 def mask_propagation(self):
     """
     Mask propagation through the graph:
     the mask_propagation method of each node's meta-operation is called on all nodes in topological order.
     """
     sorted_nodes = [
         self.nx_graph.nodes[node_name]
         for node_name in nx.topological_sort(self.nx_graph)
     ]
     for node in sorted_nodes:
         node_type = self.graph.node_type_fn(node)
         cls = self.get_class_by_type_name(node_type)()
         cls.mask_propagation(self.model, node, self.graph, self.nx_graph)
     nncf_logger.info('Finished mask propagation in graph')
Example #27
 def apply_init(self):
     runner = HessianAwarePrecisionInitializeRunner(self._algo, self._model, self._data_loader,
                                                    self._num_data_points,
                                                    self._all_quantizations, self._ordered_weight_quantizations,
                                                    self._bits, self._traces_per_layer_path)
     runner.run(self._criterion, self._iter_number, self._tolerance)
     self._model.rebuild_graph()
     if self._is_distributed:
         # NOTE: Order of quantization modules must be the same on GPUs to correctly broadcast num_bits
         sorted_quantizers = OrderedDict(sorted(self._all_quantizations.items(), key=lambda x: str(x[0])))
         for quantizer in sorted_quantizers.values():  # type: BaseQuantizer
             quantizer.broadcast_num_bits()
         if is_main_process():
             str_bw = [str(element) for element in self.get_bitwidth_per_scope(sorted_quantizers)]
             nncf_logger.info('\n'.join(['\n\"bitwidth_per_scope\": [', ',\n'.join(str_bw), ']']))
Example #28
 def prune_model(self):
     """
     The model pruner works in three stages:
     1. Mask propagation: propagate pruning masks through the graph.
     2. Propagation of the can_prune attribute (up and then down) through the graph. This attribute indicates
     whether a node can actually be pruned.
     3. Mask application in accordance with the can_prune attribute (only where pruning is allowed).
     :return:
     """
     nncf_logger.info('Start pruning model')
     self.mask_propagation()
     self.propagate_can_prune_attr_up()
     self.propagate_can_prune_attr_down()
     self.apply_mask()
     nncf_logger.info('Finished pruning model')
Example #29
    def _apply_minmax_init(self,
                           min_values,
                           max_values,
                           log_module_name: str = None):
        ranges = max_values - min_values
        max_range = torch.max(max_values - min_values)
        eps = 1e-2
        correction = (clamp(ranges, low=eps * max_range, high=max_range) -
                      ranges) * 0.5
        self._input_range_param_storage.data = (ranges + 2 * correction).data
        if self._is_using_log_scale_storage:
            self._input_range_param_storage.data.log_()

        self.input_low.data = (min_values - correction).data

        nncf_logger.info("Set input_low: {} and input_range: {} for {}".format(
            get_flat_tensor_contents_string(self.input_low),
            get_flat_tensor_contents_string(self.input_range),
            log_module_name))
Example #30
    def input_prune(self, model: NNCFNetwork, nx_node, graph: NNCFGraph,
                    nx_graph: nx.DiGraph):
        input_mask = nx_node['input_masks'][0]
        if input_mask is None:
            return
        bool_mask = torch.tensor(input_mask, dtype=torch.bool)

        nncf_node = graph._nx_node_to_nncf_node(nx_node)
        node_module = model.get_module_by_scope(
            nncf_node.op_exec_context.scope_in_model)
        old_num_channels = int(node_module.weight.size(0))

        node_module.in_channels = int(torch.sum(bool_mask))
        node_module.weight = torch.nn.Parameter(node_module.weight[bool_mask])

        nncf_logger.info(
            'Pruned ConvTranspose {} by input mask. Old input filters number: {}, new filters number:'
            ' {}.'.format(nx_node['key'], old_num_channels,
                          node_module.in_channels))