def _onnx_dims_to_shape(dims):
    """Convert ONNX ``TensorShapeProto.dim`` entries into a list of ints.

    A dimension that carries a concrete ``dim_value`` keeps that value; a
    symbolic (``dim_param``) or unset dimension is reported as -1.
    """
    return [dim.dim_value if dim.HasField("dim_value") else -1 for dim in dims]


def load_graph(fname):
    """Parse an ONNX model file into a tf2onnx ``Graph``.

    Nodes with an empty name get a generated name. Initializers, graph
    inputs (except those that are also initializers) and graph outputs have
    their dtype and shape recorded on the graph; graph output names are
    appended to the ``output_names`` list handed to ``Graph``.

    Args:
        fname: Path to a serialized ONNX ``ModelProto``.

    Returns:
        A tuple ``(graph, producer_name)`` with the constructed ``Graph`` and
        the model's ``producer_name``.
    """
    with open(fname, "rb") as f:
        data = f.read()
    model_proto = onnx.ModelProto()
    model_proto.ParseFromString(data)
    onnx_nodes = model_proto.graph.node
    # Filled from the model outputs after Graph is constructed.
    # NOTE(review): this relies on Graph keeping a reference to this list
    # rather than copying it — confirm.
    output_names = []

    # some pytorch model had empty names - make one up
    for node in onnx_nodes:
        if not node.name:
            node.name = tf2onnx.utils.make_name("was_empty")

    g = Graph(onnx_nodes, output_shapes={}, dtypes={}, output_names=output_names)

    for initializer in model_proto.graph.initializer:
        name = initializer.name
        g.initializers[name] = initializer
        g.set_dtype(name, initializer.data_type)
        # ``dims`` fully describes the tensor shape; decoding the whole tensor
        # (numpy_helper.to_array) just to read ``.shape`` is wasted work.
        g.set_shape(name, tuple(initializer.dims))

    for value_info in model_proto.graph.input:
        name = value_info.name
        if name in g.initializers:
            # Initializers may also be listed in graph.input; they are not
            # real model inputs, so skip them.
            continue
        g.add_model_input(name, value_info)
        tensor_type = value_info.type.tensor_type
        g.set_dtype(name, tensor_type.elem_type)
        g.set_shape(name, _onnx_dims_to_shape(tensor_type.shape.dim))

    for value_info in model_proto.graph.output:
        name = value_info.name
        tensor_type = value_info.type.tensor_type
        g.set_dtype(name, tensor_type.elem_type)
        g.set_shape(name, _onnx_dims_to_shape(tensor_type.shape.dim))
        output_names.append(name)

    # TODO: this is a hack in case a output name does not follow tensorflow convention
    for node in g.get_nodes():
        for name in node.output:
            g._nodes_by_name[name] = node  # pylint: disable=protected-access

    return g, model_proto.producer_name
