def __init__(self, model, dummy_input, apply_scope_name_workarounds=True):
    self._src_model = model
    self._named_modules = OrderedDict(model.named_modules())
    self._adj_map = None
    self._layers_topological_order = None
    self._top_level_ops = set()
    model_clone = distiller.make_non_parallel_copy(model)

    # Switch all instances of torch.nn.ModuleList in the model to our DistillerModuleList
    # See documentation of _DistillerModuleList class for details on why this is done
    model_clone, converted_module_names_map = _to_distiller_modulelist(model_clone)

    with torch.onnx.set_training(model_clone, False):
        device = distiller.model_device(model_clone)
        dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device)
        self.dummy_input = dummy_input
        try:
            trace, _ = jit.get_trace_graph(model_clone, dummy_input, _force_outplace=True)
        except AttributeError:
            trace, _ = jit._get_trace_graph(model_clone, dummy_input, _force_outplace=True)

        # As of PyTorch 1.3.0, ONNX trace optimization has an issue that results in incorrect
        # scope names of nodes in the trace graph. These can make it impossible, in some cases,
        # to derive the connectivity of the model using the original module names, so we try to
        # detect these cases and apply workarounds.
        # The issue is: dropout ops are removed by ONNX trace optimization. However, the op
        # BEFORE the original dropout op gets the scope name of the dropout op.
        pre_dropout_nodes_scope_names = OrderedDict()

        prev_non_dropout_op = None
        for node in trace.graph().nodes():
            kind = node.kind()
            if 'aten' not in kind:
                continue
            if kind == 'aten::dropout':
                if prev_non_dropout_op:
                    pre_dropout_nodes_scope_names[node.scopeName()] = prev_non_dropout_op.scopeName()
            else:
                prev_non_dropout_op = node

        # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes
        # composing a GEMM operation; etc.
        torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)

        graph = trace.graph()
        self.ops = OrderedDict()
        self.module_ops_map = defaultdict(list)
        self.params = OrderedDict()
        self.edges = []
        self.temp = OrderedDict()

        in_out = list(graph.inputs()) + list(graph.outputs())
        for param in in_out:
            self.__add_param(param)

        for node in graph.nodes():
            new_op = self.__create_op(node)

            if apply_scope_name_workarounds:
                # Here we apply the workaround to the issue of a dropout op's scope name
                # overriding the previous op's scope name
                if new_op['name'] in pre_dropout_nodes_scope_names:
                    new_op['orig-name'] = pre_dropout_nodes_scope_names[new_op['name']]
                    new_op['name'] = new_op['orig-name']

            # Convert the graph node's scope name to a PyTorch module name
            module_name = onnx_name_2_pytorch_name(new_op['orig-name'])

            # Get the name from before the conversion to DistillerModuleList
            module_name = converted_module_names_map[module_name]

            if len(module_name) == 0:
                # Special case where the module name is an empty string - this happens
                # when the op is called from the "top-level" of the model
                new_op['name'] = 'top_level_op'
            else:
                new_op['name'] = module_name

            # Save the calling module name in the op dict. Denormalize it so it can
            # be directly matched with the actual model
            module_name = distiller.denormalize_module_name(self._src_model, module_name)
            new_op['module-name'] = module_name

            # The node's scope name in the graph corresponds to the module from which the op was
            # called. This means that when ops are invoked from the same module via functional
            # calls or direct operations on tensors, these ops will have the SAME MODULE NAME
            # associated with them.
            # For example:
            #   t = t1 + t2
            #   t = F.relu(t)
            # In this case the add operation and the ReLU operation will have the same name,
            # which is derived from the module they're contained in.
            #
            # Another case where different ops will have the same module name is when a module
            # is reused:
            #   out = self.conv1(x)
            #   out = self.relu(out)    <=== First use of self.relu
            #   out = self.conv2(out)
            #   out = self.relu(out)    <=== Second use of self.relu
            # In this case the graph will have 2 distinct ReLU nodes, with the same scope name.
            #
            # Operators with the same name create very confusing graphs (in ResNet, for example),
            # so we "unroll" them.
            same_module_cnt = len(self.module_ops_map[module_name])
            if same_module_cnt:
                # TODO: Was this meant to be applied only to 'top_level_ops'? Also, it's not
                # applied to the first module that had the same name
                new_op['name'] += "_%s_%d" % (new_op['type'], same_module_cnt)
            self.module_ops_map[module_name].append(new_op['name'])

            # Finally we register the new op in the ops collection
            self.ops[new_op['name']] = new_op

            for input_ in node.inputs():
                self.__add_input(new_op, input_)
                self.edges.append(SummaryGraph.Edge(input_.debugName(), new_op['name']))

            for output in node.outputs():
                self.__add_output(new_op, output)
                self.edges.append(SummaryGraph.Edge(new_op['name'], output.debugName()))

            new_op['attrs'] = OrderedDict(
                [(attr_name, node[attr_name]) for attr_name in node.attributeNames()])

        self.__merge_pad_avgpool()
        self.add_macs_attr()
        self.add_footprint_attr()
        self.add_arithmetic_intensity_attr()
        del trace
        del graph
        del model_clone
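# Usage sketch (illustrative): the __init__ above appears to belong to distiller's
# SummaryGraph class (note the SummaryGraph.Edge references), so construction would
# look roughly like the following, assuming SummaryGraph is exported at the package
# top level. The torchvision model and input shape are assumptions for the example.

import torch
import torchvision
import distiller

model = torchvision.models.resnet18()
dummy_input = torch.randn(1, 3, 224, 224)
sg = distiller.SummaryGraph(model, dummy_input)

# Ops are keyed by their (unrolled) module-derived names; edges connect op names
# to the debug names of the tensors flowing between them.
print(len(sg.ops), "ops,", len(sg.edges), "edges")
for op_name, op in list(sg.ops.items())[:5]:
    print(op_name, op['type'], op['module-name'])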
def _get_scoped_trace_graph(
    module: nn.Module,
    inputs: Tuple[object, ...],
    aliases: Dict[Union[str, nn.Module], str],
) -> torch._C.Graph:  # pyre-ignore[11]
    """
    Traces the provided module using torch.jit._get_trace_graph, but adds
    submodule scope information to each graph node. The resulting graph
    is in-lined and has all model parameters treated as inputs. The input
    model has the scope name '', while its descendants have names of the
    form 'child.grandchild.grandgrandchild...'.

    Args:
        model (nn.Module) : The module to trace
        inputs (tuple) : Inputs used during the trace of the model
        aliases (dict(str or nn.Module, str)) : maps modules and module names
            to the canonical name to be used as the scope for that module.

    Returns:
        graph (torch._C.Graph) : The pytorch JIT trace of the model
    """

    class ScopePushHook(object):
        def __init__(self, name: str) -> None:
            self.name = name

        def __call__(self, module: nn.Module, inputs: Tuple[object, ...]) -> Tuple[object, ...]:
            tracing_state = torch._C._get_tracing_state()
            if tracing_state:
                tracing_state.push_scope(self.name)
            return inputs

    class ScopePopHook(object):
        def __call__(
            self,
            module: nn.Module,
            inputs: Tuple[object, ...],
            outputs: Tuple[object, ...],
        ) -> Tuple[object, ...]:
            tracing_state = torch._C._get_tracing_state()
            if tracing_state:
                tracing_state.pop_scope()
            return outputs

    seen = set()
    hook_handles = []  # type: List[Any]

    def register_hooks(mod: nn.Module, name: str) -> None:
        prehook = mod.register_forward_pre_hook(ScopePushHook(name))  # pyre-ignore[16]
        posthook = mod.register_forward_hook(ScopePopHook())  # pyre-ignore[16]
        hook_handles.append(prehook)
        hook_handles.append(posthook)

    # Torch script does not support parallel torch models, but we still
    # want the scope names to be correct for the complete module.
    if isinstance(
        module, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)
    ):
        # Since DataParallel just wraps the model, add an extra set of hooks
        # to the model it wraps to account for the wrapper. Then trace it.
        root_name = aliases[module]
        module = module.module
        register_hooks(module, root_name)

    # We don't need the duplication here, but self._model.named_modules()
    # gives slightly different results for some wrapped models.
    for name, mod in _named_modules_with_dup(module):
        if mod not in seen:
            name = aliases[mod]
            register_hooks(mod, name)
            seen.add(mod)

    if hasattr(torch.jit, "get_trace_graph"):
        trace, _ = torch.jit.get_trace_graph(module, inputs)
        graph = trace.graph()
    else:
        graph, _ = _get_trace_graph(module, inputs)

    for handle in hook_handles:
        handle.remove()

    return graph
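# Usage sketch (illustrative): _get_scoped_trace_graph expects an `aliases` dict that
# maps each submodule (and, per the docstring, each of its names) to the canonical
# scope name to use for it. Building it directly from named_modules(), as done here,
# is an assumption for the example; the surrounding code presumably has its own helper.

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))

aliases = {}
for name, mod in model.named_modules():
    aliases[mod] = name   # the root module maps to '', children to '0', '1', ...
    aliases[name] = name

graph = _get_scoped_trace_graph(model, (torch.randn(1, 3, 32, 32),), aliases)

# Each traced node now carries the scope of the module it was called from.
for node in graph.nodes():
    print(node.scopeName(), node.kind())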
def _get_scoped_trace_graph(
    module: nn.Module,
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    aliases: Dict[Union[str, nn.Module], str],
) -> torch._C.Graph:  # pyre-ignore[11]
    """
    Traces the provided module using torch.jit._get_trace_graph, but adds
    submodule scope information to each graph node. The resulting graph
    is in-lined and has all model parameters treated as inputs. The input
    model has the scope name '', while its descendants have names of the
    form 'child.grandchild.grandgrandchild...'.

    Args:
        model (nn.Module) : The module to trace
        inputs (tuple) : Inputs used during the trace of the model
        aliases (dict(str or nn.Module, str)) : maps modules and module names
            to the canonical name to be used as the scope for that module.

    Returns:
        graph (torch._C.Graph) : The pytorch JIT trace of the model
    """

    class ScopePushHook:
        def __init__(self, name: str) -> None:
            self.name = name

        def __call__(self, module: nn.Module, inputs: Any) -> Any:
            tracing_state = torch._C._get_tracing_state()
            if tracing_state:
                tracing_state.push_scope(self.name)
            return inputs

    class ScopePopHook:
        def __call__(self, module: nn.Module, inputs: Any, outputs: Any) -> Any:
            tracing_state = torch._C._get_tracing_state()
            if tracing_state:
                tracing_state.pop_scope()
            return outputs

    hook_handles: List[Any] = []

    def register_hooks(mod: nn.Module, name: str) -> None:
        prehook = mod.register_forward_pre_hook(ScopePushHook(name))
        posthook = mod.register_forward_hook(ScopePopHook())
        hook_handles.append(prehook)
        hook_handles.append(posthook)

    # Unwrap DDP, but correct the scope names for the root module.
    if isinstance(
        module, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)
    ):
        # Since DataParallel just wraps the model, add an extra set of hooks
        # to the model it wraps to account for the wrapper. Then trace it.
        root_name = aliases[module]
        module = module.module
        register_hooks(module, root_name)

    for name, mod in _named_modules_without_dup(module):
        name = aliases[mod]
        register_hooks(mod, name)

    graph, _ = _get_trace_graph(module, inputs)

    for handle in hook_handles:
        handle.remove()
    return graph
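# Usage sketch (illustrative): a tracer with this name and signature exists in
# fvcore's JIT analysis utilities. Assuming that is the surrounding library, the
# scoped graph is normally consumed indirectly through the public counters rather
# than by calling this helper directly; the model and input below are assumptions.

import torch
import torchvision
from fvcore.nn import FlopCountAnalysis

model = torchvision.models.resnet18()
flops = FlopCountAnalysis(model, torch.randn(1, 3, 224, 224))

print(flops.total())                     # total count (a fused multiply-add counts as one flop)
print(flops.by_module_and_operator())    # keyed by the scope names assigned during tracing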