コード例 #1
0
    def __init__(self, args):
        """Build the text-detection pipeline: pre-process ops, post-process op
        and inference predictor.

        Args:
            args: parsed argument namespace. Read here: ``det_algorithm``
                ("DB", "EAST" or "SAST"), ``use_zero_copy_run``,
                ``det_limit_side_len``, ``det_limit_type`` and the
                algorithm-specific ``det_db_*``, ``det_east_*`` or
                ``det_sast_*`` options.

        Raises:
            SystemExit: with status 1 (via ``sys.exit``) when
                ``det_algorithm`` is not one of the supported values.
        """
        self.args = args
        self.det_algorithm = args.det_algorithm
        self.use_zero_copy_run = args.use_zero_copy_run
        # Default pre-process pipeline; the SAST branch below replaces the
        # resize step at index 0.
        pre_process_list = [{
            'DetResizeForTest': {
                'limit_side_len': args.det_limit_side_len,
                'limit_type': args.det_limit_type
            }
        }, {
            'NormalizeImage': {
                'std': [0.229, 0.224, 0.225],
                'mean': [0.485, 0.456, 0.406],
                'scale': '1./255.',
                'order': 'hwc'
            }
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                'keep_keys': ['image', 'shape']
            }
        }]
        postprocess_params = {}
        if self.det_algorithm == "DB":
            postprocess_params['name'] = 'DBPostProcess'
            postprocess_params["thresh"] = args.det_db_thresh
            postprocess_params["box_thresh"] = args.det_db_box_thresh
            postprocess_params["max_candidates"] = 1000
            postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
            postprocess_params["use_dilation"] = True
        elif self.det_algorithm == "EAST":
            postprocess_params['name'] = 'EASTPostProcess'
            postprocess_params["score_thresh"] = args.det_east_score_thresh
            postprocess_params["cover_thresh"] = args.det_east_cover_thresh
            postprocess_params["nms_thresh"] = args.det_east_nms_thresh
        elif self.det_algorithm == "SAST":
            # SAST resizes by the long side only, instead of the default
            # limit_side_len / limit_type resize.
            pre_process_list[0] = {
                'DetResizeForTest': {'resize_long': args.det_limit_side_len}
            }
            postprocess_params['name'] = 'SASTPostProcess'
            postprocess_params["score_thresh"] = args.det_sast_score_thresh
            postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
            self.det_sast_polygon = args.det_sast_polygon
            # Polygon and non-polygon SAST output use different
            # sampling/expansion settings.
            if self.det_sast_polygon:
                postprocess_params["sample_pts_num"] = 6
                postprocess_params["expand_scale"] = 1.2
                postprocess_params["shrink_ratio_of_width"] = 0.2
            else:
                postprocess_params["sample_pts_num"] = 2
                postprocess_params["expand_scale"] = 1.0
                postprocess_params["shrink_ratio_of_width"] = 0.3
        else:
            # An unknown algorithm is a configuration error: log it at error
            # level and exit non-zero (status 0 would tell calling scripts
            # that the run succeeded).
            logger.error("unknown det_algorithm:{}".format(self.det_algorithm))
            sys.exit(1)

        self.preprocess_op = create_operators(pre_process_list)
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors = utility.create_predictor(
            args, 'det', logger)
コード例 #2
0
 def __init__(self, args):
     """Set up the recognizer: input shape, batch settings, label decoder
     and predictor.

     The label decoder is picked from ``args.rec_algorithm``: "SRN" and
     "RARE" have dedicated decoders, any other value uses CTC decoding.
     """
     self.rec_image_shape = list(map(int, args.rec_image_shape.split(",")))
     self.character_type = args.rec_char_type
     self.rec_batch_num = args.rec_batch_num
     self.rec_algorithm = args.rec_algorithm
     self.max_text_length = args.max_text_length

     # Algorithm -> decoder op name; anything not listed falls back to CTC.
     decoder_by_algorithm = {
         "SRN": 'SRNLabelDecode',
         "RARE": 'AttnLabelDecode',
     }
     decoder_name = decoder_by_algorithm.get(self.rec_algorithm,
                                             'CTCLabelDecode')
     postprocess_params = {
         'name': decoder_name,
         "character_type": args.rec_char_type,
         "character_dict_path": args.rec_char_dict_path,
         "use_space_char": args.use_space_char
     }
     self.postprocess_op = build_post_process(postprocess_params)
     predictor_parts = utility.create_predictor(args, 'rec', logger)
     self.predictor, self.input_tensor, self.output_tensors = predictor_parts
