def __init__(
    self,
    model: nn.Module,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
) -> None:
    """
    Store the model/inputs pair to analyze and install the default settings.

    Args:
        model: The model to analyze
        inputs: The inputs to the model for analysis. Tracing runs
            `model.forward(inputs)`, so the inputs must be a tensor or a
            tuple of tensors (see
            https://pytorch.org/docs/stable/generated/torch.jit.trace.html#torch.jit.trace).
            To trace a different method, or to use unsupported input types,
            wrap the model in a module whose forward accepts tensors.
    """
    self._model, self._inputs = model, inputs
    # Operator name -> handle; starts empty and is filled in later.
    self._op_handles: Dict[str, Handle] = {}
    # Every name under which a submodule is reachable (duplicates included).
    self._named_modules: Dict[str, nn.Module] = {
        mod_name: submodule
        for mod_name, submodule in _named_modules_with_dup(model)
    }
    # Submodules and all of their names -> canonical submodule name.
    self._aliases: Dict[Union[nn.Module, str], str] = self._get_aliases(model)
    # No analysis has been run yet.
    self._stats: Optional[Statistics] = None
    # Private copy so per-instance changes do not touch the module default.
    self._ignored_ops: Set[str] = copy(_IGNORED_OPS)

    # Default configuration, applied through the public setters.
    self.unsupported_ops_warnings(True)
    self.uncalled_modules_warnings(True)
    self.tracer_warnings("no_tracer_warning")
    self.ancestor_mode("owner")
def _analyze(self) -> "Statistics":
    """
    Trace the model once and aggregate operator counts per module.

    The result is memoized in ``self._stats``. If it is already set, it is
    returned as-is without re-tracing.

    Each graph node is attributed to a set of owning module names. In
    ``"caller"`` mode that set is every scope on the node's scope stack.
    Otherwise it is ``self._get_all_ancestors`` of the innermost scope.
    - Nodes whose kind has a registered handle add the handle's counts to
      every owner.
    - All other nodes are tallied as unsupported, unless the kind is in
      ``self._ignored_ops`` or is a ``prim::`` operator.

    Returns:
        Statistics: counts, unsupported ops and uncalled module names.
    """
    cached = self._stats
    if cached is not None:
        return cached

    # Warning filters only apply while the trace is being produced.
    with warnings.catch_warnings():
        trace_mode = self._warn_trace
        if trace_mode == "none":
            warnings.simplefilter("ignore")
        elif trace_mode == "no_tracer_warning":
            warnings.filterwarnings("ignore", category=TracerWarning)
        graph = _get_scoped_trace_graph(self._model, self._inputs, self._aliases)

    # Seed every module (by canonical name) with empty counters, so modules
    # that never appear in the trace still report zero. The duplicate-aware
    # walk is used because .named_modules() differs for some wrapped models.
    module_names = [
        self._aliases[submodule]
        for _, submodule in _named_modules_with_dup(self._model)
    ]
    counts = {scope: Counter() for scope in module_names}
    unsupported_ops = {scope: Counter() for scope in module_names}

    seen_scopes = set()
    for node in graph.nodes():
        op_kind = node.kind()
        scopes = node.scopeName().split("/")
        seen_scopes.update(scopes)
        if self._ancestor_mode == "caller":
            owners = set(scopes)
        else:
            owners = self._get_all_ancestors(scopes[-1])
            seen_scopes.update(owners)

        if op_kind in self._op_handles:
            node_counts = self._op_handles[op_kind](
                list(node.inputs()), list(node.outputs())
            )
            if isinstance(node_counts, Number):
                node_counts = Counter({self._simplify_op_name(op_kind): node_counts})
            # Owners form a set, so each op is added at most once per module.
            for scope in owners:
                counts[scope] += node_counts
        elif op_kind not in self._ignored_ops and not op_kind.startswith("prim::"):
            for scope in owners:
                unsupported_ops[scope][op_kind] += 1

    uncalled_mods = set(self._aliases.values()) - seen_scopes
    result = Statistics(
        counts=counts, unsupported_ops=unsupported_ops, uncalled_mods=uncalled_mods
    )
    self._stats = result
    self._warn_unsupported_ops(unsupported_ops[""])
    self._warn_uncalled_mods(uncalled_mods)
    return result
