Ejemplo n.º 1
0
    def __init__(
        self,
        model: nn.Module,
        inputs: Union[Tensor, Tuple[Tensor, ...]],
    ) -> None:
        """
        Prepare an analysis of ``model`` run on ``inputs``.

        Args:
            model: The model to analyze
            inputs: The inputs to the model for analysis.

        The analysis traces the execution of ``model.forward(inputs)``, so
        ``inputs`` must be a tensor or a tuple of tensors (see
        https://pytorch.org/docs/stable/generated/torch.jit.trace.html#torch.jit.trace).
        To trace another method, or to use inputs the tracer does not accept,
        wrap the model in a module whose ``forward`` adapts the call.
        """
        # What is being analyzed.
        self._model = model
        self._inputs = inputs

        # Op kind -> handle that computes counts for that op; empty by default.
        self._op_handles: Dict[str, Handle] = {}

        # Every submodule name (duplicates included) -> the submodule object.
        module_pairs = _named_modules_with_dup(model)
        self._named_modules: Dict[str, nn.Module] = dict(module_pairs)

        # Submodule objects and all of their names -> canonical submodule name.
        self._aliases: Dict[Union[nn.Module, str], str] = self._get_aliases(model)

        # Cached analysis result; None means "not computed yet".
        self._stats: Optional[Statistics] = None

        # Copy of the default ignored-op set so that per-instance edits do not
        # touch the shared _IGNORED_OPS object.
        self._ignored_ops: Set[str] = copy(_IGNORED_OPS)

        # Default settings, applied through the public setters (kept last and
        # in this order, after all attributes above exist).
        self.unsupported_ops_warnings(True)
        self.uncalled_modules_warnings(True)
        self.tracer_warnings("no_tracer_warning")
        self.ancestor_mode("owner")
Ejemplo n.º 2
0
    def _analyze(self) -> "Statistics":
        """
        Trace the model once and aggregate operator statistics per module.

        The first call traces ``self._model`` on ``self._inputs`` and stores
        the resulting ``Statistics`` on ``self._stats``. Every later call
        returns that cached object without tracing again.

        Returns:
            Statistics holding per-module ``counts`` (produced by the
            registered op handles), per-module ``unsupported_ops`` (op kinds
            with no handle that are neither ignored nor ``prim::``), and
            ``uncalled_mods`` (canonical module names never seen in the trace).
        """
        cached = self._stats
        if cached is not None:
            return cached

        # Apply the configured warning filter only while tracing.
        with warnings.catch_warnings():
            if self._warn_trace == "none":
                warnings.simplefilter("ignore")
            elif self._warn_trace == "no_tracer_warning":
                warnings.filterwarnings("ignore", category=TracerWarning)
            graph = _get_scoped_trace_graph(
                self._model, self._inputs, self._aliases
            )

        # Seed an empty counter for every module so that modules absent from
        # the trace graph still report zero. _named_modules_with_dup is used
        # because self._model.named_modules() gives slightly different results
        # for some wrapped models.
        module_names = [
            self._aliases[submodule]
            for _, submodule in _named_modules_with_dup(self._model)
        ]
        counts = {mod_name: Counter() for mod_name in module_names}
        unsupported_ops = {mod_name: Counter() for mod_name in module_names}

        seen_scopes = set()
        for node in graph.nodes():
            op_kind = node.kind()
            scopes = node.scopeName().split("/")
            seen_scopes.update(scopes)

            # "caller": credit every scope on the node's path.
            # Otherwise: credit the ancestors of the innermost scope.
            if self._ancestor_mode == "caller":
                owners = set(scopes)
            else:
                owners = self._get_all_ancestors(scopes[-1])
                seen_scopes.update(owners)

            if op_kind in self._op_handles:
                node_inputs = list(node.inputs())
                node_outputs = list(node.outputs())
                op_counts = self._op_handles[op_kind](node_inputs, node_outputs)
                # A handle may return a bare number; key it by the simplified op name.
                if isinstance(op_counts, Number):
                    op_counts = Counter(
                        {self._simplify_op_name(op_kind): op_counts})
                # Each owner receives this op's counts once.
                for owner in owners:
                    counts[owner] += op_counts
                continue

            # No handle: skip ignored ops and every prim:: operator.
            if op_kind in self._ignored_ops or op_kind.startswith("prim::"):
                continue
            for owner in owners:
                unsupported_ops[owner][op_kind] += 1

        uncalled_mods = set(self._aliases.values()) - seen_scopes
        stats = Statistics(counts=counts,
                           unsupported_ops=unsupported_ops,
                           uncalled_mods=uncalled_mods)
        self._stats = stats
        # "" is the scope name of the root model.
        self._warn_unsupported_ops(unsupported_ops[""])
        self._warn_uncalled_mods(uncalled_mods)
        return stats
