Ejemplo n.º 1
0
def dataset_test(lmdb_path, batch_size):
    """Walk an LMDB dataset once and print every batch.

    Args:
        lmdb_path: Path handed to ``LmdbDataset``.
        batch_size: Number of samples per batch.
    """
    loader = DataLoader(
        LmdbDataset(lmdb_path),
        batch_size,
        shuffle=False,
        num_workers=0,
    )

    for batch_idx, (img, label) in enumerate(loader):
        # Full contents first, then only the shapes.
        print(batch_idx, img, label)
        print(batch_idx, img.shape, label.shape)
Ejemplo n.º 2
0
 def setUp(self):
     """Create the validation dataset and a shuffled loader over it."""
     # Resize -> centre crop -> tensor -> per-channel normalisation.
     pipeline = transforms.Compose([
         transforms.ToPILImage(),
         transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225]),
     ])
     self.dataset = LmdbDataset("val.lmdb", transform=pipeline)

     loader_options = dict(batch_size=4, shuffle=True, num_workers=4)
     self.dataloader = DataLoader(self.dataset, **loader_options)
def main(opt):
    """Save clean and corrupted copies of every evaluation dataset as PNGs.

    For each entry of ``opt.eval_data`` and each transform (``None`` for the
    clean data plus one partial per corruption/magnitude pair), every image is
    resized to 100x32, optionally corrupted, and written to
    ``corrupted-data/<transform name>/<dataset>/<index>.png``.

    Args:
        opt: Options object. This function reads ``opt.seed`` and
            ``opt.eval_data`` and forwards ``opt`` to ``LmdbDataset``.
    """
    # 'None' corresponds to the clean data
    transform_fns = [None]
    # Make tests reproducible
    rng = np.random.default_rng(opt.seed)
    corruptions = [
        Curve(rng=rng),
        Distort(rng),
        Stretch(rng),
        Rotate(rng=rng),
        Perspective(rng),
        Shrink(rng),
        TranslateX(rng),
        TranslateY(rng),
        VGrid(rng),
        HGrid(rng),
        Grid(rng),
        RectGrid(rng),
        EllipseGrid(rng),
        GaussianNoise(rng),
        ShotNoise(rng),
        ImpulseNoise(rng),
        SpeckleNoise(rng),
        GaussianBlur(rng),
        DefocusBlur(rng),
        MotionBlur(rng),
        GlassBlur(rng),
        ZoomBlur(rng),
        Contrast(rng),
        Brightness(rng),
        JpegCompression(rng),
        Pixelate(rng),
        Fog(rng),
        Snow(rng),
        Frost(rng),
        Rain(rng),
        Shadow(rng),
        Posterize(rng),
        Solarize(rng),
        Invert(rng),
        Equalize(rng),
        AutoContrast(rng),
        Sharpness(rng),
        Color(rng)
    ]
    # Wrap each corruption in a partial per magnitude level. Only level 0 is
    # generated here (``range(1)``); widen the range for more severity levels.
    for c in corruptions:
        for level in range(1):
            p = partial(c, mag=level)
            p.__name__ = '{}-{}'.format(c.__class__.__name__, level)
            transform_fns.append(p)

    for tr in transform_fns:
        name = 'Clean' if tr is None else tr.__name__
        for d in os.listdir(opt.eval_data):
            outdir = os.path.join('corrupted-data', name, d)
            # exist_ok lets an interrupted run be restarted without a
            # FileExistsError on directories that were already created.
            os.makedirs(outdir, exist_ok=True)
            # NOTE(review): ``tr`` is also handed to LmdbDataset; confirm the
            # dataset does not apply it a second time.
            dataset = LmdbDataset(os.path.join(opt.eval_data, d), opt, tr)
            for i, (img, label) in enumerate(dataset):
                print(outdir, i)
                img = img.resize((100, 32))
                if tr is not None:
                    img = tr(img)
                img.save(os.path.join(outdir, '{:04d}.png'.format(i)))
