예제 #1
0
def convert_frozen_to_onnx(
    settings: SerializationSettings, frozen_graph_def: tf.GraphDef
) -> Any:
    """Translate a frozen TensorFlow graph into an ONNX model.

    Adapted from the tf2onnx converter script
    (https://github.com/onnx/tensorflow-onnx/blob/master/tf2onnx/convert.py).
    The graph is constant-folded with ``tf_optimize``, imported into a fresh
    ``tf.Graph``, translated by ``process_tf_graph`` and finally passed
    through the ONNX graph optimizer.

    Args:
        settings: supplies ``onnx_opset`` (the target opset) and
            ``brain_name`` (handed to ``make_model``).
        frozen_graph_def: the frozen TensorFlow graph to convert.

    Returns:
        The model built by ``make_model`` from the optimized ONNX graph.
    """
    inputs, outputs = (
        _get_input_node_names(frozen_graph_def),
        _get_output_node_names(frozen_graph_def),
    )
    logger.info(f"onnx export - inputs:{inputs} outputs:{outputs}")

    # Fold constants up front so the converter works on a simplified graph.
    optimized_graph_def = tf_optimize(
        inputs, outputs, frozen_graph_def, fold_constant=True
    )

    tf_graph = tf.Graph()
    with tf_graph.as_default():
        tf.import_graph_def(optimized_graph_def, name="")

    # NOTE(review): the session object is never referenced; presumably kept
    # from upstream tf2onnx — confirm process_tf_graph needs it before removal.
    with tf.Session(graph=tf_graph):
        converted_graph = process_tf_graph(
            tf_graph,
            input_names=inputs,
            output_names=outputs,
            opset=settings.onnx_opset,
        )

    optimized_onnx_graph = optimizer.optimize_graph(converted_graph)
    return optimized_onnx_graph.make_model(settings.brain_name)
def convert_frozen_to_onnx(settings: SerializationSettings,
                           frozen_graph_def: tf.GraphDef) -> Any:
    """Convert a frozen TensorFlow graph into an ONNX model proto.

    This is basically
    https://github.com/onnx/tensorflow-onnx/blob/master/tf2onnx/convert.py:
    the graph is constant-folded with ``tf_optimize``, imported into a fresh
    ``tf.Graph``, translated with ``process_tf_graph`` and then run through
    the ONNX graph optimizer.

    Nodes whose names appear in ``MODEL_CONSTANTS`` are read from the frozen
    graph before conversion and appended to the model's graph initializers
    afterwards, so they survive the conversion.

    :param settings: Serialization options. ``settings.onnx_opset`` selects
        the target opset and ``settings.brain_name`` is passed to
        ``make_model``.
    :param frozen_graph_def: The frozen TensorFlow graph to convert.
    :return: The model proto built by ``make_model``, with the saved
        constants added to ``graph.initializer``.
    """
    # Some constants in the graph need to be read by the inference system.
    # These aren't used by the model anywhere, so trying to make sure they
    # propagate through conversion and import is a losing battle. Instead,
    # save them now, so that we can add them back later.
    # NOTE(review): each such constant is assumed to hold its value in
    # ``int_val``; a node storing it elsewhere (e.g. ``tensor_content``)
    # would raise IndexError here.
    constant_values = {
        node.name: node.attr["value"].tensor.int_val[0]
        for node in frozen_graph_def.node
        if node.name in MODEL_CONSTANTS
    }

    inputs = _get_input_node_names(frozen_graph_def)
    outputs = _get_output_node_names(frozen_graph_def)
    logger.info(f"onnx export - inputs:{inputs} outputs:{outputs}")

    frozen_graph_def = tf_optimize(inputs,
                                   outputs,
                                   frozen_graph_def,
                                   fold_constant=True)

    with tf.Graph().as_default() as tf_graph:
        tf.import_graph_def(frozen_graph_def, name="")
    # NOTE(review): the session object is never referenced; presumably kept
    # from upstream tf2onnx — confirm process_tf_graph needs it before removal.
    with tf.Session(graph=tf_graph):
        converted_graph = process_tf_graph(
            tf_graph,
            input_names=inputs,
            output_names=outputs,
            opset=settings.onnx_opset,
        )

    onnx_graph = optimizer.optimize_graph(converted_graph)
    model_proto = onnx_graph.make_model(settings.brain_name)

    # Save the constant values back to the graph initializer.
    # This will ensure the importer gets them as global constants.
    # A list (not a generator) keeps the original semantics: every node is
    # built before any is appended to the initializer.
    constant_nodes = [
        _make_onnx_node_for_constant(name, value)
        for name, value in constant_values.items()
    ]
    model_proto.graph.initializer.extend(constant_nodes)
    return model_proto