Exemple #1
0
            q_stats.update({node: (args.q_mean[idx], args.q_std[idx])})
        converter.quantized_input_stats = q_stats

    tflite_model = converter.convert()
    open(args.output_dir + ".tflite", "wb").write(tflite_model)
    print("Model successfully converted to tflite flatbuffer")

    # Compile the flatbuffer for edge TPU
    subprocess.run(["edgetpu_compiler", args.output_dir + ".tflite"], check=True)
    print("Model successfully compiled")


if __name__ == '__main__':
    # Build the CLI: shared converter options plus Edge-TPU-specific ones.
    parser = argparse.ArgumentParser()
    setup_args(parser)
    add_edgetpu_args(parser)
    args = parser.parse_args()

    # Load the frozen graph from disk.
    with tf.compat.v1.gfile.GFile(args.input, "rb") as graph_file:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(graph_file.read())

    # Analyze the graph once and reuse the result downstream.
    graph_chars = converter_util.GraphCharacteristics(graph_def)

    # Resolve the input tensor shape.  TODO: Only supports one input tensor
    input_dims = converter_util.get_input_dims(args, graph_chars.input_nodes[0])

    convert_to_edgetpu(args, input_dims, graph_chars=graph_chars)
Exemple #2
0
def convert_to_edgetpu(args, input_dims, graph_chars=None):
    """Convert a frozen TensorFlow graph to a TFLite flatbuffer and compile it
    for the Edge TPU with ``edgetpu_compiler``.

    Args:
        args: Parsed CLI namespace. Uses ``input`` (path to the frozen graph),
            ``output_dir`` (output file prefix), and ``q_mean`` / ``q_std``
            (per-input quantization statistics, one pair per input node).
        input_dims: Shape for the (single) input tensor, keyed to the first
            input node name.
        graph_chars: Optional pre-computed
            ``converter_util.GraphCharacteristics``; when ``None`` the graph
            is re-read from ``args.input``.

    Returns:
        None. Prints an error and returns early on argument mismatches;
        raises ``subprocess.CalledProcessError`` if the compiler fails.
    """
    if len(args.q_mean) != len(args.q_std):
        print("Error: Number of q_mean arguments ({}) not equal to number of q_std arguments ({})"
              .format(len(args.q_mean), len(args.q_std)))
        return

    if graph_chars is None:
        with tf.compat.v1.gfile.GFile(args.input, "rb") as f:
            graph_def = tf.compat.v1.GraphDef()
            graph_def.ParseFromString(f.read())

        graph_chars = converter_util.GraphCharacteristics(graph_def)

    if len(graph_chars.input_nodes) != len(args.q_mean):
        print("Error: Number of input nodes ({}) not equal to number of quantization parameters ({})"
              .format(len(graph_chars.input_nodes), len(args.q_mean)))
        return

    output_nodes = []

    # Correct for multiple output dimensions.
    if "TFLite_Detection_PostProcess" in graph_chars.output_node_names:
        output_nodes = ["TFLite_Detection_PostProcess", "TFLite_Detection_PostProcess:1",
                        "TFLite_Detection_PostProcess:2", "TFLite_Detection_PostProcess:3"]
    # BUG FIX: protobuf attr maps return a default AttrValue (never None), so
    # the old test ``attr['_output_types'] is not None`` was always true and
    # this branch was taken unconditionally. Check for a non-empty type list.
    elif any(len(node.attr['_output_types'].list.type) > 0
             for node in graph_chars.output_nodes):
        for node in graph_chars.output_nodes:
            num_out = len(node.attr['_output_types'].list.type)
            if num_out > 0:
                print("Node {} has {} output dimensions".format(node.name, num_out))
                output_nodes = [node.name + ':' + str(i) for i in range(num_out)]
                # BUG FIX: only rewrite index 0 when this node produced the
                # list; the original did it unguarded, which raised IndexError
                # for nodes without outputs and clobbered entry 0 with the
                # wrong node's name. The first output keeps the bare node name
                # (no ":0" suffix).
                output_nodes[0] = node.name
    else:
        # BUG FIX: downstream code expects node-name strings (it calls
        # ``node.split(':')`` and passes them to TFLiteConverter); the
        # original assigned NodeDef objects here.
        output_nodes = graph_chars.output_node_names

    print("Corrected output names: ", output_nodes)

    # Check for quantization: any output marked _output_quantized triggers a
    # quantized conversion using the CLI-supplied per-input stats.
    quantized = False
    for node in output_nodes:
        if graph_chars.nodes_by_name[node.split(':')[0]].attr['_output_quantized'].b:
            print("Quantization detected. Using quantized conversion with means and STDs ({}, {})"
                  .format(args.q_mean, args.q_std))
            quantized = True
            break

    # Convert and save model
    converter = tf.lite.TFLiteConverter.from_frozen_graph(args.input, graph_chars.input_node_names, output_nodes,
                                                          input_shapes={graph_chars.input_node_names[0]: input_dims})
    converter.allow_custom_ops = True
    if quantized:
        converter.inference_type = tf.compat.v1.lite.constants.QUANTIZED_UINT8  # TODO: Fix assumption that quantization is 8-bit
        converter.inference_input_type = tf.compat.v1.lite.constants.QUANTIZED_UINT8
        # Map each input node name to its (mean, std) quantization stats.
        converter.quantized_input_stats = {
            node: (args.q_mean[idx], args.q_std[idx])
            for idx, node in enumerate(graph_chars.input_node_names)
        }

    tflite_model = converter.convert()
    tflite_path = args.output_dir + ".tflite"
    # BUG FIX: use a context manager so the file handle is closed promptly
    # (the original relied on GC to close the anonymous handle).
    with open(tflite_path, "wb") as out_file:
        out_file.write(tflite_model)
    print("Model successfully converted to tflite flatbuffer")

    # Compile the flatbuffer for edge TPU
    subprocess.run(["edgetpu_compiler", tflite_path], check=True)
    print("Model successfully compiled")