Ejemplo n.º 4
0
# Move the model to the configured device; on CUDA it is additionally wrapped
# in DataParallel.
model = model.to(config.device)
if config.device == 'cuda':
    model = torch.nn.DataParallel(model)

# Switch to evaluation mode before iterating over the test sets.
model.eval()

# Root directory and sub-directory names of the evaluation LMDB datasets.
test_data_dir = "../data/data_lmdb_release/evaluation/"
test_data_set = [
    "IIIT5k_3000", "SVT", "IC03_867", "IC13_1015", "IC15_1811", "SVTP",
    "CUTE80"
]
device = config.device

for test_data in test_data_set:
    # One dataset and one sequential (unshuffled) loader per benchmark.
    path = test_data_dir + test_data
    test_dataset = LmdbDataset(path, config.lmdb_config)
    data_loader = DataLoader(
        test_dataset,
        batch_size=config.batch_size,
        num_workers=4,
        shuffle=False,
        pin_memory=True,
        drop_last=False,
    )
    # Append the sample count to the benchmark name, e.g. "SVT(647)".
    test_data += '(%d)' % (len(test_dataset))
    targets = []
    pred_rec = []
    for i, data_in in enumerate(data_loader):

        if test_dataset.use_bidecoder:
            # In this mode a batch unpacks into images, two label sets and
            # lengths.
            imgs, labels1, labels2, lengths = data_in
