示例#1
0
 def convert_tr(self, graph_def, fetches):
     """Run the TF-TRT converter on `graph_def` and import the converted graph.

     Args:
       graph_def: graph definition handed to the converter as
         `input_graph_def`.
       fetches: tensor names. The text before the first ':' of each name is
         passed to the converter as `nodes_blacklist`; the unmodified names
         are used as `return_elements` when importing the converted graph.

     Returns:
       The value returned by `tf.import_graph_def` for
       `return_elements=fetches`.
     """
     from tensorflow.python.compiler.tensorrt import trt  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
     # Strip everything from the first ':' onward to turn tensor names into
     # the node names the converter is given.
     blacklisted_nodes = [name.split(':')[0] for name in fetches]
     trt_converter = trt.TrtGraphConverter(
         input_graph_def=graph_def,
         nodes_blacklist=blacklisted_nodes,
         precision_mode=self.tensorrt)
     converted_graph_def = trt_converter.convert()
     return tf.import_graph_def(converted_graph_def, return_elements=fetches)
示例#2
0
def main():
    """Freeze a classification model, convert it with TF-TRT, and save it.

    Command-line arguments:
      --model (required): model name, passed to
        `download_classification_checkpoint` and `build_classification_graph`.

    The checkpoint is downloaded into 'data'. The converted graph is
    serialized to './model/<checkpoint base name>_frozen_fp16.pb'.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', help='tf-trt model.', required=True)
    args = parser.parse_args()

    checkpoint_path, num_classes = download_classification_checkpoint(
        args.model, 'data')
    frozen_graph, input_names, output_names = build_classification_graph(
        model=args.model,
        checkpoint=checkpoint_path,
        num_classes=num_classes,
        is_remove_relu6=True)
    print(input_names, output_names, num_classes)

    converter = trt.TrtGraphConverter(
        input_graph_def=frozen_graph,
        nodes_blacklist=output_names,  # output nodes
        max_batch_size=1,
        is_dynamic_op=False,
        # 32 MiB; an alternative is trt.DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES.
        max_workspace_size_bytes=1 << 25,
        precision_mode=trt.TrtPrecisionMode.FP16,
        minimum_segment_size=50)
    trt_graph = converter.convert()

    # Count how many TRTEngineOp nodes the conversion produced.
    trt_engine_opts = sum(
        1 for n in trt_graph.node if str(n.op) == 'TRTEngineOp')
    print("trt_engine_opts = {}".format(trt_engine_opts))

    model_dir = os.path.join('.', 'model')
    # MakeDirs succeeds when the directory already exists, so no separate
    # existence check (and no check-then-create race) is needed.
    tf.gfile.MakeDirs(model_dir)

    base_name = os.path.splitext(os.path.basename(checkpoint_path))[0]
    save_model_file_name = base_name + '_frozen_fp16.pb'
    with open(os.path.join(model_dir, save_model_file_name), 'wb') as f:
        f.write(trt_graph.SerializeToString())
示例#3
0
def main():
    """Freeze a detection model, save it, then save a TF-TRT FP16 version.

    Command-line arguments:
      --model: model name passed to `download_detection_model` (downloaded
        into 'data'). Takes precedence over --path.
      --path: checkpoint directory; 'pipeline.config' and 'model.ckpt' are
        looked up inside it. Used only when --model is not given.
      --output: directory the .pb files are written to (default 'model').
      --force_nms_cpu: flag forwarded to `build_detection_graph`.
      --threshold: parsed (float, default 0.5) but not used by this function.

    Two files are written to the output directory:
      '<checkpoint base name>_frozen.pb' for the frozen graph and
      '<checkpoint base name>_frozen_fp16.pb' for the converted graph.

    If neither --model nor --path is given, or --path does not exist, an error
    is printed and the function returns without building anything.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', help='tf-trt model.')
    parser.add_argument('--path', help='path to checkpoint dir.')
    parser.add_argument('--output', help='Output dir.', default='model')
    parser.add_argument('--force_nms_cpu', help='Force NMS CPU', action='store_true')
    parser.add_argument('--threshold', help='Score threshold', default=0.5, type=float)
    args = parser.parse_args()

    model_dir = args.output
    # MakeDirs creates missing parent directories and succeeds if the
    # directory already exists, so a nested --output path works.
    tf.gfile.MakeDirs(model_dir)

    if args.model:
        config_path, checkpoint_path = download_detection_model(args.model, 'data')

    elif args.path:
        if not tf.gfile.Exists(args.path):
            print('Error: Checkpoint dir does not exist!')
            return

        config_path = os.path.join(args.path, 'pipeline.config')
        checkpoint_path = os.path.join(args.path, 'model.ckpt')

    else:
        print('Error: Either model or path is not specified in the argument.')
        return

    frozen_graph, input_names, output_names = build_detection_graph(
        config=config_path,
        force_nms_cpu=args.force_nms_cpu,
        checkpoint=checkpoint_path,
        batch_size=1
    )
    print(input_names, output_names)

    # Both output files share the checkpoint's base name.
    base_name = os.path.splitext(os.path.basename(checkpoint_path))[0]

    save_model_file_name = base_name + '_frozen.pb'
    with open(os.path.join(model_dir, save_model_file_name), 'wb') as f:
        f.write(frozen_graph.SerializeToString())

    converter = trt.TrtGraphConverter(
        input_graph_def=frozen_graph,
        nodes_blacklist=output_names,  # output nodes
        max_batch_size=1,
        is_dynamic_op=False,
        # 32 MiB; an alternative is trt.DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES.
        max_workspace_size_bytes=1 << 25,
        precision_mode=trt.TrtPrecisionMode.FP16,
        minimum_segment_size=50)
    trt_graph = converter.convert()

    # Count how many TRTEngineOp nodes the conversion produced.
    trt_engine_opts = sum(1 for n in trt_graph.node if str(n.op) == 'TRTEngineOp')
    print("trt_engine_opts = {}".format(trt_engine_opts))

    save_model_file_name = base_name + '_frozen_fp16.pb'
    with open(os.path.join(model_dir, save_model_file_name), 'wb') as f:
        f.write(trt_graph.SerializeToString())