Ejemplo n.º 3
0
 def _get_aliases(self, model: nn.Module) -> Dict[Union[str, nn.Module], str]:
     """
     Map each submodule, and every name it is registered under, to one canonical name.

     A module's canonical name is the first name ``_named_modules_with_dup``
     yields for that module object. Every name of that module, including
     later aliases, maps to that same canonical name.

     Args:
         model: Root module whose submodules are enumerated.

     Returns:
         A dict keyed by both module objects and name strings, with
         canonical names as values.
     """
     canonical: Dict[Union[str, nn.Module], str] = {}
     for mod_name, submodule in _named_modules_with_dup(model):
         # setdefault keeps the first name recorded for this module object.
         first_name = canonical.setdefault(submodule, mod_name)
         canonical[mod_name] = first_name
     return canonical
Ejemplo n.º 4
0
    def __init__(
        self,
        model: nn.Module,
        inputs: Union[Tensor, Tuple[Tensor, ...]],
    ) -> None:
        """
        Prepare an analysis of ``model`` evaluated on ``inputs``.

        Args:
            model: The model to analyze
            inputs: The inputs to the model for analysis.
        """
        # What is being analyzed.
        self._model = model
        self._inputs = inputs

        # Op kind -> handle computing counts for that op; starts empty.
        self._op_handles: Dict[str, Handle] = {}

        # Every submodule name (duplicates included) -> the submodule object.
        module_pairs = _named_modules_with_dup(model)
        self._named_modules: Dict[str, nn.Module] = dict(module_pairs)

        # Submodule objects and all of their names -> canonical submodule name.
        self._aliases: Dict[Union[nn.Module, str], str] = self._get_aliases(model)

        # Cached analysis result; None until computed.
        self._stats: Optional[Statistics] = None

        # Per-instance copy so edits never touch the shared _IGNORED_OPS object.
        self._ignored_ops: Set[str] = copy(_IGNORED_OPS)

        # Defaults, applied through the public setters after all attributes exist.
        self.unsupported_ops_warnings(True)
        self.uncalled_modules_warnings(True)
        self.tracer_warnings("no_tracer_warning")
        self.ancestor_mode("owner")
Ejemplo n.º 5
0
def _named_modules_without_dup(model: nn.Module) -> Iterator[Tuple[str, nn.Module]]:
    """
    Lazily yield ``(name, module)`` pairs, one per distinct submodule object.

    Pairs are taken from ``_named_modules_with_dup`` in the order it produces
    them. When the same module object appears under several names, only its
    first name is yielded. Like .named_modules(), but the results are slightly
    different for some wrapped models.
    """
    already_yielded = set()
    for mod_name, submodule in _named_modules_with_dup(model):
        if submodule in already_yielded:
            continue
        already_yielded.add(submodule)
        yield mod_name, submodule
