def main(config, device, logger, vdl_writer):
    """Run table-structure inference on every image in ``Global.infer_img``.

    For each image the post process yields an HTML structure string
    (``res_html_code``) and cell boxes (``res_loc``). The box coordinates are
    multiplied by the image width/height (so presumably normalized to [0, 1]
    -- confirm), clipped to the image, and logged together with the HTML.

    Args:
        config: Parsed configuration; the ``Global``, ``PostProcess``,
            ``Architecture`` and ``Eval`` sections are read.
        device: Unused here; kept for a uniform entry-point signature.
        logger: Logger for progress and result messages.
        vdl_writer: Unused here; kept for a uniform entry-point signature.
    """
    global_config = config['Global']

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model; the head size follows the post-process character set
    if hasattr(post_process_class, 'character'):
        config['Architecture']["Head"]['out_channels'] = len(
            getattr(post_process_class, 'character'))
    model = build_model(config['Architecture'])
    load_model(config, model)

    # create data ops: drop label ops and keep only the image tensor
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Label' in op_name:
            continue
        if op_name == 'KeepKeys':
            op[op_name]['keep_keys'] = ['image']
        transforms.append(op)
    global_config['infer_mode'] = True
    ops = create_operators(transforms, global_config)

    model.eval()
    for file in get_image_file_list(config['Global']['infer_img']):
        logger.info("infer_img: {}".format(file))
        with open(file, 'rb') as f:
            img_bytes = f.read()
        data = {'image': img_bytes}
        batch = transform(data, ops)
        images = np.expand_dims(batch[0], axis=0)
        images = paddle.to_tensor(images)
        preds = model(images)
        post_result = post_process_class(preds)
        res_html_code = post_result['res_html_code']
        res_loc = post_result['res_loc']

        img = cv2.imread(file)
        imgh, imgw = img.shape[0:2]
        res_loc_final = []
        for x0, y0, x1, y1 in res_loc[0]:
            # scale to pixel coordinates and clip to the image bounds
            left = max(int(imgw * x0), 0)
            top = max(int(imgh * y0), 0)
            right = min(int(imgw * x1), imgw - 1)
            bottom = min(int(imgh * y1), imgh - 1)
            # NOTE(review): the annotated image is neither saved nor returned
            cv2.rectangle(img, (left, top), (right, bottom), (0, 0, 255), 2)
            res_loc_final.append([left, top, right, bottom])
        logger.info("result: {}, {}".format(res_html_code, res_loc_final))
    logger.info("success!")
def main():
    """Run KIE inference over a tab-separated label list and time the model.

    ``Global.infer_img`` is read as a text file whose lines are
    ``<relative image path><TAB><label>``. Each image is loaded from
    ``Eval.dataset.data_dir`` and transformed. It is then run through the
    model, and the softmax-normalized node predictions are drawn with
    ``draw_kie_result``. The time spent in the model call plus softmax is
    accumulated and summarized at the end.

    Relies on module-level ``config`` and ``logger``.
    """
    global_config = config['Global']

    # build model
    model = build_model(config['Architecture'])
    load_model(config, model)

    # create data ops; unlike the other inference entry points every eval
    # transform is kept (the label string is supplied in ``data`` below)
    transforms = list(config['Eval']['dataset']['transforms'])
    data_dir = config['Eval']['dataset']['data_dir']
    ops = create_operators(transforms, global_config)

    save_res_path = config['Global']['save_res_path']
    class_path = config['Global']['class_path']
    idx_to_cls = read_class_list(class_path)
    save_dir = os.path.dirname(save_res_path)
    # a bare file name has an empty dirname and os.makedirs('') would raise
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    model.eval()

    # number of leading samples excluded from the throughput figure
    warmup_times = 0
    count_t = []
    # NOTE(review): the results file is created/truncated but nothing is
    # written to it here; predictions are only visualized.
    with open(save_res_path, "wb"):
        with open(config['Global']['infer_img'], "rb") as f:
            lines = f.readlines()
        for index, data_line in enumerate(lines):
            data_line = data_line.decode('utf-8')
            substr = data_line.strip("\n").split("\t")
            img_path, label = data_dir + "/" + substr[0], substr[1]
            data = {'img_path': img_path, 'label': label}
            with open(data['img_path'], 'rb') as f:
                img = f.read()
            data['image'] = img
            batch = transform(data, ops)
            # add a batch dimension to every transformed field
            batch_pred = [
                paddle.to_tensor(np.expand_dims(item, axis=0))
                for item in batch
            ]
            st = time.time()
            node, edge = model(batch_pred)
            node = F.softmax(node, -1)
            count_t.append(time.time() - st)
            draw_kie_result(batch, node, idx_to_cls, index)
    logger.info("success!")
    logger.info("It took {} s for predict {} images.".format(
        np.sum(count_t), len(count_t)))
    timed = count_t[warmup_times:]
    # nothing to average when no sample was timed
    if timed:
        ips = len(timed) / np.sum(timed)
        logger.info("The ips is {} images/s".format(ips))