def _get_aliases(self, model: nn.Module) -> Dict[Union[str, nn.Module], str]:
    """
    Map every submodule, and every name it is reachable under, to its
    canonical name.

    The canonical name of a submodule is the first name that
    ``_named_modules_with_dup(model)`` yields for it. Later (duplicate)
    names of the same submodule map to that first name. Module objects and
    name strings share one dictionary.
    """
    alias_map = {}
    for mod_name, submodule in _named_modules_with_dup(model):
        # setdefault keeps the first name recorded for this submodule.
        canonical = alias_map.setdefault(submodule, mod_name)
        alias_map[mod_name] = canonical
    return alias_map
def __init__(
    self,
    model: nn.Module,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
) -> None:
    """
    Remember what to analyze and apply the default analysis settings.

    Args:
        model: The model to analyze
        inputs: The inputs to the model for analysis.
    """
    self._model, self._inputs = model, inputs
    # Operator name -> handle; empty until handles are registered.
    self._op_handles: Dict[str, Handle] = {}
    # Mapping from names (duplicates included) to submodules.
    self._named_modules: Dict[str, nn.Module] = {
        mod_name: submodule
        for mod_name, submodule in _named_modules_with_dup(model)
    }
    # Submodules and their aliases -> canonical name of each submodule.
    self._aliases: Dict[Union[nn.Module, str], str] = self._get_aliases(model)
    # Cached analysis result; None means "not computed yet".
    self._stats: Optional[Statistics] = None
    # Per-instance copy of the default ignore list.
    self._ignored_ops: Set[str] = copy(_IGNORED_OPS)

    # Default configuration via the public setters.
    self.unsupported_ops_warnings(True)
    self.uncalled_modules_warnings(True)
    self.tracer_warnings("no_tracer_warning")
    self.ancestor_mode("owner")
def _named_modules_without_dup(model: nn.Module) -> Iterator[Tuple[str, nn.Module]]:
    """
    Like .named_modules(), but the results are slightly different for some
    wrapped models.

    Yields each submodule once, paired with the first name under which
    ``_named_modules_with_dup(model)`` produces it.
    """
    visited = set()
    for name, mod in _named_modules_with_dup(model):
        if mod in visited:
            continue
        visited.add(mod)
        yield name, mod
def _get_scoped_trace_graph(
    module: nn.Module,
    inputs: Tuple[object, ...],
    aliases: Dict[Union[str, nn.Module], str],
) -> torch._C.Graph:  # pyre-ignore[11]
    """
    Traces the provided module using torch.jit._get_trace_graph, but adds
    submodule scope information to each graph node. The resulting graph
    is in-lined and has all model parameters treated as inputs. The input
    model has the scope name '', while its descendants have names of the
    form 'child.grandchild.grandgrandchild...'.

    The forward hooks used to record scopes are always removed before this
    function returns, including when hook registration or tracing raises.

    Args:
        module (nn.Module): The module to trace
        inputs (tuple): Inputs used during the trace of the model
        aliases (dict(str or nn.Module, str)): maps modules and module names
            to the canonical name to be used as the scope for that module.

    Returns:
        graph (torch._C.Graph): The pytorch JIT trace of the model
    """

    class ScopePushHook(object):
        # Forward pre-hook: opens a named scope, but only while tracing.
        def __init__(self, name: str) -> None:
            self.name = name

        def __call__(self, module: nn.Module, inputs: Tuple[object, ...]) -> Tuple[object, ...]:
            tracing_state = torch._C._get_tracing_state()
            if tracing_state:
                tracing_state.push_scope(self.name)
            return inputs

    class ScopePopHook(object):
        # Forward hook: closes the scope opened by ScopePushHook.
        def __call__(
            self,
            module: nn.Module,
            inputs: Tuple[object, ...],
            outputs: Tuple[object, ...],
        ) -> Tuple[object, ...]:
            tracing_state = torch._C._get_tracing_state()
            if tracing_state:
                tracing_state.pop_scope()
            return outputs

    seen = set()
    hook_handles = []  # type: List[Any]

    def register_hooks(mod: nn.Module, name: str) -> None:
        prehook = mod.register_forward_pre_hook(ScopePushHook(name))  # pyre-ignore[16]
        posthook = mod.register_forward_hook(ScopePopHook())  # pyre-ignore[16]
        hook_handles.append(prehook)
        hook_handles.append(posthook)

    try:
        # Torch script does not support parallel torch models, but we still
        # want the scope names to be correct for the complete module.
        if isinstance(
            module, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)
        ):
            # Since DataParallel just wraps the model, add an extra set of hooks
            # to the model it wraps to account for the wrapper. Then trace it.
            root_name = aliases[module]
            module = module.module
            register_hooks(module, root_name)

        # We don't need the duplication here, but self._model.named_modules()
        # gives slightly different results for some wrapped models.
        for _, mod in _named_modules_with_dup(module):
            if mod not in seen:
                register_hooks(mod, aliases[mod])
                seen.add(mod)

        if hasattr(torch.jit, "get_trace_graph"):
            trace, _ = torch.jit.get_trace_graph(module, inputs)
            graph = trace.graph()
        else:
            graph, _ = _get_trace_graph(module, inputs)
    finally:
        # Detach every hook even on failure. Leftover hooks would remain on
        # the caller's modules, and a later trace of the same model would
        # push each scope twice.
        for handle in hook_handles:
            handle.remove()

    return graph
