def main(args):
    """
    Run prediction over a dataset, auto tuning the TRT dynamic shape first
    when auto tune is enabled for these args.

    The file at args.auto_tuned_shape_file is treated as temporary: when
    auto tune is enabled it is removed before this function returns, also
    when tuning or prediction raises.

    Args:
        args: input args. args.auto_tuned_shape_file is read here; the
            object is otherwise passed on to the helpers unchanged.
    Returns:
        None
    """
    tune_enabled = use_auto_tune(args)

    try:
        if tune_enabled:
            dataset = get_dataset(args)
            tune_img_nums = 10
            auto_tune(args, dataset, tune_img_nums)

        predictor = DatasetPredictor(args)
        predictor.run_dataset()
    finally:
        # Cleanup lives in `finally` so a failed run does not leave a stale
        # shape file behind for the next invocation.
        if tune_enabled and os.path.exists(args.auto_tuned_shape_file):
            os.remove(args.auto_tuned_shape_file)
def main(args):
    """
    Run the benchmark predictor on args.image_path, auto tuning the TRT
    dynamic shape first when auto tune is enabled for these args.

    args.save_dir is created if it does not exist yet. The file at
    args.auto_tuned_shape_file is treated as temporary: when auto tune is
    enabled it is removed before this function returns, also when tuning
    or prediction raises.

    Args:
        args: input args. save_dir, image_path and auto_tuned_shape_file
            are read here; the object is also passed on to the helpers.
    Returns:
        None
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(args.save_dir, exist_ok=True)

    tune_enabled = use_auto_tune(args)

    try:
        if tune_enabled:
            auto_tune(args, args.image_path, 1)

        predictor = PredictorBenchmark(args)
        predictor.run(args.image_path)
    finally:
        # Cleanup lives in `finally` so a failed run does not leave a stale
        # shape file behind for the next invocation.
        if tune_enabled and os.path.exists(args.auto_tuned_shape_file):
            os.remove(args.auto_tuned_shape_file)
def auto_tune(args, dataset, img_nums):
    """
    Use images to auto tune the dynamic shape for trt sub graph.
    The tuned shape is saved in args.auto_tuned_shape_file.

    Samples from the dataset are fed one by one through a predictor that
    is configured to collect shape range info. If a run fails, the failure
    is logged, the shape file is removed when present, and the function
    returns without raising.

    Args:
        args: input args, read by attribute (cfg, print_detail,
            auto_tuned_shape_file).
        dataset(dataset): a dataset supporting len() and iteration that
            yields (img, _) pairs.
        img_nums(int): the max nums of images used for auto tune.
    Returns:
        None
    """
    logger.info("Auto tune the dynamic shape for GPU TRT.")

    assert use_auto_tune(args)

    num = min(len(dataset), img_nums)

    cfg = DeployConfig(args.cfg)
    pred_cfg = PredictConfig(cfg.model, cfg.params)
    pred_cfg.enable_use_gpu(100, 0)
    if not args.print_detail:
        pred_cfg.disable_glog_info()
    pred_cfg.collect_shape_range_info(args.auto_tuned_shape_file)

    predictor = create_predictor(pred_cfg)
    input_names = predictor.get_input_names()
    input_handle = predictor.get_input_handle(input_names[0])

    for idx, (img, _) in enumerate(dataset):
        # Add a leading batch axis of size 1 around the single image.
        data = np.array([img])
        input_handle.reshape(data.shape)
        input_handle.copy_from_cpu(data)
        try:
            predictor.run()
        except Exception as err:
            # Only ordinary errors count as a tuning failure; a bare except
            # would also swallow KeyboardInterrupt / SystemExit and let the
            # caller carry on as if the user had never interrupted.
            logger.info(
                "Auto tune fail. Usually, the error is out of GPU memory, "
                "because the model and image is too large. \n")
            logger.info("Auto tune error detail: {}: {}".format(
                type(err).__name__, err))
            del predictor
            if os.path.exists(args.auto_tuned_shape_file):
                os.remove(args.auto_tuned_shape_file)
            return

        # The limit is checked after each run, so the next sample is never
        # fetched needlessly; a non-empty dataset always contributes at
        # least one sample.
        if idx + 1 >= num:
            break

    logger.info("Auto tune success.\n")