def main():
    """Run end-to-end (detection + recognition) inference on images.

    For every image in ``Global.infer_img`` one line is appended to
    ``Global.save_res_path``. The line is
    ``<image path><TAB><JSON list of {"transcription", "points"}>``, and the
    result is also drawn with ``draw_e2e_res``.

    Relies on module-level ``config`` and ``logger``.
    """
    global_config = config['Global']

    # build model
    model = build_model(config['Architecture'])
    load_model(config, model)

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # create data ops: drop label ops, keep the image and its shape info
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Label' in op_name:
            continue
        elif op_name == 'KeepKeys':
            op[op_name]['keep_keys'] = ['image', 'shape']
        transforms.append(op)
    ops = create_operators(transforms, global_config)

    save_res_path = config['Global']['save_res_path']
    save_dir = os.path.dirname(save_res_path)
    # a bare file name has an empty dirname and os.makedirs('') would raise
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    model.eval()
    with open(save_res_path, "wb") as fout:
        for file in get_image_file_list(config['Global']['infer_img']):
            logger.info("infer_img: {}".format(file))
            with open(file, 'rb') as f:
                img = f.read()
            data = {'image': img}
            batch = transform(data, ops)
            images = np.expand_dims(batch[0], axis=0)
            shape_list = np.expand_dims(batch[1], axis=0)
            images = paddle.to_tensor(images)
            preds = model(images)
            post_result = post_process_class(preds, shape_list)
            points, strs = post_result['points'], post_result['texts']
            # write result
            dt_boxes_json = [{
                "transcription": text,
                "points": poly.tolist()
            } for poly, text in zip(points, strs)]
            otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n"
            fout.write(otstr.encode())
            src_img = cv2.imread(file)
            draw_e2e_res(points, strs, config, src_img, file)
    logger.info("success!")
def main():
    """Export a trained model for inference.

    The config comes from the command line (``ArgsParser`` + ``merge_config``).
    When the post process exposes a ``character`` set, the recognition
    head(s) are sized from it. The weights are then loaded, and the model is
    written under ``Global.save_inference_dir``. A distillation architecture
    is exported sub-model by sub-model, each under its own directory.
    """
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    merge_config(FLAGS.opt)
    logger = get_logger()

    post_process_class = build_post_process(config["PostProcess"],
                                            config["Global"])

    # recognition algorithms: match the head width to the character set
    if hasattr(post_process_class, "character"):
        char_num = len(getattr(post_process_class, "character"))
        arch = config["Architecture"]
        if arch["algorithm"] in ["Distillation", ]:
            # every sub-model gets the same head width, and only its final
            # tensor is needed for inference
            for sub_cfg in arch["Models"].values():
                sub_cfg["Head"]["out_channels"] = char_num
                sub_cfg["return_all_feats"] = False
        else:
            arch["Head"]["out_channels"] = char_num

    model = build_model(config["Architecture"])
    load_model(config, model)
    model.eval()

    save_path = config["Global"]["save_inference_dir"]
    arch_config = config["Architecture"]

    if arch_config["algorithm"] not in ["Distillation", ]:
        export_single_model(model, arch_config,
                            os.path.join(save_path, "inference"), logger)
        return

    # distillation: one export per sub-model, in <save_path>/<name>/inference
    sub_archs = list(arch_config["Models"].values())
    for i, sub_name in enumerate(model.model_name_list):
        export_single_model(model.model_list[i], sub_archs[i],
                            os.path.join(save_path, sub_name, "inference"),
                            logger)
