示例#1
0
def get_all_modules_by_type(model,
                            module_types,
                            current_scope=None,
                            ignored_scopes=None,
                            target_scopes=None) -> Dict['Scope', Module]:
    """Recursively collect submodules of ``model`` whose class name is in ``module_types``.

    Children are visited depth-first via ``model.named_children()``. Each child
    gets a scope built as the parent scope plus
    ``ScopeElement(<child class name>, <attribute name>)``.

    :param model: Module whose children are searched.
    :param module_types: A single class name (``str``) or a collection of class
        names (list, tuple, set, ...) compared against ``type(module).__name__``.
    :param current_scope: Scope of ``model`` itself. When ``None``, a root scope
        consisting of ``model``'s class name is created.
    :param ignored_scopes: Scope patterns checked with ``in_scope_list``. A child
        whose scope matches is skipped together with its whole subtree.
    :param target_scopes: Scope patterns checked with ``in_scope_list``. When not
        ``None``, a child whose scope does not match is neither collected nor
        descended into.
    :return: An ``OrderedDict`` mapping each matching child's ``Scope`` to the
        module, in traversal order.
    """
    if isinstance(module_types, str):
        module_types = [module_types]
    found = OrderedDict()
    from nncf.dynamic_graph.context import Scope
    from nncf.dynamic_graph.context import ScopeElement
    if current_scope is None:
        current_scope = Scope()
        current_scope.push(ScopeElement(model.__class__.__name__))
    for name, module in model.named_children():
        child_scope_element = ScopeElement(module.__class__.__name__, name)
        child_scope = current_scope.copy()
        child_scope.push(child_scope_element)
        scope_str = str(child_scope)

        if in_scope_list(scope_str, ignored_scopes):
            continue
        # Outside the target scopes: skip this child and its whole subtree.
        if target_scopes is not None and not in_scope_list(scope_str,
                                                           target_scopes):
            continue

        # A membership test works for any container (list, tuple, set...),
        # unlike ``list.count``.
        if type(module).__name__ in module_types:
            found[child_scope] = module
        found.update(get_all_modules_by_type(module,
                                             module_types,
                                             current_scope=child_scope,
                                             ignored_scopes=ignored_scopes,
                                             target_scopes=target_scopes))
    return found
示例#2
0
def replace_modules(model: nn.Module,
                    replace_fn,
                    affected_scopes,
                    ignored_scopes=None,
                    target_scopes=None,
                    memo=None,
                    current_scope=None):
    """Walk ``model``'s children depth-first and swap them for ``replace_fn`` results.

    For every non-``None`` child, ``replace_fn(child)`` is called:

    * ``None``: the child is left as is.
    * The same object: the child is left as is. If it is already an NNCF module
      (``is_nncf_module``), its scope is still recorded in ``affected_scopes``.
    * A different object: if the child's scope matches ``ignored_scopes``, the
      child and its subtree are skipped. Otherwise the child is replaced on
      ``model`` when ``target_scopes`` is ``None`` or matches, and the
      replacement's scope is recorded.

    Recursion always descends into the original child, not the replacement.

    :param model: Module whose children are processed.
    :param replace_fn: Callable mapping a module to a replacement, to itself,
        or to ``None``.
    :param affected_scopes: List that recorded scopes are appended to, in place.
    :param ignored_scopes: Scope patterns checked with ``in_scope_list``.
    :param target_scopes: Scope patterns checked with ``in_scope_list``. ``None``
        means every scope is a target.
    :param memo: Set of already-visited modules. ``None`` marks the top-level
        call, which also rebuilds ``current_scope`` from ``model``'s class name.
    :param current_scope: Scope of ``model`` in recursive calls.
    :return: Tuple ``(model, affected_scopes)``.
    """
    if memo is None:
        # Top-level call: a fresh visited set and a root scope. A caller-supplied
        # current_scope is discarded here.
        memo = set()
        current_scope = Scope()
        current_scope.push(ScopeElement(model.__class__.__name__))

    # Shared submodules are processed only once.
    if model in memo:
        return model, affected_scopes
    memo.add(model)

    for child_name, child in model.named_children():
        if child is None:
            continue

        child_element = ScopeElement(child.__class__.__name__, child_name)
        child_scope = current_scope.copy()
        child_scope.push(child_element)
        candidate = replace_fn(child)

        if candidate is not None:
            candidate_element = ScopeElement(candidate.__class__.__name__,
                                             child_name)
            candidate_scope = current_scope.copy()
            candidate_scope.push(candidate_element)

            if candidate is child:
                # Already NNCF-wrapped by a previous compression stage: track it too.
                if is_nncf_module(candidate):
                    affected_scopes.append(candidate_scope)
            elif in_scope_list(str(child_scope), ignored_scopes):
                nncf_logger.info(
                    "Ignored wrapping modules in scope: {}".format(child_scope))
                # Skips the recursion into this child as well.
                continue
            elif target_scopes is None or in_scope_list(str(child_scope),
                                                        target_scopes):
                nncf_logger.info("Wrapping module {} by {}".format(
                    str(child_scope), str(candidate_scope)))
                if isinstance(model, nn.Sequential):
                    # pylint: disable=protected-access
                    model._modules[child_name] = candidate
                else:
                    setattr(model, child_name, candidate)
                affected_scopes.append(candidate_scope)

        _, affected_scopes = replace_modules(child,
                                             replace_fn,
                                             affected_scopes,
                                             ignored_scopes=ignored_scopes,
                                             target_scopes=target_scopes,
                                             memo=memo,
                                             current_scope=child_scope)
    return model, affected_scopes
