def process_tf_graph(graph, continue_on_error=False, verbose=False, target=None, opset=0):
    """Convert tensorflow graph to onnx graph.

    Args:
        graph: tensorflow graph to convert
        continue_on_error: forwarded to tensorflow_onnx_mapping
        verbose: when true, print op/attr/mapping summary counts
        target: list of workarounds; DEFAULT_TARGET is used when None
        opset: opset handed to the Graph constructor

    Returns:
        the converted Graph
    """
    target = DEFAULT_TARGET if target is None else target
    onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes = tensorflow_to_onnx(graph)
    onnx_graph = Graph(onnx_nodes, output_shapes, dtypes, target, opset)

    # rewrites: each rewriter receives the op list produced by the previous one
    current_ops = onnx_graph.get_nodes()
    pipeline = (rewrite_flatten, rewrite_random_uniform, rewrite_random_normal,
                rewrite_dropout, rewrite_transpose)
    for rewriter in pipeline:
        current_ops = rewriter(onnx_graph, current_ops)
    onnx_graph.set_nodes(current_ops)

    # sort before mapping, and again afterwards since mapping may change the nodes
    onnx_graph.topological_sort(onnx_graph.get_nodes())
    mapped_op, unmapped_op = tensorflow_onnx_mapping(onnx_graph, continue_on_error)
    onnx_graph.topological_sort(onnx_graph.get_nodes())
    onnx_graph.update_proto()

    if verbose:
        print("tensorflow ops: {}".format(op_cnt))
        print("tensorflow attr: {}".format(attr_cnt))
        print("onnx mapped: {}".format(mapped_op))
        print("onnx unmapped: {}".format(unmapped_op))
    return onnx_graph
def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=None, opset=None,
                     custom_op_handlers=None, custom_rewriter=None):
    """Convert tensorflow graph to onnx graph.

    Args:
        tf_graph: tensorflow graph
        continue_on_error: if an op can't be processed (aka there is no mapping), continue;
            in this mode topological-sort failures (e.g. graph cycles) are logged and
            ignored so that all missing ops can be reported
        verbose: print summary stats
        target: list of workarounds applied to help certain platforms
            (DEFAULT_TARGET when None)
        opset: the opset to be used (int, default is latest)
        custom_op_handlers: dictionary of custom ops handlers
        custom_rewriter: list of custom graph rewriters, run after the built-in ones
    Return:
        onnx graph
    """
    import logging

    if target is None:
        target = DEFAULT_TARGET

    onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes = tensorflow_to_onnx(tf_graph)
    g = Graph(onnx_nodes, output_shapes, dtypes, target, opset)

    def topological_sort(ops):
        """Topologically sort ops in g; tolerate failures only when continue_on_error is set."""
        if not continue_on_error:
            g.topological_sort(ops)
            return
        try:
            g.topological_sort(ops)
        except Exception as ex:  # pylint: disable=broad-except
            # If we continue on error, ignore graph cycles so we can report all missing ops.
            # Catching Exception (not a bare except) still lets KeyboardInterrupt/SystemExit
            # propagate, and logging keeps the ignored failure visible.
            logging.getLogger(__name__).warning(
                "topological sort failed, ignored because continue_on_error is set: %s", ex)

    ops = g.get_nodes()

    # rewrite graph: built-in rewriters first, then any user-supplied ones
    rewriters = [rewrite_transpose, rewrite_flatten, rewrite_random_uniform,
                 rewrite_random_normal, rewrite_dropout]
    if custom_rewriter is not None:
        rewriters.extend(custom_rewriter)
    for rewrite in rewriters:
        ops = rewrite(g, ops)
    g.set_nodes(ops)
    topological_sort(g.get_nodes())

    if custom_op_handlers is None:
        custom_op_handlers = {}
    mapped_op, unmapped_op = tensorflow_onnx_mapping(g, continue_on_error, custom_op_handlers)
    # mapping can add/replace nodes, so sort again before building the proto
    topological_sort(g.get_nodes())
    g.update_proto()

    if verbose:
        print("tensorflow ops: {}".format(op_cnt))
        print("tensorflow attr: {}".format(attr_cnt))
        print("onnx mapped: {}".format(mapped_op))
        print("onnx unmapped: {}".format(unmapped_op))
    return g