def main():
    """Run classifier inference on each image and log the post-processed output.

    Relies on module-level ``config`` and ``logger``.
    """
    global_config = config['Global']

    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    model = build_model(config['Architecture'])
    load_model(config, model)

    # eval transforms minus label ops; KeepKeys is narrowed to the image
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Label' not in op_name:
            if op_name == 'KeepKeys':
                op[op_name]['keep_keys'] = ['image']
            transforms.append(op)
    global_config['infer_mode'] = True
    ops = create_operators(transforms, global_config)

    model.eval()
    for img_path in get_image_file_list(config['Global']['infer_img']):
        logger.info("infer_img: {}".format(img_path))
        with open(img_path, 'rb') as fh:
            raw = fh.read()
        batch = transform({'image': raw}, ops)
        # single-image batch
        images = paddle.to_tensor(np.expand_dims(batch[0], axis=0))
        post_result = post_process_class(model(images))
        for cls_result in post_result:
            logger.info('\t result: {}'.format(cls_result))
    logger.info("success!")
def main():
    """Evaluate a trained model on the ``Eval`` dataset and log the metrics.

    Metrics stored in the checkpoint (if any) are logged first, followed by
    the metrics returned by ``program.eval``. Relies on module-level
    ``config``, ``device`` and ``logger``.
    """

    def log_dict(title, values):
        # log a title line, then one "key:value" line per entry
        logger.info(title)
        for key, value in values.items():
            logger.info('{}:{}'.format(key, value))

    global_config = config['Global']
    valid_dataloader = build_dataloader(config, 'Eval', device, logger)
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # recognition algorithms: match the head width to the character set
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
        arch = config['Architecture']
        if arch["algorithm"] in ["Distillation", ]:
            for sub_cfg in arch["Models"].values():
                sub_cfg["Head"]['out_channels'] = char_num
        else:
            arch["Head"]['out_channels'] = char_num

    model = build_model(config['Architecture'])
    # flag forwarded to program.eval for these algorithms
    extra_input = config['Architecture']['algorithm'] in [
        "SRN", "NRTR", "SAR", "SEED"
    ]
    model_type = config['Architecture'].get('model_type', None)

    best_model_dict = load_model(config, model)
    if len(best_model_dict):
        log_dict('metric in ckpt ***************', best_model_dict)

    eval_class = build_metric(config['Metric'])
    metric = program.eval(model, valid_dataloader, post_process_class,
                          eval_class, model_type, extra_input)
    log_dict('metric eval ***************', metric)
def main():
    """Compute character centers on the training data and pickle them.

    The ``Eval`` dataset section is redirected to the ``Train`` data, so the
    eval dataloader iterates training images. The head is configured with
    ``return_feats = True``. The result of ``program.get_center`` is written
    to ``train_center.pkl`` in the working directory.

    Relies on module-level ``config``, ``device`` and ``logger``.
    """
    global_config = config['Global']

    # feed the training data through the eval pipeline
    eval_ds = config['Eval']['dataset']
    train_ds = config['Train']['dataset']
    for key in ('name', 'data_dir', 'label_file_list'):
        eval_ds[key] = train_ds[key]
    eval_dataloader = build_dataloader(config, 'Eval', device, logger)

    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    head_cfg = config['Architecture']["Head"]
    if hasattr(post_process_class, 'character'):
        head_cfg['out_channels'] = len(
            getattr(post_process_class, 'character'))
    # ask the head for features as well (consumed by program.get_center)
    head_cfg["return_feats"] = True

    model = build_model(config['Architecture'])
    best_model_dict = load_model(config, model)
    if len(best_model_dict):
        logger.info('metric in ckpt ***************')
        for name, value in best_model_dict.items():
            logger.info('{}:{}'.format(name, value))

    char_center = program.get_center(model, eval_dataloader,
                                     post_process_class)

    with open("train_center.pkl", 'wb') as out_file:
        pickle.dump(char_center, out_file)