コード例 #3
0
 def __init__(self, args):
     """Set up the recognizer's predictor, input shape and character ops.

     Args:
         args: parsed argument namespace. ``rec_algorithm`` must be one of
             "CRNN", "Rosetta", "STAR-Net" (loss type 'ctc'), "RARE"
             ('attention') or "SRN" ('srn'); the loss type is stored on
             ``self.loss_type`` and passed to ``CharacterOps``. The predictor
             is only created when ``use_pdserving`` is False.

     Raises:
         ValueError: if ``rec_algorithm`` is not one of the values above.
     """
     # Resolve the loss type first so an unsupported algorithm fails with a
     # clear message before the model is loaded, instead of leaving
     # self.loss_type unset and handing CharacterOps a config without it.
     loss_type_by_algorithm = {
         "CRNN": 'ctc',
         "Rosetta": 'ctc',
         "STAR-Net": 'ctc',
         "RARE": 'attention',
         "SRN": 'srn',
     }
     loss_type = loss_type_by_algorithm.get(args.rec_algorithm)
     if loss_type is None:
         raise ValueError(
             "unknown rec_algorithm:{}".format(args.rec_algorithm))

     if args.use_pdserving is False:
         self.predictor, self.input_tensor, self.output_tensors =\
             utility.create_predictor(args, mode="rec")
         self.use_zero_copy_run = args.use_zero_copy_run
     self.rec_image_shape = [
         int(v) for v in args.rec_image_shape.split(",")
     ]
     self.character_type = args.rec_char_type
     self.rec_batch_num = args.rec_batch_num
     self.rec_algorithm = args.rec_algorithm
     self.text_len = args.max_text_length
     self.loss_type = loss_type
     char_ops_params = {
         "character_type": args.rec_char_type,
         "character_dict_path": args.rec_char_dict_path,
         "use_space_char": args.use_space_char,
         "max_text_length": args.max_text_length,
         'loss_type': loss_type,
     }
     self.char_ops = CharacterOps(char_ops_params)
コード例 #4
0
    def __init__(self, args):
        """Set up table-structure recognition: pre-process ops, the
        ``TableLabelDecode`` post-process op and the 'table' predictor."""
        # Each pre-process step as its own single-key op config, applied in
        # this order by create_operators.
        resize_op = {'ResizeTableImage': {'max_len': args.table_max_len}}
        normalize_op = {
            'NormalizeImage': {
                'std': [0.229, 0.224, 0.225],
                'mean': [0.485, 0.456, 0.406],
                'scale': '1./255.',
                'order': 'hwc'
            }
        }
        pad_op = {'PaddingTableImage': None}
        to_chw_op = {'ToCHWImage': None}
        keep_op = {'KeepKeys': {'keep_keys': ['image']}}
        self.preprocess_op = create_operators(
            [resize_op, normalize_op, pad_op, to_chw_op, keep_op])

        decode_config = dict(
            name='TableLabelDecode',
            character_type=args.table_char_type,
            character_dict_path=args.table_char_dict_path,
        )
        self.postprocess_op = build_post_process(decode_config)

        (self.predictor, self.input_tensor, self.output_tensors,
         self.config) = utility.create_predictor(args, 'table', logger)
