Example #1
    def __init__(self, model, dummy_input, apply_scope_name_workarounds=True):
        self._src_model = model
        self._named_modules = OrderedDict(model.named_modules())
        self._adj_map = None
        self._layers_topological_order = None
        self._top_level_ops = set()
        model_clone = distiller.make_non_parallel_copy(model)

        # Switch all instances of torch.nn.ModuleList in the model to our DistillerModuleList
        # See documentation of _DistillerModuleList class for details on why this is done
        model_clone, converted_module_names_map = _to_distiller_modulelist(
            model_clone)

        with torch.onnx.set_training(model_clone, False):

            device = distiller.model_device(model_clone)
            dummy_input = distiller.convert_tensors_recursively_to(
                dummy_input, device=device)
            self.dummy_input = dummy_input
            try:
                trace, _ = jit.get_trace_graph(model_clone,
                                               dummy_input,
                                               _force_outplace=True)
            except AttributeError:
                trace, _ = jit._get_trace_graph(model_clone,
                                                dummy_input,
                                                _force_outplace=True)

            # As of PyTorch 1.3.0, ONNX trace optimization has an issue that results in incorrect scope names
            # of nodes in the trace graph.
            # These can make it impossible, in some cases, to derive the connectivity of the model using the original
            # module names. So we try to detect these cases and apply workarounds.

            # The issue is:
            #   Dropout ops are removed by ONNX trace optimization. However, the op BEFORE the original dropout op
            #   gets the scope name of the dropout op.
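            #   For example (illustrative): with self.fc = nn.Linear(...) followed by self.drop = nn.Dropout(),
            #   the fused Gemm node produced for 'fc' can end up carrying the scope name of 'drop' after
            #   optimization, so we record a mapping from each dropout node's scope name back to the scope
            #   name of the op that precedes it.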
            pre_dropout_nodes_scope_names = OrderedDict()

            prev_non_dropout_op = None
            for node in trace.graph().nodes():
                kind = node.kind()
                if 'aten' not in kind:
                    continue
                if kind == 'aten::dropout':
                    if prev_non_dropout_op:
                        pre_dropout_nodes_scope_names[node.scopeName()] = \
                            prev_non_dropout_op.scopeName()
                else:
                    prev_non_dropout_op = node

            # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes
            # composing a GEMM operation; etc.
            torch.onnx._optimize_trace(trace,
                                       torch.onnx.OperatorExportTypes.ONNX)

            graph = trace.graph()
            self.ops = OrderedDict()
            self.module_ops_map = defaultdict(list)
            self.params = OrderedDict()
            self.edges = []
            self.temp = OrderedDict()

            in_out = list(graph.inputs()) + list(graph.outputs())
            for param in in_out:
                self.__add_param(param)

            for node in graph.nodes():
                new_op = self.__create_op(node)

                if apply_scope_name_workarounds:
                    # Here we apply the workaround to the issue of dropout op scope name overriding previous op's
                    # scope name
                    if new_op['name'] in pre_dropout_nodes_scope_names:
                        new_op['orig-name'] = pre_dropout_nodes_scope_names[
                            new_op['name']]
                        new_op['name'] = new_op['orig-name']

                # Convert the graph node's scope name to a PyTorch module name
                module_name = onnx_name_2_pytorch_name(new_op['orig-name'])

                # Get name from before conversion to DistillerModuleList
                module_name = converted_module_names_map[module_name]

                if len(module_name) == 0:
                    # Special case where the module name is an empty string - this happens
                    # when the op is called from the "top-level" of the model
                    new_op['name'] = 'top_level_op'
                else:
                    new_op['name'] = module_name

                # Save the calling module name in the op dict. Denormalize it so it can
                # be directly matched with the actual model
                module_name = distiller.denormalize_module_name(
                    self._src_model, module_name)
                new_op['module-name'] = module_name

                # The node's scope name in the graph corresponds to the module from which the op was called.
                # This means that when ops are invoked from the same module via functional calls or direct
                # operations on tensors, these ops will have the SAME MODULE NAME associated with them.
                # For example:
                #   t = t1 + t2
                #   t = F.relu(t)
                # In this case the add operation and the ReLU operation will have the same name, which is
                # derived from the module they're contained in.
                #
                # Another case where different ops will have the same module name is when a module is reused:
                #   out = self.conv1(x)
                #   out = self.relu(out)    <=== First use of self.relu
                #   out = self.conv2(out)
                #   out = self.relu(out)    <=== Second use of self.relu
                # In this case the graph will have 2 distinct ReLU nodes, with the same scope name.
                #
                # Operators with the same name create very confusing graphs (in ResNet, for example),
                # so we "unroll" them.
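                # The "unrolled" name is the module name with the op type and an index appended, so the
                # second use of self.relu above would get a name along the lines of 'relu_Relu_1'
                # (illustrative; the exact suffix depends on the op's ONNX type string).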
                same_module_cnt = len(self.module_ops_map[module_name])
                if same_module_cnt:
                    # TODO: Was this meant to be applied only to 'top_level_ops'? Also, it's not
                    #       applied to the first module that had the same name
                    new_op['name'] += "_%s_%d" % (new_op['type'],
                                                  same_module_cnt)
                self.module_ops_map[module_name].append(new_op['name'])

                # Finally we register the new op in the ops collection
                self.ops[new_op['name']] = new_op

                for input_ in node.inputs():
                    self.__add_input(new_op, input_)
                    self.edges.append(
                        SummaryGraph.Edge(input_.debugName(), new_op['name']))

                for output in node.outputs():
                    self.__add_output(new_op, output)
                    self.edges.append(
                        SummaryGraph.Edge(new_op['name'], output.debugName()))

                new_op['attrs'] = OrderedDict([
                    (attr_name, node[attr_name])
                    for attr_name in node.attributeNames()
                ])

        self.__merge_pad_avgpool()
        self.add_macs_attr()
        self.add_footprint_attr()
        self.add_arithmetic_intensity_attr()
        del trace
        del graph
        del model_clone
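
A minimal usage sketch for the constructor above, assuming it belongs to Distiller's SummaryGraph class (as the SummaryGraph.Edge references suggest) and that SummaryGraph is exported at the distiller package level; the model and input shape below are illustrative, not taken from the example:

import torch
import torchvision
import distiller

model = torchvision.models.resnet18()
dummy_input = torch.randn(1, 3, 224, 224)

# The constructor traces the model with the dummy input, builds the op/edge
# collections shown above, and annotates each op with MACs and footprint data.
sgraph = distiller.SummaryGraph(model, dummy_input)

# Each edge links an op name to the debug name of a tensor it consumes or produces.
for edge in sgraph.edges[:5]:
    print(edge)
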
Example #2
def _get_scoped_trace_graph(
    module: nn.Module,
    inputs: Tuple[object, ...],
    aliases: Dict[Union[str, nn.Module], str],
) -> torch._C.Graph:  # pyre-ignore[11]
    """
    Traces the provided module using torch.jit._get_trace_graph, but adds
    submodule scope information to each graph node. The resulting graph
    is in-lined and has all model parameters treated as inputs. The input
    model has the scope name '', while its descendants have names of the
    form 'child.grandchild.grandgrandchild...'.

    Args:
        module (nn.Module) : The module to trace
        inputs (tuple) : Inputs used during the trace of the model
        aliases (dict(str or nn.Module, str)) : maps modules and module
            names to the canonical name to be used as the scope for
            that module.

    Returns:
        graph (torch._C.Graph) : The PyTorch JIT trace graph of the model
    """
    class ScopePushHook(object):
        def __init__(self, name: str) -> None:
            self.name = name

        def __call__(self, module: nn.Module,
                     inputs: Tuple[object, ...]) -> Tuple[object, ...]:
            tracing_state = torch._C._get_tracing_state()
            if tracing_state:
                tracing_state.push_scope(self.name)
            return inputs

    class ScopePopHook(object):
        def __call__(
            self,
            module: nn.Module,
            inputs: Tuple[object, ...],
            outputs: Tuple[object, ...],
        ) -> Tuple[object, ...]:
            tracing_state = torch._C._get_tracing_state()
            if tracing_state:
                tracing_state.pop_scope()
            return outputs

    seen = set()
    hook_handles = []  # type: List[Any]

    def register_hooks(mod: nn.Module, name: str) -> None:
        prehook = mod.register_forward_pre_hook(
            ScopePushHook(name))  # pyre-ignore[16]
        posthook = mod.register_forward_hook(ScopePopHook())  # pyre-ignore[16]
        hook_handles.append(prehook)
        hook_handles.append(posthook)

    # Torch script does not support parallel torch models, but we still
    # want the scope names to be correct for the complete module.
    if isinstance(
        module, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)
    ):

        # Since DataParallel just wraps the model, add an extra set of hooks
        # to the model it wraps to account for the wrapper. Then trace it.
        root_name = aliases[module]
        module = module.module
        register_hooks(module, root_name)

    # We don't need the duplication here, but self._model.named_modules()
    # gives slightly different results for some wrapped models.
    for name, mod in _named_modules_with_dup(module):
        if mod not in seen:
            name = aliases[mod]
            register_hooks(mod, name)
            seen.add(mod)

    if hasattr(torch.jit, "get_trace_graph"):
        trace, _ = torch.jit.get_trace_graph(module, inputs)
        graph = trace.graph()
    else:
        graph, _ = _get_trace_graph(module, inputs)

    for handle in hook_handles:
        handle.remove()

    return graph
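
A hedged usage sketch for the function above; building the aliases mapping by hand as done here is an assumption based on the docstring (in the originating library the mapping is normally prepared by the calling analysis code):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))
inputs = (torch.randn(1, 3, 32, 32),)

# Map every module object and its qualified name to the canonical scope name.
aliases = {}
for name, mod in model.named_modules():
    aliases[mod] = name
    aliases[name] = name

graph = _get_scoped_trace_graph(model, inputs, aliases)

# Every traced node now carries the scope of the module that produced it.
for node in graph.nodes():
    print(node.scopeName(), node.kind())
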
Example #3
def _get_scoped_trace_graph(
    module: nn.Module,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    aliases: Dict[Union[str, nn.Module], str],
) -> torch._C.Graph:  # pyre-ignore[11]
    """
    Traces the provided module using torch.jit._get_trace_graph, but adds
    submodule scope information to each graph node. The resulting graph
    is in-lined and has all model parameters treated as inputs. The input
    model has the scope name '', while its descendants have names of the
    form 'child.grandchild.grandgrandchild...'.

    Args:
        module (nn.Module) : The module to trace
        inputs (tuple) : Inputs used during the trace of the model
        aliases (dict(str or nn.Module, str)) : maps modules and module
            names to the canonical name to be used as the scope for
            that module.

    Returns:
        graph (torch._C.Graph) : The PyTorch JIT trace graph of the model
    """

    class ScopePushHook:
        def __init__(self, name: str) -> None:
            self.name = name

        def __call__(self, module: nn.Module, inputs: Any) -> Any:
            tracing_state = torch._C._get_tracing_state()
            if tracing_state:
                tracing_state.push_scope(self.name)
            return inputs

    class ScopePopHook:
        def __call__(self, module: nn.Module, inputs: Any, outputs: Any) -> Any:
            tracing_state = torch._C._get_tracing_state()
            if tracing_state:
                tracing_state.pop_scope()
            return outputs

    hook_handles: List[Any] = []

    def register_hooks(mod: nn.Module, name: str) -> None:
        prehook = mod.register_forward_pre_hook(ScopePushHook(name))
        posthook = mod.register_forward_hook(ScopePopHook())
        hook_handles.append(prehook)
        hook_handles.append(posthook)

    # Unwrap DDP, but correct the scope names for the root module.
    if isinstance(
        module, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)
    ):
        # Since DataParallel just wraps the model, add an extra set of hooks
        # to the model it wraps to account for the wrapper. Then trace it.
        root_name = aliases[module]
        module = module.module
        register_hooks(module, root_name)

    for name, mod in _named_modules_without_dup(module):
        name = aliases[mod]
        register_hooks(mod, name)

    graph, _ = _get_trace_graph(module, inputs)

    for handle in hook_handles:
        handle.remove()

    return graph
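
With either variant, the returned graph can be consumed by walking its nodes and grouping them by the scope names the hooks pushed; the small aggregation below is an illustration of that, not part of the library:

from collections import Counter

def ops_per_scope(graph: "torch._C.Graph") -> Counter:
    # scopeName() returns the canonical module name pushed by ScopePushHook
    counts = Counter()
    for node in graph.nodes():
        counts[node.scopeName()] += 1
    return counts
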