def main(config, device, logger, vdl_writer):
    """Train a model as described by ``config``.

    Steps:
    1. Initializes the parallel environment when ``Global.distributed`` is set.
    2. Builds the dataloaders, post process, model, loss, optimizer and
       metric, and loads any pretrained/checkpoint weights.
    3. Sets up a ``GradScaler`` when ``Global.use_amp`` is true.
    4. Hands off to ``program.train``.

    The function returns early, after logging an error, when the training
    dataloader is empty.
    """
    global_config = config['Global']
    distributed = global_config['distributed']
    if distributed:
        dist.init_parallel_env()

    train_dataloader = build_dataloader(config, 'Train', device, logger)
    if len(train_dataloader) == 0:
        logger.error(
            "No Images in train dataset, please ensure\n" +
            "\t1. The images num in the train label_file_list should be larger than or equal with batch size.\n"
            +
            "\t2. The annotation file and path in the configuration file are provided normally."
        )
        return

    valid_dataloader = (build_dataloader(config, 'Eval', device, logger)
                        if config['Eval'] else None)

    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # recognition algorithms: match the head width to the character set
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
        arch = config['Architecture']
        if arch["algorithm"] in ["Distillation", ]:
            for sub_cfg in arch["Models"].values():
                sub_cfg["Head"]['out_channels'] = char_num
        else:
            arch["Head"]['out_channels'] = char_num

    model = build_model(config['Architecture'])
    if distributed:
        model = paddle.DataParallel(model)

    loss_class = build_loss(config['Loss'])
    optimizer, lr_scheduler = build_optimizer(
        config['Optimizer'],
        epochs=global_config['epoch_num'],
        step_each_epoch=len(train_dataloader),
        parameters=model.parameters())
    eval_class = build_metric(config['Metric'])
    pre_best_model_dict = load_model(config, model, optimizer)

    logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
    if valid_dataloader is not None:
        logger.info('valid dataloader has {} iters'.format(
            len(valid_dataloader)))

    # optional automatic mixed precision
    scaler = None
    if global_config.get("use_amp", False):
        AMP_RELATED_FLAGS_SETTING = {
            'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
            'FLAGS_max_inplace_grad_add': 8,
        }
        paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
        scaler = paddle.amp.GradScaler(
            init_loss_scaling=global_config.get("scale_loss", 1.0),
            use_dynamic_loss_scaling=global_config.get(
                "use_dynamic_loss_scaling", False))

    program.train(config, train_dataloader, valid_dataloader, device, model,
                  loss_class, optimizer, lr_scheduler, post_process_class,
                  eval_class, pre_best_model_dict, logger, vdl_writer, scaler)
def main(config, device, logger, vdl_writer):
    """Prune a model with FPGM filter pruning, then fine-tune it.

    Steps:
    1. Builds the dataloaders, post process, model, loss, optimizer and
       metric, and loads pretrained weights.
    2. Prunes the convolution weights with ``FPGMFilterPruner``. Pruning uses
       either sensitivity-derived ratios or a fixed 10% ratio; see
       ``run_sensitive_analysis``.
    3. Trains the pruned model with ``program.train``.

    FLOPs are logged before and after pruning.

    Raises:
        ValueError: if ``Architecture.model_type`` is neither ``'det'`` nor
            ``'rec'``; no pruning input shape is defined for other types.
    """
    # init dist environment
    if config['Global']['distributed']:
        dist.init_parallel_env()

    global_config = config['Global']

    # build dataloader
    train_dataloader = build_dataloader(config, 'Train', device, logger)
    if config['Eval']:
        valid_dataloader = build_dataloader(config, 'Eval', device, logger)
    else:
        valid_dataloader = None

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model; for rec algorithms the head width follows the charset
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
        config['Architecture']["Head"]['out_channels'] = char_num
    model = build_model(config['Architecture'])

    model_type = config['Architecture']['model_type']
    if model_type == 'det':
        input_shape = [1, 3, 640, 640]
    elif model_type == 'rec':
        input_shape = [1, 3, 32, 320]
    else:
        # previously this fell through to a NameError on input_shape
        raise ValueError(
            "Unsupported model_type {!r} for pruning; expected 'det' or "
            "'rec'".format(model_type))
    main_indicator = 'hmean' if model_type == 'det' else 'acc'

    flops = paddle.flops(model, input_shape)
    logger.info("FLOPs before pruning: {}".format(flops))

    from paddleslim.dygraph import FPGMFilterPruner
    model.train()
    pruner = FPGMFilterPruner(model, input_shape)

    # build loss
    loss_class = build_loss(config['Loss'])

    # build optim
    optimizer, lr_scheduler = build_optimizer(
        config['Optimizer'],
        epochs=config['Global']['epoch_num'],
        step_each_epoch=len(train_dataloader),
        parameters=model.parameters())

    # build metric
    eval_class = build_metric(config['Metric'])
    # load pretrain model
    pre_best_model_dict = load_model(config, model, optimizer)

    # valid_dataloader is None without an Eval section; len(None) would raise
    logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
    if valid_dataloader is not None:
        logger.info('valid dataloader has {} iters'.format(
            len(valid_dataloader)))

    def eval_fn():
        # NOTE(review): ``False`` is passed positionally; in the eval entry
        # point this slot carries model_type -- confirm the intent.
        metric = program.eval(model, valid_dataloader, post_process_class,
                              eval_class, False)
        logger.info("metric[{}]: {}".format(main_indicator,
                                            metric[main_indicator]))
        return metric[main_indicator]

    # run_sensitive_analysis=True: compute the sensitivity of each
    # convolution automatically. Sensitivity is the accuracy loss on the
    # eval dataset at different pruning ratios. The per-layer ratios are
    # then chosen to keep the loss within ``loss=0.02``.
    # run_sensitive_analysis=False: prune every eligible weight by a fixed
    # ratio (10%). The larger the ratio, the more convolution weights are
    # cropped.
    run_sensitive_analysis = False
    if run_sensitive_analysis:
        params_sensitive = pruner.sensitive(
            eval_func=eval_fn,
            sen_file="./deploy/slim/prune/sen.pickle",
            skip_vars=[
                "conv2d_57.w_0", "conv2d_transpose_2.w_0",
                "conv2d_transpose_3.w_0"
            ])
        logger.info(
            "The sensitivity analysis results of model parameters saved in sen.pickle"
        )
        # calculate pruned params's ratio
        params_sensitive = pruner._get_ratios_by_loss(
            params_sensitive, loss=0.02)
        for key, ratio in params_sensitive.items():
            logger.info("{}, {}".format(key, ratio))
    else:
        params_sensitive = {}
        for param in model.parameters():
            if 'transpose' not in param.name and 'linear' not in param.name:
                # fixed 10% prune ratio for every eligible weight
                params_sensitive[param.name] = 0.1

    pruner.prune_vars(params_sensitive, [0])
    flops = paddle.flops(model, input_shape)
    logger.info("FLOPs after pruning: {}".format(flops))

    # start train
    program.train(config, train_dataloader, valid_dataloader, device, model,
                  loss_class, optimizer, lr_scheduler, post_process_class,
                  eval_class, pre_best_model_dict, logger, vdl_writer)
