Beispiel #1
0
def replace_modules(model: nn.Module,
                    replace_fn,
                    affected_scopes,
                    ignored_scopes=None,
                    target_scopes=None,
                    memo=None,
                    current_scope=None):
    """Recursively replace child modules of ``model`` using ``replace_fn``.

    Each non-None child from ``model.named_children()`` is passed to
    ``replace_fn``:

    * If it returns a different module, that child is swapped in on ``model``.
      This happens only when the child's scope is not matched by
      ``ignored_scopes`` and is matched by ``target_scopes`` (or
      ``target_scopes`` is None). The replacement's scope is then appended to
      ``affected_scopes``.
    * If it returns the same module and ``is_nncf_module`` accepts it, its
      scope is appended to ``affected_scopes`` as well.
    * If it returns None, the child is left as is.

    Children that would be replaced but fall into ``ignored_scopes`` are
    skipped entirely, including recursion. Every other child is traversed
    recursively.

    Args:
        model: module whose children are processed; mutated in place.
        replace_fn: callable taking a module and returning its replacement,
            the same module, or None.
        affected_scopes: list collecting the ``Scope`` of each replaced (or
            already NNCF-wrapped) module; appended to in place and returned.
        ignored_scopes: scope patterns (checked with ``in_scope_list``) whose
            modules must not be replaced.
        target_scopes: if not None, only scopes matched by these patterns are
            replaced.
        memo: set of modules already visited; internal recursion state.
        current_scope: ``Scope`` of ``model``; internal recursion state.

    Returns:
        Tuple ``(model, affected_scopes)``.
    """
    # Initialise each piece of recursion state on its own. Previously both
    # were tied to ``memo is None``, so passing ``memo`` alone crashed on
    # ``current_scope.copy()`` and a ``current_scope`` passed alone was
    # silently overwritten.
    if memo is None:
        memo = set()
    if current_scope is None:
        current_scope = Scope()
        current_scope.push(ScopeElement(model.__class__.__name__))

    # A module object reachable through several paths is processed only once.
    if model in memo:
        return model, affected_scopes
    memo.add(model)

    for name, module in model.named_children():
        if module is None:
            continue

        child_scope = current_scope.copy()
        child_scope.push(ScopeElement(module.__class__.__name__, name))
        replaced_module = replace_fn(module)

        if replaced_module is not None:
            replaced_scope = current_scope.copy()
            replaced_scope.push(
                ScopeElement(replaced_module.__class__.__name__, name))
            if module is not replaced_module:
                if in_scope_list(str(child_scope), ignored_scopes):
                    nncf_logger.info(
                        "Ignored wrapping modules in scope: {}".format(
                            child_scope))
                    # Ignored children are neither replaced nor descended into.
                    continue

                if target_scopes is None or in_scope_list(
                        str(child_scope), target_scopes):
                    nncf_logger.info("Wrapping module {} by {}".format(
                        str(child_scope), str(replaced_scope)))
                    if isinstance(model, nn.Sequential):
                        # pylint: disable=protected-access
                        model._modules[name] = replaced_module
                    else:
                        setattr(model, name, replaced_module)
                    affected_scopes.append(replaced_scope)
            elif is_nncf_module(replaced_module):
                # Got an NNCF-wrapped module from previous compression stage, track its scope as well
                affected_scopes.append(replaced_scope)

        # NOTE(review): recursion continues into the original child ``module``,
        # not the replacement that may have been installed above.
        _, affected_scopes = replace_modules(module, replace_fn,
                                             affected_scopes, ignored_scopes,
                                             target_scopes, memo, child_scope)
    return model, affected_scopes
