def _shape_from_value_info(value_info):
    """Return the shape recorded in an ONNX ValueInfoProto as a list of ints.

    Dimensions that carry a concrete ``dim_value`` are returned as that value;
    any other dimension (symbolic ``dim_param`` or unset) is returned as -1.
    """
    # The check must be made on each dimension proto, not on the tensor type:
    # ``dim_value`` is a field of TensorShapeProto.Dimension.
    return [dim.dim_value if dim.HasField("dim_value") else -1
            for dim in value_info.type.tensor_type.shape.dim]


def load_graph(fname):
    """Load a serialized ONNX model file into a ``Graph``.

    Args:
        fname: path of the ``.onnx`` file to read (opened in binary mode).

    Returns:
        A tuple ``(graph, producer_name)``.
        ``graph`` has:
        - initializers registered with their dtype and shape;
        - graph inputs (excluding those that are initializers) added;
        - output names collected.
        ``producer_name`` is the model's ``producer_name`` field.
    """
    with open(fname, "rb") as f:
        data = f.read()
    model_proto = onnx.ModelProto()
    model_proto.ParseFromString(data)
    onnx_nodes = model_proto.graph.node
    output_names = []

    # some pytorch models have nodes with empty names - make one up
    for node in onnx_nodes:
        if not node.name:
            node.name = tf2onnx.utils.make_name("was_empty")

    # output_names is still empty here; it is filled in after the outputs loop below.
    # NOTE(review): this relies on Graph keeping a reference to this list rather
    # than copying it - confirm against Graph.__init__.
    g = Graph(onnx_nodes, output_shapes={}, dtypes={}, output_names=output_names)

    for initializer in model_proto.graph.initializer:
        value = numpy_helper.to_array(initializer)
        name = initializer.name
        g.initializers[name] = initializer
        g.set_dtype(name, initializer.data_type)
        g.set_shape(name, value.shape)

    for value_info in model_proto.graph.input:
        name = value_info.name
        if name in g.initializers:
            # initializers listed among graph inputs are not real model inputs; skip them
            continue
        shape = _shape_from_value_info(value_info)
        dtype = value_info.type.tensor_type.elem_type
        g.set_dtype(name, dtype)
        g.set_shape(name, shape)
        g.add_graph_input(name, dtype, shape)

    for value_info in model_proto.graph.output:
        name = value_info.name
        shape = _shape_from_value_info(value_info)
        dtype = value_info.type.tensor_type.elem_type
        g.set_dtype(name, dtype)
        g.set_shape(name, shape)
        output_names.append(name)

    # TODO: this is a hack in case an output name does not follow tensorflow convention
    for node in g.get_nodes():
        for name in node.output:
            g._nodes_by_name[name] = node  # pylint: disable=protected-access

    return g, model_proto.producer_name
def _create_empty_graph(self, inputs, shapes, dtypes):
    """Build a ``Graph`` with no nodes that declares only the given graph inputs.

    ``inputs``, ``shapes`` and ``dtypes`` are parallel sequences. They are paired
    with ``zip``, so surplus entries in a longer sequence are ignored. The
    graph's target and opset are taken from ``self.config``.
    """
    cfg = self.config
    empty_graph = Graph([], target=cfg.target, opset=cfg.opset)
    for input_name, input_shape, input_dtype in zip(inputs, shapes, dtypes):
        # add_graph_input takes (name, dtype, shape): dtype comes before shape
        empty_graph.add_graph_input(input_name, input_dtype, input_shape)
    return empty_graph
def rewrite(self):
    """Attach body graphs to the ``Scan`` nodes of ``self.g``.

    Each ``Scan`` node with pending body-graph info in ``BodyGraphDict`` is
    processed as follows:
    - its info is popped from ``BodyGraphDict``;
    - the subgraph is located with ``LoopRewriterBase.find_subgraph``;
    - the subgraph's non-Const nodes are wrapped in a new ``Graph``;
    - that graph is set as the node's ``body`` attribute.
    The wrapped nodes are then removed from the outer graph's node list.

    Returns:
        The list returned by ``self.g.get_nodes()``, modified in place so that
        nodes moved into Scan bodies are no longer in it.
    """
    log.debug("enter custom rnn late rewriter")
    nodes = self.g.get_nodes()
    nodes_to_remove = []
    for scan_node in nodes:
        if scan_node.type != "Scan":
            continue

        log.debug("late write for scan node %s", scan_node.name)
        num_scan_inputs = scan_node.get_attr("num_scan_inputs").i
        if not BodyGraphDict.has_body_graph_info(scan_node.name):
            continue

        body_graph_meta = BodyGraphDict.pop_body_graph_info(scan_node.name)
        onnx_nodes, _ = LoopRewriterBase.find_subgraph(body_graph_meta, self.g)

        log.debug("start creating body graph for scan node %s ", scan_node.name)
        # Const nodes are not put into the body graph (they need to be replaced
        # by initializers), and they are not removed from the outer graph either.
        body_nodes = {n for n in onnx_nodes if n.type not in ("Const", "ConstV2")}
        nodes_to_remove.extend(body_nodes)

        ops = [n.op for n in body_nodes]
        # The body graph receives the outer graph's shape/dtype dicts directly (not copies).
        body_g = Graph(ops,
                       output_shapes=self.g._output_shapes,  # pylint: disable=protected-access
                       dtypes=self.g._dtypes)  # pylint: disable=protected-access

        log.debug("start preparing body graph inputs nodes")
        temp_nodes = body_g.get_nodes()
        # Inputs at positions >= first_scan_input are the scan inputs.
        first_scan_input = len(body_graph_meta.input_ids) - num_scan_inputs
        input_pairs = zip(body_graph_meta.input_ids, body_graph_meta.initial_input_ids)
        for i, (input_name, init_input_id) in enumerate(input_pairs):
            shape = body_g.get_shape(input_name)
            dtype = body_g.get_dtype(input_name)
            if shape is None:
                # Fall back to the shape of the initial value in the outer graph.
                # NOTE(review): if that is also None, list(shape) below raises TypeError.
                shape = self.g.get_shape(init_input_id)
                if i >= first_scan_input:
                    shape = list(shape)[2:]  # delete [1, time,]
            body_g.add_graph_input(input_name, dtype, list(shape))

        log.debug("start preparing body graph outputs nodes")
        new_output_names = []
        for o in body_graph_meta.output_ids:
            # Insert an Identity node per output. The same output id is sometimes
            # needed both as a state output and as a scan output, but ONNX does not
            # allow the same name to appear more than once among graph outputs.
            node = body_g.make_node("Identity", inputs=[o],
                                    shapes=[body_g.get_shape(o)],
                                    dtypes=[body_g.get_dtype(o)])
            new_output_names.append(node.output[0])
            temp_nodes.append(node)

        body_g.set_nodes(temp_nodes)
        body_g.topological_sort(body_g.get_nodes())

        log.debug("start make graph based on body graph nodes")
        body_g.outputs = new_output_names
        graph = body_g.make_graph("scan body graph", graph_name=scan_node.name + "_body_graph")
        scan_node.set_attr("body", graph)

    # Remove the nodes that were moved into body graphs from g.
    # A single set-based pass replaces per-node list.remove (O(R*N)).
    # Slice assignment keeps the mutation in place on the list from get_nodes().
    removed = set(nodes_to_remove)
    nodes[:] = [n for n in nodes if n not in removed]
    return nodes