def main(config, device, logger, vdl_writer):
    """Prune a model by FPGM sensitivity analysis and export it for inference.

    Steps:
    1. Runs ``pruner.sensitive`` (results in ``./sen.pickle``) and derives
       per-layer ratios with ``loss=0.02``.
    2. Prunes the model and loads the checkpoint.
    3. Evaluates the model and logs the main metric (``hmean`` for det,
       ``acc`` otherwise).
    4. Exports a static inference model to
       ``<Global.save_inference_dir>/inference``.

    Args:
        config: Parsed configuration.
        device: Device passed to ``build_dataloader``.
        logger: Logger for progress messages.
        vdl_writer: Unused here; kept for a uniform entry-point signature.

    Raises:
        ValueError: if ``Architecture.model_type`` is neither ``'det'`` nor
            ``'rec'``.
    """
    global_config = config['Global']

    # build dataloader
    valid_dataloader = build_dataloader(config, 'Eval', device, logger)

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model; for rec algorithms the head width follows the charset
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
        config['Architecture']["Head"]['out_channels'] = char_num
    model = build_model(config['Architecture'])

    model_type = config['Architecture']['model_type']
    if model_type == 'det':
        input_shape = [1, 3, 640, 640]
    elif model_type == 'rec':
        input_shape = [1, 3, 32, 320]
    else:
        # previously this fell through to a NameError on input_shape
        raise ValueError(
            "Unsupported model_type {!r} for pruning; expected 'det' or "
            "'rec'".format(model_type))
    main_indicator = 'hmean' if model_type == 'det' else 'acc'

    flops = paddle.flops(model, input_shape)
    logger.info("FLOPs before pruning: {}".format(flops))

    from paddleslim.dygraph import FPGMFilterPruner
    model.train()
    pruner = FPGMFilterPruner(model, input_shape)

    # build metric
    eval_class = build_metric(config['Metric'])

    def eval_fn():
        metric = program.eval(model, valid_dataloader, post_process_class,
                              eval_class)
        logger.info("metric[{}]: {}".format(main_indicator,
                                            metric[main_indicator]))
        return metric[main_indicator]

    # NOTE(review): the ratios are derived before load_model() is called
    # below; confirm this matches how the checkpoint was pruned in training.
    params_sensitive = pruner.sensitive(
        eval_func=eval_fn,
        sen_file="./sen.pickle",
        skip_vars=[
            "conv2d_57.w_0", "conv2d_transpose_2.w_0", "conv2d_transpose_3.w_0"
        ])
    logger.info(
        "The sensitivity analysis results of model parameters saved in sen.pickle"
    )
    # calculate pruned params's ratio
    params_sensitive = pruner._get_ratios_by_loss(params_sensitive, loss=0.02)
    for key, ratio in params_sensitive.items():
        logger.info("{}, {}".format(key, ratio))

    pruner.prune_vars(params_sensitive, [0])
    flops = paddle.flops(model, input_shape)
    logger.info("FLOPs after pruning: {}".format(flops))

    # load pretrain model
    load_model(config, model)
    metric = program.eval(model, valid_dataloader, post_process_class,
                          eval_class)
    # the old format string had a single placeholder and dropped the value
    logger.info("metric[{}]: {}".format(main_indicator,
                                        metric[main_indicator]))

    # start export model
    from paddle.jit import to_static

    infer_shape = [3, -1, -1]
    if model_type == "rec":
        infer_shape = [3, 32, -1]  # for rec model, H must be 32
        arch = config['Architecture']
        if 'Transform' in arch and arch['Transform'] is not None and arch[
                'Transform']['name'] == 'TPS':
            logger.info(
                'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training'
            )
            infer_shape[-1] = 100
    model = to_static(
        model,
        input_spec=[
            paddle.static.InputSpec(
                shape=[None] + infer_shape, dtype='float32')
        ])

    save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
    paddle.jit.save(model, save_path)
    logger.info('inference model is saved to {}'.format(save_path))
