Example #1
def main(config, device, logger, vdl_writer):
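    # Table structure recognition inference: build the model and post process,
    # run each input image, decode the predicted HTML structure and cell
    # locations, and draw the predicted cell boxes on the source image.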
    global_config = config['Global']

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model
    if hasattr(post_process_class, 'character'):
        config['Architecture']["Head"]['out_channels'] = len(
            getattr(post_process_class, 'character'))

    model = build_model(config['Architecture'])

    load_model(config, model)

    # create data ops
    transforms = []
    use_padding = False
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Label' in op_name:
            continue
        if op_name == 'KeepKeys':
            op[op_name]['keep_keys'] = ['image']
        if op_name == "ResizeTableImage":
            use_padding = True
            padding_max_len = op['ResizeTableImage']['max_len']
        transforms.append(op)

    global_config['infer_mode'] = True
    ops = create_operators(transforms, global_config)

    model.eval()
    for file in get_image_file_list(config['Global']['infer_img']):
        logger.info("infer_img: {}".format(file))
        with open(file, 'rb') as f:
            img = f.read()
            data = {'image': img}
        batch = transform(data, ops)
        images = np.expand_dims(batch[0], axis=0)
        images = paddle.to_tensor(images)
        preds = model(images)
        post_result = post_process_class(preds)
        res_html_code = post_result['res_html_code']
        res_loc = post_result['res_loc']
        img = cv2.imread(file)
        imgh, imgw = img.shape[0:2]
        res_loc_final = []
        for rno in range(len(res_loc[0])):
            x0, y0, x1, y1 = res_loc[0][rno]
            left = max(int(imgw * x0), 0)
            top = max(int(imgh * y0), 0)
            right = min(int(imgw * x1), imgw - 1)
            bottom = min(int(imgh * y1), imgh - 1)
            cv2.rectangle(img, (left, top), (right, bottom), (0, 0, 255), 2)
            res_loc_final.append([left, top, right, bottom])
        res_loc_str = json.dumps(res_loc_final)
        logger.info("result: {}, {}".format(res_html_code, res_loc_final))
    logger.info("success!")
Example #2
def main():
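    # Key information extraction (KIE) inference: read an annotation list file,
    # run the model to get node/edge predictions for each image, visualize the
    # node classification results, and report the total time and IPS.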
    global_config = config['Global']

    # build model
    model = build_model(config['Architecture'])
    load_model(config, model)

    # create data ops
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        transforms.append(op)

    data_dir = config['Eval']['dataset']['data_dir']

    ops = create_operators(transforms, global_config)

    save_res_path = config['Global']['save_res_path']
    class_path = config['Global']['class_path']
    idx_to_cls = read_class_list(class_path)
    if not os.path.exists(os.path.dirname(save_res_path)):
        os.makedirs(os.path.dirname(save_res_path))

    model.eval()

    warmup_times = 0
    count_t = []
    with open(save_res_path, "wb") as fout:
        with open(config['Global']['infer_img'], "rb") as f:
            lines = f.readlines()
            for index, data_line in enumerate(lines):
                if index == 10:
                    warmup_t = time.time()
                data_line = data_line.decode('utf-8')
                substr = data_line.strip("\n").split("\t")
                img_path, label = data_dir + "/" + substr[0], substr[1]
                data = {'img_path': img_path, 'label': label}
                with open(data['img_path'], 'rb') as f:
                    img = f.read()
                    data['image'] = img
                st = time.time()
                batch = transform(data, ops)
                batch_pred = [0] * len(batch)
                for i in range(len(batch)):
                    batch_pred[i] = paddle.to_tensor(
                        np.expand_dims(batch[i], axis=0))
                st = time.time()
                node, edge = model(batch_pred)
                node = F.softmax(node, -1)
                count_t.append(time.time() - st)
                draw_kie_result(batch, node, idx_to_cls, index)
    logger.info("success!")
    logger.info("It took {} s for predict {} images.".format(
        np.sum(count_t), len(count_t)))
    ips = len(count_t[warmup_times:]) / np.sum(count_t[warmup_times:])
    logger.info("The ips is {} images/s".format(ips))
Example #3
def main():
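    # End-to-end text spotting inference: run each input image, post-process
    # the predictions into polygon points and transcriptions, save them as
    # JSON lines, and draw the results on the source image.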
    global_config = config['Global']

    # build model
    model = build_model(config['Architecture'])

    load_model(config, model)

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # create data ops
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Label' in op_name:
            continue
        elif op_name == 'KeepKeys':
            op[op_name]['keep_keys'] = ['image', 'shape']
        transforms.append(op)

    ops = create_operators(transforms, global_config)

    save_res_path = config['Global']['save_res_path']
    if not os.path.exists(os.path.dirname(save_res_path)):
        os.makedirs(os.path.dirname(save_res_path))

    model.eval()
    with open(save_res_path, "wb") as fout:
        for file in get_image_file_list(config['Global']['infer_img']):
            logger.info("infer_img: {}".format(file))
            with open(file, 'rb') as f:
                img = f.read()
                data = {'image': img}
            batch = transform(data, ops)
            images = np.expand_dims(batch[0], axis=0)
            shape_list = np.expand_dims(batch[1], axis=0)
            images = paddle.to_tensor(images)
            preds = model(images)
            post_result = post_process_class(preds, shape_list)
            points, strs = post_result['points'], post_result['texts']
            # write result
            dt_boxes_json = []
            for poly, s in zip(points, strs):
                tmp_json = {"transcription": s}
                tmp_json['points'] = poly.tolist()
                dt_boxes_json.append(tmp_json)
            otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n"
            fout.write(otstr.encode())
            src_img = cv2.imread(file)
            draw_e2e_res(points, strs, config, src_img, file)
    logger.info("success!")
Example #4
def main():
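    # Export a trained model to the inference format: set the recognition head
    # size from the character dict, handle distillation sub-models separately,
    # and save each (sub-)model with export_single_model.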
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    merge_config(FLAGS.opt)
    logger = get_logger()
    # build post process

    post_process_class = build_post_process(config["PostProcess"],
                                            config["Global"])

    # build model
    # for rec algorithm
    if hasattr(post_process_class, "character"):
        char_num = len(getattr(post_process_class, "character"))
        if config["Architecture"]["algorithm"] in ["Distillation",
                                                   ]:  # distillation model
            for key in config["Architecture"]["Models"]:
                config["Architecture"]["Models"][key]["Head"][
                    "out_channels"] = char_num
                # just one final tensor needs to be exported for inference
                config["Architecture"]["Models"][key][
                    "return_all_feats"] = False
        else:  # base rec model
            config["Architecture"]["Head"]["out_channels"] = char_num
    model = build_model(config["Architecture"])
    load_model(config, model)
    model.eval()

    save_path = config["Global"]["save_inference_dir"]

    arch_config = config["Architecture"]

    if arch_config["algorithm"] in ["Distillation", ]:  # distillation model
        archs = list(arch_config["Models"].values())
        for idx, name in enumerate(model.model_name_list):
            sub_model_save_path = os.path.join(save_path, name, "inference")
            export_single_model(model.model_list[idx], archs[idx],
                                sub_model_save_path, logger)
    else:
        save_path = os.path.join(save_path, "inference")
        export_single_model(model, arch_config, save_path, logger)
Example #5
def main():
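    # Text recognition inference: build the model and post process, run each
    # input image in infer mode, and log the decoded recognition result.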
    global_config = config['Global']

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model
    model = build_model(config['Architecture'])

    load_model(config, model)

    # create data ops
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Label' in op_name:
            continue
        elif op_name == 'KeepKeys':
            op[op_name]['keep_keys'] = ['image']
        transforms.append(op)
    global_config['infer_mode'] = True
    ops = create_operators(transforms, global_config)

    model.eval()
    for file in get_image_file_list(config['Global']['infer_img']):
        logger.info("infer_img: {}".format(file))
        with open(file, 'rb') as f:
            img = f.read()
            data = {'image': img}
        batch = transform(data, ops)

        images = np.expand_dims(batch[0], axis=0)
        images = paddle.to_tensor(images)
        preds = model(images)
        post_result = post_process_class(preds)
        for rec_result in post_result:
            logger.info('\t result: {}'.format(rec_result))
    logger.info("success!")
Example #6
def main():
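    # Evaluation entry: build the Eval dataloader, model and metric, restore
    # the checkpoint, and report the evaluation metrics via program.eval.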
    global_config = config['Global']
    # build dataloader
    valid_dataloader = build_dataloader(config, 'Eval', device, logger)

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model
    # for rec algorithm
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
        if config['Architecture']["algorithm"] in [
                "Distillation",
        ]:  # distillation model
            for key in config['Architecture']["Models"]:
                config['Architecture']["Models"][key]["Head"][
                    'out_channels'] = char_num
        else:  # base rec model
            config['Architecture']["Head"]['out_channels'] = char_num

    model = build_model(config['Architecture'])
    extra_input = config['Architecture']['algorithm'] in [
        "SRN", "NRTR", "SAR", "SEED"
    ]
    if "model_type" in config['Architecture'].keys():
        model_type = config['Architecture']['model_type']
    else:
        model_type = None

    best_model_dict = load_model(config, model)
    if len(best_model_dict):
        logger.info('metric in ckpt ***************')
        for k, v in best_model_dict.items():
            logger.info('{}:{}'.format(k, v))

    # build metric
    eval_class = build_metric(config['Metric'])
    # start eval
    metric = program.eval(model, valid_dataloader, post_process_class,
                          eval_class, model_type, extra_input)
    logger.info('metric eval ***************')
    for k, v in metric.items():
        logger.info('{}:{}'.format(k, v))
Example #7
def main():
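    # Compute per-character feature centers: point the Eval dataloader at the
    # training data, enable the head's return_feats, collect center features
    # with program.get_center, and pickle them to train_center.pkl.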
    global_config = config['Global']
    # build dataloader
    config['Eval']['dataset']['name'] = config['Train']['dataset']['name']
    config['Eval']['dataset']['data_dir'] = config['Train']['dataset'][
        'data_dir']
    config['Eval']['dataset']['label_file_list'] = config['Train']['dataset'][
        'label_file_list']
    eval_dataloader = build_dataloader(config, 'Eval', device, logger)

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model
    # for rec algorithm
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
        config['Architecture']["Head"]['out_channels'] = char_num

    # set return_feats to True
    config['Architecture']["Head"]["return_feats"] = True

    model = build_model(config['Architecture'])

    best_model_dict = load_model(config, model)
    if len(best_model_dict):
        logger.info('metric in ckpt ***************')
        for k, v in best_model_dict.items():
            logger.info('{}:{}'.format(k, v))

    # get features from train data
    char_center = program.get_center(model, eval_dataloader,
                                     post_process_class)

    # serialize the character centers to disk
    with open("train_center.pkl", 'wb') as f:
        pickle.dump(char_center, f)
    return
Example #8
def main(config, device, logger, vdl_writer):
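    # Training entry: build dataloaders, model, loss, optimizer and metric,
    # optionally wrap the model for distributed training and AMP, then hand
    # everything to program.train.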
    # init dist environment
    if config['Global']['distributed']:
        dist.init_parallel_env()

    global_config = config['Global']

    # build dataloader
    train_dataloader = build_dataloader(config, 'Train', device, logger)
    if len(train_dataloader) == 0:
        logger.error(
            "No images found in the train dataset. Please ensure that:\n" +
            "\t1. The number of images in the train label_file_list is no less than the batch size.\n" +
            "\t2. The annotation file and data paths in the configuration file are set correctly.")
        return

    if config['Eval']:
        valid_dataloader = build_dataloader(config, 'Eval', device, logger)
    else:
        valid_dataloader = None

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model
    # for rec algorithm
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
        if config['Architecture']["algorithm"] in [
                "Distillation",
        ]:  # distillation model
            for key in config['Architecture']["Models"]:
                config['Architecture']["Models"][key]["Head"][
                    'out_channels'] = char_num
        else:  # base rec model
            config['Architecture']["Head"]['out_channels'] = char_num

    model = build_model(config['Architecture'])
    if config['Global']['distributed']:
        model = paddle.DataParallel(model)

    # build loss
    loss_class = build_loss(config['Loss'])

    # build optim
    optimizer, lr_scheduler = build_optimizer(
        config['Optimizer'],
        epochs=config['Global']['epoch_num'],
        step_each_epoch=len(train_dataloader),
        parameters=model.parameters())

    # build metric
    eval_class = build_metric(config['Metric'])
    # load pretrain model
    pre_best_model_dict = load_model(config, model, optimizer)
    logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
    if valid_dataloader is not None:
        logger.info('valid dataloader has {} iters'.format(
            len(valid_dataloader)))

    use_amp = config["Global"].get("use_amp", False)
    if use_amp:
        AMP_RELATED_FLAGS_SETTING = {
            'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
            'FLAGS_max_inplace_grad_add': 8,
        }
        paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
        scale_loss = config["Global"].get("scale_loss", 1.0)
        use_dynamic_loss_scaling = config["Global"].get(
            "use_dynamic_loss_scaling", False)
        scaler = paddle.amp.GradScaler(
            init_loss_scaling=scale_loss,
            use_dynamic_loss_scaling=use_dynamic_loss_scaling)
    else:
        scaler = None

    # start train
    program.train(config, train_dataloader, valid_dataloader, device, model,
                  loss_class, optimizer, lr_scheduler, post_process_class,
                  eval_class, pre_best_model_dict, logger, vdl_writer, scaler)
Example #9
def main(config, device, logger, vdl_writer):
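    # Training with FPGM filter pruning (PaddleSlim): measure FLOPs, choose
    # pruning ratios either by sensitivity analysis or a fixed 10% ratio,
    # prune the model, and fine-tune it with program.train.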
    # init dist environment
    if config['Global']['distributed']:
        dist.init_parallel_env()

    global_config = config['Global']

    # build dataloader
    train_dataloader = build_dataloader(config, 'Train', device, logger)
    if config['Eval']:
        valid_dataloader = build_dataloader(config, 'Eval', device, logger)
    else:
        valid_dataloader = None

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model
    # for rec algorithm
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
        config['Architecture']["Head"]['out_channels'] = char_num
    model = build_model(config['Architecture'])
    if config['Architecture']['model_type'] == 'det':
        input_shape = [1, 3, 640, 640]
    elif config['Architecture']['model_type'] == 'rec':
        input_shape = [1, 3, 32, 320]
    flops = paddle.flops(model, input_shape)

    logger.info("FLOPs before pruning: {}".format(flops))

    from paddleslim.dygraph import FPGMFilterPruner
    model.train()

    pruner = FPGMFilterPruner(model, input_shape)

    # build loss
    loss_class = build_loss(config['Loss'])

    # build optim
    optimizer, lr_scheduler = build_optimizer(
        config['Optimizer'],
        epochs=config['Global']['epoch_num'],
        step_each_epoch=len(train_dataloader),
        parameters=model.parameters())

    # build metric
    eval_class = build_metric(config['Metric'])
    # load pretrain model
    pre_best_model_dict = load_model(config, model, optimizer)

    logger.info(
        'train dataloader has {} iters, valid dataloader has {} iters'.format(
            len(train_dataloader), len(valid_dataloader)))

    def eval_fn():
        metric = program.eval(model, valid_dataloader, post_process_class,
                              eval_class, False)
        if config['Architecture']['model_type'] == 'det':
            main_indicator = 'hmean'
        else:
            main_indicator = 'acc'

        logger.info("metric[{}]: {}".format(main_indicator,
                                            metric[main_indicator]))
        return metric[main_indicator]

    run_sensitive_analysis = False
    """
    run_sensitive_analysis=True: 
        Automatically compute the sensitivities of convolutions in a model. 
        The sensitivity of a convolution is the losses of accuracy on test dataset in 
        differenct pruned ratios. The sensitivities can be used to get a group of best 
        ratios with some condition.
    
    run_sensitive_analysis=False: 
        Set prune trim ratio to a fixed value, such as 10%. The larger the value, 
        the more convolution weights will be cropped.

    """

    if run_sensitive_analysis:
        params_sensitive = pruner.sensitive(
            eval_func=eval_fn,
            sen_file="./deploy/slim/prune/sen.pickle",
            skip_vars=[
                "conv2d_57.w_0", "conv2d_transpose_2.w_0",
                "conv2d_transpose_3.w_0"
            ])
        logger.info(
            "The sensitivity analysis results of the model parameters are saved in sen.pickle"
        )
        # calculate the pruning ratio for each parameter
        params_sensitive = pruner._get_ratios_by_loss(params_sensitive,
                                                      loss=0.02)
        for key in params_sensitive.keys():
            logger.info("{}, {}".format(key, params_sensitive[key]))
    else:
        params_sensitive = {}
        for param in model.parameters():
            if 'transpose' not in param.name and 'linear' not in param.name:
                # set the pruning ratio to 10%; the larger the value, the more convolution weights are pruned
                params_sensitive[param.name] = 0.1

    plan = pruner.prune_vars(params_sensitive, [0])

    flops = paddle.flops(model, input_shape)
    logger.info("FLOPs after pruning: {}".format(flops))

    # start train

    program.train(config, train_dataloader, valid_dataloader, device, model,
                  loss_class, optimizer, lr_scheduler, post_process_class,
                  eval_class, pre_best_model_dict, logger, vdl_writer)
Example #10
def main(config, device, logger, vdl_writer):
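    # Pruning pipeline: run sensitivity analysis to derive per-parameter
    # pruning ratios, prune the model, evaluate it, and export the pruned
    # model to a static inference graph with paddle.jit.save.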

    global_config = config['Global']

    # build dataloader
    valid_dataloader = build_dataloader(config, 'Eval', device, logger)

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model
    # for rec algorithm
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
        config['Architecture']["Head"]['out_channels'] = char_num
    model = build_model(config['Architecture'])

    if config['Architecture']['model_type'] == 'det':
        input_shape = [1, 3, 640, 640]
    elif config['Architecture']['model_type'] == 'rec':
        input_shape = [1, 3, 32, 320]

    flops = paddle.flops(model, input_shape)
    logger.info("FLOPs before pruning: {}".format(flops))

    from paddleslim.dygraph import FPGMFilterPruner
    model.train()
    pruner = FPGMFilterPruner(model, input_shape)

    # build metric
    eval_class = build_metric(config['Metric'])

    def eval_fn():
        metric = program.eval(model, valid_dataloader, post_process_class,
                              eval_class)
        if config['Architecture']['model_type'] == 'det':
            main_indicator = 'hmean'
        else:
            main_indicator = 'acc'
        logger.info("metric[{}]: {}".format(main_indicator, metric[
            main_indicator]))
        return metric[main_indicator]

    params_sensitive = pruner.sensitive(
        eval_func=eval_fn,
        sen_file="./sen.pickle",
        skip_vars=[
            "conv2d_57.w_0", "conv2d_transpose_2.w_0", "conv2d_transpose_3.w_0"
        ])

    logger.info(
        "The sensitivity analysis results of the model parameters are saved in sen.pickle"
    )
    # calculate the pruning ratio for each parameter
    params_sensitive = pruner._get_ratios_by_loss(params_sensitive, loss=0.02)
    for key in params_sensitive.keys():
        logger.info("{}, {}".format(key, params_sensitive[key]))

    plan = pruner.prune_vars(params_sensitive, [0])

    flops = paddle.flops(model, input_shape)
    logger.info("FLOPs after pruning: {}".format(flops))

    # load pretrain model
    load_model(config, model)
    metric = program.eval(model, valid_dataloader, post_process_class,
                          eval_class)
    if config['Architecture']['model_type'] == 'det':
        main_indicator = 'hmean'
    else:
        main_indicator = 'acc'
    logger.info("metric['']: {}".format(main_indicator, metric[main_indicator]))

    # start export model
    from paddle.jit import to_static

    infer_shape = [3, -1, -1]
    if config['Architecture']['model_type'] == "rec":
        infer_shape = [3, 32, -1]  # for rec model, H must be 32

        if 'Transform' in config['Architecture'] and config['Architecture'][
                'Transform'] is not None and config['Architecture'][
                    'Transform']['name'] == 'TPS':
            logger.info(
                'When the network contains a TPS transform, variable-length input is not supported; the input size must be the same as during training'
            )
            infer_shape[-1] = 100
    model = to_static(
        model,
        input_spec=[
            paddle.static.InputSpec(
                shape=[None] + infer_shape, dtype='float32')
        ])

    save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
    paddle.jit.save(model, save_path)
    logger.info('inference model is saved to {}'.format(save_path))
Example #11
def main():
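    # Quantized-model evaluation and export: wrap the model with QAT using the
    # quant_config defined below, load the trained weights, evaluate the
    # quantized model, and export it (per sub-model for distillation) with
    # export_single_model.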
    ############################################################################################################
    # 1. quantization configs
    ############################################################################################################
    quant_config = {
        # weight preprocess type, default is None and no preprocessing is performed.
        'weight_preprocess_type': None,
        # activation preprocess type, default is None and no preprocessing is performed.
        'activation_preprocess_type': None,
        # weight quantize type, default is 'channel_wise_abs_max'
        'weight_quantize_type': 'channel_wise_abs_max',
        # activation quantize type, default is 'moving_average_abs_max'
        'activation_quantize_type': 'moving_average_abs_max',
        # weight quantize bit num, default is 8
        'weight_bits': 8,
        # activation quantize bit num, default is 8
        'activation_bits': 8,
        # data type after quantization, such as 'uint8' or 'int8'; default is 'int8'
        'dtype': 'int8',
        # window size for 'range_abs_max' quantization. default is 10000
        'window_size': 10000,
        # The decay coefficient of moving average, default is 0.9
        'moving_rate': 0.9,
        # for dygraph quantization, layers of type in quantizable_layer_type will be quantized
        'quantizable_layer_type': ['Conv2D', 'Linear'],
    }
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    merge_config(FLAGS.opt)
    logger = get_logger()
    # build post process

    post_process_class = build_post_process(config['PostProcess'],
                                            config['Global'])

    # build model
    # for rec algorithm
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
        if config['Architecture']["algorithm"] in [
                "Distillation",
        ]:  # distillation model
            for key in config['Architecture']["Models"]:
                config['Architecture']["Models"][key]["Head"][
                    'out_channels'] = char_num
        else:  # base rec model
            config['Architecture']["Head"]['out_channels'] = char_num

    model = build_model(config['Architecture'])

    # get QAT model
    quanter = QAT(config=quant_config)
    quanter.quantize(model)

    load_model(config, model)
    model.eval()

    # build metric
    eval_class = build_metric(config['Metric'])

    # build dataloader
    valid_dataloader = build_dataloader(config, 'Eval', device, logger)

    use_srn = config['Architecture']['algorithm'] == "SRN"
    model_type = config['Architecture'].get('model_type', None)
    # start eval
    metric = program.eval(model, valid_dataloader, post_process_class,
                          eval_class, model_type, use_srn)

    logger.info('metric eval ***************')
    for k, v in metric.items():
        logger.info('{}:{}'.format(k, v))

    infer_shape = [3, 32, 100] if model_type == "rec" else [3, 640, 640]

    save_path = config["Global"]["save_inference_dir"]

    arch_config = config["Architecture"]
    if arch_config["algorithm"] in [
            "Distillation",
    ]:  # distillation model
        for idx, name in enumerate(model.model_name_list):
            sub_model_save_path = os.path.join(save_path, name, "inference")
            export_single_model(quanter, model.model_list[idx], infer_shape,
                                sub_model_save_path, logger)
    else:
        save_path = os.path.join(save_path, "inference")
        export_single_model(quanter, model, infer_shape, save_path, logger)
Example #12
def main():
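    # Text recognition inference with algorithm-specific inputs: SRN takes
    # extra position/attention-bias tensors and SAR a valid_ratio; the decoded
    # result for each image is logged and written to save_res_path.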
    global_config = config['Global']

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
        if config['Architecture']["algorithm"] in [
                "Distillation",
        ]:  # distillation model
            for key in config['Architecture']["Models"]:
                config['Architecture']["Models"][key]["Head"][
                    'out_channels'] = char_num
        else:  # base rec model
            config['Architecture']["Head"]['out_channels'] = char_num

    model = build_model(config['Architecture'])

    load_model(config, model)

    # create data ops
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Label' in op_name:
            continue
        elif op_name in ['RecResizeImg']:
            op[op_name]['infer_mode'] = True
        elif op_name == 'KeepKeys':
            if config['Architecture']['algorithm'] == "SRN":
                op[op_name]['keep_keys'] = [
                    'image', 'encoder_word_pos', 'gsrm_word_pos',
                    'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
                ]
            elif config['Architecture']['algorithm'] == "SAR":
                op[op_name]['keep_keys'] = ['image', 'valid_ratio']
            else:
                op[op_name]['keep_keys'] = ['image']
        transforms.append(op)
    global_config['infer_mode'] = True
    ops = create_operators(transforms, global_config)

    save_res_path = config['Global'].get('save_res_path',
                                         "./output/rec/predicts_rec.txt")
    if not os.path.exists(os.path.dirname(save_res_path)):
        os.makedirs(os.path.dirname(save_res_path))

    model.eval()

    with open(save_res_path, "w") as fout:
        for file in get_image_file_list(config['Global']['infer_img']):
            logger.info("infer_img: {}".format(file))
            with open(file, 'rb') as f:
                img = f.read()
                data = {'image': img}
            batch = transform(data, ops)
            if config['Architecture']['algorithm'] == "SRN":
                encoder_word_pos_list = np.expand_dims(batch[1], axis=0)
                gsrm_word_pos_list = np.expand_dims(batch[2], axis=0)
                gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0)
                gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0)

                others = [
                    paddle.to_tensor(encoder_word_pos_list),
                    paddle.to_tensor(gsrm_word_pos_list),
                    paddle.to_tensor(gsrm_slf_attn_bias1_list),
                    paddle.to_tensor(gsrm_slf_attn_bias2_list)
                ]
            if config['Architecture']['algorithm'] == "SAR":
                valid_ratio = np.expand_dims(batch[-1], axis=0)
                img_metas = [paddle.to_tensor(valid_ratio)]

            images = np.expand_dims(batch[0], axis=0)
            images = paddle.to_tensor(images)
            if config['Architecture']['algorithm'] == "SRN":
                preds = model(images, others)
            elif config['Architecture']['algorithm'] == "SAR":
                preds = model(images, img_metas)
            else:
                preds = model(images)
            post_result = post_process_class(preds)
            info = None
            if isinstance(post_result, dict):
                rec_info = dict()
                for key in post_result:
                    if len(post_result[key][0]) >= 2:
                        rec_info[key] = {
                            "label": post_result[key][0][0],
                            "score": float(post_result[key][0][1]),
                        }
                info = json.dumps(rec_info)
            else:
                if len(post_result[0]) >= 2:
                    info = post_result[0][0] + "\t" + str(post_result[0][1])

            if info is not None:
                logger.info("\t result: {}".format(info))
                fout.write(os.path.basename(file) + "\t" + info + "\n")
    logger.info("success!")
Example #13
def main():
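    # Text detection inference: run each input image, post-process to detected
    # boxes (handling dict outputs from multi-head models), write the boxes as
    # JSON lines, and draw the detection results.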
    global_config = config['Global']

    # build model
    model = build_model(config['Architecture'])

    load_model(config, model)
    # build post process
    post_process_class = build_post_process(config['PostProcess'])

    # create data ops
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Label' in op_name:
            continue
        elif op_name == 'KeepKeys':
            op[op_name]['keep_keys'] = ['image', 'shape']
        transforms.append(op)

    ops = create_operators(transforms, global_config)

    save_res_path = config['Global']['save_res_path']
    if not os.path.exists(os.path.dirname(save_res_path)):
        os.makedirs(os.path.dirname(save_res_path))

    model.eval()
    with open(save_res_path, "wb") as fout:
        for file in get_image_file_list(config['Global']['infer_img']):
            logger.info("infer_img: {}".format(file))
            with open(file, 'rb') as f:
                img = f.read()
                data = {'image': img}
            batch = transform(data, ops)

            images = np.expand_dims(batch[0], axis=0)
            shape_list = np.expand_dims(batch[1], axis=0)
            images = paddle.to_tensor(images)
            preds = model(images)
            post_result = post_process_class(preds, shape_list)

            src_img = cv2.imread(file)

            dt_boxes_json = []
            # parse boxes when post_result is a dict
            if isinstance(post_result, dict):
                det_box_json = {}
                for k in post_result.keys():
                    boxes = post_result[k][0]['points']
                    dt_boxes_list = []
                    for box in boxes:
                        tmp_json = {"transcription": ""}
                        tmp_json['points'] = box.tolist()
                        dt_boxes_list.append(tmp_json)
                    det_box_json[k] = dt_boxes_list
                    save_det_path = os.path.dirname(
                        config['Global']
                        ['save_res_path']) + "/det_results_{}/".format(k)
                    draw_det_res(boxes, config, src_img, file, save_det_path)
            else:
                boxes = post_result[0]['points']
                dt_boxes_json = []
                # write result
                for box in boxes:
                    tmp_json = {"transcription": ""}
                    tmp_json['points'] = box.tolist()
                    dt_boxes_json.append(tmp_json)
                save_det_path = os.path.dirname(
                    config['Global']['save_res_path']) + "/det_results/"
                draw_det_res(boxes, config, src_img, file, save_det_path)
            otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n"
            fout.write(otstr.encode())

    logger.info("success!")
Example #14
def main(config, device, logger, vdl_writer):
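    # Quantization-aware training: build the model, quantize it in place with
    # QAT (using quant_config and PACT activation preprocessing), then run the
    # standard training loop via program.train.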
    # init dist environment
    if config['Global']['distributed']:
        dist.init_parallel_env()

    global_config = config['Global']

    # build dataloader
    train_dataloader = build_dataloader(config, 'Train', device, logger)
    if config['Eval']:
        valid_dataloader = build_dataloader(config, 'Eval', device, logger)
    else:
        valid_dataloader = None

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model
    # for rec algorithm
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
        if config['Architecture']["algorithm"] in [
                "Distillation",
        ]:  # distillation model
            for key in config['Architecture']["Models"]:
                config['Architecture']["Models"][key]["Head"][
                    'out_channels'] = char_num
        else:  # base rec model
            config['Architecture']["Head"]['out_channels'] = char_num
    model = build_model(config['Architecture'])

    quanter = QAT(config=quant_config, act_preprocess=PACT)
    quanter.quantize(model)

    if config['Global']['distributed']:
        model = paddle.DataParallel(model)

    # build loss
    loss_class = build_loss(config['Loss'])

    # build optim
    optimizer, lr_scheduler = build_optimizer(
        config['Optimizer'],
        epochs=config['Global']['epoch_num'],
        step_each_epoch=len(train_dataloader),
        parameters=model.parameters())

    # build metric
    eval_class = build_metric(config['Metric'])
    # load pretrain model
    pre_best_model_dict = load_model(config, model, optimizer)

    logger.info(
        'train dataloader has {} iters, valid dataloader has {} iters'.format(
            len(train_dataloader), len(valid_dataloader)))

    # start train
    program.train(config, train_dataloader, valid_dataloader, device, model,
                  loss_class, optimizer, lr_scheduler, post_process_class,
                  eval_class, pre_best_model_dict, logger, vdl_writer)