def rewrite(self):
    """Build and attach the body graph of every pending Scan node in ``self.g``.

    For each node of type ``"Scan"`` that has body-graph info registered in
    ``BodyGraphDict`` (the info is popped), the body nodes are located with
    ``LoopRewriterBase.find_subgraph``. Const/ConstV2 nodes among them become
    initializers of a new ``Graph``; the remaining nodes become its nodes.
    Body inputs are registered as model inputs, and each body output is routed
    through a fresh Identity node. The resulting ONNX graph is stored as the
    Scan node's ``"body"`` attribute.

    After all Scan nodes are processed, every collected body node is removed
    from the outer node list; a node not in that list whose first output is
    an initializer of ``self.g`` has that initializer deleted instead.

    Returns:
        The node list obtained from ``self.g.get_nodes()``, with the body
        nodes removed.

    Raises:
        ValueError: if a body node is neither in the outer node list nor
            backed by an initializer of ``self.g``.
    """
    log.debug("enter custom rnn late rewriter")
    nodes = self.g.get_nodes()
    nodes_to_remove = []
    for scan_node in nodes:
        if scan_node.type != "Scan":
            continue
        log.debug("late write for scan node %s", scan_node.name)
        num_scan_inputs = scan_node.get_attr("num_scan_inputs").i
        if not BodyGraphDict.has_body_graph_info(scan_node.name):
            continue

        body_graph_meta = BodyGraphDict.pop_body_graph_info(scan_node.name)
        onnx_nodes, _ = LoopRewriterBase.find_subgraph(body_graph_meta, self.g)
        nodes_to_remove.extend(onnx_nodes)

        log.debug("start creating body graph for scan node %s ", scan_node.name)
        # Const nodes must not be nodes of the body graph; their values are
        # carried over as initializers instead. One pass partitions the list
        # (the old per-node list.remove was quadratic).
        body_graph_initializers = {}
        body_nodes = []
        for n in onnx_nodes:
            if n.type in ("Const", "ConstV2"):
                body_graph_initializers[n.output[0]] = self.g.initializers[n.output[0]]
            else:
                body_nodes.append(n)

        # Deduplicate while keeping first-seen order. A plain set() would make
        # the body graph's node order depend on object hashes, i.e. differ
        # from run to run.
        ops = [n.op for n in dict.fromkeys(body_nodes)]

        body_g = Graph(ops, output_shapes=self.g._output_shapes,  # pylint: disable=protected-access
                       dtypes=self.g._dtypes)  # pylint: disable=protected-access
        body_g._initializers = body_graph_initializers  # pylint: disable=protected-access

        log.debug("start preparing body graph inputs nodes")
        temp_nodes = body_g.get_nodes()
        input_count = len(body_graph_meta.input_ids)
        # The last ``num_scan_inputs`` body inputs are scan inputs.
        first_scan_input = input_count - num_scan_inputs
        body_inputs = zip(body_graph_meta.input_ids, body_graph_meta.initial_input_ids)
        for i, (input_name, init_input_id) in enumerate(body_inputs):
            shape = body_g.get_shape(input_name)
            dtype = body_g.get_dtype(input_name)
            if shape is None:
                # Fall back to the shape of the matching outer-graph input.
                loop_input_shape = list(self.g.get_shape(init_input_id))
                if i >= first_scan_input:
                    # Outer scan inputs are laid out as [1, time, ...]; drop
                    # those two leading dims for the per-step body input.
                    loop_input_shape = loop_input_shape[2:]
            else:
                loop_input_shape = list(shape)

            onnx_input_shape = utils.make_onnx_shape(loop_input_shape)
            val = helper.make_tensor_value_info(input_name, dtype, onnx_input_shape)
            body_g.add_model_input(input_name, val)

        log.debug("start preparing body graph outputs nodes")
        new_output_names = []
        for o in body_graph_meta.output_ids:
            # insert identity node, since sometimes we need output same output_id as state_output
            # and scan_out, but ONNX don't allow the same output_id appeared more than once as
            # output node.
            identity_name = utils.make_name("Identity")
            identity_output = utils.port_name(identity_name)
            node = Node(
                helper.make_node("Identity", [o], [identity_output], name=identity_name),
                body_g)
            body_g.set_dtype(identity_output, body_g.get_dtype(o))
            body_g.copy_shape(o, identity_output)
            new_output_names.append(identity_output)
            temp_nodes.append(node)

        body_g.set_nodes(temp_nodes)
        body_g.topological_sort(body_g.get_nodes())

        log.debug("start make graph based on body graph nodes")
        body_g.output_names = new_output_names
        graph = body_g.make_graph("scan body graph")
        scan_node.set_attr("body", graph)

    # remove nodes in body graph from g
    for n in set(nodes_to_remove):
        if n in nodes:
            nodes.remove(n)
        elif self.g.is_initializer(n.output[0]):
            del self.g.initializers[n.output[0]]
        else:
            raise ValueError("error when removing nodes")

    return nodes