def main():
    """Evaluate a quantization-aware-trained model and export it for inference.

    Steps:
    1. Reads the config from the command line and builds the model.
    2. Applies ``QAT(config=quant_config).quantize`` and loads the
       checkpoint.
    3. Evaluates on the ``Eval`` dataset.
    4. Exports with ``export_single_model``; a distillation architecture is
       exported once per sub-model.

    NOTE(review): ``device`` is not defined in this function and is read
    from module scope -- confirm the caller sets it.
    """
    # 1. quantization configs
    quant_config = {
        # preprocessing applied to weights / activations (None = disabled)
        'weight_preprocess_type': None,
        'activation_preprocess_type': None,
        # quantization algorithms for weights and activations
        'weight_quantize_type': 'channel_wise_abs_max',
        'activation_quantize_type': 'moving_average_abs_max',
        # bit widths for weights and activations
        'weight_bits': 8,
        'activation_bits': 8,
        # data type after quantization, such as 'uint8' or 'int8'
        'dtype': 'int8',
        # window size used by 'range_abs_max' quantization
        'window_size': 10000,
        # decay coefficient of the moving average
        'moving_rate': 0.9,
        # in dygraph mode, layers of these types are quantized
        'quantizable_layer_type': ['Conv2D', 'Linear'],
    }

    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    merge_config(FLAGS.opt)
    logger = get_logger()

    post_process_class = build_post_process(config['PostProcess'],
                                            config['Global'])

    # recognition algorithms: match the head width to the character set
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
        arch = config['Architecture']
        if arch["algorithm"] in ["Distillation", ]:
            for sub_cfg in arch["Models"].values():
                sub_cfg["Head"]['out_channels'] = char_num
        else:
            arch["Head"]['out_channels'] = char_num

    model = build_model(config['Architecture'])

    # wrap the model for quantization before loading the QAT checkpoint
    quanter = QAT(config=quant_config)
    quanter.quantize(model)
    load_model(config, model)
    model.eval()

    eval_class = build_metric(config['Metric'])
    valid_dataloader = build_dataloader(config, 'Eval', device, logger)

    use_srn = config['Architecture']['algorithm'] == "SRN"
    model_type = config['Architecture'].get('model_type', None)

    metric = program.eval(model, valid_dataloader, post_process_class,
                          eval_class, model_type, use_srn)
    logger.info('metric eval ***************')
    for metric_name, metric_value in metric.items():
        logger.info('{}:{}'.format(metric_name, metric_value))

    if model_type == "rec":
        infer_shape = [3, 32, 100]
    else:
        infer_shape = [3, 640, 640]

    save_path = config["Global"]["save_inference_dir"]
    arch_config = config["Architecture"]
    if arch_config["algorithm"] in ["Distillation", ]:
        # one export per sub-model, in <save_path>/<name>/inference
        for i, sub_name in enumerate(model.model_name_list):
            export_single_model(quanter, model.model_list[i], infer_shape,
                                os.path.join(save_path, sub_name,
                                             "inference"), logger)
    else:
        export_single_model(quanter, model, infer_shape,
                            os.path.join(save_path, "inference"), logger)