コード例 #5
0
    def __init__(self, args):
        """Build the detector's pre-process op, post-process op and predictor.

        Args:
            args: parsed argument namespace. ``det_algorithm`` selects "DB",
                "EAST" or "SAST"; ``det_max_side_len`` is passed to the
                pre-process op, and the matching ``det_db_*``,
                ``det_east_*`` or ``det_sast_*`` options configure the
                post-process op.

        Raises:
            SystemExit: with status 1 (via ``sys.exit``) for an unsupported
                ``det_algorithm``.
        """
        max_side_len = args.det_max_side_len
        self.det_algorithm = args.det_algorithm
        # The same pre-process config is shared by all three algorithms.
        preprocess_params = {'max_side_len': max_side_len}
        postprocess_params = {}
        if self.det_algorithm == "DB":
            self.preprocess_op = DBProcessTest(preprocess_params)
            postprocess_params["thresh"] = args.det_db_thresh
            postprocess_params["box_thresh"] = args.det_db_box_thresh
            postprocess_params["max_candidates"] = 1000
            postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
            self.postprocess_op = DBPostProcess(postprocess_params)
        elif self.det_algorithm == "EAST":
            self.preprocess_op = EASTProcessTest(preprocess_params)
            postprocess_params["score_thresh"] = args.det_east_score_thresh
            postprocess_params["cover_thresh"] = args.det_east_cover_thresh
            postprocess_params["nms_thresh"] = args.det_east_nms_thresh
            # NOTE(review): 'EASTPostPocess' (sic) must match the name imported
            # elsewhere in this module; confirm before renaming.
            self.postprocess_op = EASTPostPocess(postprocess_params)
        elif self.det_algorithm == "SAST":
            self.preprocess_op = SASTProcessTest(preprocess_params)
            postprocess_params["score_thresh"] = args.det_sast_score_thresh
            postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
            postprocess_params["sample_pts_num"] = args.det_sast_sample_pts_num
            postprocess_params["expand_scale"] = args.det_sast_expand_scale
            postprocess_params["shrink_ratio_of_width"] = args.det_sast_shrink_ratio_of_width
            self.postprocess_op = SASTPostProcess(postprocess_params)
        else:
            # An unknown algorithm is a configuration error: log it at error
            # level and exit non-zero (status 0 would signal success).
            logger.error("unknown det_algorithm:{}".format(self.det_algorithm))
            sys.exit(1)

        self.predictor, self.input_tensor, self.output_tensors =\
            utility.create_predictor(args, mode="det")
コード例 #6
0
 def __init__(self, args):
     """Set up the text-direction classifier's predictor and settings.

     Args:
         args: parsed argument namespace. The predictor is only created when
             ``use_pdserving`` is False. ``cls_batch_num`` sets the batch
             size; namespaces that lack it fall back to ``rec_batch_num``.
     """
     if args.use_pdserving is False:
         self.predictor, self.input_tensor, self.output_tensors = \
             utility.create_predictor(args, mode="cls")
         self.use_zero_copy_run = args.use_zero_copy_run
     self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")]
     # The classifier previously took its batch size from rec_batch_num,
     # ignoring the classifier's own setting. Prefer cls_batch_num and keep
     # rec_batch_num only as a fallback for older argument namespaces.
     cls_batch_num = getattr(args, 'cls_batch_num', None)
     self.cls_batch_num = (args.rec_batch_num
                           if cls_batch_num is None else cls_batch_num)
     self.label_list = args.label_list
     self.cls_thresh = args.cls_thresh
コード例 #7
0
ファイル: predict_rec.py プロジェクト: ioracion/PaddleOCR
 def __init__(self, args):
     """Set up the recognizer: input shape, label decoder, predictor and,
     when ``args.benchmark`` is truthy, an ``auto_log.AutoLogger``.

     The decoder is chosen from ``args.rec_algorithm`` ("SRN", "RARE",
     "NRTR", "SAR"); any other value uses ``CTCLabelDecode``.
     """
     self.rec_image_shape = list(map(int, args.rec_image_shape.split(",")))
     self.rec_batch_num = args.rec_batch_num
     self.rec_algorithm = args.rec_algorithm

     # Algorithm -> decoder op name; unlisted algorithms fall back to CTC.
     decoder_names = {
         "SRN": 'SRNLabelDecode',
         "RARE": 'AttnLabelDecode',
         'NRTR': 'NRTRLabelDecode',
         "SAR": 'SARLabelDecode',
     }
     self.postprocess_op = build_post_process({
         'name': decoder_names.get(self.rec_algorithm, 'CTCLabelDecode'),
         "character_dict_path": args.rec_char_dict_path,
         "use_space_char": args.use_space_char
     })
     (self.predictor, self.input_tensor, self.output_tensors,
      self.config) = utility.create_predictor(args, 'rec', logger)
     self.benchmark = args.benchmark
     self.use_onnx = args.use_onnx
     if not args.benchmark:
         return

     # Benchmark mode only: auto_log is imported lazily so it is not required
     # for normal inference.
     import auto_log
     pid = os.getpid()
     gpu_id = utility.get_infer_gpuid()
     tracked_timers = [
         'preprocess_time', 'inference_time', 'postprocess_time'
     ]
     self.autolog = auto_log.AutoLogger(
         model_name="rec",
         model_precision=args.precision,
         batch_size=args.rec_batch_num,
         data_shape="dynamic",
         save_path=None,  # NOTE(review): args.save_log_path is not passed here
         inference_config=self.config,
         pids=pid,
         process_name=None,
         gpu_ids=gpu_id if args.use_gpu else None,
         time_keys=tracked_timers,
         warmup=0,
         logger=logger)
