def parse(graph, args=None, omit_useless_nodes=True):
    """This method parses an optimized PyTorch model graph and produces
    a list of nodes and node stats for eventual conversion to TensorBoard
    protobuf format.

    Args:
        graph (torch._C.Graph): The optimized model graph to be parsed.
        args (tuple): input tensor[s] for the model.
        omit_useless_nodes (boolean): Whether to remove graph inputs that
            have no users (zero fanout).
    """
    n_inputs = len(args) if args is not None else 0

    nodes_py = GraphPy()
    for i, node in enumerate(graph.inputs()):
        if omit_useless_nodes:
            if len(node.uses()) == 0:  # number of users of the node (= fanout)
                continue

        if i < n_inputs:
            nodes_py.append(NodePyIO(node, 'input'))
        else:
            nodes_py.append(NodePyIO(node))  # parameter

    for node in graph.nodes():
        nodes_py.append(NodePyOP(node))

    for node in graph.outputs():  # must be placed last
        nodes_py.append(NodePyIO(node, 'output'))
    nodes_py.find_common_root()
    nodes_py.populate_namespace_from_OP_to_IO()
    return nodes_py
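
# --- Usage sketch (illustrative, not part of the original source) ---
# A minimal example of how parse() above might be driven. Assumptions:
# the GraphPy/NodePyIO/NodePyOP helpers referenced by parse() are defined
# in this module; only torch.jit.trace and the traced module's .graph
# attribute are standard PyTorch API. The function name is hypothetical.
def _example_parse():
    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
    dummy_input = (torch.randn(1, 3, 32, 32),)
    # torch.jit.trace records the ops executed on the dummy input and
    # exposes the resulting IR as a torch._C.Graph via .graph
    trace = torch.jit.trace(model, dummy_input)
    return parse(trace.graph, args=dummy_input)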
def parse(self, graph, trace, args=None, omit_useless_nodes=True):
    """This method parses an optimized PyTorch model graph and produces
    a list of nodes and node stats for eventual conversion to TensorBoard
    protobuf format.

    Args:
        graph (torch._C.Graph): The model graph to be parsed.
        trace (PyTorch JIT TracedModule): The model trace to be parsed.
        args (tuple): input tensor[s] for the model.
        omit_useless_nodes (boolean): Whether to remove graph inputs that
            have no users (zero fanout).
    """
    nodes_py = GraphPy()
    for node in graph.inputs():
        if omit_useless_nodes:
            if not node.uses():  # number of users of the node (= fanout)
                continue

        if node.type().kind() != CLASSTYPE_KIND:
            nodes_py.append(NodePyIO(node, 'input'))

    attr_to_scope = dict()

    def node_to_name(d):
        return str(d).split(":")[0].strip()

    for node in graph.nodes():
        if node.kind() == GETATTR_KIND:
            attr_name = node.s('name')
            node_name = node_to_name(node)
            parent = node.input().node()
            # If the parent node is not the top-level "self" node
            if parent.kind() == GETATTR_KIND:
                parent_scope = attr_to_scope[node_to_name(parent)]
                attr_scope = parent_scope.split('/')[-1]
                attr_to_scope[node_name] = '{}/{}.{}'.format(
                    parent_scope, attr_scope, attr_name)
            else:
                attr_to_scope[node_name] = '__module.{}'.format(attr_name)
            # We don't need classtype nodes; scope will provide this information
            if node.output().type().kind() != CLASSTYPE_KIND:
                node_py = NodePyOP(node)
                node_py.scopeName = attr_to_scope[node_name]
                nodes_py.append(node_py)
        else:
            nodes_py.append(NodePyOP(node))

    # Create sink nodes for output ops
    for i, node in enumerate(graph.outputs()):
        node_py = NodePyIO(node, 'output')
        node_py.debugName = "output.{}".format(i + 1)
        node_py.inputs = [node.debugName()]
        nodes_py.append(node_py)

    # Map each '__module.x.y' attribute path to a readable 'Type[attr]' alias
    alias_to_name = dict()
    base_name = parse_traced_name(trace._name)
    for name, module in trace.named_modules(prefix='__module'):
        mod_name = parse_traced_name(module._name)
        attr_name = name.split('.')[-1]
        alias_to_name[name] = '{}[{}]'.format(mod_name, attr_name)

    # Rewrite every op's scopeName into the aliased, slash-separated form
    for node in nodes_py.nodes_op:
        module_aliases = node.scopeName.split('/')[-1].split('.')
        module_name = ''
        for i, alias in enumerate(module_aliases):
            if i == 0:
                module_name = alias
                node.scopeName = base_name
            else:
                module_name += '.' + alias
                node.scopeName += '/' + (alias_to_name[module_name]
                                         if module_name in alias_to_name
                                         else alias)

    nodes_py.populate_namespace_from_OP_to_IO()
    return nodes_py.to_proto()
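
# --- Usage sketch (illustrative, not part of the original source) ---
# parse() above is a method, so this sketch assumes it lives on some
# graph-builder object (the parameter name `builder` is hypothetical).
# It shows the expected call shape: the same trace supplies both the
# optimized graph and the named_modules() used for scope aliasing.
def _example_parse_to_proto(builder):
    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(4, 2), nn.ReLU())
    dummy_input = torch.randn(1, 4)
    trace = torch.jit.trace(model, dummy_input)
    # returns the TensorBoard GraphDef produced by nodes_py.to_proto()
    return builder.parse(trace.graph, trace, args=(dummy_input,))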
def _build_graph(self):
    """
    Build graph using our defined format from jit trace.
    There are basically three steps: first, construct the necessary
    information (data structures); second, extract all the modules and
    convert them to nodes; third, extract all functions and convert them
    to nodes.

    Returns
    -------
    dict
        use name to index nodes, key: node name, value: node
    dict
        use input (its name) to index nodes,
        key: input, value: list of nodes that take this input
    dict
        use output (its name) to index nodes,
        key: output, value: node that generates this output
    """
    omit_useless_nodes = True
    graph = self.trace.graph
    _logger.debug(graph)
    # build output mapping, from output debugName to its node
    output_to_node = {x.debugName(): n for n in graph.nodes()
                      for x in n.outputs()}
    # build input mapping, from input debugName to its node
    input_to_node = {x.debugName(): n for n in graph.nodes()
                     for x in n.inputs()}
    # build module mapping, from module name to all nodes (as list) under this module scope
    module_to_nodes = defaultdict(list)
    # the mapping of function (non-module in forward) to nodes, key is scope name
    func_to_nodes = defaultdict(list)

    nodes_py = GraphPy()
    for node in graph.inputs():
        if omit_useless_nodes:
            if not node.uses():  # number of users of the node (= fanout)
                continue

        if node.type().kind() != 'ClassType':
            nodes_py.append(NodePyIO(node, 'input'))

    self.leaf_modules = self._extract_leaf_modules()
    module_to_type = {name: parse_traced_name(module._name)
                      for name, module in self.trace.named_modules()}

    # associate module names with their trace graph nodes
    for node in graph.nodes():
        module_name = self._get_module_name(node.scopeName())
        if module_name in self.leaf_modules:
            module_to_nodes[module_name].append(node)
        else:
            func_to_nodes[node.scopeName()].append(node)

    # build a node group for each module
    for module_name, node_cpps in module_to_nodes.items():
        use_count = 0
        merged = set()
        for node in node_cpps:
            if node not in merged:
                # Modules that have the same scope name may have different
                # locations in the graph. Furthermore, there are also lots
                # of prim:: nodes in node_cpps, so we also need to call
                # _expand_module_node.
                unique_name = module_name
                if use_count > 0:
                    unique_name = module_name + '.%d' % use_count
                node_group = self._expand_module_node(
                    node, module_name, unique_name,
                    module_to_type[module_name], node_cpps,
                    input_to_node, output_to_node, 'module')
                nodes_py.nodes_op.append(node_group)
                use_count += 1
                merged.update(node_group.node_cpps)

    # each scope_name may have multiple funcs; we split them and create a node for each
    # build node groups for torch.nn.functional calls
    for _, nodes in func_to_nodes.items():
        # extract non-prim:: nodes
        non_prim_nodes = list()
        for node in nodes:
            if not node.kind().startswith('prim::'):
                non_prim_nodes.append(node)
        # expand each non-prim node
        for node in non_prim_nodes:
            node_group = self._expand_non_prim_node(
                node, nodes, input_to_node, output_to_node, 'func')
            nodes_py.nodes_op.append(node_group)
            # get shape info for view (aten::view) func
            # if node_group.op_type in ['aten::view', 'aten::flatten']:
            #     node_group.auxiliary = self._extract_shape_info(node)

    for node in graph.outputs():  # create sink nodes for output ops
        node_py = NodePyIO(node, 'output')
        nodes_py.append(node_py)

    self.nodes_py = nodes_py
    # build index
    return self._build_index(self.nodes_py.nodes_op)
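
# --- Standalone sketch (illustrative, not part of the original source) ---
# The two comprehensions at the top of _build_graph() index the JIT graph
# by value names. This self-contained version runs against any traced model
# using only the public torch.jit API; note that an input debugName can feed
# several consumers, which the list-valued variant below preserves (the
# dict-comprehension form above keeps only the last consumer per input).
def _example_debugname_maps():
    from collections import defaultdict

    import torch
    import torch.nn as nn

    trace = torch.jit.trace(nn.Linear(4, 2), torch.randn(1, 4))
    graph = trace.graph
    # map every output debugName to the node that produces it
    output_to_node = {x.debugName(): n
                      for n in graph.nodes() for x in n.outputs()}
    # map every input debugName to all nodes that consume it
    input_to_nodes = defaultdict(list)
    for n in graph.nodes():
        for x in n.inputs():
            input_to_nodes[x.debugName()].append(n)
    return output_to_node, input_to_nodes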