def test_insert_node2(self):
    """Insert an Abs node n7 on output n1:0 and compare the graphviz rendering."""
    net = self.sample_net()
    graph = Graph(net.node, output_shapes={}, dtypes={})
    inserted = graph.insert_new_node_on_output("Abs", "n1:0", name="n7")

    all_ops = graph.get_nodes()
    all_ops.append(inserted)
    graph.topological_sort(all_ops)

    rendered = onnx_to_graphviz(graph)
    expected = ('digraph { n1 [op_type=Abs] n7 [op_type=Abs] n3 [op_type=Abs] n2 [op_type=Abs] '
                'n4 [op_type=Add] n5 [op_type=Abs] n6 [op_type=Identity] '
                'input -> n1 n1:0 -> n7 n7:0 -> n3 n7:0 -> n2 n2:0 -> n4 n3:0 -> n4 n4:0 -> n5 n5:0 -> n6 }')
    self.assertEqual(expected, rendered)
def _value_info_shape(value_info):
    """Return the tensor shape of an onnx ValueInfoProto as a list; dims without a dim_value become -1."""
    # dim_value/dim_param form a oneof on each dimension, so HasField tells whether a
    # concrete size is present. (The previous hasattr() check on tensor_type was always
    # False, which turned every dimension into -1.)
    return [dim.dim_value if dim.HasField("dim_value") else -1
            for dim in value_info.type.tensor_type.shape.dim]


def load_graph(fname):
    """Load a serialized onnx model file into a Graph.

    Args:
        fname: path of the serialized onnx ModelProto to read

    Returns:
        tuple (graph, producer_name) where producer_name comes from the model proto
    """
    with open(fname, "rb") as f:
        data = f.read()
    model_proto = onnx.ModelProto()
    model_proto.ParseFromString(data)
    onnx_nodes = model_proto.graph.node
    output_names = []

    # some pytorch model had empty names - make one up
    for node in onnx_nodes:
        if not node.name:
            node.name = tf2onnx.utils.make_name("was_empty")

    g = Graph(onnx_nodes, output_shapes={}, dtypes={}, output_names=output_names)

    for initializer in model_proto.graph.initializer:
        name = initializer.name
        g.initializers[name] = initializer
        g.set_dtype(name, initializer.data_type)
        # the tensor's dims are its shape; no need to decode the whole payload to read it
        g.set_shape(name, tuple(initializer.dims))

    for graph_input in model_proto.graph.input:
        name = graph_input.name
        if name in g.initializers:
            # ignore if it is not a model input
            continue
        shape = _value_info_shape(graph_input)
        dtype = graph_input.type.tensor_type.elem_type
        g.set_dtype(name, dtype)
        g.set_shape(name, shape)
        g.add_graph_input(name, dtype, shape)

    for graph_output in model_proto.graph.output:
        name = graph_output.name
        g.set_dtype(name, graph_output.type.tensor_type.elem_type)
        g.set_shape(name, _value_info_shape(graph_output))
        # NOTE(review): output_names was already passed to Graph above; this relies on Graph
        # keeping a reference to the same list - confirm.
        output_names.append(name)

    # TODO: this is a hack in case a output name does not follow tensorflow convention
    for node in g.get_nodes():
        for name in node.output:
            g._nodes_by_name[name] = node  # pylint: disable=protected-access
    return g, model_proto.producer_name
