Example 1
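Both functions below come from PaddleOCR's fluid-era (v1.x) text-recognition utilities and are shown without their module header. They assume roughly the following imports and helpers; the module paths here follow the PaddleOCR v1.x layout and are an assumption, not part of the original listing:

import os
import time

import numpy as np
import paddle.fluid as fluid

# Assumed module paths (PaddleOCR v1.x layout); adjust to your checkout.
from ppocr.utils.character import (cal_predicts_accuracy,
                                   cal_predicts_accuracy_srn,
                                   convert_rec_attention_infer_res,
                                   convert_rec_label_to_lod)
from ppocr.utils.save_load import save_model
from ppocr.utils.stats import TrainingStats
from ppocr.utils.utility import initial_logger

logger = initial_logger()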
def eval_rec_run(exe, config, eval_info_dict, mode):
    """
    Run evaluation program, return program outputs.
    """
    char_ops = config['Global']['char_ops']
    total_sample_num = 0
    total_acc_num = 0
    total_batch_num = 0
    if mode == "eval":
        is_remove_duplicate = False
    else:
        is_remove_duplicate = True

    for data in eval_info_dict['reader']():
        img_num = len(data)
        img_list = []
        label_list = []
        for ino in range(img_num):
            img_list.append(data[ino][0])
            label_list.append(data[ino][1])

        if config['Global']['loss_type'] != "srn":
            img_list = np.concatenate(img_list, axis=0)
            outs = exe.run(eval_info_dict['program'],
                           feed={'image': img_list},
                           fetch_list=eval_info_dict['fetch_varname_list'],
                           return_numpy=False)
            preds = np.array(outs[0])

            if config['Global']['loss_type'] == "attention":
                preds, preds_lod = convert_rec_attention_infer_res(preds)
            else:
                preds_lod = outs[0].lod()[0]
            labels, labels_lod = convert_rec_label_to_lod(label_list)
            acc, acc_num, sample_num = cal_predicts_accuracy(
                char_ops, preds, preds_lod, labels, labels_lod,
                is_remove_duplicate)
        else:
            encoder_word_pos_list = []
            gsrm_word_pos_list = []
            gsrm_slf_attn_bias1_list = []
            gsrm_slf_attn_bias2_list = []
            for ino in range(img_num):
                encoder_word_pos_list.append(data[ino][2])
                gsrm_word_pos_list.append(data[ino][3])
                gsrm_slf_attn_bias1_list.append(data[ino][4])
                gsrm_slf_attn_bias2_list.append(data[ino][5])

            img_list = np.concatenate(img_list, axis=0)
            label_list = np.concatenate(label_list, axis=0)
            encoder_word_pos_list = np.concatenate(
                encoder_word_pos_list, axis=0).astype(np.int64)
            gsrm_word_pos_list = np.concatenate(
                gsrm_word_pos_list, axis=0).astype(np.int64)
            gsrm_slf_attn_bias1_list = np.concatenate(
                gsrm_slf_attn_bias1_list, axis=0).astype(np.float32)
            gsrm_slf_attn_bias2_list = np.concatenate(
                gsrm_slf_attn_bias2_list, axis=0).astype(np.float32)

            labels = label_list

            outs = exe.run(eval_info_dict['program'],
                           feed={
                               'image': img_list,
                               'encoder_word_pos': encoder_word_pos_list,
                               'gsrm_word_pos': gsrm_word_pos_list,
                               'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1_list,
                               'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2_list
                           },
                           fetch_list=eval_info_dict['fetch_varname_list'],
                           return_numpy=False)
            preds = np.array(outs[0])
            acc, acc_num, sample_num = cal_predicts_accuracy_srn(
                char_ops, preds, labels, config['Global']['max_text_length'])

        total_acc_num += acc_num
        total_sample_num += sample_num
        #logger.info("eval batch id: {}, acc: {}".format(total_batch_num, acc))
        total_batch_num += 1
    avg_acc = total_acc_num * 1.0 / total_sample_num
    metrics = {'avg_acc': avg_acc,
               'total_acc_num': total_acc_num,
               'total_sample_num': total_sample_num}
    return metrics
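For reference, eval_rec_run reads only three keys from eval_info_dict. A minimal call sketch follows; the variable names are illustrative, and how the reader and program are built depends on your setup:

# Illustrative sketch: the keys are the ones eval_rec_run reads; the values
# (hypothetical names) must be built by the caller.
eval_info_dict = {
    'reader': eval_reader,                     # yields batches of (image, label, ...) tuples
    'program': eval_program,                   # fluid Program for evaluation
    'fetch_varname_list': fetch_varname_list,  # fetched variables, predictions first
}
metrics = eval_rec_run(exe, config, eval_info_dict, mode="eval")
print("avg_acc: {:.4f} over {} samples".format(
    metrics['avg_acc'], metrics['total_sample_num']))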
Example 2
def train_eval_rec_run(config,
                       exe,
                       train_info_dict,
                       eval_info_dict,
                       is_slim=None):
    """
    Feed data to the model and fetch the measures and loss for recognition
    Args:
        config: config
        exe:
        train_info_dict: information dict for training
        eval_info_dict: information dict for evaluation
    """
    train_batch_id = 0
    log_smooth_window = config['Global']['log_smooth_window']
    epoch_num = config['Global']['epoch_num']
    print_batch_step = config['Global']['print_batch_step']
    eval_batch_step = config['Global']['eval_batch_step']
    start_eval_step = 0
    if isinstance(eval_batch_step, list) and len(eval_batch_step) >= 2:
        start_eval_step = eval_batch_step[0]
        eval_batch_step = eval_batch_step[1]
        logger.info(
            "During training, evaluation runs every {} iterations once "
            "iteration {} is reached".format(eval_batch_step, start_eval_step))
    save_epoch_step = config['Global']['save_epoch_step']
    save_model_dir = config['Global']['save_model_dir']
    if not os.path.exists(save_model_dir):
        os.makedirs(save_model_dir)
    train_stats = TrainingStats(log_smooth_window, ['loss', 'acc'])
    best_eval_acc = -1
    best_batch_id = 0
    best_epoch = 0
    train_loader = train_info_dict['reader']
    for epoch in range(epoch_num):
        train_loader.start()
        try:
            while True:
                t1 = time.time()
                train_outs = exe.run(
                    program=train_info_dict['compile_program'],
                    fetch_list=train_info_dict['fetch_varname_list'],
                    return_numpy=False)
                # Map each fetched variable name to its index in train_outs.
                fetch_map = dict(
                    zip(train_info_dict['fetch_name_list'],
                        range(len(train_outs))))

                loss = np.mean(np.array(train_outs[fetch_map['total_loss']]))
                lr = np.mean(np.array(train_outs[fetch_map['lr']]))
                preds_idx = fetch_map['decoded_out']
                preds = np.array(train_outs[preds_idx])
                labels_idx = fetch_map['label']
                labels = np.array(train_outs[labels_idx])

                if config['Global']['loss_type'] != 'srn':
                    preds_lod = train_outs[preds_idx].lod()[0]
                    labels_lod = train_outs[labels_idx].lod()[0]

                    acc, acc_num, img_num = cal_predicts_accuracy(
                        config['Global']['char_ops'], preds, preds_lod, labels,
                        labels_lod)
                else:
                    acc, acc_num, img_num = cal_predicts_accuracy_srn(
                        config['Global']['char_ops'], preds, labels,
                        config['Global']['max_text_length'])
                t2 = time.time()
                train_batch_elapse = t2 - t1
                stats = {'loss': loss, 'acc': acc}
                train_stats.update(stats)
                if train_batch_id > start_eval_step and \
                        (train_batch_id - start_eval_step) % print_batch_step == 0:
                    logs = train_stats.log()
                    strs = 'epoch: {}, iter: {}, lr: {:.6f}, {}, time: {:.3f}'.format(
                        epoch, train_batch_id, lr, logs, train_batch_elapse)
                    logger.info(strs)

                if train_batch_id > 0 and train_batch_id % eval_batch_step == 0:
                    model_average = train_info_dict['model_average']
                    if model_average is not None:
                        model_average.apply(exe)
                    metrics = eval_rec_run(exe, config, eval_info_dict, "eval")
                    eval_acc = metrics['avg_acc']
                    eval_sample_num = metrics['total_sample_num']
                    if eval_acc > best_eval_acc:
                        best_eval_acc = eval_acc
                        best_batch_id = train_batch_id
                        best_epoch = epoch
                        save_path = save_model_dir + "/best_accuracy"
                        if is_slim is None:
                            save_model(train_info_dict['train_program'],
                                       save_path)
                        else:
                            import paddleslim as slim
                            if is_slim == "prune":
                                slim.prune.save_model(
                                    exe, train_info_dict['train_program'],
                                    save_path)
                            elif is_slim == "quant":
                                save_model(eval_info_dict['program'],
                                           save_path)
                            else:
                                raise ValueError(
                                    "Only 'quant' and 'prune' are supported "
                                    "currently, but received {}".format(is_slim))
                    strs = 'Test iter: {}, acc:{:.6f}, best_acc:{:.6f}, best_epoch:{}, best_batch_id:{}, eval_sample_num:{}'.format(
                        train_batch_id, eval_acc, best_eval_acc, best_epoch,
                        best_batch_id, eval_sample_num)
                    logger.info(strs)
                train_batch_id += 1

        except fluid.core.EOFException:
            train_loader.reset()
        # Periodic checkpoint: epoch 0 is saved only when save_epoch_step == 1,
        # later epochs whenever epoch is a multiple of save_epoch_step.
        if (epoch == 0 and save_epoch_step == 1) or \
                (epoch > 0 and epoch % save_epoch_step == 0):
            save_path = save_model_dir + "/iter_epoch_%d" % epoch
            if is_slim is None:
                save_model(train_info_dict['train_program'], save_path)
            else:
                import paddleslim as slim
                if is_slim == "prune":
                    slim.prune.save_model(exe,
                                          train_info_dict['train_program'],
                                          save_path)
                elif is_slim == "quant":
                    save_model(eval_info_dict['program'], save_path)
                else:
                    raise ValueError(
                        "Only 'quant' and 'prune' are supported currently, "
                        "but received {}".format(is_slim))
    return
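train_eval_rec_run is the driver: it loops over epochs, logs smoothed loss and accuracy, triggers eval_rec_run every eval_batch_step iterations, and checkpoints both the best and the periodic models. A hedged sketch of the dictionary it reads follows; construction of the values is elided, and the names are hypothetical, but the keys match what the function body accesses:

# Illustrative sketch: only the keys read by train_eval_rec_run are shown;
# the values (hypothetical names) must be built by the caller.
train_info_dict = {
    'reader': train_loader,               # py_reader/DataLoader exposing start()/reset()
    'compile_program': compiled_program,  # program run each training iteration
    'train_program': train_program,       # program whose weights are checkpointed
    'fetch_name_list': ['total_loss', 'lr', 'decoded_out', 'label'],
    'fetch_varname_list': fetch_varname_list,
    'model_average': None,                # or a fluid.optimizer.ModelAverage instance
}
train_eval_rec_run(config, exe, train_info_dict, eval_info_dict)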