Beispiel #2
0
def replace_modules(model: nn.Module,
                    replace_fn,
                    ignored_scopes=None,
                    target_scopes=None,
                    memo=None,
                    prefix=None,
                    logger=None):
    """Recursively replace child modules of ``model`` using ``replace_fn``.

    Each non-None child from ``model.named_children()`` is named with
    ``get_node_name(module, name, prefix)`` and passed to ``replace_fn``.

    When ``replace_fn`` returns a different, non-None module:

    * If the child's name is matched by ``ignored_scopes``, the child is
      skipped, including recursion.
    * Otherwise, if ``target_scopes`` is None or matches the name, the child
      is swapped in on ``model``.

    Every child that is not skipped is traversed recursively.

    Args:
        model: module whose children are processed; mutated in place.
        replace_fn: callable taking a module and returning its replacement,
            the same module, or None.
        ignored_scopes: name patterns (checked with ``in_scope_list``) whose
            modules must not be replaced.
        target_scopes: if not None, only names matched by these patterns are
            replaced.
        memo: set of modules already visited; internal recursion state.
        prefix: name of ``model`` used to build child names; internal
            recursion state, defaults to ``model``'s class name.
        logger: optional logger; when None nothing is logged.

    Returns:
        ``model`` itself.
    """
    # Initialise each piece of recursion state on its own. Previously
    # ``prefix`` was overwritten whenever ``memo`` was None, and was left as
    # None whenever ``memo`` was supplied.
    if memo is None:
        memo = set()
    if prefix is None:
        prefix = model.__class__.__name__

    # A module object reachable through several paths is processed only once.
    if model in memo:
        return model
    memo.add(model)

    for name, module in model.named_children():
        if module is None:
            continue

        child_name = get_node_name(module, name, prefix)
        replaced_module = replace_fn(module)

        if replaced_module is not None and module is not replaced_module:
            if in_scope_list(child_name, ignored_scopes):
                if logger is not None:
                    logger.info(
                        "Ignored wrapping modules in scope: {}".format(
                            child_name))
                # Ignored children are neither replaced nor descended into.
                continue

            if target_scopes is None or in_scope_list(
                    child_name, target_scopes):
                if logger is not None:
                    logger.info("Wrapping module {} by {}".format(
                        child_name,
                        get_node_name(replaced_module, name, prefix)))
                if isinstance(model, nn.Sequential):
                    # pylint: disable=protected-access
                    model._modules[name] = replaced_module
                else:
                    setattr(model, name, replaced_module)

        # NOTE(review): recursion continues into the original child ``module``,
        # not the replacement that may have been installed above.
        replace_modules(module, replace_fn, ignored_scopes, target_scopes,
                        memo, child_name, logger)
    return model
 def _register_weight_sparsifying_operations(self, device, ignored_scopes,
                                             target_scopes, logger):
     """Attach a weight-sparsifying pre-forward operation to eligible modules.

     Candidate modules are collected with
     ``get_all_modules_by_type(self._model, NNCF_MODULES)``. A module is
     skipped (with a log line) when its name matches ``ignored_scopes``. It is
     also skipped silently when ``target_scopes`` is given and does not match.

     For each remaining module, an operation from
     ``self.create_weight_sparsifying_operation`` is wrapped in
     ``UpdateWeight``, moved to ``device`` and registered as a pre-forward
     operation. A ``SparseModuleInfo`` for the module is then recorded in
     ``self.sparsified_module_info``, which is reset on every call.

     ``logger.info`` is called unconditionally for every visited module, so
     ``logger`` must not be None when there are candidates.
     """
     modules_by_scope = get_all_modules_by_type(self._model, NNCF_MODULES)
     self.sparsified_module_info = []
     for scope_name, nncf_module in modules_by_scope.items():
         if in_scope_list(scope_name, ignored_scopes):
             logger.info(
                 "Ignored adding Weight Sparsifier in scope: {}".format(
                     scope_name))
             continue
         outside_targets = target_scopes is not None and not in_scope_list(
             scope_name, target_scopes)
         if outside_targets:
             continue

         logger.info("Adding Weight Sparsifier in scope: {}".format(
             scope_name))
         sparsifier = self.create_weight_sparsifying_operation(nncf_module)
         pre_op_id = nncf_module.register_pre_forward_operation(
             UpdateWeight(sparsifier).to(device))
         sparsifier_operand = nncf_module.get_pre_op(pre_op_id).operand
         self.sparsified_module_info.append(
             SparseModuleInfo(scope_name, nncf_module, sparsifier_operand))
 def apply_init(self):
     """Assign user-requested bit widths to the quantizers of matching scopes.

     Each entry of ``self._bitwidth_per_scope`` must have length 2: its first
     item is the bit width and its second is a scope name or pattern for
     ``in_scope_list``. Every quantizer in ``self._all_quantizations`` whose
     scope matches gets ``num_bits`` set to that bit width.

     Raises:
         ValueError: if an entry does not have exactly two items, or if an
             entry's scope matches no quantizer.
     """
     for entry in self._bitwidth_per_scope:
         if len(entry) != 2:
             raise ValueError('Invalid format of bitwidth per scope: [int, str] is expected')
         num_bits, target_scope = entry[0], entry[1]

         matched_any = False
         for scope, quantizer in self._all_quantizations.items():
             if not in_scope_list(str(scope), target_scope):
                 continue
             quantizer.num_bits = num_bits
             matched_any = True

         if not matched_any:
             raise ValueError(
                 'Invalid scope name `{}`, failed to assign bitwidth {} to it'.format(target_scope, num_bits))