Exemple #3
0
def add_plugin(graph, input_dims, graph_chars=None):
    """Rewrite an SSD detection graph to use TensorRT plugin nodes.

    Collapses the preprocessor, grid-anchor generator, postprocessor and
    concat namespaces into TRT plugin nodes, then prunes dangling inputs
    left behind by the rewrite.

    NOTE(review): reads the module-level ``args`` for the ``debug`` flag —
    assumes the script entry point populated it; confirm before reuse.
    """
    graph_def = graph.as_graph_def()

    if graph_chars is None:
        graph_chars = converter_util.GraphCharacteristics(graph_def)

    num_classes = converter_util.get_num_classes(graph_def,
                                                 graph_chars=graph_chars)
    input_order = converter_util.get_NMS_input_order(graph_def,
                                                     "Postprocessor",
                                                     graph_chars=graph_chars)

    # -1 marks an NMS input that could not be located; bail out early.
    if -1 in input_order:
        print("NMS input order error: {} Aborting".format(input_order))
        exit(1)

    if args.debug:
        print("Detected number of classes: ", num_classes)
        print("Detected NMS input order: ", input_order)

    # Drop Assert nodes and forward Identity nodes so the rewrite is clean.
    graph.remove(graph.find_nodes_by_op("Assert"),
                 remove_exclusive_dependencies=True)
    graph.forward_inputs(graph.find_nodes_by_op("Identity"))

    input_plugin = gs.create_plugin_node(name="Input",
                                         op="Placeholder",
                                         shape=(1, ) + input_dims)

    # TODO: Consider automation of parameters
    prior_box = gs.create_plugin_node(name="MultipleGridAnchorGenerator",
                                      op="GridAnchor_TRT",
                                      minSize=0.2,
                                      maxSize=0.95,
                                      aspectRatios=[1.0, 2.0, 0.5, 3.0, 0.33],
                                      variance=[0.1, 0.1, 0.2, 0.2],
                                      featureMapShapes=[19, 10, 5, 3, 2, 1],
                                      numLayers=6)

    nms = gs.create_plugin_node(name="NMS",
                                op="NMS_TRT",
                                shareLocation=1,
                                varianceEncodedInTarget=0,
                                backgroundLabelId=0,
                                confidenceThreshold=0.3,
                                nmsThreshold=0.6,
                                topK=100,
                                keepTopK=100,
                                numClasses=num_classes,
                                inputOrder=input_order,
                                confSigmoid=1,
                                isNormalized=1)

    concat_box_loc = gs.create_plugin_node("concat_box_loc",
                                           op="FlattenConcat_TRT",
                                           axis=1,
                                           ignoreBatch=0)

    concat_box_conf = gs.create_plugin_node("concat_box_conf",
                                            op="FlattenConcat_TRT",
                                            axis=1,
                                            ignoreBatch=0)

    concat_priorbox = gs.create_node("concat_priorbox", op="ConcatV2", axis=2)

    # Map original graph namespaces onto the plugin replacements.
    graph.collapse_namespaces({
        "MultipleGridAnchorGenerator": prior_box,
        "Preprocessor": input_plugin,
        "ToFloat": input_plugin,
        "Cast": input_plugin,
        "image_tensor": input_plugin,
        "Postprocessor": nms,
        "concat": concat_box_loc,
        "concat_1": concat_box_conf,
        "Concatenate": concat_priorbox,
        "MultipleGridAnchorGenerator/Concatenate": concat_priorbox,
        "SecondStagePostprocessor": nms,
    })

    graph.remove(graph.graph_outputs, remove_exclusive_dependencies=False)

    # Prune inputs left dangling by the namespace collapse.
    nms_nodes = graph.find_nodes_by_op("NMS_TRT")
    if nms_nodes and "Input" in nms_nodes[0].input:
        nms_nodes[0].input.remove("Input")

    input_node = graph.find_nodes_by_name("Input")[0]
    for stale in ("image_tensor:0", "image_tensor"):
        if stale in input_node.input:
            input_node.input.remove(stale)

    to_float_nodes = graph.find_nodes_by_name("ToFloat_3")
    if to_float_nodes and "image_tensor:0" in to_float_nodes[0].input:
        to_float_nodes[0].input.remove("image_tensor:0")

    return graph