Example #1
0
def load_graph(fname):
    """Load a serialized ONNX model file into a tf2onnx ``Graph``.

    Nodes with an empty name are given a generated name. Initializers are
    registered on the graph with their dtype and shape. Graph inputs that are
    not initializers become graph inputs. Graph outputs get their dtype/shape
    recorded, and their names are collected as the graph's output names.

    Args:
        fname: path of the serialized ``onnx.ModelProto`` to read.

    Returns:
        A ``(graph, producer_name)`` tuple; ``producer_name`` is copied from
        the loaded ``ModelProto``.
    """

    def _shape_of(value_info):
        # Check the *dimension* for dim_value (not tensor_type, which never
        # has that field). Dims without a concrete value, e.g. symbolic
        # dim_param ones, are reported as -1.
        return [dim.dim_value if dim.HasField("dim_value") else -1
                for dim in value_info.type.tensor_type.shape.dim]

    # Only the read needs the file handle; parse and build after closing it.
    with open(fname, "rb") as f:
        data = f.read()
    model_proto = onnx.ModelProto()
    model_proto.ParseFromString(data)
    onnx_nodes = model_proto.graph.node
    # Filled in only after Graph construction (from graph.output below);
    # presumably Graph keeps a reference to this same list - confirm.
    output_names = []

    # some pytorch model had empty names - make one up
    for node in onnx_nodes:
        if not node.name:
            node.name = tf2onnx.utils.make_name("was_empty")

    g = Graph(onnx_nodes, output_shapes={}, dtypes={}, output_names=output_names)
    for init in model_proto.graph.initializer:
        name = init.name
        g.initializers[name] = init
        g.set_dtype(name, init.data_type)
        # The tensor's dims are its shape; no need to decode the whole payload.
        g.set_shape(name, tuple(init.dims))

    for value_info in model_proto.graph.input:
        name = value_info.name
        if name in g.initializers:
            # ignore if it is not a model input
            continue
        shape = _shape_of(value_info)
        dtype = value_info.type.tensor_type.elem_type
        g.set_dtype(name, dtype)
        g.set_shape(name, shape)
        g.add_graph_input(name, dtype, shape)

    for value_info in model_proto.graph.output:
        name = value_info.name
        g.set_dtype(name, value_info.type.tensor_type.elem_type)
        g.set_shape(name, _shape_of(value_info))
        output_names.append(name)

    # TODO: this is a hack in case a output name does not follow tensorflow convention
    for node in g.get_nodes():
        for name in node.output:
            g._nodes_by_name[name] = node  # pylint: disable=protected-access

    return g, model_proto.producer_name
Example #2
0
 def _create_empty_graph(self, inputs, shapes, dtypes):
     """Build a node-less Graph that contains only the given graph inputs.

     ``inputs``, ``shapes`` and ``dtypes`` are walked in lockstep (stopping at
     the shortest of the three). Each (name, shape, dtype) triple is registered
     through ``add_graph_input``. Target and opset are taken from
     ``self.config``.
     """
     cfg = self.config
     empty = Graph([], target=cfg.target, opset=cfg.opset)
     for name, shape, dtype in zip(inputs, shapes, dtypes):
         empty.add_graph_input(name, dtype, shape)
     return empty
Example #3
0
    def rewrite(self):
        """Attach a generated body graph to each Scan node that has one pending.

        For every ``Scan`` node whose name is registered in ``BodyGraphDict``,
        this does the following:

        - pops the recorded body-graph metadata;
        - extracts the matching sub-graph nodes from ``self.g`` with
          ``LoopRewriterBase.find_subgraph``;
        - builds a standalone ``Graph`` from the non-Const nodes;
        - adds its graph inputs and Identity-backed outputs;
        - stores the result on the Scan node's ``body`` attribute.

        Returns:
            The list from ``self.g.get_nodes()`` with the nodes that were moved
            into body graphs removed (the list is mutated in place).
        """
        log.debug("enter custom rnn late rewriter")
        nodes = self.g.get_nodes()
        nodes_to_remove = []
        for scan_node in nodes:
            if scan_node.type != "Scan":
                continue
            log.debug("late write for scan node %s", scan_node.name)
            if not BodyGraphDict.has_body_graph_info(scan_node.name):
                continue
            # Read only for Scan nodes that will actually be rewritten.
            num_scan_inputs = scan_node.get_attr("num_scan_inputs").i

            body_graph_meta = BodyGraphDict.pop_body_graph_info(scan_node.name)
            onnx_nodes, _ = LoopRewriterBase.find_subgraph(
                body_graph_meta, self.g)

            log.debug("start creating body graph for scan node %s ",
                      scan_node.name)
            # Const nodes are kept out of the body graph: when the nodes are
            # set they need to be replaced as initializers. They are also not
            # scheduled for removal from the outer graph.
            body_nodes = {n for n in onnx_nodes
                          if n.type not in ("Const", "ConstV2")}
            nodes_to_remove.extend(body_nodes)

            ops = [n.op for n in body_nodes]
            body_g = Graph(ops,
                           output_shapes=self.g._output_shapes,  # pylint: disable=protected-access
                           dtypes=self.g._dtypes)  # pylint: disable=protected-access

            log.debug("start preparing body graph inputs nodes")
            temp_nodes = body_g.get_nodes()
            input_count = len(body_graph_meta.input_ids)
            # The trailing `num_scan_inputs` inputs are treated as scan inputs.
            first_scan_input = input_count - num_scan_inputs
            for idx, (input_name, init_input_id) in enumerate(
                    zip(body_graph_meta.input_ids,
                        body_graph_meta.initial_input_ids)):
                shape = body_g.get_shape(input_name)
                dtype = body_g.get_dtype(input_name)
                if shape is None:
                    # Fall back to the shape of the value fed in from the outer graph.
                    shape = self.g.get_shape(init_input_id)
                    if idx >= first_scan_input:
                        shape = list(shape)[2:]  # delete [1, time,]
                body_g.add_graph_input(input_name, dtype, list(shape))

            log.debug("start preparing body graph outputs nodes")
            new_output_names = []
            for output_id in body_graph_meta.output_ids:
                # Insert an Identity node: the same output_id is sometimes
                # needed both as state_output and as scan_out, but ONNX does
                # not allow one output_id to appear more than once as a graph
                # output.
                identity = body_g.make_node("Identity",
                                            inputs=[output_id],
                                            shapes=[body_g.get_shape(output_id)],
                                            dtypes=[body_g.get_dtype(output_id)])
                new_output_names.append(identity.output[0])
                temp_nodes.append(identity)

            body_g.set_nodes(temp_nodes)
            body_g.topological_sort(body_g.get_nodes())

            log.debug("start make graph based on body graph nodes")
            body_g.outputs = new_output_names
            graph = body_g.make_graph("scan body graph",
                                      graph_name=scan_node.name + "_body_graph")
            scan_node.set_attr("body", graph)

        # remove nodes in body graph from g
        for n in set(nodes_to_remove):
            if n in nodes:
                nodes.remove(n)

        return nodes