def _analyze(self) -> "Statistics":
    """
    Trace the model and aggregate per-module operator statistics.

    The result is cached in ``self._stats``. If it is already set, it is
    returned without re-tracing.

    Each node of the scoped trace graph is attributed to a set of module
    names. When ``self._ancestor_mode`` is ``"caller"``, that set is every
    scope on the node's scope stack. Otherwise it is
    ``self._get_all_ancestors`` of the node's innermost scope.
    - Nodes whose kind has a registered handle add that handle's counts to
      each of those modules.
    - Other nodes are tallied as unsupported ops, unless
      ``self._should_ignore_node(node)`` is true.

    Returns:
        Statistics: per-module counts, per-module unsupported ops, and the
        canonical names of modules that never appeared as a scope.

    Raises:
        ValueError: if a handle produces a count whose type is not ``int``,
            ``float``, ``np.float64`` or ``np.int64``.
    """
    # Don't calculate if results are already stored.
    stats = self._stats
    if stats is not None:
        return stats

    # Apply the configured tracer-warning filter only while tracing.
    with warnings.catch_warnings():
        if self._warn_trace == "none":
            warnings.simplefilter("ignore")
        elif self._warn_trace == "no_tracer_warning":
            warnings.filterwarnings("ignore", category=TracerWarning)
        graph = _get_scoped_trace_graph(self._model, self._inputs, self._aliases)

    # Assures even modules not in the trace graph are initialized to zero count
    counts = {}
    unsupported_ops = {}
    # We don't need the duplication here, but self._model.named_modules()
    # gives slightly different results for some wrapped models.
    for _, mod in _named_modules_with_dup(self._model):
        name = self._aliases[mod]
        counts[name] = Counter()
        unsupported_ops[name] = Counter()

    # Every scope name observed in the trace; used to find uncalled modules.
    all_seen = set()
    for node in graph.nodes():
        kind = node.kind()
        if kind == "prim::PythonOp":
            # for PythonOp, pyname contains the actual name in Python
            kind = kind + "." + node.pyname()
        # The node's scope name is a "/"-separated stack of module scopes.
        scope_names = node.scopeName().split("/")
        all_seen.update(scope_names)
        if self._ancestor_mode == "caller":
            # Attribute the node to every module on its scope stack.
            ancestors = set(scope_names)
        else:
            # Attribute the node via the innermost scope's ancestors; those
            # names are also marked as seen.
            ancestors = self._get_all_ancestors(scope_names[-1])
            all_seen.update(ancestors)
        if kind not in self._op_handles:
            if self._should_ignore_node(node):
                continue
            for name in ancestors:
                unsupported_ops[name][kind] += 1
        else:
            inputs, outputs = list(node.inputs()), list(node.outputs())
            op_counts = self._op_handles[kind](inputs, outputs)
            # A handle may return a bare number; key it by the simplified op name.
            if isinstance(op_counts, Number):
                op_counts = Counter({self._simplify_op_name(kind): op_counts})
            # Only whitelisted count types are accepted (per the message,
            # presumably to avoid overflow in narrow numeric types).
            for v in op_counts.values():
                if not isinstance(v, (int, float, np.float64, np.int64)):
                    raise ValueError(
                        f"Invalid type {type(v)} for the flop count! "
                        "Please use a wider type to avoid overflow."
                    )

            # Assures an op contributes at most once to a module
            # NOTE: Counter += keeps only positive entries.
            for name in ancestors:
                counts[name] += op_counts

    # Canonical module names that never appeared as a scope.
    uncalled_mods = set(self._aliases.values()) - all_seen
    stats = Statistics(
        counts=counts, unsupported_ops=unsupported_ops, uncalled_mods=uncalled_mods
    )
    # Cache for subsequent calls.
    self._stats = stats
    # Unsupported-op warnings use the entry for scope "" (the root module).
    self._warn_unsupported_ops(unsupported_ops[""])
    self._warn_uncalled_mods(uncalled_mods)
    return stats
