# NOTE(review): the statements before the __main__ guard below duplicate the
# tail of convert_to_edgetpu() (see that function) and reference names
# (q_stats, args, converter, node, idx) that are undefined at module scope —
# this looks like a paste/merge artifact; confirm and remove.
q_stats.update({node: (args.q_mean[idx], args.q_std[idx])})
converter.quantized_input_stats = q_stats
tflite_model = converter.convert()
open(args.output_dir + ".tflite", "wb").write(tflite_model)
print("Model successfully converted to tflite flatbuffer")
# Compile the flatbuffer for edge TPU
subprocess.run(["edgetpu_compiler", args.output_dir + ".tflite"], check=True)
print("Model successfully compiled")


if __name__ == '__main__':
    # Build the CLI: shared converter options plus EdgeTPU-specific ones.
    parser = argparse.ArgumentParser()
    setup_args(parser)
    add_edgetpu_args(parser)
    args = parser.parse_args()

    # Create graph: parse the frozen GraphDef from the input file.
    with tf.compat.v1.gfile.GFile(args.input, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())

    # Get graph data (input/output node bookkeeping).
    graph_chars = converter_util.GraphCharacteristics(graph_def)

    # Set input dimensions from the first input node.
    # TODO: Only supports one input tensor
    input_dims = converter_util.get_input_dims(args, graph_chars.input_nodes[0])

    convert_to_edgetpu(args, input_dims, graph_chars=graph_chars)
def convert_to_edgetpu(args, input_dims, graph_chars=None):
    """Convert a frozen TF graph to a .tflite flatbuffer and compile it for the Edge TPU.

    Args:
        args: parsed CLI namespace; uses args.input (frozen graph path),
            args.output_dir (output path prefix), args.q_mean and args.q_std
            (per-input quantization stats, parallel lists).
        input_dims: shape assigned to the first (only supported) input tensor.
        graph_chars: optional pre-computed converter_util.GraphCharacteristics;
            parsed from args.input when None.

    Side effects: writes "<args.output_dir>.tflite" and invokes the external
    "edgetpu_compiler" on it (raises CalledProcessError on compiler failure).
    Returns None; prints an error and returns early on inconsistent arguments.
    """
    if len(args.q_mean) != len(args.q_std):
        print("Error: Number of q_mean arguments ({}) not equal to number of q_std arguments ({})"
              .format(len(args.q_mean), len(args.q_std)))
        return

    if graph_chars is None:
        with tf.compat.v1.gfile.GFile(args.input, "rb") as f:
            graph_def = tf.compat.v1.GraphDef()
            graph_def.ParseFromString(f.read())
        graph_chars = converter_util.GraphCharacteristics(graph_def)

    if len(graph_chars.input_nodes) != len(args.q_mean):
        print("Error: Number of input nodes ({}) not equal to number of quantization parameters ({})"
              .format(len(graph_chars.input_nodes), len(args.q_mean)))
        return

    # Correct for multiple output dimensions: expand multi-output nodes into
    # explicit "name:i" tensor names (the first tensor keeps the bare name).
    output_nodes = []
    if "TFLite_Detection_PostProcess" in graph_chars.output_node_names:
        # SSD post-processing custom op always exposes four output tensors.
        output_nodes = ["TFLite_Detection_PostProcess",
                        "TFLite_Detection_PostProcess:1",
                        "TFLite_Detection_PostProcess:2",
                        "TFLite_Detection_PostProcess:3"]
    elif any(len(node.attr['_output_types'].list.type) > 0
             for node in graph_chars.output_nodes):
        # BUG FIX: the original tested `attr['_output_types'] is not None`,
        # which is always true for a protobuf attr map (missing keys yield a
        # default AttrValue), and it *overwrote* output_nodes on every loop
        # iteration, keeping only the last node's tensors (and leaving the
        # list empty when every node had zero '_output_types').
        for node in graph_chars.output_nodes:
            num_out = len(node.attr['_output_types'].list.type)
            if num_out > 0:
                print("Node {} has {} output dimensions".format(node.name, num_out))
                output_nodes.append(node.name)  # tensor 0 uses the bare name
                output_nodes.extend(node.name + ':' + str(i)
                                    for i in range(1, num_out))
            else:
                output_nodes.append(node.name)
    else:
        # BUG FIX: was `graph_chars.output_nodes` (NodeDef objects); the
        # quantization check below and the TFLite converter need name strings.
        output_nodes = graph_chars.output_node_names
    print("Corrected output names: ", output_nodes)

    # Check for quantization: the graph is quantized if any output node
    # carries a truthy '_output_quantized' attribute.
    quantized = False
    for node in output_nodes:
        if graph_chars.nodes_by_name[node.split(':')[0]].attr['_output_quantized'].b:
            print("Quantization detected. Using quantized conversion with means and STDs ({}, {})"
                  .format(args.q_mean, args.q_std))
            quantized = True
            break

    # Convert and save model
    converter = tf.lite.TFLiteConverter.from_frozen_graph(
        args.input, graph_chars.input_node_names, output_nodes,
        input_shapes={graph_chars.input_node_names[0]: input_dims})
    converter.allow_custom_ops = True
    if quantized:
        # TODO: Fix assumption that quantization is 8-bit
        converter.inference_type = tf.compat.v1.lite.constants.QUANTIZED_UINT8
        converter.inference_input_type = tf.compat.v1.lite.constants.QUANTIZED_UINT8
        # (mean, std) per input node, keyed by input name.
        converter.quantized_input_stats = {
            node: (args.q_mean[idx], args.q_std[idx])
            for idx, node in enumerate(graph_chars.input_node_names)}
    tflite_model = converter.convert()

    # BUG FIX: use a context manager instead of a leaked open() handle.
    with open(args.output_dir + ".tflite", "wb") as out_file:
        out_file.write(tflite_model)
    print("Model successfully converted to tflite flatbuffer")

    # Compile the flatbuffer for edge TPU
    subprocess.run(["edgetpu_compiler", args.output_dir + ".tflite"], check=True)
    print("Model successfully compiled")