def main():
    """Run text-recognition inference on images and save the results.

    For every image in ``Global.infer_img`` the post-processed result is
    logged and written to ``Global.save_res_path``. When that key is absent
    the default is ``./output/rec/predicts_rec.txt``. Each line is written as
    ``<basename><TAB><text><TAB><score>``. When the post process returns a
    dict, the line is ``<basename><TAB><JSON>`` with one
    ``{"label", "score"}`` entry per key. SRN and SAR receive their extra
    inputs from the transformed batch.

    Relies on module-level ``config`` and ``logger``.
    """
    global_config = config['Global']

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model; for rec algorithms the head width follows the charset
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
        if config['Architecture']["algorithm"] in ["Distillation", ]:
            # distillation model
            for key in config['Architecture']["Models"]:
                config['Architecture']["Models"][key]["Head"][
                    'out_channels'] = char_num
        else:
            # base rec model
            config['Architecture']["Head"]['out_channels'] = char_num

    model = build_model(config['Architecture'])
    load_model(config, model)
    algorithm = config['Architecture']['algorithm']

    # create data ops
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Label' in op_name:
            continue
        elif op_name in ['RecResizeImg']:
            op[op_name]['infer_mode'] = True
        elif op_name == 'KeepKeys':
            if algorithm == "SRN":
                op[op_name]['keep_keys'] = [
                    'image', 'encoder_word_pos', 'gsrm_word_pos',
                    'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
                ]
            elif algorithm == "SAR":
                op[op_name]['keep_keys'] = ['image', 'valid_ratio']
            else:
                op[op_name]['keep_keys'] = ['image']
        transforms.append(op)
    global_config['infer_mode'] = True
    ops = create_operators(transforms, global_config)

    save_res_path = config['Global'].get('save_res_path',
                                         "./output/rec/predicts_rec.txt")
    save_dir = os.path.dirname(save_res_path)
    # a bare file name has an empty dirname and os.makedirs('') would raise
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    model.eval()

    # explicit UTF-8 so recognized (possibly non-ASCII) text does not depend
    # on the platform locale; matches the .encode() used by sibling scripts
    with open(save_res_path, "w", encoding="utf-8") as fout:
        for file in get_image_file_list(config['Global']['infer_img']):
            logger.info("infer_img: {}".format(file))
            with open(file, 'rb') as f:
                img = f.read()
            data = {'image': img}
            batch = transform(data, ops)
            images = paddle.to_tensor(np.expand_dims(batch[0], axis=0))
            if algorithm == "SRN":
                # batch[1:5] follow the SRN keep_keys order set above
                others = [
                    paddle.to_tensor(np.expand_dims(batch[i], axis=0))
                    for i in range(1, 5)
                ]
                preds = model(images, others)
            elif algorithm == "SAR":
                valid_ratio = np.expand_dims(batch[-1], axis=0)
                img_metas = [paddle.to_tensor(valid_ratio)]
                preds = model(images, img_metas)
            else:
                preds = model(images)

            post_result = post_process_class(preds)
            info = None
            if isinstance(post_result, dict):
                rec_info = dict()
                for key in post_result:
                    if len(post_result[key][0]) >= 2:
                        rec_info[key] = {
                            "label": post_result[key][0][0],
                            "score": float(post_result[key][0][1]),
                        }
                info = json.dumps(rec_info)
            elif len(post_result[0]) >= 2:
                info = post_result[0][0] + "\t" + str(post_result[0][1])

            if info is not None:
                logger.info("\t result: {}".format(info))
                fout.write(os.path.basename(file) + "\t" + info + "\n")
    logger.info("success!")