Beispiel #5
0
 def _should_consider_scope(self, scope_str: str) -> bool:
     """Return whether ``scope_str`` should be processed.

     A scope qualifies when it matches ``self.target_scopes`` (or
     ``self.target_scopes`` is None, meaning every scope is a target) and it
     does not match ``self.ignored_scopes``. Matching is done with
     ``in_scope_list``. The ignored-scopes check is skipped when the target
     check already fails.
     """
     if self.target_scopes is not None:
         if not in_scope_list(scope_str, self.target_scopes):
             return False
     return not in_scope_list(scope_str, self.ignored_scopes)
Beispiel #6
0
def replace_modules(model: nn.Module,
                    replace_fn,
                    affected_scopes,
                    ignored_scopes=None,
                    target_scopes=None,
                    memo=None,
                    current_scope=None,
                    eval_ops_exec_ctx_str: List[str] = None,
                    reset: bool = False):
    """Recursively replace child modules of ``model`` using ``replace_fn``.

    Each non-None child from ``model.named_children()`` is passed to
    ``replace_fn``. When it returns a different module, the child is swapped
    in on ``model`` and the replacement's scope is appended to
    ``affected_scopes``, unless one of these applies:

    * the child's scope is matched by ``ignored_scopes``;
    * ``eval_ops_exec_ctx_str`` is non-empty and none of its scopes lies
      inside the child's scope (the child was not called in eval mode);
    * ``target_scopes`` is given and does not match the child's scope.

    In the first two cases the child is also not descended into.

    When ``replace_fn`` returns the same module and ``is_nncf_module`` accepts
    it, its scope is recorded too. If ``reset`` is True, that module's
    ``reset()`` is also called.

    Args:
        model: module whose children are processed; mutated in place.
        replace_fn: callable taking a module and returning its replacement,
            the same module, or None.
        affected_scopes: list collecting the ``Scope`` of each replaced (or
            already NNCF-wrapped) module; appended to in place and returned.
        ignored_scopes: scope patterns (checked with ``in_scope_list``) whose
            modules must not be replaced.
        target_scopes: if not None, only scopes matched by these patterns are
            replaced.
        memo: set of modules already visited; internal recursion state.
        current_scope: ``Scope`` of ``model``; internal recursion state.
        eval_ops_exec_ctx_str: string forms (parsed with ``Scope.from_str``)
            of operations executed in eval mode. None or empty disables the
            eval-mode filter.
        reset: whether to call ``reset()`` on already NNCF-wrapped modules.

    Returns:
        Tuple ``(model, affected_scopes)``.
    """
    # Initialise each piece of recursion state on its own. Previously both
    # were tied to ``memo is None``, so passing ``memo`` alone crashed on
    # ``current_scope.copy()`` and a ``current_scope`` passed alone was
    # silently overwritten.
    if memo is None:
        memo = set()
    if current_scope is None:
        current_scope = Scope()
        current_scope.push(ScopeElement(model.__class__.__name__))
    if eval_ops_exec_ctx_str is None:
        eval_ops_exec_ctx_str = []

    # A module object reachable through several paths is processed only once.
    if model in memo:
        return model, affected_scopes
    memo.add(model)

    for name, module in model.named_children():
        if module is None:
            continue

        child_scope = current_scope.copy()
        child_scope.push(ScopeElement(module.__class__.__name__, name))
        replaced_module = replace_fn(module)

        if replaced_module is not None:
            replaced_scope = current_scope.copy()
            replaced_scope.push(
                ScopeElement(replaced_module.__class__.__name__, name))
            if module is not replaced_module:
                if in_scope_list(str(child_scope), ignored_scopes):
                    nncf_logger.info(
                        "Ignored wrapping modules specified in scope: {}".
                        format(child_scope))
                    continue

                # child_scope isn't ignored, if there's at least a single operation or a module called in eval mode
                # inside it
                called_in_eval = any(
                    Scope.from_str(op_ctx_str) in child_scope
                    for op_ctx_str in eval_ops_exec_ctx_str)
                if eval_ops_exec_ctx_str and not called_in_eval:
                    nncf_logger.info(
                        "Ignored wrapping modules not called in eval mode in scope: {}"
                        .format(child_scope))
                    continue

                if target_scopes is None or in_scope_list(
                        str(child_scope), target_scopes):
                    nncf_logger.info("Wrapping module {} by {}".format(
                        str(child_scope), str(replaced_scope)))
                    set_replaced_module_by_name(model, name, replaced_module)
                    affected_scopes.append(replaced_scope)
            elif is_nncf_module(replaced_module):
                # Got an NNCF-wrapped module from previous compression stage, track its scope as well
                affected_scopes.append(replaced_scope)
                if reset:
                    replaced_module.reset()

        # NOTE(review): recursion continues into the original child ``module``,
        # not the replacement that may have been installed above.
        _, affected_scopes = replace_modules(module,
                                             replace_fn,
                                             affected_scopes,
                                             ignored_scopes,
                                             target_scopes,
                                             memo,
                                             child_scope,
                                             eval_ops_exec_ctx_str,
                                             reset=reset)
    return model, affected_scopes