Ejemplo n.º 5
0
def train():
    """Train the AdvancedEAST model with periodic evaluation and checkpoints.

    All settings come from the module-level ``cfg``. Each pass of the outer
    loop is one epoch over the training loader. A checkpoint is written every
    100 epochs, validation runs every ``cfg.valInterval`` epochs (the weights
    with the best mean F1 score are kept), and the process is terminated via
    ``sys.exit()`` once the epoch counter reaches ``cfg.epoch_num``.
    """
    # Dataset preparation.
    train_dataset_lmdb = LmdbDataset(cfg.lmdb_trainset_dir_name)
    val_dataset_lmdb = LmdbDataset(cfg.lmdb_valset_dir_name)

    train_loader = torch.utils.data.DataLoader(
        train_dataset_lmdb, batch_size=cfg.batch_size,
        collate_fn=data_collate,
        shuffle=True,
        num_workers=int(cfg.workers),
        pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        val_dataset_lmdb, batch_size=cfg.batch_size,
        collate_fn=data_collate,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(cfg.workers),
        pin_memory=True)

    # -------------------- training process --------------------
    model = advancedEAST()
    if int(cfg.train_task_id[-3:]) != 256:
        # Warm-start from the best-loss checkpoint of the task id that
        # ``idx_dic`` maps the current one to (the next smaller size).
        id_num = cfg.train_task_id[-3:]
        idx_dic = {'384': 256, '512': 384, '640': 512, '736': 640}
        model.load_state_dict(torch.load('./saved_model/3T{}_best_loss.pth'.format(idx_dic[id_num])))
    elif os.path.exists('./saved_model/3T{}_best_loss.pth'.format(cfg.train_task_id)):
        model.load_state_dict(torch.load('./saved_model/3T{}_best_loss.pth'.format(cfg.train_task_id)))

    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.decay)
    loss_func = quad_loss

    train_Loss_list = []
    val_Loss_list = []

    # Start training; resume the epoch counter from the checkpoint name when
    # it ends in ``_<number>.<ext>``.
    start_iter = 0
    if cfg.saved_model != '':
        try:
            start_iter = int(cfg.saved_model.split('_')[-1].split('.')[0])
            print('continue to train, start_iter: {}'.format(start_iter))
        except (AttributeError, ValueError) as e:
            # Name carries no parsable counter: report it and start from 0.
            print(e)

    start_time = time.time()
    best_mF1_score = 0
    i = start_iter
    step_num = 0
    loss_avg = Averager()
    val_loss_avg = Averager()
    eval_p_r_f = eval_pre_rec_f1()

    while True:
        # Validation below switches to eval mode, so re-enable training mode
        # at the start of every epoch.
        model.train()
        # training-----------------------------
        for image_tensors, labels, gt_xy_list in train_loader:
            step_num += 1
            batch_x = image_tensors.to(device).float()
            batch_y = labels.to(device).float()  # float64 -> float32

            out = model(batch_x)
            loss = loss_func(batch_y, out)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_avg.add(loss)
            train_Loss_list.append(loss_avg.val())
            if i == 5 or (i + 1) % 10 == 0:
                eval_p_r_f.add(out, gt_xy_list)  # very time-consuming!

        # save model per 100 epochs.
        if (i + 1) % 100 == 0:
            torch.save(model.state_dict(), './saved_models/{}/{}_iter_{}.pth'.format(cfg.train_task_id, cfg.train_task_id, step_num+1))

        print('Epoch:[{}/{}] Training Loss: {:.3f}'.format(i + 1, cfg.epoch_num, train_Loss_list[-1].item()))
        loss_avg.reset()

        # Training metrics are only collected on the epochs selected above.
        if i == 5 or (i + 1) % 10 == 0:
            mPre, mRec, mF1_score = eval_p_r_f.val()
            print('Training meanPrecision:{:.2f}% meanRecall:{:.2f}% meanF1-score:{:.2f}%'.format(mPre, mRec, mF1_score))
            eval_p_r_f.reset()

        # evaluation--------------------------------
        if (i + 1) % cfg.valInterval == 0:
            elapsed_time = time.time() - start_time
            print('Elapsed time:{}s'.format(round(elapsed_time)))
            model.eval()
            # No gradients are needed for validation; skipping the autograd
            # graph keeps memory use down.
            with torch.no_grad():
                for image_tensors, labels, gt_xy_list in valid_loader:
                    # Same float32 cast as in the training loop.
                    batch_x = image_tensors.to(device).float()
                    batch_y = labels.to(device).float()  # float64 -> float32

                    out = model(batch_x)
                    loss = loss_func(batch_y, out)

                    val_loss_avg.add(loss)
                    val_Loss_list.append(val_loss_avg.val())
                    eval_p_r_f.add(out, gt_xy_list)

            mPre, mRec, mF1_score = eval_p_r_f.val()
            print('validation meanPrecision:{:.2f}% meanRecall:{:.2f}% meanF1-score:{:.2f}%'.format(mPre, mRec, mF1_score))
            eval_p_r_f.reset()

            if mF1_score > best_mF1_score:  # remember the best model
                best_mF1_score = mF1_score
                torch.save(model.state_dict(), './saved_models/{}/{}_best_mF1_score_{:.3f}.pth'.format(cfg.train_task_id, cfg.train_task_id, mF1_score))
                torch.save(model.state_dict(), './saved_model/{}_best_mF1_score.pth'.format(cfg.train_task_id))

            print('Validation loss:{:.3f}'.format(val_loss_avg.val().item()))
            val_loss_avg.reset()

        # ``>=`` rather than ``==`` so that resuming from a counter already
        # past ``cfg.epoch_num`` still terminates.
        if i >= cfg.epoch_num:
            torch.save(model.state_dict(), './saved_models/{}/{}_iter_{}.pth'.format(cfg.train_task_id, cfg.train_task_id, i+1))
            print('End the training')
            break
        i += 1

    sys.exit()
