Beispiel #1
0
def get_all_modules_by_type(model,
                            module_types=None,
                            current_scope=None,
                            ignored_scopes=None,
                            target_scopes=None) -> Dict['Scope', Module]:
    """Recursively collect descendant modules of ``model``, keyed by scope.

    The traversal follows ``model.named_children()`` depth-first. ``model``
    itself is never part of the result; only its descendants are.

    :param model: Module whose children are inspected.
    :param module_types: Class name, or collection of class names, to keep.
        A single string is treated as a one-element list. ``None`` keeps
        modules of every type.
    :param current_scope: Scope of ``model``. When ``None`` a new scope is
        started from the class name of ``model``.
    :param ignored_scopes: Passed to ``matches_any``; a child whose scope
        string matches is skipped together with its whole subtree.
    :param target_scopes: When not ``None``, passed to ``matches_any``; a
        child whose scope string does not match is skipped together with
        its whole subtree.
    :return: Ordered mapping from scope to module, with each parent listed
        before its own descendants.
    """
    from nncf.torch.dynamic_graph.scope import Scope
    from nncf.torch.dynamic_graph.scope import ScopeElement

    if isinstance(module_types, str):
        module_types = [module_types]

    if current_scope is None:
        current_scope = Scope()
        current_scope.push(ScopeElement(model.__class__.__name__))

    collected = OrderedDict()
    for child_name, child in model.named_children():
        scope_of_child = current_scope.copy()
        scope_of_child.push(ScopeElement(child.__class__.__name__, child_name))
        scope_str = str(scope_of_child)

        # An ignored child hides its entire subtree.
        if matches_any(scope_str, ignored_scopes):
            continue

        # A child outside the target scopes is not descended into either.
        if target_scopes is not None and not matches_any(scope_str, target_scopes):
            continue

        if module_types is None or type(child).__name__ in module_types:
            collected[scope_of_child] = child

        # The type filter only affects what is recorded, not what is traversed.
        collected.update(
            get_all_modules_by_type(child,
                                    module_types,
                                    current_scope=scope_of_child,
                                    ignored_scopes=ignored_scopes,
                                    target_scopes=target_scopes))
    return collected
def replace_modules(model: nn.Module, replace_fn, affected_scopes, ignored_scopes=None, target_scopes=None, memo=None,
                    current_scope=None, eval_op_scopes: List[Scope] = None, reset: bool = False):
    """Recursively substitute children of ``model`` with what ``replace_fn`` returns.

    :param model: Module whose children are processed in place.
    :param replace_fn: Callable invoked on every non-``None`` child. A return
        value of ``None`` leaves the child untouched; returning the same
        object means no substitution; any other object is a replacement
        candidate.
    :param affected_scopes: List that receives the scope of every module that
        was substituted, and of every unchanged child for which
        ``is_nncf_module`` is true. It is appended to in place.
    :param ignored_scopes: Passed to ``matches_any``; a replacement candidate
        whose original scope matches is not substituted and its subtree is
        not visited.
    :param target_scopes: When not ``None``, a candidate is substituted only
        if its original scope matches according to ``matches_any``.
    :param memo: Set of modules already visited, so that a module object
        reachable through several parents is processed once. ``None`` marks
        the top-level call.
    :param current_scope: Scope of ``model``. NOTE: whenever ``memo`` is
        ``None`` this argument is discarded and a new scope is started from
        the class name of ``model``.
    :param eval_op_scopes: When non-empty, a candidate is substituted only if
        at least one of these scopes is contained in the candidate's original
        scope; otherwise the candidate and its subtree are skipped.
    :param reset: When true, ``reset()`` is called on every unchanged child
        for which ``is_nncf_module`` is true.
    :return: Tuple ``(model, affected_scopes)``.
    """
    if memo is None:
        # Top-level call: start the visited set and the root scope.
        memo = set()
        current_scope = Scope()
        current_scope.push(ScopeElement(model.__class__.__name__))

    if model in memo:
        return model, affected_scopes
    memo.add(model)

    for child_name, child in model.named_children():
        if child is None:
            continue

        original_scope = current_scope.copy()
        original_scope.push(ScopeElement(child.__class__.__name__, child_name))
        candidate = replace_fn(child)

        if candidate is not None:
            candidate_scope = current_scope.copy()
            candidate_scope.push(ScopeElement(candidate.__class__.__name__, child_name))

            if candidate is child:
                if is_nncf_module(candidate):
                    # Got an NNCF-wrapped module from previous compression stage, track its scope as well
                    affected_scopes.append(candidate_scope)
                    if reset:
                        candidate.reset()
            else:
                if matches_any(str(original_scope), ignored_scopes):
                    nncf_logger.info("Ignored wrapping modules specified in scope: {}".format(original_scope))
                    continue

                if eval_op_scopes is None:
                    eval_op_scopes = []
                # With a non-empty eval_op_scopes, the child is kept only if at least one
                # operation or module called in eval mode lies inside its scope.
                if eval_op_scopes and not any(eval_scope in original_scope for eval_scope in eval_op_scopes):
                    nncf_logger.info(
                        "Ignored wrapping modules not called in eval mode in scope: {}".format(original_scope))
                    continue

                if target_scopes is None or matches_any(str(original_scope), target_scopes):
                    nncf_logger.info("Wrapping module {} by {}".format(str(original_scope),
                                                                       str(candidate_scope)))
                    set_replaced_module_by_name(model, child_name, candidate)
                    affected_scopes.append(candidate_scope)

        # Recursion descends into the original child, not into the replacement.
        _, affected_scopes = replace_modules(child, replace_fn, affected_scopes, ignored_scopes, target_scopes,
                                             memo, original_scope, eval_op_scopes, reset=reset)
    return model, affected_scopes