Code example #1
def process_tf_graph(graph, continue_on_error=False, verbose=False, target=None, opset=0):
    """Convert tensorflow graph to onnx graph."""

    if target is None:
        target = DEFAULT_TARGET

    onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes = tensorflow_to_onnx(graph)

    g = Graph(onnx_nodes, output_shapes, dtypes, target, opset)
    ops = g.get_nodes()

    # rewrites
    for rewrite in [rewrite_flatten,
                    rewrite_random_uniform,
                    rewrite_random_normal,
                    rewrite_dropout,
                    rewrite_transpose]:
        ops = rewrite(g, ops)
        g.set_nodes(ops)

    g.topological_sort(g.get_nodes())
    mapped_op, unmapped_op = tensorflow_onnx_mapping(g, continue_on_error)
    g.topological_sort(g.get_nodes())

    g.update_proto()
    if verbose:
        print("tensorflow ops: {}".format(op_cnt))
        print("tensorflow attr: {}".format(attr_cnt))
        print("onnx mapped: {}".format(mapped_op))
        print("onnx unmapped: {}".format(unmapped_op))
    return g
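
A minimal driver sketch for the function above, assuming a TF 1.x session whose default graph already holds the frozen model (the session setup, opset value, and print format are illustrative, not part of the original example):

import tensorflow as tf

with tf.Session() as sess:
    # process_tf_graph is the function shown above; it returns the tf2onnx Graph
    onnx_graph = process_tf_graph(sess.graph, verbose=True, opset=7)
    print("converted {} onnx nodes".format(len(onnx_graph.get_nodes())))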
Code example #2
def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=None,
                     opset=None, custom_op_handlers=None, custom_rewriter=None):
    """Convert tensorflow graph to onnx graph.
        Args:
            tf_graph: tensorflow graph
            continue_on_error: if an op can't be processed (i.e. there is no mapping), continue
            verbose: print summary stats
            target: list of workarounds applied to help certain platforms
            opset: the opset to be used (int, default is latest)
            custom_op_handlers: dictionary of custom ops handlers
            custom_rewriter: list of custom graph rewriters
        Return:
            onnx graph
    """
    def topological_sort(ops):
        if not continue_on_error:
            g.topological_sort(ops)
        else:
            try:
                g.topological_sort(ops)
            except Exception:
                # if we continue on error, ignore graph cycles so we can report all missing ops
                pass

    if target is None:
        target = DEFAULT_TARGET

    onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes = tensorflow_to_onnx(tf_graph)

    g = Graph(onnx_nodes, output_shapes, dtypes, target, opset)
    ops = g.get_nodes()

    # rewrite graph
    rewriters = [rewrite_transpose, rewrite_flatten, rewrite_random_uniform,
                 rewrite_random_normal, rewrite_dropout]
    if custom_rewriter is not None:
        rewriters.extend(custom_rewriter)
    for rewrite in rewriters:
        ops = rewrite(g, ops)
        g.set_nodes(ops)
    topological_sort(g.get_nodes())

    if custom_op_handlers is None:
        custom_op_handlers = {}
    mapped_op, unmapped_op = tensorflow_onnx_mapping(g, continue_on_error, custom_op_handlers)
    topological_sort(g.get_nodes())
    g.update_proto()
    if verbose:
        print("tensorflow ops: {}".format(op_cnt))
        print("tensorflow attr: {}".format(attr_cnt))
        print("onnx mapped: {}".format(mapped_op))
        print("onnx unmapped: {}".format(unmapped_op))
    return g
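
A hypothetical custom rewriter, following the ops = rewrite(g, ops) convention that the built-in rewriters use in the loop above; the pattern matching itself is elided and the names are made up:

def rewrite_my_pattern(g, ops):
    # A real rewriter would match a subgraph in `ops`, build replacement
    # nodes through `g`, and return the updated node list.
    return ops

# Usage (tf_graph is a frozen tensorflow graph, as in the signature above):
# g = process_tf_graph(tf_graph, custom_rewriter=[rewrite_my_pattern], opset=7)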
Code example #3
    def test_insert_node2(self):
        model_proto = self.sample_net()
        nodes = model_proto.node
        g = Graph(nodes, output_shapes={}, dtypes={})
        n7 = g.insert_new_node_on_output("Abs", "n1:0", name="n7")
        ops = g.get_nodes()
        ops.append(n7)
        g.topological_sort(ops)
        result = onnx_to_graphviz(g)
        expected = 'digraph { n1 [op_type=Abs] n7 [op_type=Abs] n3 [op_type=Abs] n2 [op_type=Abs] ' \
                   'n4 [op_type=Add] n5 [op_type=Abs] n6 [op_type=Identity] ' \
                   'input -> n1 n1:0 -> n7 n7:0 -> n3 n7:0 -> n2 n2:0 -> n4 n3:0 -> n4 n4:0 -> n5 n5:0 -> n6 }'
        self.assertEqual(expected, result)
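
The test depends on a sample_net() helper that is not shown. A plausible sketch, inferred from the expected graphviz string, is given below; it returns an onnx GraphProto (whose .node field the test reads), and the tensor shapes are made up:

from onnx import TensorProto, helper

def sample_net():
    n1 = helper.make_node("Abs", ["input"], ["n1:0"], name="n1")
    n2 = helper.make_node("Abs", ["n1:0"], ["n2:0"], name="n2")
    n3 = helper.make_node("Abs", ["n1:0"], ["n3:0"], name="n3")
    n4 = helper.make_node("Add", ["n2:0", "n3:0"], ["n4:0"], name="n4")
    n5 = helper.make_node("Abs", ["n4:0"], ["n5:0"], name="n5")
    n6 = helper.make_node("Identity", ["n5:0"], ["n6:0"], name="n6")
    return helper.make_graph(
        [n1, n2, n3, n4, n5, n6], "sample_net",
        [helper.make_tensor_value_info("input", TensorProto.FLOAT, [2, 3])],
        [helper.make_tensor_value_info("n6:0", TensorProto.FLOAT, [2, 3])])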
Code example #4
def load_graph(fname):
    with open(fname, "rb") as f:
        data = f.read()
        model_proto = onnx.ModelProto()
        model_proto.ParseFromString(data)
        onnx_nodes = model_proto.graph.node
        output_names = []

        # some pytorch models had empty node names - make one up
        for node in onnx_nodes:
            if not node.name:
                node.name = tf2onnx.utils.make_name("was_empty")

        g = Graph(onnx_nodes, output_shapes={}, dtypes={}, output_names=output_names)
        for i in model_proto.graph.initializer:
            v = numpy_helper.to_array(i)
            name = i.name
            g.initializers[name] = i
            dtype = i.data_type
            g.set_dtype(name, dtype)
            g.set_shape(name, v.shape)
        for i in model_proto.graph.input:
            name = i.name
            if name in g.initializers:
                # ignore if it is not a model input
                continue
            shape = [j.dim_value if j.HasField("dim_value") else -1
                     for j in i.type.tensor_type.shape.dim]
            dtype = i.type.tensor_type.elem_type
            g.set_dtype(name, dtype)
            g.set_shape(name, shape)
            g.add_graph_input(name, dtype, shape)
        for i in model_proto.graph.output:
            name = i.name
            shape = [j.dim_value if j.HasField("dim_value") else -1
                     for j in i.type.tensor_type.shape.dim]
            dtype = i.type.tensor_type.elem_type
            g.set_dtype(name, dtype)
            g.set_shape(name, shape)
            output_names.append(name)

        # TODO: this is a hack in case an output name does not follow the tensorflow convention
        for node in g.get_nodes():
            for name in node.output:
                g._nodes_by_name[name] = node  # pylint: disable=protected-access

    return g, model_proto.producer_name
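
Illustrative call of the loader above (the file name is a placeholder):

g, producer = load_graph("model.onnx")
print("loaded {} nodes, produced by {}".format(len(g.get_nodes()), producer))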
Code example #5
    def rewrite(self):
        log.debug("enter custom rnn late rewriter")
        nodes = self.g.get_nodes()
        nodes_to_remove = []
        for scan_node in nodes:
            if scan_node.type != "Scan":
                continue
            log.debug("late write for scan node %s", scan_node.name)
            num_scan_inputs = scan_node.get_attr("num_scan_inputs").i
            if not BodyGraphDict.has_body_graph_info(scan_node.name):
                continue

            body_graph_meta = BodyGraphDict.pop_body_graph_info(scan_node.name)
            onnx_nodes, _ = LoopRewriterBase.find_subgraph(
                body_graph_meta, self.g)
            nodes_to_remove.extend(onnx_nodes)

            log.debug("start creating body graph for scan node %s ",
                      scan_node.name)
            body_graph_initializers = {}
            const_nodes = [
                n for n in onnx_nodes if n.type in ("Const", "ConstV2")
            ]
            for n in const_nodes:
                # when setting the node list, Const nodes are removed; they are carried over as body-graph initializers instead.
                body_graph_initializers[n.output[0]] = self.g.initializers[
                    n.output[0]]
                onnx_nodes.remove(n)

            onnx_nodes = set(onnx_nodes)

            ops = []
            for op in onnx_nodes:
                onnx_op = op.op
                ops.append(onnx_op)

            body_g = Graph(ops,
                           output_shapes=self.g._output_shapes,
                           dtypes=self.g._dtypes)
            body_g._initializers = body_graph_initializers

            log.debug("start preparing body graph inputs nodes")
            temp_nodes = body_g.get_nodes()
            i = 0
            input_count = len(body_graph_meta.input_ids)
            for input_name, init_input_id in zip(
                    body_graph_meta.input_ids,
                    body_graph_meta.initial_input_ids):
                shape = body_g.get_shape(input_name)
                dtype = body_g.get_dtype(input_name)
                if shape is None:
                    shape = self.g.get_shape(init_input_id)
                    if i >= input_count - num_scan_inputs:
                        loop_input_shape = list(shape)[2:]  # delete [1, time,]
                    else:
                        loop_input_shape = list(shape)
                else:
                    loop_input_shape = list(shape)

                onnx_input_shape = utils.make_onnx_shape(loop_input_shape)
                val = helper.make_tensor_value_info(input_name, dtype,
                                                    onnx_input_shape)
                body_g.add_model_input(input_name, val)
                i += 1

            log.debug("start preparing body graph outputs nodes")
            new_output_names = []
            for o in body_graph_meta.output_ids:
                # insert an Identity node, since sometimes the same output_id is needed both as
                # state_output and as scan_out, but ONNX does not allow the same output_id to
                # appear more than once as a graph output.
                identity_name = utils.make_name("Identity")
                identity_output = utils.port_name(identity_name)
                node = Node(
                    helper.make_node("Identity", [o], [identity_output],
                                     name=identity_name), body_g)
                body_g.set_dtype(identity_output, body_g.get_dtype(o))
                body_g.copy_shape(o, identity_output)
                new_output_names.append(identity_output)
                temp_nodes.append(node)

            body_g.set_nodes(temp_nodes)
            body_g.topological_sort(body_g.get_nodes())

            log.debug("start make graph based on body graph nodes")
            body_g.output_names = new_output_names
            graph = body_g.make_graph("scan body graph")
            scan_node.set_attr("body", graph)

        # remove nodes in body graph from g
        for n in set(nodes_to_remove):
            if n in nodes:
                nodes.remove(n)
            elif self.g.is_initializer(n.output[0]):
                del self.g.initializers[n.output[0]]
            else:
                raise ValueError("error when removing nodes")

        return nodes
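
The Identity insertion in the loop above works around the ONNX rule that the same tensor name may not appear more than once among a graph's outputs. A standalone illustration of that aliasing trick (tensor and node names are made up):

from onnx import helper

def alias_output(tensor_name, alias):
    # Route the tensor through its own Identity node so the same value can
    # be exposed under several distinct graph output names.
    return helper.make_node("Identity", [tensor_name], [alias + ":0"], name=alias)

state_output = alias_output("cell/hidden:0", "state_output")
scan_output = alias_output("cell/hidden:0", "scan_output")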