示例#3
0
def replace_modules(model: nn.Module,
                    replace_fn,
                    affected_scopes,
                    ignored_scopes=None,
                    target_scopes=None,
                    memo=None,
                    current_scope=None,
                    eval_ops_exec_ctx_str: List[str] = None,
                    reset: bool = False):
    """Recursively replace children of ``model`` with the results of ``replace_fn``.

    For every non-``None`` child, ``replace_fn(child)`` is called:

    * ``None``: the child is left untouched.
    * The same object: if it is an NNCF module (``is_nncf_module``), e.g. one
      wrapped by a previous compression stage, its scope is recorded in
      ``affected_scopes``. When ``reset`` is true, its ``reset()`` is also called.
    * A different object: the child and its subtree are skipped if either
      condition holds:

      - its scope matches ``ignored_scopes``;
      - ``eval_ops_exec_ctx_str`` is non-empty and none of those op scopes lies
        inside the child's scope.

      Otherwise the child is replaced via ``set_replaced_module_by_name`` when
      ``target_scopes`` is ``None`` or matches, and the replacement's scope is
      recorded.

    Recursion always descends into the original child, not the replacement.

    :param model: Module whose children are processed.
    :param replace_fn: Callable returning a replacement for a module, the module
        itself, or ``None``.
    :param affected_scopes: List that recorded scopes are appended to, in place.
    :param ignored_scopes: Scope patterns checked with ``in_scope_list``.
    :param target_scopes: Scope patterns checked with ``in_scope_list``. ``None``
        means every scope is a target.
    :param memo: Set of already-visited modules. ``None`` marks the top-level
        call, which also rebuilds ``current_scope`` from ``model``'s class name.
    :param current_scope: Scope of ``model`` in recursive calls.
    :param eval_ops_exec_ctx_str: String forms (parsed with ``Scope.from_str``)
        of scopes of operations/modules called in eval mode. ``None`` or an empty
        list disables the eval-mode filter.
    :param reset: Whether to call ``reset()`` on already-NNCF modules found.
    :return: Tuple ``(model, affected_scopes)``.
    """
    if memo is None:
        # Top-level call: a fresh visited set and a root scope. A caller-supplied
        # current_scope is discarded here.
        memo = set()
        current_scope = Scope()
        current_scope.push(ScopeElement(model.__class__.__name__))

    # Shared submodules are processed only once.
    if model in memo:
        return model, affected_scopes
    memo.add(model)

    # Loop-invariant: normalize once per call instead of per replaced child.
    if eval_ops_exec_ctx_str is None:
        eval_ops_exec_ctx_str = []

    for name, module in model.named_children():
        if module is None:
            continue

        child_scope_element = ScopeElement(module.__class__.__name__, name)
        child_scope = current_scope.copy()
        child_scope.push(child_scope_element)
        replaced_module = replace_fn(module)

        if replaced_module is not None:
            replaced_scope_element = ScopeElement(
                replaced_module.__class__.__name__, name)
            replaced_scope = current_scope.copy()
            replaced_scope.push(replaced_scope_element)
            if module is not replaced_module:
                if in_scope_list(str(child_scope), ignored_scopes):
                    nncf_logger.info(
                        "Ignored wrapping modules specified in scope: {}".
                        format(child_scope))
                    continue
                # child_scope isn't ignored if there's at least a single
                # operation or module called in eval mode inside it; any()
                # stops parsing at the first match.
                if eval_ops_exec_ctx_str and not any(
                        Scope.from_str(op_ctx_str) in child_scope
                        for op_ctx_str in eval_ops_exec_ctx_str):
                    nncf_logger.info(
                        "Ignored wrapping modules not called in eval mode in scope: {}"
                        .format(child_scope))
                    continue

                if target_scopes is None or in_scope_list(
                        str(child_scope), target_scopes):
                    nncf_logger.info("Wrapping module {} by {}".format(
                        str(child_scope), str(replaced_scope)))
                    set_replaced_module_by_name(model, name, replaced_module)
                    affected_scopes.append(replaced_scope)
            elif is_nncf_module(replaced_module):
                # Got an NNCF-wrapped module from previous compression stage, track its scope as well
                affected_scopes.append(replaced_scope)
                if reset:
                    replaced_module.reset()
        _, affected_scopes = replace_modules(module,
                                             replace_fn,
                                             affected_scopes,
                                             ignored_scopes,
                                             target_scopes,
                                             memo,
                                             child_scope,
                                             eval_ops_exec_ctx_str,
                                             reset=reset)
    return model, affected_scopes