def _value_info_shape(value_info):
    """Return the static shape declared by an ONNX ``ValueInfoProto``.

    Dimensions that carry a concrete ``dim_value`` keep that value; symbolic
    (``dim_param``) or unset dimensions are reported as -1.
    """
    # Presence must be checked on each dimension: ``dim_value`` and
    # ``dim_param`` are alternatives of a oneof, so ``HasField`` tells which
    # one is actually set.
    return [dim.dim_value if dim.HasField("dim_value") else -1
            for dim in value_info.type.tensor_type.shape.dim]


def load_graph(fname):
    """Load a serialized ONNX model file into a tf2onnx ``Graph``.

    Args:
        fname: path to a file containing a serialized ``onnx.ModelProto``.

    Returns:
        A ``(graph, producer_name)`` tuple: the ``Graph`` built from the
        model's nodes, with dtypes/shapes registered for its initializers,
        inputs and outputs, and the model's ``producer_name`` string.
    """
    # Only reading the bytes needs the file handle; parse after closing it.
    with open(fname, "rb") as f:
        data = f.read()
    model_proto = onnx.ModelProto()
    model_proto.ParseFromString(data)
    onnx_nodes = model_proto.graph.node
    # Handed to Graph below and only filled afterwards, while visiting the
    # model outputs; this relies on Graph keeping a reference to this same
    # list object (presumably -- TODO confirm against Graph.__init__).
    output_names = []

    # Some PyTorch-exported models contain nodes with empty names; make one up.
    for node in onnx_nodes:
        if not node.name:
            node.name = tf2onnx.utils.make_name("was_empty")

    g = Graph(onnx_nodes, output_shapes={}, dtypes={}, output_names=output_names)

    for tensor in model_proto.graph.initializer:
        name = tensor.name
        g.initializers[name] = tensor
        g.set_dtype(name, tensor.data_type)
        # A TensorProto's dims are its shape; there is no need to decode the
        # (possibly very large) tensor payload just to read it.
        g.set_shape(name, tuple(tensor.dims))

    for value_info in model_proto.graph.input:
        name = value_info.name
        if name in g.initializers:
            # Listed as a graph input but backed by an initializer, so it is
            # not a real model input.
            continue
        g.add_model_input(name, value_info)
        g.set_dtype(name, value_info.type.tensor_type.elem_type)
        g.set_shape(name, _value_info_shape(value_info))

    for value_info in model_proto.graph.output:
        name = value_info.name
        g.set_dtype(name, value_info.type.tensor_type.elem_type)
        g.set_shape(name, _value_info_shape(value_info))
        output_names.append(name)

    # TODO: this is a hack in case an output name does not follow the
    # tensorflow convention: index every node under each of its output names.
    for node in g.get_nodes():
        for name in node.output:
            g._nodes_by_name[name] = node  # pylint: disable=protected-access

    return g, model_proto.producer_name
    def rewrite(self):
        """Build and attach a body graph for every Scan node with registered body info.

        For each ``Scan`` node in ``self.g`` whose name is registered in
        ``BodyGraphDict``:

        - its loop-body nodes are located with ``LoopRewriterBase.find_subgraph``;
        - those nodes are assembled into a new ``Graph``, with Const/ConstV2
          nodes carried over as initializers instead of nodes;
        - model inputs and Identity-wrapped outputs are declared;
        - the result is stored in the Scan node's ``body`` attribute.

        After all Scan nodes are processed, every collected body node is
        removed from the outer node list. A body node that is not in that list
        but is backed by an initializer of ``self.g`` has that initializer
        deleted instead.

        Returns:
            The node list obtained from ``self.g.get_nodes()`` at the start,
            with the body-graph nodes removed from it in place.

        Raises:
            ValueError: if a collected body node is neither in the outer node
                list nor backed by an initializer of ``self.g``.
        """
        log.debug("enter custom rnn late rewriter")
        # Outer-graph node list; the removal loop at the end mutates this list
        # in place and returns it.
        nodes = self.g.get_nodes()
        # Every node moved into some Scan body graph. They are removed from the
        # outer graph only after all Scan nodes have been processed.
        nodes_to_remove = []
        for scan_node in nodes:
            if scan_node.type != "Scan":
                continue
            log.debug("late write for scan node %s", scan_node.name)
            # The last `num_scan_inputs` body inputs are scan inputs (see the
            # input-shape handling below).
            num_scan_inputs = scan_node.get_attr("num_scan_inputs").i
            if not BodyGraphDict.has_body_graph_info(scan_node.name):
                continue

            # The registered body-graph info is popped, i.e. consumed, here.
            body_graph_meta = BodyGraphDict.pop_body_graph_info(scan_node.name)
            onnx_nodes, _ = LoopRewriterBase.find_subgraph(
                body_graph_meta, self.g)
            # Copied before the Const filtering below, so Const nodes are still
            # scheduled for removal from the outer graph.
            nodes_to_remove.extend(onnx_nodes)

            log.debug("start creating body graph for scan node %s ",
                      scan_node.name)
            body_graph_initializers = {}
            const_nodes = [
                n for n in onnx_nodes if n.type in ("Const", "ConstV2")
            ]
            for n in const_nodes:
                # Const nodes must not become nodes of the body graph. Their
                # values are carried over as body-graph initializers instead,
                # taken from the outer graph's initializer with the same name.
                body_graph_initializers[n.output[0]] = self.g.initializers[
                    n.output[0]]
                onnx_nodes.remove(n)

            # Deduplicate. Set iteration order is arbitrary; the body nodes are
            # topologically sorted further down.
            onnx_nodes = set(onnx_nodes)

            # Underlying ONNX node protos of the body nodes.
            ops = []
            for op in onnx_nodes:
                onnx_op = op.op
                ops.append(onnx_op)

            # NOTE(review): the outer graph's shape/dtype maps are passed as-is.
            # If Graph keeps them by reference, the set_dtype/copy_shape calls
            # on body_g below also write into self.g's maps -- confirm.
            body_g = Graph(ops,
                           output_shapes=self.g._output_shapes,
                           dtypes=self.g._dtypes)
            body_g._initializers = body_graph_initializers

            log.debug("start preparing body graph inputs nodes")
            temp_nodes = body_g.get_nodes()
            i = 0
            input_count = len(body_graph_meta.input_ids)
            # Declare one body-graph model input per loop input. Each is paired
            # with the outer-graph tensor that initially feeds it.
            for input_name, init_input_id in zip(
                    body_graph_meta.input_ids,
                    body_graph_meta.initial_input_ids):
                shape = body_g.get_shape(input_name)
                dtype = body_g.get_dtype(input_name)
                if shape is None:
                    # Fall back to the shape of the initial input in the outer
                    # graph. NOTE(review): list(shape) raises TypeError if that
                    # shape is unknown as well.
                    shape = self.g.get_shape(init_input_id)
                    if i >= input_count - num_scan_inputs:
                        # Scan input (one of the last num_scan_inputs): the body
                        # sees it without its two leading dims.
                        loop_input_shape = list(shape)[2:]  # delete [1, time,]
                    else:
                        loop_input_shape = list(shape)
                else:
                    loop_input_shape = list(shape)

                onnx_input_shape = utils.make_onnx_shape(loop_input_shape)
                val = helper.make_tensor_value_info(input_name, dtype,
                                                    onnx_input_shape)
                body_g.add_model_input(input_name, val)
                i += 1

            log.debug("start preparing body graph outputs nodes")
            new_output_names = []
            for o in body_graph_meta.output_ids:
                # Insert an Identity node per output. The same output id may be
                # needed both as a state output and as a scan output, but ONNX
                # does not allow a name to appear more than once among a
                # graph's outputs.
                identity_name = utils.make_name("Identity")
                identity_output = utils.port_name(identity_name)
                node = Node(
                    helper.make_node("Identity", [o], [identity_output],
                                     name=identity_name), body_g)
                body_g.set_dtype(identity_output, body_g.get_dtype(o))
                body_g.copy_shape(o, identity_output)
                new_output_names.append(identity_output)
                temp_nodes.append(node)

            body_g.set_nodes(temp_nodes)
            body_g.topological_sort(body_g.get_nodes())

            log.debug("start make graph based on body graph nodes")
            body_g.output_names = new_output_names
            graph = body_g.make_graph("scan body graph")
            scan_node.set_attr("body", graph)

        # Remove the nodes that were moved into body graphs from the outer graph.
        for n in set(nodes_to_remove):
            if n in nodes:
                nodes.remove(n)
            elif self.g.is_initializer(n.output[0]):
                # Not in the outer node list but backed by an outer-graph
                # initializer (e.g. a Const): drop that initializer instead.
                del self.g.initializers[n.output[0]]
            else:
                raise ValueError("error when removing nodes")

        return nodes