Ejemplo n.º 6
0
def _get_scoped_trace_graph(
    module: nn.Module,
    inputs: Tuple[object, ...],
    aliases: Dict[Union[str, nn.Module], str],
) -> torch._C.Graph:  # pyre-ignore[11]
    """
    Traces the provided module using torch.jit._get_trace_graph, but adds
    submodule scope information to each graph node. The resulting graph
    is in-lined and has all model parameters treated as inputs. The input
    model has the scope name '', while its descendants have names of the
    form 'child.grandchild.grandgrandchild...'.

    The scope information comes from temporary forward pre-/post-hooks that
    push and pop the module's canonical name on the active tracing state.
    The hooks are always removed before returning, even when registration or
    tracing raises, so a failed trace never leaves the caller's module
    instrumented.

    Args:
        module (nn.Module): The module to trace
        inputs (tuple): Inputs used during the trace of the model
        aliases (dict(str or nn.Module, str)): maps modules and module
            names to the canonical name to be used as the scope for
            that module.

    Returns:
        graph (torch._C.Graph): The pytorch JIT trace of the model
    """
    class ScopePushHook(object):
        # Forward pre-hook: enter this module's scope while tracing is active.
        def __init__(self, name: str) -> None:
            self.name = name

        def __call__(self, module: nn.Module,
                     inputs: Tuple[object, ...]) -> Tuple[object, ...]:
            tracing_state = torch._C._get_tracing_state()
            if tracing_state:
                tracing_state.push_scope(self.name)
            return inputs

    class ScopePopHook(object):
        # Forward hook: leave the scope entered by the matching ScopePushHook.
        def __call__(
            self,
            module: nn.Module,
            inputs: Tuple[object, ...],
            outputs: Tuple[object, ...],
        ) -> Tuple[object, ...]:
            tracing_state = torch._C._get_tracing_state()
            if tracing_state:
                tracing_state.pop_scope()
            return outputs

    seen = set()
    hook_handles = []  # type: List[Any]

    def register_hooks(mod: nn.Module, name: str) -> None:
        prehook = mod.register_forward_pre_hook(
            ScopePushHook(name))  # pyre-ignore[16]
        posthook = mod.register_forward_hook(ScopePopHook())  # pyre-ignore[16]
        hook_handles.append(prehook)
        hook_handles.append(posthook)

    try:
        # Torch script does not support parallel torch models, but we still
        # want the scope names to be correct for the complete module.
        if isinstance(
            module,
            (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel),
        ):
            # Since DataParallel just wraps the model, add an extra set of hooks
            # to the model it wraps to account for the wrapper. Then trace it.
            root_name = aliases[module]
            module = module.module
            register_hooks(module, root_name)

        # We don't need the duplication here, but self._model.named_modules()
        # gives slightly different results for some wrapped models.
        for _, mod in _named_modules_with_dup(module):
            if mod not in seen:
                register_hooks(mod, aliases[mod])
                seen.add(mod)

        if hasattr(torch.jit, "get_trace_graph"):
            trace, _ = torch.jit.get_trace_graph(module, inputs)
            graph = trace.graph()
        else:
            graph, _ = _get_trace_graph(module, inputs)
    finally:
        # Detach every hook registered above, whether or not tracing succeeded.
        for handle in hook_handles:
            handle.remove()

    return graph
Ejemplo n.º 7
0
    def _analyze(self) -> "Statistics":
        """
        Trace the model once and aggregate operator statistics per module.

        The first call traces ``self._model`` on ``self._inputs`` and caches
        the resulting ``Statistics`` on ``self._stats``. Later calls return the
        cached object unchanged.

        Returns:
            Statistics with per-module ``counts`` (from the registered op
            handles), per-module ``unsupported_ops`` (op kinds with no handle
            that ``_should_ignore_node`` does not skip), and ``uncalled_mods``
            (canonical module names that never appear in the trace).

        Raises:
            ValueError: if an op handle yields a count whose type is not int,
                float, np.float64 or np.int64.
        """
        cached = self._stats
        if cached is not None:
            return cached

        # Apply the configured warning filter only while tracing.
        with warnings.catch_warnings():
            if self._warn_trace == "none":
                warnings.simplefilter("ignore")
            elif self._warn_trace == "no_tracer_warning":
                warnings.filterwarnings("ignore", category=TracerWarning)
            graph = _get_scoped_trace_graph(self._model, self._inputs, self._aliases)

        # Seed an empty counter per module so modules absent from the trace
        # graph still report zero. _named_modules_with_dup is used because
        # self._model.named_modules() gives slightly different results for
        # some wrapped models.
        module_names = [
            self._aliases[submodule]
            for _, submodule in _named_modules_with_dup(self._model)
        ]
        counts = {mod_name: Counter() for mod_name in module_names}
        unsupported_ops = {mod_name: Counter() for mod_name in module_names}

        seen_scopes = set()
        for node in graph.nodes():
            op_kind = node.kind()
            if op_kind == "prim::PythonOp":
                # Refine the generic kind with the Python-side op name.
                op_kind = op_kind + "." + node.pyname()

            scopes = node.scopeName().split("/")
            seen_scopes.update(scopes)

            # "caller": credit every scope on the node's path.
            # Otherwise: credit the ancestors of the innermost scope.
            if self._ancestor_mode == "caller":
                owners = set(scopes)
            else:
                owners = self._get_all_ancestors(scopes[-1])
                seen_scopes.update(owners)

            if op_kind in self._op_handles:
                node_inputs = list(node.inputs())
                node_outputs = list(node.outputs())
                op_counts = self._op_handles[op_kind](node_inputs, node_outputs)
                # A handle may return a bare number; key it by the simplified op name.
                if isinstance(op_counts, Number):
                    op_counts = Counter({self._simplify_op_name(op_kind): op_counts})
                # Reject count types that could silently overflow.
                for v in op_counts.values():
                    if not isinstance(v, (int, float, np.float64, np.int64)):
                        raise ValueError(
                            f"Invalid type {type(v)} for the flop count! "
                            "Please use a wider type to avoid overflow."
                        )
                # Each owner receives this op's counts once.
                for owner in owners:
                    counts[owner] += op_counts
            elif not self._should_ignore_node(node):
                for owner in owners:
                    unsupported_ops[owner][op_kind] += 1

        uncalled_mods = set(self._aliases.values()) - seen_scopes
        stats = Statistics(
            counts=counts, unsupported_ops=unsupported_ops, uncalled_mods=uncalled_mods
        )
        self._stats = stats
        # "" is the scope name of the root model.
        self._warn_unsupported_ops(unsupported_ops[""])
        self._warn_uncalled_mods(uncalled_mods)
        return stats
