def get_all_modules_by_type(model, module_types=None, current_scope=None, ignored_scopes=None,
                            target_scopes=None) -> Dict['Scope', Module]:
    """Recursively collect descendants of ``model`` whose class name is in ``module_types``.

    :param model: root module; its ``named_children()`` are walked depth-first.
    :param module_types: a single class name, or a collection of class names (list, tuple, set, ...)
        compared against ``type(module).__name__``. ``None`` matches every module.
    :param current_scope: scope of ``model`` itself. When ``None``, a fresh root scope containing
        only ``model``'s class name is created.
    :param ignored_scopes: patterns passed to ``matches_any``; a child whose scope string matches is
        skipped together with its entire subtree.
    :param target_scopes: patterns passed to ``matches_any``; when given, only matching children are
        recorded, but non-matching children are still descended into.
    :return: ``OrderedDict`` mapping each matching descendant's ``Scope`` to the module, in pre-order
        (a parent is inserted before any of its descendants).
    """
    # Function-local imports; NOTE(review): presumably to avoid an import cycle with the scope module.
    from nncf.torch.dynamic_graph.scope import Scope
    from nncf.torch.dynamic_graph.scope import ScopeElement

    # A bare string is treated as a single type name, not as an iterable of characters.
    if isinstance(module_types, str):
        module_types = [module_types]

    found = OrderedDict()
    if current_scope is None:
        current_scope = Scope()
        current_scope.push(ScopeElement(model.__class__.__name__))

    for name, module in model.named_children():
        child_scope = current_scope.copy()
        child_scope.push(ScopeElement(module.__class__.__name__, name))

        # Ignored scopes prune the whole subtree, so there is no recursion below.
        if matches_any(str(child_scope), ignored_scopes):
            continue

        if target_scopes is None or matches_any(str(child_scope), target_scopes):
            # `in` works for any container (list/tuple/set); `.count` failed on sets.
            if module_types is None or type(module).__name__ in module_types:
                found[child_scope] = module

        # Descend even if this child was not a target: its descendants may still match.
        found.update(get_all_modules_by_type(module, module_types, current_scope=child_scope,
                                             ignored_scopes=ignored_scopes, target_scopes=target_scopes))
    return found
def replace_modules(model: nn.Module, replace_fn, affected_scopes, ignored_scopes=None, target_scopes=None, memo=None,
                    current_scope=None, eval_op_scopes: List[Scope] = None, reset: bool = False):
    """Recursively replace children of ``model`` with the modules returned by ``replace_fn``.

    ``replace_fn`` is called on every non-``None`` child, and its return value decides what happens:

    * ``None`` -- the child is left alone.
    * a different object -- the child becomes a replacement candidate. It is *not* replaced if:
        - its scope string matches ``ignored_scopes``;
        - ``eval_op_scopes`` is non-empty and none of its entries is contained in the child's scope;
        - ``target_scopes`` is given and does not match.
      Otherwise the new module is installed via ``set_replaced_module_by_name`` and its scope is
      recorded.
    * the same object -- if ``is_nncf_module`` says it is already wrapped, its scope is recorded and,
      when ``reset`` is true, ``replaced_module.reset()`` is called.

    :param model: module whose children are processed.
    :param replace_fn: callable taking a module and returning ``None``, that module, or a replacement.
    :param affected_scopes: list that recorded scopes are appended to in place. The same list is also
        returned.
    :param ignored_scopes: patterns passed to ``matches_any``; matching children are neither replaced
        nor descended into.
    :param target_scopes: patterns passed to ``matches_any``; when given, only matching children are
        replaced.
    :param memo: set of already-visited modules. Each module object is processed at most once.
        Pass ``None`` on the top-level call.
    :param current_scope: scope of ``model``. On the top-level call (``memo is None``) it is always
        rebuilt from ``model``'s class name, as before. If a ``memo`` is given without a scope, the
        same root scope is built instead of crashing.
    :param eval_op_scopes: scopes of operations/modules seen in eval mode. ``None`` or empty disables
        this filter.
    :param reset: whether to call ``reset()`` on already-wrapped NNCF modules.
    :return: ``(model, affected_scopes)``.
    """
    if memo is None:
        # Top-level call: fresh visited-set and root scope. Any passed-in current_scope is ignored here.
        memo = set()
        current_scope = Scope()
        current_scope.push(ScopeElement(model.__class__.__name__))
    elif current_scope is None:
        # memo supplied without a scope: previously crashed on `current_scope.copy()` below.
        current_scope = Scope()
        current_scope.push(ScopeElement(model.__class__.__name__))

    if model in memo:
        return model, affected_scopes
    memo.add(model)

    # None and [] are equivalent here (both disable eval-mode filtering); normalise once, not per child.
    if eval_op_scopes is None:
        eval_op_scopes = []

    for name, module in model.named_children():
        if module is None:
            continue

        child_scope = current_scope.copy()
        child_scope.push(ScopeElement(module.__class__.__name__, name))
        replaced_module = replace_fn(module)

        if replaced_module is not None:
            replaced_scope = current_scope.copy()
            replaced_scope.push(ScopeElement(replaced_module.__class__.__name__, name))

            if module is not replaced_module:
                if matches_any(str(child_scope), ignored_scopes):
                    nncf_logger.info("Ignored wrapping modules specified in scope: {}".format(child_scope))
                    continue

                # child_scope isn't ignored if there's at least a single operation or a module called
                # in eval mode inside it.
                called_in_eval = any(eval_op_scope in child_scope for eval_op_scope in eval_op_scopes)
                if eval_op_scopes and not called_in_eval:
                    nncf_logger.info(
                        "Ignored wrapping modules not called in eval mode in scope: {}".format(child_scope))
                    continue

                if target_scopes is None or matches_any(str(child_scope), target_scopes):
                    nncf_logger.info("Wrapping module {} by {}".format(str(child_scope), str(replaced_scope)))
                    set_replaced_module_by_name(model, name, replaced_module)
                    affected_scopes.append(replaced_scope)
            elif is_nncf_module(replaced_module):
                # Got an NNCF-wrapped module from previous compression stage, track its scope as well.
                affected_scopes.append(replaced_scope)
                if reset:
                    replaced_module.reset()

        # NOTE(review): recursion descends into the original `module` under `child_scope`, even when it
        # was just replaced above. Confirm this is intended for non-leaf replacements.
        _, affected_scopes = replace_modules(module, replace_fn, affected_scopes, ignored_scopes, target_scopes,
                                             memo, child_scope, eval_op_scopes, reset=reset)
    return model, affected_scopes