def main():
    """Run detection inference, save boxes as JSON lines, and draw them.

    One line is written per image to ``Global.save_res_path``:
    ``<image path><TAB><JSON>``. For a plain post-process result the JSON is
    a list of ``{"transcription": "", "points": ...}``. For a dict result
    (one entry per output key), the JSON is an object mapping each key to
    such a list. Drawings go to ``det_results/`` (or ``det_results_<key>/``)
    next to the results file.

    Relies on module-level ``config`` and ``logger``.
    """
    global_config = config['Global']

    # build model
    model = build_model(config['Architecture'])
    load_model(config, model)
    # build post process
    post_process_class = build_post_process(config['PostProcess'])

    # create data ops: drop label ops, keep the image and its shape info
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Label' in op_name:
            continue
        elif op_name == 'KeepKeys':
            op[op_name]['keep_keys'] = ['image', 'shape']
        transforms.append(op)
    ops = create_operators(transforms, global_config)

    save_res_path = config['Global']['save_res_path']
    save_dir = os.path.dirname(save_res_path)
    # a bare file name has an empty dirname and os.makedirs('') would raise
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    # with an empty dirname, "" + "/det_results/" would point at the
    # filesystem root; fall back to the working directory instead
    draw_base = save_dir if save_dir else "."

    model.eval()
    with open(save_res_path, "wb") as fout:
        for file in get_image_file_list(config['Global']['infer_img']):
            logger.info("infer_img: {}".format(file))
            with open(file, 'rb') as f:
                img = f.read()
            data = {'image': img}
            batch = transform(data, ops)

            images = np.expand_dims(batch[0], axis=0)
            shape_list = np.expand_dims(batch[1], axis=0)
            images = paddle.to_tensor(images)
            preds = model(images)
            post_result = post_process_class(preds, shape_list)

            src_img = cv2.imread(file)
            if isinstance(post_result, dict):
                # parse boxes per output key
                det_box_json = {}
                for k in post_result.keys():
                    boxes = post_result[k][0]['points']
                    det_box_json[k] = [{
                        "transcription": "",
                        "points": box.tolist()
                    } for box in boxes]
                    save_det_path = draw_base + "/det_results_{}/".format(k)
                    draw_det_res(boxes, config, src_img, file, save_det_path)
                # previously the per-key boxes were dropped and [] was saved
                result_json = det_box_json
            else:
                boxes = post_result[0]['points']
                result_json = [{
                    "transcription": "",
                    "points": box.tolist()
                } for box in boxes]
                save_det_path = draw_base + "/det_results/"
                draw_det_res(boxes, config, src_img, file, save_det_path)
            # write result
            otstr = file + "\t" + json.dumps(result_json) + "\n"
            fout.write(otstr.encode())

    logger.info("success!")
def main(config, device, logger, vdl_writer):
    """Run quantization-aware training (QAT with PACT activation preprocessing).

    This is the regular training pipeline, with one difference: the freshly
    built model is passed through
    ``QAT(config=quant_config, act_preprocess=PACT).quantize`` before it is
    (optionally) wrapped in ``paddle.DataParallel``. ``quant_config`` and
    ``PACT`` are read from module scope.
    """
    # init dist environment
    if config['Global']['distributed']:
        dist.init_parallel_env()

    global_config = config['Global']

    # build dataloader
    train_dataloader = build_dataloader(config, 'Train', device, logger)
    if config['Eval']:
        valid_dataloader = build_dataloader(config, 'Eval', device, logger)
    else:
        valid_dataloader = None

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model; for rec algorithms the head width follows the charset
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
        if config['Architecture']["algorithm"] in ["Distillation", ]:
            # distillation model
            for key in config['Architecture']["Models"]:
                config['Architecture']["Models"][key]["Head"][
                    'out_channels'] = char_num
        else:
            # base rec model
            config['Architecture']["Head"]['out_channels'] = char_num
    model = build_model(config['Architecture'])

    # insert quantization before any DataParallel wrapping
    quanter = QAT(config=quant_config, act_preprocess=PACT)
    quanter.quantize(model)

    if config['Global']['distributed']:
        model = paddle.DataParallel(model)

    # build loss
    loss_class = build_loss(config['Loss'])

    # build optim
    optimizer, lr_scheduler = build_optimizer(
        config['Optimizer'],
        epochs=config['Global']['epoch_num'],
        step_each_epoch=len(train_dataloader),
        parameters=model.parameters())

    # build metric
    eval_class = build_metric(config['Metric'])
    # load pretrain model
    pre_best_model_dict = load_model(config, model, optimizer)

    # valid_dataloader is None without an Eval section; len(None) would raise
    logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
    if valid_dataloader is not None:
        logger.info('valid dataloader has {} iters'.format(
            len(valid_dataloader)))

    # start train
    program.train(config, train_dataloader, valid_dataloader, device, model,
                  loss_class, optimizer, lr_scheduler, post_process_class,
                  eval_class, pre_best_model_dict, logger, vdl_writer)