def _analyze(self) -> "Statistics":
    """
    Trace the model once and aggregate operator counts per module scope.

    The result is memoized in ``self._stats``. If it is already set, it is
    returned without re-tracing.

    Each graph node is attributed once to every distinct scope on its scope
    stack.
    - Nodes whose kind has a registered handle add the handle's counts.
      A handle may return either a Counter or a bare number; a number is
      stored under the op's simplified name.
    - Other nodes are tallied as unsupported, unless the kind is in
      ``self._ignored_ops`` or is a ``prim::`` operator.

    Modules that never appear as a scope are reported as uncalled, except
    containers whose ``forward`` is the unimplemented ``nn.Module`` stub
    (``nn.ModuleList``, ``nn.ModuleDict``, bare ``nn.Module``).

    Returns:
        Statistics: counts, unsupported ops and uncalled module names.
    """
    from numbers import Number

    # Don't calculate if results are already stored.
    stats = self._stats
    if stats is not None:
        return stats

    with warnings.catch_warnings():
        if self._warn_trace == "none":
            warnings.simplefilter("ignore")
        elif self._warn_trace == "no_tracer_warning":
            warnings.filterwarnings("ignore", category=TracerWarning)
        graph = _get_scoped_trace_graph(self._model, self._inputs, self._aliases)

    # Assures even modules not in the trace graph are initialized to zero count
    counts = {}
    unsupported_ops = {}
    # We don't need the duplication here, but self._model.named_modules()
    # gives slightly different results for some wrapped models.
    for _, mod in _named_modules_with_dup(self._model):
        name = self._aliases[mod]
        counts[name] = Counter()
        unsupported_ops[name] = Counter()

    all_seen = set()
    for node in graph.nodes():
        kind = node.kind()
        scope_names = node.scopeName().split("/")
        all_seen.update(scope_names)
        # Assures an op contributes at most once to a module
        owners = set(scope_names)
        if kind not in self._op_handles:
            # ignore all prim:: operators
            if kind in self._ignored_ops or kind.startswith("prim::"):
                continue
            for name in owners:
                unsupported_ops[name][kind] += 1
        else:
            inputs, outputs = list(node.inputs()), list(node.outputs())
            op_counts = self._op_handles[kind](inputs, outputs)
            # A bare number cannot be added to a Counter; wrap it as the
            # sibling implementations do.
            if isinstance(op_counts, Number):
                op_counts = Counter({self._simplify_op_name(kind): op_counts})
            for name in owners:
                counts[name] += op_counts

    # Containers are not meant to be called anyway (they don't have forward)
    no_forward_types = (nn.ModuleList, nn.ModuleDict, nn.Module)

    def has_forward(module) -> bool:
        if module is None:
            # Name does not resolve to a known module (e.g. the model was
            # changed after construction); keep reporting it rather than crash.
            return True
        forward = type(module).forward
        return not any(forward is t.forward for t in no_forward_types)

    uncalled_mods = {
        m
        for m in set(self._aliases.values()) - all_seen
        if has_forward(self._named_modules.get(m))
    }
    stats = Statistics(
        counts=counts, unsupported_ops=unsupported_ops, uncalled_mods=uncalled_mods
    )
    self._stats = stats
    self._warn_unsupported_ops(unsupported_ops[""])
    self._warn_uncalled_mods(uncalled_mods)
    return stats