Ejemplo n.º 6
0
def train(field):
    """Train a CRNN text recogniser for one field, checkpointing as it goes.

    Settings such as ``BATCH_SIZE``, ``IMAGE_HEIGHT``, ``nc``,
    ``number_hidden``, ``db_path``, ``lr`` and ``model_path`` are read from
    module level. A checkpoint is written every 500 batches.

    Args:
        field: Field name passed to ``LmdbDataset`` and used in checkpoint
            file names. ``'address'`` and ``'psb'`` are trained with a batch
            size of 1 because their image length varies.
    """
    # Use a context manager so the alphabet file handle is not leaked.
    with open('./cn-alphabet.json', 'rb') as alphabet_file:
        alphabet = ''.join(json.load(alphabet_file))
    nclass = len(alphabet) + 1  # add the dash -
    batch_size = BATCH_SIZE
    if field in ('address', 'psb'):
        batch_size = 1  # image length varies

    converter = LabelConverter(alphabet)
    criterion = CTCLoss(zero_infinity=True)

    crnn = CRNN(IMAGE_HEIGHT, nc, nclass, number_hidden)
    crnn.apply(weights_init)

    image_transform = transforms.Compose([
        Rescale(IMAGE_HEIGHT),
        transforms.ToTensor(),
        Normalize()
    ])

    dataset = LmdbDataset(db_path, field, image_transform)
    dataloader = DataLoader(dataset, batch_size=batch_size,
                            shuffle=True, num_workers=4)

    # Reusable buffers that ``utils.load_data`` fills for every batch.
    image = torch.FloatTensor(batch_size, 3, IMAGE_HEIGHT, IMAGE_HEIGHT)
    text = torch.IntTensor(batch_size * 5)
    length = torch.IntTensor(batch_size)

    image = Variable(image)
    text = Variable(text)
    length = Variable(length)

    loss_avg = utils.averager()
    optimizer = optim.RMSprop(crnn.parameters(), lr=lr)

    if torch.cuda.is_available():
        crnn.cuda()
        crnn = nn.DataParallel(crnn)
        image = image.cuda()
        criterion = criterion.cuda()

    def train_batch(net, iteration):
        """Run one optimisation step on the next batch and return its cost."""
        # ``next(it)`` works on every iterator; the ``.next()`` method is a
        # Python 2 idiom that current DataLoader iterators no longer provide.
        cpu_images, cpu_texts = next(iteration)
        current_batch_size = cpu_images.size(0)
        utils.load_data(image, cpu_images)
        t, l = converter.encode(cpu_texts)
        utils.load_data(text, t)
        utils.load_data(length, l)

        preds = net(image)
        preds_size = Variable(
            torch.IntTensor([preds.size(0)] * current_batch_size))
        cost = criterion(preds, text, preds_size, length) / current_batch_size
        net.zero_grad()
        cost.backward()
        optimizer.step()
        return cost

    nepoch = 25
    for epoch in range(nepoch):
        train_iter = iter(dataloader)
        i = 0
        while i < len(dataloader):
            for p in crnn.parameters():
                p.requires_grad = True
            crnn.train()

            cost = train_batch(crnn, train_iter)
            loss_avg.add(cost)
            i += 1

            # Every 500 batches: report the running loss, then checkpoint.
            if i % 500 == 0:
                print('%s [%d/%d][%d/%d] Loss: %f' %
                        (datetime.datetime.now(), epoch, nepoch, i, len(dataloader), loss_avg.val()))
                loss_avg.reset()

                torch.save(
                    crnn.state_dict(), f'{model_path}crnn_{field}_{epoch}_{i}.pth')
Ejemplo n.º 7
0
# Training configuration and TensorBoard writer for this run.
config = TainTestConfig()
# TIMESTAMP = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
writer = SummaryWriter(
    log_dir="../data/logs/" + config.name,
    flush_secs=30,
)

# config.lmdb_config.num_samples = 10000
# config.batch_size = 1024
# Scale batch size and worker count with the number of GPUs. Fall back to 1
# so a CPU-only host does not end up with batch_size == 0, which DataLoader
# rejects.
n_device = max(torch.cuda.device_count(), 1)
config.batch_size = 256 * n_device
config.iter_to_valid = 128 * 8

# Training data: all configured LMDB datasets concatenated into one.
train_dataset = torch.utils.data.ConcatDataset(
    [LmdbDataset(path, config.lmdb_config) for path in config.train_data])

data_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    num_workers=2 * n_device,  #4,
    shuffle=True,
    pin_memory=True,
    drop_last=True,
)