def add_plugin(graph, input_dims, graph_chars=None):
    """Rewrite a graphsurgeon graph to use TensorRT SSD plugin nodes.

    Replaces the preprocessor/anchor-generator/postprocessor namespaces with
    GridAnchor_TRT, NMS_TRT, and FlattenConcat_TRT plugin nodes so the graph
    can be converted to a TensorRT engine.

    Args:
        graph: a graphsurgeon DynamicGraph (mutated in place).
        input_dims: input tensor dimensions, without the batch dimension
            (a leading 1 is prepended for the Placeholder shape).
        graph_chars: optional pre-computed converter_util.GraphCharacteristics;
            derived from the graph when None.

    Returns the mutated graph. Calls exit(1) if the NMS input order cannot
    be determined.

    NOTE(review): reads module-global `args` (args.debug) rather than a
    parameter — fails with NameError if called before argument parsing;
    confirm intended.
    """
    graph_def = graph.as_graph_def()
    if graph_chars is None:
        graph_chars = converter_util.GraphCharacteristics(graph_def)
    num_classes = converter_util.get_num_classes(graph_def, graph_chars=graph_chars)
    input_order = converter_util.get_NMS_input_order(graph_def, "Postprocessor",
                                                    graph_chars=graph_chars)
    # -1 marks an NMS input whose order could not be resolved.
    if any(x == -1 for x in input_order):
        print("NMS input order error: {} Aborting".format(input_order))
        # NOTE(review): exit() comes from the site module; sys.exit(1) is the
        # conventional spelling — confirm before changing.
        exit(1)
    if args.debug:
        print("Detected number of classes: ", num_classes)
        print("Detected NMS input order: ", input_order)

    # Drop Assert nodes and bypass Identity nodes before collapsing namespaces.
    assert_nodes = graph.find_nodes_by_op("Assert")
    graph.remove(assert_nodes, remove_exclusive_dependencies=True)
    identity_nodes = graph.find_nodes_by_op("Identity")
    graph.forward_inputs(identity_nodes)

    # Plugin nodes that replace whole namespaces of the original graph.
    Input = gs.create_plugin_node(name="Input",
                                  op="Placeholder",
                                  shape=(1, ) + input_dims)
    # TODO: Consider automation of parameters
    # NOTE(review): anchor sizes/aspect ratios/feature-map shapes below are
    # hard-coded for a specific SSD configuration — verify against the model.
    PriorBox = gs.create_plugin_node(name="MultipleGridAnchorGenerator",
                                     op="GridAnchor_TRT",
                                     minSize=0.2,
                                     maxSize=0.95,
                                     aspectRatios=[1.0, 2.0, 0.5, 3.0, 0.33],
                                     variance=[0.1, 0.1, 0.2, 0.2],
                                     featureMapShapes=[19, 10, 5, 3, 2, 1],
                                     numLayers=6)
    NMS = gs.create_plugin_node(name="NMS",
                                op="NMS_TRT",
                                shareLocation=1,
                                varianceEncodedInTarget=0,
                                backgroundLabelId=0,
                                confidenceThreshold=0.3,
                                nmsThreshold=0.6,
                                topK=100,
                                keepTopK=100,
                                numClasses=num_classes,
                                inputOrder=input_order,
                                confSigmoid=1,
                                isNormalized=1)
    concat_box_loc = gs.create_plugin_node("concat_box_loc",
                                           op="FlattenConcat_TRT",
                                           axis=1,
                                           ignoreBatch=0)
    concat_box_conf = gs.create_plugin_node("concat_box_conf",
                                            op="FlattenConcat_TRT",
                                            axis=1,
                                            ignoreBatch=0)
    concat_priorbox = gs.create_node("concat_priorbox",
                                     op="ConcatV2",
                                     axis=2)

    # Map original namespaces onto the plugin nodes defined above.
    namespace_map = {
        "MultipleGridAnchorGenerator": PriorBox,
        "Preprocessor": Input,
        "ToFloat": Input,
        "Cast": Input,
        "image_tensor": Input,
        "Postprocessor": NMS,
        "concat": concat_box_loc,
        "concat_1": concat_box_conf,
        "Concatenate": concat_priorbox,
        "MultipleGridAnchorGenerator/Concatenate": concat_priorbox,
        "SecondStagePostprocessor": NMS
    }
    graph.collapse_namespaces(namespace_map)
    # Keep shared dependencies when removing the old graph outputs.
    graph.remove(graph.graph_outputs, remove_exclusive_dependencies=False)

    # Clean up stray inputs left behind by the namespace collapse.
    if graph.find_nodes_by_op("NMS_TRT"):
        if "Input" in graph.find_nodes_by_op("NMS_TRT")[0].input:
            graph.find_nodes_by_op("NMS_TRT")[0].input.remove("Input")
    if "image_tensor:0" in graph.find_nodes_by_name("Input")[0].input:
        graph.find_nodes_by_name("Input")[0].input.remove("image_tensor:0")
    if "image_tensor" in graph.find_nodes_by_name("Input")[0].input:
        graph.find_nodes_by_name("Input")[0].input.remove("image_tensor")
    if graph.find_nodes_by_name("ToFloat_3"):
        if "image_tensor:0" in graph.find_nodes_by_name("ToFloat_3")[0].input:
            graph.find_nodes_by_name("ToFloat_3")[0].input.remove(
                "image_tensor:0")
    return graph