def rewrite(self):
    """Attach a body graph to every Scan node in self.g that has recorded body-graph info.

    For each Scan node whose name is registered in BodyGraphDict, the subgraph found by
    LoopRewriterBase.find_subgraph is built into a separate Graph (Const nodes become
    initializers of that graph) and set as the Scan node's "body" attribute. Afterwards
    the collected subgraph nodes are removed from the outer node list, or - when a node
    is not in that list but its output is an initializer - the initializer is deleted.

    Returns:
        the node list obtained from self.g.get_nodes(), with body-graph nodes removed

    Raises:
        ValueError: if a body-graph node is neither in the node list nor an initializer
    """
    log.debug("enter custom rnn late rewriter")
    nodes = self.g.get_nodes()
    nodes_to_remove = []
    for scan_node in nodes:
        if scan_node.type != "Scan":
            continue
        log.debug("late write for scan node %s", scan_node.name)
        num_scan_inputs = scan_node.get_attr("num_scan_inputs").i
        if not BodyGraphDict.has_body_graph_info(scan_node.name):
            continue

        body_graph_meta = BodyGraphDict.pop_body_graph_info(scan_node.name)
        onnx_nodes, _ = LoopRewriterBase.find_subgraph(body_graph_meta, self.g)
        nodes_to_remove.extend(onnx_nodes)

        log.debug("start creating body graph for scan node %s ", scan_node.name)
        body_graph_initializers = {}
        const_nodes = [n for n in onnx_nodes if n.type in ("Const", "ConstV2")]
        for n in const_nodes:
            # when set nodes, Const should be removed, they need be replaced as initializers.
            body_graph_initializers[n.output[0]] = self.g.initializers[n.output[0]]
            onnx_nodes.remove(n)

        # De-duplicate in first-seen order. Iterating a set here made the op order handed
        # to Graph depend on object hashes, i.e. differ from run to run.
        seen = set()
        ops = []
        for op in onnx_nodes:
            if op not in seen:
                seen.add(op)
                ops.append(op.op)
        body_g = Graph(ops, output_shapes=self.g._output_shapes,  # pylint: disable=protected-access
                       dtypes=self.g._dtypes)  # pylint: disable=protected-access
        body_g._initializers = body_graph_initializers  # pylint: disable=protected-access

        log.debug("start preparing body graph inputs nodes")
        temp_nodes = body_g.get_nodes()
        input_count = len(body_graph_meta.input_ids)
        # inputs at index >= first_scan_input are the scan inputs
        first_scan_input = input_count - num_scan_inputs
        for i, (input_name, init_input_id) in enumerate(
                zip(body_graph_meta.input_ids, body_graph_meta.initial_input_ids)):
            shape = body_g.get_shape(input_name)
            dtype = body_g.get_dtype(input_name)
            if shape is None:
                # NOTE(review): list(None) raises TypeError if the outer graph has no shape either
                shape = self.g.get_shape(init_input_id)
                if i >= first_scan_input:
                    loop_input_shape = list(shape)[2:]  # delete [1, time,]
                else:
                    loop_input_shape = list(shape)
            else:
                loop_input_shape = list(shape)
            onnx_input_shape = utils.make_onnx_shape(loop_input_shape)
            val = helper.make_tensor_value_info(input_name, dtype, onnx_input_shape)
            body_g.add_model_input(input_name, val)

        log.debug("start preparing body graph outputs nodes")
        new_output_names = []
        for o in body_graph_meta.output_ids:
            # insert identity node, since sometimes we need output same output_id as state_output
            # and scan_out, but ONNX don't allow the same output_id appeared more than once as
            # output node.
            identity_name = utils.make_name("Identity")
            identity_output = utils.port_name(identity_name)
            node = Node(
                helper.make_node("Identity", [o], [identity_output], name=identity_name),
                body_g)
            body_g.set_dtype(identity_output, body_g.get_dtype(o))
            body_g.copy_shape(o, identity_output)
            new_output_names.append(identity_output)
            temp_nodes.append(node)

        body_g.set_nodes(temp_nodes)
        body_g.topological_sort(body_g.get_nodes())

        log.debug("start make graph based on body graph nodes")
        body_g.output_names = new_output_names
        graph = body_g.make_graph("scan body graph")
        scan_node.set_attr("body", graph)

    # remove nodes in body graph from g
    # Each unique node is handled once, in first-seen order; outer-graph nodes are then
    # dropped with a single pass instead of repeated O(n) list.remove calls.
    outer_nodes = set(nodes)
    handled = set()
    dropped = set()
    for n in nodes_to_remove:
        if n in handled:
            continue
        handled.add(n)
        if n in outer_nodes:
            dropped.add(n)
        elif self.g.is_initializer(n.output[0]):
            del self.g.initializers[n.output[0]]
        else:
            raise ValueError("error when removing nodes")
    # slice-assign so the same list object (possibly the graph's own) is updated in place
    nodes[:] = [n for n in nodes if n not in dropped]
    return nodes