# Evaluation data. The config is mutated here, after the training loader has
# already been built with the previous batch size.
path = "../data/lmdbs/evaluation/IIIT5K_3000"
config.batch_size = 1024
config.lmdb_config.num_samples = 1000
test_dataset = LmdbDataset(path, config.lmdb_config)
def demo(opt):
    """Run a saved recognition model over an LMDB dataset and report accuracy.

    Predictions are compared with the labels case-insensitively on
    alphanumeric characters only. Every mismatch is printed and the offending
    image is saved to ``result/``. A per-batch header is appended to
    ``./log_demo_result.txt``. The overall accuracy is printed at the end.

    Args:
        opt: Options object. ``opt.num_class`` is set here, and
            ``opt.input_channel`` is set to 3 when ``opt.rgb`` is truthy.

    Raises:
        SystemExit: With status 1 if a failure image cannot be saved.
    """
    # Model configuration.
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # Prepare the data: labelled LMDB samples collated to a common size.
    AlignCollate_demo = AlignCollate(imgH=opt.imgH,
                                     imgW=opt.imgW,
                                     keep_ratio_with_pad=opt.PAD)
    demo_data = LmdbDataset(root=opt.image_folder, opt=opt, mode='Val')

    demo_loader = torch.utils.data.DataLoader(demo_data,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=int(opt.workers),
                                              collate_fn=AlignCollate_demo,
                                              pin_memory=True,
                                              drop_last=True)

    # Failure images are written here; create the directory up front so the
    # first mismatch does not abort the run.
    os.makedirs('result', exist_ok=True)

    # predict
    model.eval()
    fail_count, sample_count = 0, 0
    record_count = 1
    # The context manager closes the log even when the loop exits early.
    with open('./log_demo_result.txt', 'a') as log, torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                              batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length +
                                             1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)

                # Select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, preds_size)

            else:
                preds = model(image, text_for_pred, is_train=False)

                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            dashed_line = '-' * 80
            head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'

            print(f'{dashed_line}\n{head}\n{dashed_line}')
            log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)

            # The second element of each batch is used as the ground truth.
            for image_tensor, gt, pred, pred_max_prob in zip(
                    image_tensors, image_path_list, preds_str, preds_max_prob):
                if 'Attn' in opt.Prediction:
                    # Prune after the "end of sentence" token ([s]). Only cut
                    # when the token is present: ``find`` returns -1
                    # otherwise, which would drop the last character.
                    pred_EOS = pred.find('[s]')
                    if pred_EOS != -1:
                        pred = pred[:pred_EOS]
                        pred_max_prob = pred_max_prob[:pred_EOS]

                if pred_max_prob.shape[0] > 0:
                    # calculate confidence score (= multiply of pred_max_prob)
                    confidence_score = pred_max_prob.cumprod(dim=0)[-1]
                else:
                    confidence_score = 0.0

                compare_gt = "".join(x.upper() for x in gt if x.isalnum())
                compare_pred = "".join(x.upper() for x in pred if x.isalnum())

                if compare_gt != compare_pred:
                    fail_count += 1
                    print(
                        f'{gt:25s}\t{pred:25s}\tFail\t{confidence_score:0.4f}\t{record_count}\n'
                    )
                    im = to_pil_image(image_tensor)
                    try:
                        im.save(
                            os.path.join(
                                'result',
                                f'{fail_count}_{compare_pred}_{compare_gt}.jpeg'
                            ))
                    except Exception as e:
                        print(
                            f'Error: {e} {fail_count}_{compare_pred}_{compare_gt}'
                        )
                        raise SystemExit(1)
                sample_count += 1
                record_count += 1

    # ``drop_last=True`` can leave the loader without a single batch; avoid
    # dividing by zero in that case.
    if sample_count:
        print(f'total accuracy: {(sample_count-fail_count)/sample_count:.2f}')
    else:
        print('total accuracy: n/a (no samples were evaluated)')
Ejemplo n.º 9
0
 def setUp(self):
     """Open the training LMDB and wrap it in a shuffled loader."""
     self.dataset = LmdbDataset("train_lmdb")
     # Batches of 32 samples, loaded by four worker processes.
     loader_options = dict(batch_size=32, shuffle=True, num_workers=4)
     self.dataloader = DataLoader(self.dataset, **loader_options)