コード例 #8
0
 def __init__(self, args):
     """Set up the text-direction classifier: image shape, batch size,
     score threshold, ``ClsPostProcess`` decoder and the 'cls' predictor."""
     shape_spec = args.cls_image_shape
     self.cls_image_shape = list(map(int, shape_spec.split(",")))
     self.cls_batch_num = args.cls_batch_num
     self.cls_thresh = args.cls_thresh
     self.postprocess_op = build_post_process(
         dict(name='ClsPostProcess', label_list=args.label_list))
     (self.predictor, self.input_tensor,
      self.output_tensors) = utility.create_predictor(args, 'cls', logger)
コード例 #9
0
ファイル: predict_rec.py プロジェクト: zhypopt/PaddleOCR
 def __init__(self, args):
     """Set up a recognizer that always decodes with ``CTCLabelDecode``:
     input shape, batch settings, decoder and the 'rec' predictor."""
     self.rec_image_shape = list(map(int, args.rec_image_shape.split(",")))
     self.character_type = args.rec_char_type
     self.rec_batch_num = args.rec_batch_num
     self.rec_algorithm = args.rec_algorithm
     self.use_zero_copy_run = args.use_zero_copy_run
     # Only CTC decoding is configured here, regardless of rec_algorithm.
     decode_config = dict(
         name='CTCLabelDecode',
         character_type=args.rec_char_type,
         character_dict_path=args.rec_char_dict_path,
         use_space_char=args.use_space_char,
     )
     self.postprocess_op = build_post_process(decode_config)
     predictor_parts = utility.create_predictor(args, 'rec', logger)
     self.predictor, self.input_tensor, self.output_tensors = predictor_parts
コード例 #10
0
 def __init__(self, args):
     """Set up the recognizer's predictor, input shape and ``CharacterOps``.

     ``args.rec_algorithm == "RARE"`` selects the 'attention' loss type;
     every other algorithm uses 'ctc'.
     """
     self.predictor, self.input_tensor, self.output_tensors =\
         utility.create_predictor(args, mode="rec")
     self.rec_image_shape = [int(dim) for dim in args.rec_image_shape.split(",")]
     self.character_type = args.rec_char_type
     self.rec_batch_num = args.rec_batch_num
     self.rec_algorithm = args.rec_algorithm
     self.loss_type = 'attention' if self.rec_algorithm == "RARE" else 'ctc'
     self.char_ops = CharacterOps({
         "character_type": args.rec_char_type,
         "character_dict_path": args.rec_char_dict_path,
         'loss_type': self.loss_type,
     })