Ejemplo n.º 8
0
    def _analyze(self) -> "Statistics":
        """
        Trace the model once and aggregate operator statistics per module.

        Each traced op is credited to every scope on its node's scope path.
        The result is cached on ``self._stats``, and later calls return the
        cached object.

        Returns:
            Statistics with per-module ``counts`` (from the registered op
            handles), per-module ``unsupported_ops`` (op kinds with no handle
            that are neither ignored nor ``prim::``), and ``uncalled_mods``
            (never-traced modules, excluding containers without a forward).
        """
        cached = self._stats
        if cached is not None:
            return cached

        # Apply the configured warning filter only while tracing.
        with warnings.catch_warnings():
            if self._warn_trace == "none":
                warnings.simplefilter("ignore")
            elif self._warn_trace == "no_tracer_warning":
                warnings.filterwarnings("ignore", category=TracerWarning)
            graph = _get_scoped_trace_graph(self._model, self._inputs,
                                            self._aliases)

        # Seed an empty counter per module so modules absent from the trace
        # graph still report zero. _named_modules_with_dup is used because
        # self._model.named_modules() gives slightly different results for
        # some wrapped models.
        module_names = [
            self._aliases[submodule]
            for _, submodule in _named_modules_with_dup(self._model)
        ]
        counts = {mod_name: Counter() for mod_name in module_names}
        unsupported_ops = {mod_name: Counter() for mod_name in module_names}

        seen_scopes = set()
        for node in graph.nodes():
            op_kind = node.kind()
            scopes = node.scopeName().split("/")
            seen_scopes.update(scopes)
            # The set ensures an op contributes at most once to each module.
            owners = set(scopes)

            if op_kind in self._op_handles:
                node_inputs = list(node.inputs())
                node_outputs = list(node.outputs())
                op_counts = self._op_handles[op_kind](node_inputs, node_outputs)
                for owner in owners:
                    counts[owner] += op_counts
            elif not (op_kind in self._ignored_ops or op_kind.startswith("prim::")):
                # Unhandled, non-ignored, non-prim:: op: record as unsupported.
                for owner in owners:
                    unsupported_ops[owner][op_kind] += 1

        # Containers are not meant to be called (they don't have forward), so
        # a module whose forward is one of these classes' forward is not
        # reported as uncalled.
        container_types = (nn.ModuleList, nn.ModuleDict, nn.Module)

        def has_forward(module_type) -> bool:
            return all(
                module_type.forward is not container.forward
                for container in container_types
            )

        uncalled_mods = {
            mod_name
            for mod_name in set(self._aliases.values()) - seen_scopes
            if has_forward(type(self._named_modules.get(mod_name)))
        }

        stats = Statistics(counts=counts,
                           unsupported_ops=unsupported_ops,
                           uncalled_mods=uncalled_mods)
        self._stats = stats
        # "" is the scope name of the root model.
        self._warn_unsupported_ops(unsupported_ops[""])
        self._warn_uncalled_mods(uncalled_mods)
        return stats