Beispiel #7
0
    def __init__(self, ip_graph: InsertionPointGraph, ignored_scopes=None):
        """Build the quantizer-propagation state graph from an insertion-point graph.

        Every node of ``ip_graph`` is copied into this graph with
        propagation-specific attributes, and every edge is copied with an
        empty list of affecting propagating quantizers. For each operator node
        whose scope matches ``ignored_scopes``:

        * its key is recorded in ``self.ignored_node_keys``;
        * an auxiliary barrier node is spliced in between the operator node and
          its first successor.

        Args:
            ip_graph: source insertion-point graph. It is deep-copied first, so
                the edge-attribute writes below do not touch the caller's graph.
            ignored_scopes: scope patterns (checked with ``in_scope_list``)
                marking operators that get a barrier; deep-copied and stored as
                ``self._ignored_scopes``.
        """
        # NOTE(review): self.add_node/add_edge/remove_edge/succ are used below;
        # presumably this class derives from a networkx directed graph -- confirm.
        super().__init__()
        ip_graph = deepcopy(ip_graph)
        self._created_prop_quantizer_counter = 0

        self._ignored_scopes = deepcopy(ignored_scopes)
        self.ignored_node_keys = []

        # (barrier_node_key, operator_node_key) pairs, spliced in after all
        # regular edges have been copied.
        barrier_node_extra_edges = []
        for node_key, node in ip_graph.nodes.items():
            # Every node gets its type translated to this graph's node type enum.
            qpg_node = {
                self.NODE_TYPE_NODE_ATTR: \
                    self.ipg_node_type_to_qpsg_node_type(node[InsertionPointGraph.NODE_TYPE_NODE_ATTR])}
            # Insertion points: no quantizer yet, no affecting quantizers,
            # and the insertion-point data copied over.
            if node[InsertionPointGraph.
                    NODE_TYPE_NODE_ATTR] == InsertionPointGraphNodeType.INSERTION_POINT:
                qpg_node[self.PROPAGATING_QUANTIZER_NODE_ATTR] = None
                qpg_node[self.AFFECTING_PROPAGATING_QUANTIZERS_ATTR] = []
                qpg_node[self.INSERTION_POINT_DATA_NODE_ATTR] = node[
                    InsertionPointGraph.INSERTION_POINT_DATA_NODE_ATTR]
            # Operators: start with no allowed input quantization types and the
            # NON_QUANTIZABLE trait; the metatype is copied over.
            elif node[
                    InsertionPointGraph.
                    NODE_TYPE_NODE_ATTR] == InsertionPointGraphNodeType.OPERATOR:
                qpg_node[
                    self.ALLOWED_INPUT_QUANTIZATION_TYPES_NODE_ATTR] = set()
                qpg_node[
                    self.
                    QUANTIZATION_TRAIT_NODE_ATTR] = QuantizationTrait.NON_QUANTIZABLE
                qpg_node[self.AFFECTING_PROPAGATING_QUANTIZERS_ATTR] = []
                qpg_node[self.OPERATOR_METATYPE_NODE_ATTR] = node[
                    InsertionPointGraph.OPERATOR_METATYPE_NODE_ATTR]
                # Scope string of the operator, taken from the input-agnostic
                # part of its execution context.
                scope_node = str(
                    node[InsertionPointGraph.REGULAR_NODE_REF_NODE_ATTR][
                        NNCFGraph.OP_EXEC_CONTEXT_NODE_ATTR].input_agnostic)

                # Ignored operators get an auxiliary barrier node; its edges are
                # created after the regular edges exist (see the last loop).
                if in_scope_list(scope_node, self._ignored_scopes):
                    self.ignored_node_keys.append(node_key)
                    qpg_node_barrier = {
                        self.NODE_TYPE_NODE_ATTR:
                        QuantizerPropagationStateGraphNodeType.
                        AUXILIARY_BARRIER,
                        'label':
                        QuantizerPropagationStateGraph.BARRIER_NODE_KEY_POSTFIX
                    }
                    barrier_node_key = self.get_barrier_node_key(node_key)
                    self.add_node(barrier_node_key, **qpg_node_barrier)
                    barrier_node_extra_edges.append(
                        (barrier_node_key, node_key))

            self.add_node(node_key, **qpg_node)

        # Copy all edges; each starts with no affecting propagating quantizers.
        # This mutates the edge dicts of the deep copy, not the caller's graph.
        for from_node, to_node, edge_data in ip_graph.edges(data=True):
            edge_data[self.AFFECTING_PROPAGATING_QUANTIZERS_ATTR] = []
            self.add_edge(from_node, to_node, **edge_data)

        # Splice each barrier (u) between its operator node (v) and v's first
        # successor: v -> next becomes v -> barrier -> next.
        # Raises IndexError if an ignored operator node has no successor.
        for u_node_key, v_node_key in barrier_node_extra_edges:
            edge_attr = {
                QuantizerPropagationStateGraph.AFFECTING_PROPAGATING_QUANTIZERS_ATTR:
                []
            }
            # Per the original note, the first successor is expected to be the
            # operator's post-hook insertion point.
            next_v_node_key = list(
                self.succ[v_node_key].keys())[0]  # POST HOOK v
            self.add_edge(v_node_key, u_node_key, **edge_attr)
            self.add_edge(u_node_key, next_v_node_key, **edge_attr)
            self.remove_edge(v_node_key, next_v_node_key)