コード例 #11
0
 def __init__(self, args):
     """Build the detector's pre-process op, post-process op and (unless
     serving) the inference predictor.

     Args:
         args: parsed argument namespace. ``det_algorithm`` selects "DB",
             "EAST" or "SAST"; ``det_max_side_len`` feeds the pre-process op;
             the matching ``det_db_*``, ``det_east_*`` or ``det_sast_*``
             options configure the post-process op. The predictor is only
             created when ``use_pdserving`` is False.

     Raises:
         SystemExit: with status 1 (via ``sys.exit``) for an unsupported
             ``det_algorithm``.
     """
     max_side_len = args.det_max_side_len
     self.det_algorithm = args.det_algorithm
     # NOTE(review): test_image_shape is hard-coded to [640, 640] rather than
     # read from args.
     preprocess_params = {
         'test_image_shape': [640, 640],
         'max_side_len': max_side_len
     }
     postprocess_params = {}
     if self.det_algorithm == "DB":
         self.preprocess_op = DBProcessTest(preprocess_params)
         postprocess_params["thresh"] = args.det_db_thresh
         postprocess_params["box_thresh"] = args.det_db_box_thresh
         postprocess_params["max_candidates"] = 1000
         postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
         self.postprocess_op = DBPostProcess(postprocess_params)
     elif self.det_algorithm == "EAST":
         self.preprocess_op = EASTProcessTest(preprocess_params)
         postprocess_params["score_thresh"] = args.det_east_score_thresh
         postprocess_params["cover_thresh"] = args.det_east_cover_thresh
         postprocess_params["nms_thresh"] = args.det_east_nms_thresh
         # NOTE(review): 'EASTPostPocess' (sic) must match the name imported
         # elsewhere in this module; confirm before renaming.
         self.postprocess_op = EASTPostPocess(postprocess_params)
     elif self.det_algorithm == "SAST":
         self.preprocess_op = SASTProcessTest(preprocess_params)
         postprocess_params["score_thresh"] = args.det_sast_score_thresh
         postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
         self.det_sast_polygon = args.det_sast_polygon
         # Polygon and non-polygon SAST output use different
         # sampling/expansion settings.
         if self.det_sast_polygon:
             postprocess_params["sample_pts_num"] = 6
             postprocess_params["expand_scale"] = 1.2
             postprocess_params["shrink_ratio_of_width"] = 0.2
         else:
             postprocess_params["sample_pts_num"] = 2
             postprocess_params["expand_scale"] = 1.0
             postprocess_params["shrink_ratio_of_width"] = 0.3
         self.postprocess_op = SASTPostProcess(postprocess_params)
     else:
         # An unknown algorithm is a configuration error: log it at error
         # level and exit non-zero (status 0 would signal success).
         logger.error("unknown det_algorithm:{}".format(self.det_algorithm))
         sys.exit(1)
     if args.use_pdserving is False:
         self.use_zero_copy_run = args.use_zero_copy_run
         self.predictor, self.input_tensor, self.output_tensors =\
             utility.create_predictor(args, mode="det")
コード例 #12
0
    def __init__(self, args):
        """Build the end-to-end (PGNet) pipeline: pre-process ops,
        post-process op and the 'e2e' predictor.

        Args:
            args: parsed argument namespace. ``e2e_algorithm`` must be
                "PGNet"; ``e2e_limit_side_len``, ``e2e_pgnet_score_thresh``,
                ``e2e_char_dict_path``, ``e2e_pgnet_valid_set``,
                ``e2e_pgnet_mode`` and ``e2e_pgnet_polygon`` configure it.

        Raises:
            SystemExit: with status 1 (via ``sys.exit``) for an unsupported
                ``e2e_algorithm``.
        """
        self.args = args
        self.e2e_algorithm = args.e2e_algorithm
        # The empty E2EResizeForTest config at index 0 is a placeholder; the
        # algorithm branch below replaces it.
        pre_process_list = [{
            'E2EResizeForTest': {}
        }, {
            'NormalizeImage': {
                'std': [0.229, 0.224, 0.225],
                'mean': [0.485, 0.456, 0.406],
                'scale': '1./255.',
                'order': 'hwc'
            }
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                'keep_keys': ['image', 'shape']
            }
        }]
        postprocess_params = {}
        if self.e2e_algorithm == "PGNet":
            pre_process_list[0] = {
                'E2EResizeForTest': {
                    'max_side_len': args.e2e_limit_side_len,
                    'valid_set': 'totaltext'
                }
            }
            postprocess_params['name'] = 'PGPostProcess'
            postprocess_params["score_thresh"] = args.e2e_pgnet_score_thresh
            postprocess_params["character_dict_path"] = args.e2e_char_dict_path
            postprocess_params["valid_set"] = args.e2e_pgnet_valid_set
            postprocess_params["mode"] = args.e2e_pgnet_mode
            self.e2e_pgnet_polygon = args.e2e_pgnet_polygon
        else:
            # An unknown algorithm is a configuration error: log it at error
            # level and exit non-zero (status 0 would signal success).
            logger.error("unknown e2e_algorithm:{}".format(self.e2e_algorithm))
            sys.exit(1)

        self.preprocess_op = create_operators(pre_process_list)
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors = utility.create_predictor(
            args, 'e2e', logger)