Ejemplo n.º 10
0
def train(opt):
    """Train a text-recognition model, validating and checkpointing as it runs.

    Training loops over the training loader indefinitely. Every
    ``opt.valInterval`` iterations the running loss is logged and the model is
    validated; the weights with the best accuracy and the best normalised
    edit distance are saved under ``./saved_models/<opt.experiment_name>/``.
    A checkpoint is also written every 1e5 iterations. The process is
    terminated with ``sys.exit()`` once ``opt.num_iter`` is reached.

    Args:
        opt: Options object. ``opt.num_class`` is set here, and
            ``opt.input_channel`` is set to 3 when ``opt.rgb`` is truthy.
    """
    # Prepare the training and validation datasets.
    transform = transforms.Compose([
        ToTensor(),
    ])
    train_dataset = LmdbDataset(opt.train_data, opt=opt, transform=transform)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=int(opt.workers),
    )

    valid_dataset = LmdbDataset(root=opt.valid_data,
                                opt=opt,
                                transform=transform)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=int(opt.workers),
    )
    print('-' * 80)

    # Model configuration.
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # Weight initialisation.
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            # Parameters that kaiming_normal_ rejects get constant-1 weights.
            if 'weight' in name:
                param.data.fill_(1)

    model = model.to(device)
    model.train()
    if opt.continue_model != '':
        print(f'loading pretrained model from {opt.continue_model}')
        # map_location lets a checkpoint saved on another device be loaded.
        model.load_state_dict(
            torch.load(opt.continue_model, map_location=device))
    print("Model:")
    print(model)

    # Set up the loss.
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
            device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # Keep only the parameters that require gradients.
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # Record the final options next to the checkpoints.
    with open(f'./saved_models/{opt.experiment_name}/opt.txt',
              'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # Start training; resume the iteration counter from the checkpoint name
    # when it ends in ``_<number>.<ext>``.
    start_iter = 0
    if opt.continue_model != '':
        try:
            start_iter = int(opt.continue_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            # e.g. ``best_accuracy.pth`` carries no iteration number.
            print(f'no iteration number in {opt.continue_model}, '
                  'starting from iteration 0')

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = 1e+6
    i = start_iter

    while True:
        # train part
        for image_tensors, labels in train_loader:
            image = image_tensors.to(device)
            text, length = converter.encode(
                labels, batch_max_length=opt.batch_max_length
            )  # text: [index, index, ..., index], length: [10, 8]
            batch_size = image.size(0)

            if 'CTC' in opt.Prediction:
                # Model output is (batch, sequence, classes), e.g.
                # torch.Size([100, 63, 12]); log_softmax over the class axis.
                preds = model(image, text).log_softmax(2)
                preds_size = torch.IntTensor([preds.size(1)] *
                                             batch_size).to(device)
                # CTCLoss expects (sequence, batch, classes),
                # e.g. 100 * 63 * 7 -> 63 * 100 * 7.
                preds = preds.permute(1, 0, 2)

                # To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
                # https://github.com/jpuigcerver/PyLaia/issues/16
                # Example shapes from the original author's notes:
                #   preds: [63, 100, 7] (sequence length, batch size, classes)
                #   text: [1000], i.e. 1000 characters in total
                #   preds_size: 100 entries, each the sequence length 63
                #   length: 100 entries, each the label length (10 characters)
                torch.backends.cudnn.enabled = False
                try:
                    cost = criterion(preds, text, preds_size, length)
                finally:
                    # Re-enable cudnn even if the loss computation raises.
                    torch.backends.cudnn.enabled = True

            else:
                preds = model(image,
                              text[:, :-1])  # align with Attention.forward
                target = text[:, 1:]  # without [GO] Symbol
                cost = criterion(preds.view(-1, preds.shape[-1]),
                                 target.contiguous().view(-1))

            model.zero_grad()
            cost.backward()
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                opt.grad_clip)  # gradient clipping with 5 (Default)
            optimizer.step()

            loss_avg.add(cost)

            # validation part
            if i % opt.valInterval == 0:
                elapsed_time = time.time() - start_time
                print(
                    f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}'
                )
                # for log
                with open(
                        f'./saved_models/{opt.experiment_name}/log_train.txt',
                        'a') as log:
                    log.write(
                        f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}\n'
                    )
                    loss_avg.reset()

                    model.eval()
                    with torch.no_grad():
                        valid_loss, current_accuracy, current_norm_ED, preds, labels, infer_time, length_of_data = validation(
                            model, criterion, valid_loader, converter, opt)
                    model.train()

                    # Show the first five predictions next to their labels.
                    for pred, gt in zip(preds[:5], labels[:5]):
                        if 'Attn' in opt.Prediction:
                            # Cut at the end-of-sentence token only when it
                            # is present: ``find`` returns -1 otherwise,
                            # which would drop the last character.
                            pred_eos = pred.find('[s]')
                            if pred_eos != -1:
                                pred = pred[:pred_eos]
                            gt_eos = gt.find('[s]')
                            if gt_eos != -1:
                                gt = gt[:gt_eos]
                        print(f'{pred:20s}, gt: {gt:20s},   {str(pred == gt)}')
                        log.write(
                            f'{pred:20s}, gt: {gt:20s},   {str(pred == gt)}\n')

                    valid_log = f'[{i}/{opt.num_iter}] valid loss: {valid_loss:0.5f}'
                    valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}'
                    print(valid_log)
                    log.write(valid_log + '\n')

                    # keep best accuracy model
                    if current_accuracy > best_accuracy:
                        best_accuracy = current_accuracy
                        torch.save(
                            model.state_dict(),
                            f'./saved_models/{opt.experiment_name}/best_accuracy.pth'
                        )
                    # Lower normalised edit distance counts as better here.
                    if current_norm_ED < best_norm_ED:
                        best_norm_ED = current_norm_ED
                        torch.save(
                            model.state_dict(),
                            f'./saved_models/{opt.experiment_name}/best_norm_ED.pth'
                        )
                    best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}'
                    print(best_model_log)
                    log.write(best_model_log + '\n')

            # save model per 1e+5 iter.
            if (i + 1) % 1e+5 == 0:
                torch.save(
                    model.state_dict(),
                    f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

            if i == opt.num_iter:
                print('end the training')
                sys.exit()
            i += 1
Ejemplo n.º 11
0
def test(opt):
    """Evaluate a saved OCR model on the test set and log its accuracy.

    Two numbers are reported and appended to
    ``<opt.exp_dir>/<opt.exp_name>/log_test.txt``: the share of exactly
    matching predictions, and the mean of ``1 - edit_distance / max(len)``
    over all samples (0 for a sample whose label or prediction is empty).

    Args:
        opt: Options object. ``opt.classes`` and ``opt.num_class`` are set
            here from the label converter.
    """
    lib.print_model_settings(locals().copy())

    if 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
    else:
        converter = CTCLabelConverter(opt.character)

    opt.classes = converter.character

    # Dataset preparation.
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                      imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)

    valid_dataset = LmdbDataset(root=opt.test_data, opt=opt)
    test_data_sampler = data_sampler(valid_dataset,
                                     shuffle=False,
                                     distributed=False)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=False,
        sampler=test_data_sampler,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True,
        drop_last=False)

    print('-' * 80)

    opt.num_class = len(converter.character)

    ocrModel = ModelV1(opt).to(device)

    ## Loading pre-trained files
    print(f'loading pretrained ocr model from {opt.saved_ocr_model}')
    checkpoint = torch.load(opt.saved_ocr_model,
                            map_location=lambda storage, loc: storage)
    ocrModel.load_state_dict(checkpoint)
    # Evaluate in inference mode so layers such as dropout and batch norm use
    # their evaluation behaviour.
    ocrModel.eval()

    evalCntr = 0

    c1_s1_input_correct = 0.0
    c1_s1_input_ed_correct = 0.0

    for vCntr, (image_input_tensors, labels_gt) in enumerate(valid_loader):
        print(vCntr)

        image_input_tensors = image_input_tensors.to(device)
        text_gt, _ = converter.encode(
            labels_gt, batch_max_length=opt.batch_max_length)

        with torch.no_grad():
            currBatchSize = image_input_tensors.shape[0]
            length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                              currBatchSize).to(device)
            # Run OCR prediction
            if 'CTC' in opt.Prediction:
                preds = ocrModel(image_input_tensors, text_gt, is_train=False)
                preds_size = torch.IntTensor([preds.size(1)] * currBatchSize)
                _, preds_index = preds.max(2)
                preds_str_gt_1 = converter.decode(preds_index.data,
                                                  preds_size.data)

            else:
                preds = ocrModel(
                    image_input_tensors, text_gt[:, :-1],
                    is_train=False)  # align with Attention.forward
                _, preds_index = preds.max(2)
                preds_str_gt_1 = converter.decode(preds_index, length_for_pred)
                for idx, pred in enumerate(preds_str_gt_1):
                    # Prune after the "end of sentence" token ([s]). Only cut
                    # when the token is present: ``find`` returns -1
                    # otherwise, which would drop the last character.
                    pred_EOS = pred.find('[s]')
                    if pred_EOS != -1:
                        preds_str_gt_1[idx] = pred[:pred_EOS]

        for trImgCntr in range(currBatchSize):
            # OCR accuracy
            c1_s1_input_gt = labels_gt[trImgCntr]
            c1_s1_input_ocr = preds_str_gt_1[trImgCntr]

            if c1_s1_input_gt == c1_s1_input_ocr:
                c1_s1_input_correct += 1

            # ICDAR2019 Normalized Edit Distance: normalise by the longer of
            # the two strings; an empty string on either side scores 0.
            if c1_s1_input_gt and c1_s1_input_ocr:
                longest = max(len(c1_s1_input_gt), len(c1_s1_input_ocr))
                c1_s1_input_ed_correct += 1 - edit_distance(
                    c1_s1_input_ocr, c1_s1_input_gt) / longest

            evalCntr += 1

    # An empty loader would otherwise end in a ZeroDivisionError.
    if evalCntr == 0:
        print('No test samples were evaluated; nothing to report.')
        return

    # Despite the historical ``wer``/``cer`` names these are accuracies.
    avg_c1_s1_input_wer = c1_s1_input_correct / float(evalCntr)
    avg_c1_s1_input_cer = c1_s1_input_ed_correct / float(evalCntr)

    with open(os.path.join(opt.exp_dir, opt.exp_name, 'log_test.txt'),
              'a') as log:
        loss_log = f'Word Acc: {avg_c1_s1_input_wer:0.5f}, Test Input Char Acc: {avg_c1_s1_input_cer:0.5f}'

        print(loss_log)
        log.write(loss_log + "\n")