コード例 #13
0
ファイル: predict_det.py プロジェクト: ruyueshi/PaddleOCR
    def __init__(self, args):
        """Build the text-detection pipeline: pre-process ops, post-process
        op, predictor and, when ``args.benchmark`` is truthy, an
        ``auto_log.AutoLogger``.

        Args:
            args: parsed argument namespace. ``det_algorithm`` selects "DB",
                "EAST" or "SAST"; ``det_limit_side_len`` / ``det_limit_type``
                drive the default resize; the matching ``det_db_*``,
                ``det_east_*`` or ``det_sast_*`` options (plus
                ``use_dilation`` for DB) configure post-processing.

        Raises:
            SystemExit: with status 1 (via ``sys.exit``) for an unsupported
                ``det_algorithm``.
        """
        self.args = args
        self.det_algorithm = args.det_algorithm
        # Default pre-process pipeline; the SAST branch below replaces the
        # resize step at index 0.
        pre_process_list = [{
            'DetResizeForTest': {
                'limit_side_len': args.det_limit_side_len,
                'limit_type': args.det_limit_type,
            }
        }, {
            'NormalizeImage': {
                'std': [0.229, 0.224, 0.225],
                'mean': [0.485, 0.456, 0.406],
                'scale': '1./255.',
                'order': 'hwc'
            }
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                'keep_keys': ['image', 'shape']
            }
        }]
        postprocess_params = {}
        if self.det_algorithm == "DB":
            postprocess_params['name'] = 'DBPostProcess'
            postprocess_params["thresh"] = args.det_db_thresh
            postprocess_params["box_thresh"] = args.det_db_box_thresh
            postprocess_params["max_candidates"] = 1000
            postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
            postprocess_params["use_dilation"] = args.use_dilation
            postprocess_params["score_mode"] = args.det_db_score_mode
        elif self.det_algorithm == "EAST":
            postprocess_params['name'] = 'EASTPostProcess'
            postprocess_params["score_thresh"] = args.det_east_score_thresh
            postprocess_params["cover_thresh"] = args.det_east_cover_thresh
            postprocess_params["nms_thresh"] = args.det_east_nms_thresh
        elif self.det_algorithm == "SAST":
            # SAST resizes by the long side only, instead of the default
            # limit_side_len / limit_type resize.
            pre_process_list[0] = {
                'DetResizeForTest': {
                    'resize_long': args.det_limit_side_len
                }
            }
            postprocess_params['name'] = 'SASTPostProcess'
            postprocess_params["score_thresh"] = args.det_sast_score_thresh
            postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
            self.det_sast_polygon = args.det_sast_polygon
            # Polygon and non-polygon SAST output use different
            # sampling/expansion settings.
            if self.det_sast_polygon:
                postprocess_params["sample_pts_num"] = 6
                postprocess_params["expand_scale"] = 1.2
                postprocess_params["shrink_ratio_of_width"] = 0.2
            else:
                postprocess_params["sample_pts_num"] = 2
                postprocess_params["expand_scale"] = 1.0
                postprocess_params["shrink_ratio_of_width"] = 0.3
        else:
            # An unknown algorithm is a configuration error: log it at error
            # level and exit non-zero (status 0 would signal success).
            logger.error("unknown det_algorithm:{}".format(self.det_algorithm))
            sys.exit(1)

        self.preprocess_op = create_operators(pre_process_list)
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor(
            args, 'det', logger)

        if args.benchmark:
            # Benchmark mode only: auto_log is imported lazily so it is not
            # required for normal inference.
            import auto_log
            pid = os.getpid()
            self.autolog = auto_log.AutoLogger(
                model_name="det",
                model_precision=args.precision,
                batch_size=1,
                data_shape="dynamic",
                save_path=args.save_log_path,
                inference_config=self.config,
                pids=pid,
                process_name=None,
                # Report a GPU id only when actually running on GPU; a CPU
                # run previously still logged GPU 0.
                gpu_ids=0 if args.use_gpu else None,
                time_keys=[
                    'preprocess_time', 'inference_time', 'postprocess_time'
                ],
                warmup=10)