Ejemplo n.º 12
0
                    type=int,
                    default=1,
                    help='the number of input channel of Feature extractor')
parser.add_argument('--output_channel',
                    type=int,
                    default=512,
                    help='the number of output channel of Feature extractor')
parser.add_argument('--hidden_size',
                    type=int,
                    default=256,
                    help='the size of the LSTM hidden state')
# ``type=bool`` would turn every non-empty string (including "False" and
# "0") into True, so the value is parsed explicitly instead.
parser.add_argument('--include_space',
                    type=lambda value: value.strip().lower() in (
                        '1', 'true', 't', 'yes', 'y', 'on'),
                    default=False)

opt = parser.parse_args()

# Build the character set from the second whitespace-separated column of the
# labels file.
opt.character = []
with open(os.path.join(opt.train_data, 'kr_labels.txt'), 'r') as f:
    for line in f:
        if not line.strip():
            # Skip blank lines (e.g. a trailing newline at end of file).
            continue
        ch = line.strip().split()[1]
        if len(ch) != 1:
            print(f'{ch}s length is greater than 1')
        opt.character.append(ch)

if opt.include_space:
    opt.character.append(' ')

dataset = LmdbDataset('data/train/printed', opt)

# Smoke test: fetch samples 1..999 to make sure they can be read.
for i in range(1, 1000):
    dataset[i]