예제 #1
0
def main(opt, case):
    """Train a CRNN handwriting recognizer, validating and checkpointing
    at fixed iteration intervals.

    Args:
        opt: argparse-style namespace with the training options
            (experiment dir, batchSize, workers, imgH/imgW, alphabet, nh,
            ngpu, cuda flags, crnn checkpoint path, optimizer flags, lr,
            beta1, niter, display/val intervals, uses_old_saving).
        case: label string appended to each progress log line.
    """
    print("Arguments are : " + str(opt))

    if opt.experiment is None:
        opt.experiment = 'expr'
    os.system('mkdir {0}'.format(opt.experiment))

    # Seed every RNG so a run can be reproduced from the printed seed.
    opt.manualSeed = random.randint(1, 10000)  # fix seed
    print("Random Seed: ", opt.manualSeed)
    random.seed(opt.manualSeed)
    np.random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    cudnn.benchmark = True

    if torch.cuda.is_available() and not opt.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

        opt.cuda = True
        print('Set CUDA to true.')

    train_dataset = dataset.hwrDataset(mode="train")
    assert train_dataset

    # The shuffle needs to be false when the sizing has been done.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=opt.batchSize,
                                               shuffle=False,
                                               num_workers=int(opt.workers),
                                               collate_fn=dataset.alignCollate(
                                                   imgH=opt.imgH,
                                                   imgW=opt.imgW,
                                                   keep_ratio=True))

    test_dataset = dataset.hwrDataset(mode="test",
                                      transform=dataset.resizeNormalize(
                                          (100, 32)))

    nclass = len(opt.alphabet) + 1  # +1 for the CTC blank symbol
    nc = 1  # grayscale input

    criterion = CTCLoss()

    # custom weights initialization called on crnn
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            m.weight.data.normal_(0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

    crnn = crnn_model.CRNN(opt.imgH, nc, nclass, opt.nh)
    crnn.apply(weights_init)

    # NOTE(review): when opt.uses_old_saving is True but opt.crnn == '',
    # the model is never moved to CUDA or wrapped in DataParallel —
    # confirm that combination is not used in practice.
    if opt.cuda and not opt.uses_old_saving:
        crnn.cuda()
        crnn = torch.nn.DataParallel(crnn, device_ids=range(opt.ngpu))
        criterion = criterion.cuda()

    # BUG FIX: start_epoch was only assigned inside the checkpoint-loading
    # branches, so training from scratch (opt.crnn == '') raised
    # UnboundLocalError at the epoch loop below. Default it to 0 here.
    start_epoch = 0

    if opt.crnn != '':

        print('Loading pre-trained model from %s' % opt.crnn)
        loaded_model = torch.load(opt.crnn)

        if opt.uses_old_saving:
            # Checkpoint is a bare state_dict saved before the richer
            # {'state', 'epoch', 'accuracy', ...} format was introduced.
            print("Assuming model was saved in rudementary fashion")
            crnn.load_state_dict(loaded_model)
            crnn.cuda()

            crnn = torch.nn.DataParallel(crnn, device_ids=range(opt.ngpu))
            criterion = criterion.cuda()
        else:
            print("Loaded model accuracy: " + str(loaded_model['accuracy']))
            print("Loaded model epoch: " + str(loaded_model['epoch']))
            start_epoch = loaded_model['epoch']
            crnn.load_state_dict(loaded_model['state'])

    # Running average of the training loss between display intervals.
    loss_avg = utils.averager()

    # If following the paper's recommendation, using AdaDelta
    if opt.adam:
        optimizer = optim.Adam(crnn.parameters(),
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    elif opt.adadelta:
        optimizer = optim.Adadelta(crnn.parameters(), lr=opt.lr)
    elif opt.adagrad:
        print("Using adagrad")
        optimizer = optim.Adagrad(crnn.parameters(), lr=opt.lr)
    else:
        optimizer = optim.RMSprop(crnn.parameters(), lr=opt.lr)

    converter = utils.strLabelConverter(opt.alphabet)

    best_val_accuracy = 0

    for epoch in range(start_epoch, opt.niter):
        train_iter = iter(train_loader)
        i = 0
        while i < len(train_loader):
            # Re-enable gradients (val_batch may have frozen them).
            for p in crnn.parameters():
                p.requires_grad = True
            crnn.train()

            cost = train_batch(crnn, criterion, optimizer, train_iter, opt,
                               converter)
            loss_avg.add(cost)
            i += 1

            if i % opt.displayInterval == 0:
                print(
                    '[%d/%d][%d/%d] Loss: %f' %
                    (epoch, opt.niter, i, len(train_loader), loss_avg.val()) +
                    " " + case)
                loss_avg.reset()

            if i % opt.valInterval == 0:
                try:
                    val_loss_avg, accuracy = val_batch(crnn, opt, test_dataset,
                                                       converter, criterion)

                    model_state = {
                        'epoch': epoch + 1,
                        'iter': i,
                        'state': crnn.state_dict(),
                        'accuracy': accuracy,
                        'val_loss_avg': val_loss_avg,
                    }
                    utils.save_checkpoint(
                        model_state, accuracy > best_val_accuracy,
                        '{0}/netCRNN_{1}_{2}_{3}.pth'.format(
                            opt.experiment, epoch, i,
                            accuracy), opt.experiment)

                    if accuracy > best_val_accuracy:
                        best_val_accuracy = accuracy

                except Exception as e:
                    # Validation failures should not abort training.
                    print(e)
def val(net, test_dataset, criterion, max_iter=100):
    """Run CTC validation of *net* on *test_dataset*.

    Reports the average edit distance per sample and the normalized
    accuracy 1 - (edit distance / target length).

    Args:
        net: CRNN model to evaluate.
        test_dataset: dataset yielding (image, text) pairs.
        criterion: CTC loss used for the reported validation loss.
        max_iter: upper bound on the number of validation batches.

    NOTE(review): relies on module-level globals `image`, `text`,
    `length`, `converter`, `opt` and the `util`/`dataset`/`editdistance`
    modules — confirm they are defined in the enclosing script.
    """
    print('Start val')

    # Freeze the network for evaluation.
    # FIX: the original iterated the global `crnn` instead of the
    # `net` argument, so the parameter was silently ignored.
    for p in net.parameters():
        p.requires_grad = False

    net.eval()
    data_loader = torch.utils.data.DataLoader(test_dataset,
                                              shuffle=True,
                                              batch_size=opt.batchSize,
                                              num_workers=int(opt.workers),
                                              collate_fn=dataset.alignCollate(
                                                  imgH=32,
                                                  imgW=100,
                                                  keep_ratio=True))
    val_iter = iter(data_loader)

    n = 0  # number of samples scored
    n_correct = 0  # accumulated edit distance (despite the name)
    n_text = 0  # accumulated ground-truth length
    loss_avg = util.averager()

    # FIX: honour the max_iter argument instead of always consuming the
    # whole loader (matches the sibling val() implementations).
    max_iter = min(max_iter, len(data_loader))
    for i in range(max_iter):
        data = next(val_iter)  # Py2 `.next()` -> builtin next()
        cpu_images, cpu_texts = data
        batch_size = cpu_images.size(0)
        util.loadData(image, cpu_images)
        t, l = converter.encode(cpu_texts)

        util.loadData(text, t)
        util.loadData(length, l)

        preds = net(image)  # FIX: was the global `crnn`
        preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
        cost = criterion(preds, text, preds_size, length) / batch_size
        loss_avg.add(cost)

        # Greedy decode: per-timestep argmax, then converter.decode
        # collapses repeats and blanks.
        _, preds = preds.max(2)
        preds = preds.transpose(1, 0).contiguous().view(-1)
        sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
        for pred, target in zip(sim_preds, cpu_texts):
            # NOTE(review): `unicode` exists only on Python 2 — this
            # branch needs porting if the script moves to Python 3.
            if isinstance(target, unicode) is False:
                target = target.decode('utf-8')
            pred_encode, _ = converter.encode(pred)
            target_encode, _ = converter.encode(target)
            n_correct += editdistance.eval(pred_encode, target_encode)
            n_text += len(target_encode)
            n += 1

    # Show a few raw (uncollapsed) predictions from the last batch.
    raw_preds = converter.decode(preds.data, preds_size.data,
                                 raw=True)[:opt.n_test_disp]
    for raw_pred, sim_pred, gt in zip(raw_preds, sim_preds, cpu_texts):
        print('%-20s => %-20s, gt: %-20s' % (raw_pred, sim_pred, gt))
    len_edit = n_correct / float(n)
    len_text = n_text / float(n)
    norm = 1 - len_edit / len_text
    print('aver editd: %f, norm acc: %f' % (len_edit, norm))
예제 #3
0
    )

# Training data: LMDB dataset, optionally iterated with the project's
# random-sequential sampler instead of DataLoader's built-in shuffling.
train_dataset = dataset.lmdbDataset(root=opt.trainRoot)
assert train_dataset
if not opt.random_sample:
    sampler = dataset.randomSequentialSampler(train_dataset, opt.batchSize)
else:
    sampler = None
# BUG FIX: DataLoader raises ValueError when a non-None sampler is
# combined with shuffle=True ("sampler option is mutually exclusive with
# shuffle"), so only enable shuffling when no custom sampler is used.
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=opt.batchSize,
                                           shuffle=(sampler is None),
                                           sampler=sampler,
                                           num_workers=int(opt.workers),
                                           collate_fn=dataset.alignCollate(
                                               imgH=opt.imgH,
                                               imgW=opt.imgW,
                                               keep_ratio=opt.keep_ratio,
                                               augmentation=opt.use_aug,
                                               noise=opt.use_noise))
# Validation images: fixed resize, no augmentation or noise.
test_dataset = dataset.lmdbDataset(root=opt.valRoot,
                                   transform=dataset.resizeNormalize(
                                       (opt.imgW, opt.imgH),
                                       augmentation=False,
                                       noise=False))

nclass = len(opt.alphabet) + 1  # +1 for the CTC blank symbol
nc = 1  # grayscale input

converter = utils.strLabelConverter(opt.alphabet)
criterion = CTCLoss()

예제 #4
0
def val(net,
        net2,
        net3,
        _dataset,
        _dataset2,
        epoch,
        step,
        criterion,
        max_iter=100):
    """Ensemble validation: average the logits of up to three models.

    `net` and `net3` run on grayscale batches from `_dataset`; `net2`
    runs on RGB batches from `_dataset2`. Where the models agree on the
    number of non-blank timesteps, their output distributions are
    averaged before CTC decoding. Decoded predictions are appended to a
    per-epoch/step record file; loss, accuracy, recall and F1 are logged.

    NOTE(review): relies on module-level globals `params`, `log_dir`,
    `logger`, `image`, `image2`, `text`, `length`, `converter` —
    confirm they are defined in the enclosing script.
    """
    logger.info('Start val')
    # for p in crnn.parameters():
    #     p.requires_grad = False
    net.eval()
    net2.eval()
    net3.eval()
    net2.cuda()
    # Grayscale loader for net/net3.
    data_loader = torch.utils.data.DataLoader(
        _dataset,
        shuffle=False,
        batch_size=params.batchSize,
        num_workers=int(params.workers),
        collate_fn=dataset.alignCollate(imgH=params.imgH,
                                        imgW=params.imgW,
                                        keep_ratio=params.keep_ratio))
    # RGB loader (different height) for the resnet-based net2; both
    # loaders use shuffle=False so the two streams stay sample-aligned.
    data_loader2 = torch.utils.data.DataLoader(
        _dataset2,
        shuffle=False,
        batch_size=params.batchSize,
        num_workers=int(params.workers),
        collate_fn=dataset.alignCollate(imgH=params.resnet_imgH,
                                        imgW=params.imgW,
                                        keep_ratio=params.keep_ratio,
                                        rgb=True))
    val_iter = iter(data_loader)
    val_iter2 = iter(data_loader2)
    i = 0
    n_correct = 0
    loss_avg = utils.averager()
    # NOTE(review): this overrides the max_iter argument — the whole
    # loader is always consumed.
    max_iter = len(data_loader)
    record_dir = log_dir + 'epoch_%d_step_%d_data.txt' % (epoch, step)
    record_dir1 = log_dir + 'epoch_%d_step_%d_data1.txt' % (epoch, step)
    record_dir2 = log_dir + 'epoch_%d_step_%d_data2.txt' % (epoch, step)
    r = 1  # running line number in the prediction record file
    f = open(record_dir, "a")
    # NOTE(review): f1 and f2 are opened but never written or closed —
    # looks like leftover scaffolding; confirm before removing.
    f1 = open(record_dir1, "a")
    f2 = open(record_dir2, "a")
    num_label, num_pred = params.total_num, 0

    start = time.time()
    for i in range(max_iter):
        # NOTE(review): `.next()` is Python 2 only; use next(val_iter)
        # when porting to Python 3.
        data = val_iter.next()
        data2 = val_iter2.next()
        if i < 6000:
            pass  #continue
        i += 1  # no-op: the `for` rebinds i on the next iteration
        cpu_images, cpu_texts = data
        resnet_images, _ = data2
        batch_size = cpu_images.size(0)
        utils.loadData(image, cpu_images)
        utils.loadData(image2, resnet_images)
        t, l = converter.encode(cpu_texts)
        utils.loadData(text, t)
        utils.loadData(length, l)

        with torch.no_grad():
            n1img = net(image)
            n2img = net2(image2)
            n3img = net3(image)
        preds_size = Variable(torch.IntTensor([n1img.size(0)] * batch_size))

        # Per-timestep argmax of each model's output (class 0 = blank).
        _, n1 = n1img.max(2)
        _, n2 = n2img.max(2)
        _, n3 = n3img.max(2)
        ind = torch.arange(batch_size)
        _ind = torch.arange(batch_size)
        n1_index = n1.transpose(1, 0).data
        n2_index = n2.transpose(1, 0).data
        n3_index = n3.transpose(1, 0).data
        # ind: samples where net and net2 emit the same count of
        # non-blank timesteps; _ind: samples where all three agree.
        ind = ind[torch.sum(n1_index != 0, 1) == torch.sum(n2_index != 0, 1)]
        _ind = _ind[
            (torch.sum(n1_index != 0, 1) == torch.sum(n2_index != 0, 1)) *
            (torch.sum(n3_index != 0, 1) == torch.sum(n2_index != 0, 1))]
        # NOTE: this `i` shadows the batch-loop index (harmless — the
        # outer `for` rebinds it on the next iteration).
        for i in ind:
            ind1 = np.arange(n1img.shape[0])
            ind2 = np.arange(n2img.shape[0])
            # Timesteps where each model emitted a non-blank symbol.
            ind1 = ind1[(n1_index[int(i), :].cpu().numpy().astype(bool) != 0)]
            ind2 = ind2[(n2_index[int(i), :].cpu().numpy().astype(bool) != 0)]
            #n1img[ind1, int(i), :] = (n1img[ind1, int(i), :] + n2img[ind2, int(i), :])/2

            if torch.sum(int(i) == _ind) > 0:
                # All three models agree on length: average three heads.
                ind3 = np.arange(n1img.shape[0])
                ind3 = ind3[(n3_index[int(i), :].cpu().numpy().astype(bool) !=
                             0)]
                n1img[ind1, int(i), :] = (
                    n1img[ind1, int(i), :] + n2img[ind2, int(i), :] +
                    n3img[ind3, int(i), :]) / 3  #+ n3img[ind3, int(i), :]
            else:
                # Only net and net2 agree: average two heads.
                n1img[ind1, int(i), :] = (n1img[ind1, int(i), :] +
                                          n2img[ind2, int(i), :]) / 2

        preds = n1img
        cost = criterion(preds, text, preds_size, length) / batch_size
        loss_avg.add(cost)
        # Greedy CTC decode of the (possibly averaged) logits.
        _, preds = preds.max(2)
        preds = preds.transpose(1, 0).contiguous().view(-1)
        sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
        if not isinstance(sim_preds, list):
            sim_preds = [sim_preds]

        # Record every prediction as "<000001>.jpg <text>".
        for i, pred in enumerate(sim_preds):
            f.write(str(r).zfill(6) + ".jpg " + pred + "\n")
            r += 1
        list_1 = []
        for i in cpu_texts:
            string = i.decode('utf-8', 'strict')
            list_1.append(string)
        for pred, target in zip(sim_preds, list_1):
            if pred == target:
                n_correct += 1
        num_pred += len(sim_preds)

    print("")
    f.close()

    # Show a few raw (uncollapsed) predictions from the last batch.
    raw_preds = converter.decode(preds.data, preds_size.data,
                                 raw=True)[:params.n_test_disp]
    for raw_pred, pred, gt in zip(raw_preds, sim_preds, list_1):
        logger.info('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))

    logger.info('correct_num: %d' % (n_correct))
    logger.info('Total_num: %d' % (max_iter * params.batchSize))
    accuracy = float(n_correct) / num_pred
    recall = float(n_correct) / num_label
    logger.info(
        'Test loss: %f, accuray: %f, recall: %f, F1 score: %f, Cost : %.4fs per img'
        % (loss_avg.val(), accuracy, recall, 2 * accuracy * recall /
           (accuracy + recall + 1e-2), (time.time() - start) / max_iter))
예제 #5
0
        "WARNING: You have a CUDA device, so you should probably run with --cuda"
    )

# Training data: LMDB dataset, optionally iterated with the project's
# random-sequential sampler instead of DataLoader's built-in shuffling.
train_dataset = dataset.lmdbDataset(root=opt.trainroot)
assert train_dataset
if not opt.random_sample:
    sampler = dataset.randomSequentialSampler(train_dataset, opt.batchSize)
else:
    sampler = None
# BUG FIX: DataLoader raises ValueError when a non-None sampler is
# combined with shuffle=True ("sampler option is mutually exclusive with
# shuffle"), so only enable shuffling when no custom sampler is used.
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=opt.batchSize,
                                           shuffle=(sampler is None),
                                           sampler=sampler,
                                           num_workers=int(opt.workers),
                                           collate_fn=dataset.alignCollate(
                                               imgH=opt.imgH,
                                               imgW=opt.imgW,
                                               keep_ratio=opt.keep_ratio))
# Validation images are resized to a fixed 200x32.
test_dataset = dataset.lmdbDataset(root=opt.valroot,
                                   transform=dataset.resizeNormalize(
                                       (200, 32)))

nclass = len(opt.alphabet) + 1  # +1 for the CTC blank symbol
nc = 1  # grayscale input

converter = utils.strLabelConverter(opt.alphabet)
criterion = CTCLoss()


# custom weights initialization called on crnn
def weights_init(m):
    classname = m.__class__.__name__
예제 #6
0
def main(opt):
    """Train a CRNN handwriting recognizer, validating and checkpointing
    at fixed iteration intervals.

    Args:
        opt: argparse-style namespace with the training options
            (experiment dir, batchSize, workers, imgH/imgW, alphabet, nh,
            ngpu, cuda, crnn checkpoint path, lr/beta1 and optimizer
            flags, niter, display/val/save intervals, n_test_disp).
    """
    print(opt)

    if opt.experiment is None:
        opt.experiment = 'expr'

    os.system('mkdir {0}'.format(opt.experiment))

    # Seed every RNG so a run can be reproduced from the printed seed.
    opt.manualSeed = random.randint(1, 10000)  # fix seed

    print("Random Seed: ", opt.manualSeed)
    random.seed(opt.manualSeed)
    np.random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    cudnn.benchmark = True

    if torch.cuda.is_available() and not opt.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    train_dataset = dataset.hwrDataset(mode="train")
    assert train_dataset
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=opt.batchSize,
                                               shuffle=True,
                                               num_workers=int(opt.workers),
                                               collate_fn=dataset.alignCollate(
                                                   imgH=opt.imgH,
                                                   imgW=opt.imgW,
                                                   keep_ratio=True))

    test_dataset = dataset.hwrDataset(mode="test",
                                      transform=dataset.resizeNormalize(
                                          (100, 32)))

    nclass = len(opt.alphabet) + 1  # +1 for the CTC blank symbol
    nc = 1  # grayscale input

    criterion = CTCLoss()

    # custom weights initialization called on crnn
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            m.weight.data.normal_(0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

    crnn = crnn_model.CRNN(opt.imgH, nc, nclass, opt.nh)
    crnn.apply(weights_init)
    if opt.crnn != '':
        print('loading pretrained model from %s' % opt.crnn)
        crnn.load_state_dict(torch.load(opt.crnn))
    print(crnn)

    # Reusable input buffers, filled per batch via utils.loadData.
    # BUG FIX: the width dimension used opt.imgH twice; the layout is
    # (batch, channels, H, W), so the last dim must be opt.imgW.
    # NOTE(review): the buffer has 3 channels while the model is built
    # with nc == 1 — confirm what the collate function actually yields.
    image = torch.FloatTensor(opt.batchSize, 3, opt.imgH, opt.imgW)
    text = torch.IntTensor(opt.batchSize * 5)
    length = torch.IntTensor(opt.batchSize)

    if opt.cuda:
        crnn.cuda()
        crnn = torch.nn.DataParallel(crnn, device_ids=range(opt.ngpu))
        image = image.cuda()
        criterion = criterion.cuda()

    image = Variable(image)
    text = Variable(text)
    length = Variable(length)

    # Running average of the training loss between display intervals.
    loss_avg = utils.averager()

    # Following the paper's recommendation, force Adadelta unless --adam
    # was given explicitly (the RMSprop fallback is unreachable).
    opt.adadelta = True
    if opt.adam:
        optimizer = optim.Adam(crnn.parameters(),
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    elif opt.adadelta:
        optimizer = optim.Adadelta(crnn.parameters(), lr=opt.lr)
    else:
        optimizer = optim.RMSprop(crnn.parameters(), lr=opt.lr)

    converter = utils.strLabelConverter(opt.alphabet)

    def val(net, val_dataset, criterion, max_iter=100):
        """Evaluate `net` on up to `max_iter` batches of `val_dataset`,
        printing the CTC loss and exact-match (case-insensitive) accuracy."""
        print('Start val')

        # Freeze the network for evaluation.
        # FIX: the original iterated the outer `crnn` instead of the
        # `net` argument, so the parameter was silently ignored.
        for p in net.parameters():
            p.requires_grad = False

        net.eval()
        data_loader = torch.utils.data.DataLoader(val_dataset,
                                                  shuffle=True,
                                                  batch_size=opt.batchSize,
                                                  num_workers=int(opt.workers))
        val_iter = iter(data_loader)

        n_correct = 0
        loss_avg = utils.averager()

        max_iter = min(max_iter, len(data_loader))
        for i in range(max_iter):
            data = next(val_iter)  # Py2 `.next()` -> builtin next()
            cpu_images, cpu_texts = data
            batch_size = cpu_images.size(0)
            utils.loadData(image, cpu_images)
            t, l = converter.encode(cpu_texts)
            utils.loadData(text, t)
            utils.loadData(length, l)

            preds = net(image)  # FIX: was the outer `crnn`
            preds_size = Variable(torch.IntTensor([preds.size(0)] *
                                                  batch_size))
            cost = criterion(preds, text, preds_size, length) / batch_size
            loss_avg.add(cost)

            # Greedy decode: per-timestep argmax, then converter.decode
            # collapses repeats and blanks.
            _, preds = preds.max(2)
            preds = preds.transpose(1, 0).contiguous().view(-1)
            sim_preds = converter.decode(preds.data,
                                         preds_size.data,
                                         raw=False)
            for pred, target in zip(sim_preds, cpu_texts):
                if pred == target.lower():
                    n_correct += 1

        # Show a few raw (uncollapsed) predictions from the last batch.
        raw_preds = converter.decode(preds.data, preds_size.data,
                                     raw=True)[:opt.n_test_disp]
        for raw_pred, pred, gt in zip(raw_preds, sim_preds, cpu_texts):
            print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))

        accuracy = n_correct / float(max_iter * opt.batchSize)
        print('Test loss: %f, accuray: %f' % (loss_avg.val(), accuracy))

    for epoch in range(opt.niter):
        train_iter = iter(train_loader)
        i = 0
        while i < len(train_loader):
            # Re-enable gradients (val may have frozen them).
            for p in crnn.parameters():
                p.requires_grad = True
            crnn.train()

            cost = train_batch(crnn, criterion, optimizer, train_iter, opt,
                               converter)
            loss_avg.add(cost)
            i += 1

            if i % opt.displayInterval == 0:
                print('[%d/%d][%d/%d] Loss: %f' %
                      (epoch, opt.niter, i, len(train_loader), loss_avg.val()))
                loss_avg.reset()

            if i % opt.valInterval == 0:
                try:
                    val(crnn, test_dataset, criterion)
                except Exception as e:
                    # Validation failures should not abort training.
                    print(e)

            # do checkpointing
            if i % opt.saveInterval == 0:
                torch.save(
                    crnn.state_dict(),
                    '{0}/netCRNN_{1}_{2}.pth'.format(opt.experiment, epoch, i))
예제 #7
0
def val(net, valdataset, criterionAttention, criterionCTC, max_iter=100):
    """Evaluate the joint attention+CTC model on `valdataset`.

    Accumulates both heads' losses into one running average and reports
    the exact-match accuracy of the attention decoder's output.

    NOTE(review): relies on module-level globals `opt`, `image`,
    `textAttention`/`lengthAttention`, `textCTC`/`lengthCTC`,
    `converterAttention`, `converterCTC` — confirm they are defined in
    the enclosing script.
    """
    print('Start val')

    # Freeze parameters for evaluation.
    # FIX: the original froze the global `crnn` instead of the `net`
    # argument, so the parameter was silently ignored.
    for p in net.parameters():
        p.requires_grad = False

    net.eval()
    val_sampler = dataset.randomSequentialSampler(valdataset, opt.batchSize)
    data_loader = torch.utils.data.DataLoader(
        valdataset, batch_size=opt.batchSize,
        shuffle=False, sampler=val_sampler,
        num_workers=int(opt.workers),
        collate_fn=dataset.alignCollate(imgH=opt.imgH, imgW=opt.imgW,
                                        keep_ratio=opt.keep_ratio))
    val_iter = iter(data_loader)

    n_correct = 0
    loss_avg = utils.averager()

    max_iter = min(max_iter, len(data_loader))
    for i in range(max_iter):
        data = next(val_iter)  # Py2 `.next()` -> builtin next()
        cpu_images, cpu_texts = data
        batch_size = cpu_images.size(0)
        utils.loadData(image, cpu_images)
        # Encode the ground truth once per decoder head.
        tAttention, lAttention = converterAttention.encode(cpu_texts)
        utils.loadData(textAttention, tAttention)
        utils.loadData(lengthAttention, lAttention)
        tCTC, lCTC = converterCTC.encode(cpu_texts)
        utils.loadData(textCTC, tCTC)
        utils.loadData(lengthCTC, lCTC)

        if opt.lang:
            predsCTC, predsAttention = net(image, lengthAttention,
                                           textAttention)
        else:
            # NOTE(review): `imageAttention` is not defined in this scope
            # (NameError if this branch runs) — likely should be `image`;
            # confirm before relying on the non-lang path.
            predsCTC, predsAttention = net(imageAttention, lengthAttention)
        costAttention = criterionAttention(predsAttention, textAttention)
        preds_size = Variable(torch.IntTensor([predsCTC.size(0)] * batch_size))
        costCTC = criterionCTC(predsCTC, textCTC, preds_size,
                               lengthCTC) / batch_size
        loss_avg.add(costAttention)
        loss_avg.add(costCTC.cuda())

        # Greedy decode of the attention head.
        _, predsAttention = predsAttention.max(1)
        predsAttention = predsAttention.view(-1)
        sim_predsAttention = converterAttention.decode(predsAttention.data,
                                                       lengthAttention.data)
        for pred, target in zip(sim_predsAttention, cpu_texts):
            print(pred, target)
            if pred == target:
                n_correct += 1

    accuracy = n_correct / float(max_iter * opt.batchSize)
    print('Test loss: %f, accuray: %f' % (loss_avg.val(), accuracy))
예제 #8
0
def val(net, valdataset, criterionAttention, criterionCTC, max_iter=100):
    """Evaluate the joint attention+CTC model on `valdataset`.

    Prints each sample's ground truth and both decoders' outputs, then
    the exact-match accuracy and character error rate (CER) per head.

    NOTE(review): uses module-level globals `crnn`, `opt`, `image`,
    `textAttention`/`lengthAttention`, `textCTC`/`lengthCTC`,
    `converterAttention`, `converterCTC` — the `net` argument is only
    used for `net.eval()`; confirm `net` is the same object as `crnn`.
    """
    print('Start val')

    for p in crnn.parameters():
        p.requires_grad = False

    net.eval()
    val_sampler = dataset.randomSequentialSampler(valdataset, opt.batchSize)
    data_loader = torch.utils.data.DataLoader(valdataset,
                                              batch_size=opt.batchSize,
                                              shuffle=False,
                                              sampler=val_sampler,
                                              num_workers=int(opt.workers),
                                              collate_fn=dataset.alignCollate(
                                                  imgH=opt.imgH,
                                                  imgW=opt.imgW,
                                                  keep_ratio=opt.keep_ratio))
    # data_loader = torch.utils.data.DataLoader(
    #     dataset, shuffle=True, batch_size=opt.batchSize, num_workers=int(opt.workers))
    val_iter = iter(data_loader)

    i = 0
    n_correct = 0
    n_correctCTC = 0  # exact matches from the CTC head
    n_correctAttention = 0  # exact matches from the attention head
    distanceCTC = 0  # summed Levenshtein distance, CTC head
    distanceAttention = 0  # summed Levenshtein distance, attention head
    sum_charNum = 0  # total ground-truth characters (CER denominator)
    loss_avg = utils.averager()

    max_iter = min(max_iter, len(data_loader))
    for i in range(max_iter):
        # NOTE(review): `.next()` is Python 2 only; use next(val_iter)
        # when porting to Python 3.
        data = val_iter.next()
        i += 1  # no-op: the `for` rebinds i on the next iteration
        cpu_images, cpu_texts = data
        batch_size = cpu_images.size(0)
        utils.loadData(image, cpu_images)
        # Encode the ground truth once per decoder head.
        tAttention, lAttention = converterAttention.encode(cpu_texts)
        utils.loadData(textAttention, tAttention)
        utils.loadData(lengthAttention, lAttention)
        tCTC, lCTC = converterCTC.encode(cpu_texts)
        utils.loadData(textCTC, tCTC)
        utils.loadData(lengthCTC, lCTC)
        # print (image)

        if opt.lang:
            predsCTC, predsAttention = crnn(image, lengthAttention,
                                            textAttention)
        else:
            # NOTE(review): `imageAttention` does not appear to be defined
            # in this scope — confirm this branch is ever taken.
            predsCTC, predsAttention = crnn(imageAttention, lengthAttention)
        costAttention = criterionAttention(predsAttention, textAttention)
        preds_size = Variable(torch.IntTensor([predsCTC.size(0)] * batch_size))
        costCTC = criterionCTC(predsCTC, textCTC, preds_size,
                               lengthCTC) / batch_size
        loss_avg.add(costAttention)
        loss_avg.add(costCTC.cuda())

        # Greedy decode of the attention head.
        _, predsAttention = predsAttention.max(1)
        predsAttention = predsAttention.view(-1)
        sim_predsAttention = converterAttention.decode(predsAttention.data,
                                                       lengthAttention.data)

        # Greedy decode of the CTC head (collapse repeats and blanks).
        _, predsCTC = predsCTC.max(2)
        predsCTC = predsCTC.transpose(1, 0).contiguous().view(-1)
        sim_predsCTC = converterCTC.decode(predsCTC.data,
                                           preds_size.data,
                                           raw=False)

        # NOTE: this `i` shadows the batch-loop index (harmless — the
        # outer `for` rebinds it on the next iteration).
        for i, cpu_text in enumerate(cpu_texts):
            gtText = cpu_text.decode('utf-8')
            CTCText = sim_predsCTC[i]
            # NOTE(review): `str.decode` exists only on Python 2.
            if isinstance(CTCText, str):
                CTCText = CTCText.decode('utf-8')
            AttentionText = sim_predsAttention[i]
            print('gtText: %s' % gtText)
            print('CTCText: %s' % CTCText)
            print('AttentionText: %s' % AttentionText)
            if gtText == CTCText:
                n_correctCTC += 1
            if gtText == AttentionText:
                n_correctAttention += 1
            distanceCTC += Levenshtein.distance(CTCText, gtText)
            distanceAttention += Levenshtein.distance(AttentionText, gtText)
            sum_charNum = sum_charNum + len(gtText)

    # NOTE(review): `batch_size` here is the size of the *last* batch; a
    # short final batch skews the accuracy denominator.
    correctCTC_accuracy = n_correctCTC / float(max_iter * batch_size)
    cerCTC = distanceCTC / float(sum_charNum)
    print('Test CERCTC: %f, accuracyCTC: %f' % (cerCTC, correctCTC_accuracy))
    correctAttention_accuracy = n_correctAttention / float(
        max_iter * batch_size)
    cerAttention = distanceAttention / float(sum_charNum)
    print('Test CERAttention: %f, accuricyAttention: %f' %
          (cerAttention, correctAttention_accuracy))
예제 #9
0
        checkpoint = torch.load(model_path)
        start_epoch = checkpoint['epoch']
        # best_pred = checkpoint['best_pred']
        model.load_state_dict(checkpoint['state_dict'])
        # print("=> loaded checkpoint '{}' (epoch {} accuracy {})"
        #       .format(model_path, checkpoint['epoch'], best_pred))

    model.eval()

    train_set = dataset.imageDataset(test_set)  # dataset.graybackNormalize()
    test_loader = torch.utils.data.DataLoader(train_set,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=num_workers,
                                              collate_fn=dataset.alignCollate(
                                                  imgH=imgH,
                                                  imgW=maxW))

    file = open('logger/pred.txt', 'w', encoding='utf-8')
    index = 0
    for i, (cpu_images, _) in enumerate(test_loader):
        bsz = cpu_images.size(0)
        images = cpu_images.cuda()

        predict = model(images)
        predict_len = torch.IntTensor([predict.size(0)] * bsz)
        _, acc = predict.max(2)
        acc = acc.transpose(1, 0).contiguous().view(-1)
        prob, _ = F.softmax(predict, dim=2).max(2)
        probilities = torch.mean(prob, dim=1)
        sim_preds = converter.decode(acc.data, predict_len.data, raw=False)
예제 #10
0
파일: grcnn.py 프로젝트: happog/FudanOCR
def train_grcnn(config_yaml):
    """Train/finetune the GRCNN (CRANN) text-recognition model.

    Args:
        config_yaml: path to a YAML config holding all settings
            (data source, optimizer, CUDA flag, save paths, ...).
    """
    import sys
    sys.path.append('./recognition_model/GRCNN')

    import random
    import torch.backends.cudnn as cudnn
    import torch.optim as optim
    import torch.utils.data
    import numpy as np
    import Levenshtein
    from torch.autograd import Variable
    from torch.nn import CTCLoss
    import GRCNN.utils.keys as keys
    import GRCNN.utils.util as util
    import dataset
    import GRCNN.models.crann as crann
    import yaml
    import os
    import time

    def adjust_lr(optimizer, base_lr, epoch, step):
        # Step decay: divide the base LR by 10 every `step` epochs.
        lr = base_lr * (0.1**(epoch // step))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    def train(model, train_loader, val_loader, criterion, optimizer, opt,
              converter, epoch):
        """Run one training epoch; validate and checkpoint every `interval` iters."""
        # NOTE(review): an opt['SAVE_FREQ']-derived interval was computed here
        # originally but immediately overwritten; 100 is the effective value.
        interval = 100
        print("interval为", interval)
        model.train()

        for i, (cpu_images, cpu_gt) in enumerate(train_loader, 1):
            bsz = cpu_images.size(0)
            text, text_len = converter.encode(cpu_gt)
            v_images = Variable(cpu_images.cuda())
            v_gt = Variable(text)
            v_gt_len = Variable(text_len)

            model = model.cuda()
            predict = model(v_images)
            # CTC input lengths: every sample spans all time steps.
            predict_len = Variable(torch.IntTensor([predict.size(0)] * bsz))

            loss = criterion(predict, v_gt, predict_len, v_gt_len)
            print('train_loss', loss.item())
            # Greedy-decode the predictions to measure batch accuracy.
            _, acc = predict.max(2)
            acc = acc.transpose(1, 0).contiguous().view(-1)
            sim_preds = converter.decode(acc.data, predict_len.data, raw=False)
            n_correct = 0
            for pred, target in zip(sim_preds, cpu_gt):
                if pred.lower() == target.lower():
                    n_correct += 1
            accuracy = n_correct / float(bsz)

            # Backpropagate
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % interval == 0 and i > 0:
                print('Training @ Epoch: [{0}][{1}/{2}]; Train Accuracy:{3}'.
                      format(epoch, i, len(train_loader), accuracy))
                val(model, val_loader, criterion, converter, epoch,
                    i + epoch * len(train_loader), False)
                model.train()  # val() switched the model to eval mode
                freq = int(i / interval)
                save_checkpoint(
                    {
                        'epoch': epoch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }, '{0}/crann_{1}_{2}.pth'.format(opt['SAVE_PATH'], epoch,
                                                      freq))

    def val(model, ds_loader, criterion, converter, epoch, iteration, valonly):
        """Validate: print per-error pairs, average CER, accuracy; append to log file."""
        print('Start validating on epoch:{0}/iter:{1}...'.format(
            epoch, iteration))
        model.eval()
        ave_loss = 0.0
        ave_accuracy = 0.0
        err_sim = []
        err_gt = []
        distance = 0
        length = 0
        with torch.no_grad():
            for i, (cpu_images, cpu_gt) in enumerate(ds_loader):
                bsz = cpu_images.size(0)
                text, text_len = converter.encode(cpu_gt)
                v_Images = Variable(cpu_images.cuda())
                v_gt = Variable(text)
                v_gt_len = Variable(text_len)

                predict = model(v_Images)
                predict_len = Variable(torch.IntTensor([predict.size(0)] *
                                                       bsz))
                loss = criterion(predict, v_gt, predict_len, v_gt_len)
                ave_loss += loss.item()

                # Greedy decode; count exact (case-insensitive) matches and
                # remember the mismatches for error reporting below.
                _, acc = predict.max(2)
                acc = acc.transpose(1, 0).contiguous().view(-1)
                sim_preds = converter.decode(acc.data,
                                             predict_len.data,
                                             raw=False)
                n_correct = 0
                for pred, target in zip(sim_preds, cpu_gt):
                    length += len(target)
                    if pred.lower() == target.lower():
                        n_correct += 1.0
                    else:
                        err_sim.append(pred)
                        err_gt.append(target)
                ave_accuracy += n_correct / float(bsz)
            for pred, gt in zip(err_sim, err_gt):
                print('pred: %-20s, gt: %-20s' % (pred, gt))
                distance += Levenshtein.distance(pred, gt)
            # Average edit distance per ground-truth character (a CER proxy).
            print("The average Levenshtein distance is:", distance / length)
            if not valonly:
                # Tensorboard-style scalar logging used to happen here.
                pass

            # Build the report once; write it to the run log and stdout.
            report = (
                'Testing Accuracy:{0}, Testing Loss:{1} @ Epoch{2}, Iteration{3}'
                .format(ave_accuracy / len(ds_loader),
                        ave_loss / len(ds_loader), epoch, iteration))
            print(report)
            with open('./grcnn_9000k.txt', 'a+') as f:
                f.write(report + '\n')

    def save_checkpoint(state, file_name):
        # Best-effort save: a transient serialization failure must not
        # abort a long training run.
        try:
            time.sleep(0.01)
            torch.save(state, file_name)
        except RuntimeError:
            print("RuntimeError")

    #### Load config settings. ####
    # Close the config handle deterministically; FullLoader (PyYAML >= 5.1)
    # avoids the unsafe implicit loader of bare yaml.load.
    with open(config_yaml, encoding='utf-8') as f:
        opt = yaml.load(f, Loader=yaml.FullLoader)
    if not os.path.isdir(opt['SAVE_PATH']):
        os.system('mkdir -p {0}'.format(opt['SAVE_PATH']))
    # Seed all RNGs from one randomly drawn seed.
    manualSeed = random.randint(1, 10000)
    random.seed(manualSeed)
    np.random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    cudnn.benchmark = True

    #### Set up DataLoader. ####
    train_cfg = opt['TRAIN']
    ds_cfg = train_cfg['DATA_SOURCE']
    print('Building up dataset:{}'.format(ds_cfg['TYPE']))
    if ds_cfg['TYPE'] == 'SYN_DATA':
        text_gen = util.TextGenerator(ds_cfg['GEN_SET'], ds_cfg['GEN_LEN'])
        ds_train = dataset.synthDataset(ds_cfg['FONT_ROOT'],
                                        ds_cfg['FONT_SIZE'], text_gen)
    elif ds_cfg['TYPE'] == 'IMG_DATA':
        # The training source can be swapped here (e.g. an lmdb dataset
        # via tools.dataset_lmdb.lmdbDataset).
        ds_train = dataset.trainDataset(
            ds_cfg['IMG_ROOT'], ds_cfg['TRAIN_SET'],
            transform=None)  # dataset.graybackNormalize()
    assert ds_train
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=train_cfg['BATCH_SIZE'],
        shuffle=True,
        sampler=None,
        num_workers=opt['WORKERS'],
        collate_fn=dataset.alignCollate(imgH=train_cfg['IMG_H'],
                                        imgW=train_cfg['MAX_W']))

    val_cfg = opt['VALIDATION']
    # The validation source can likewise be swapped here.
    ds_val = dataset.testDataset(val_cfg['IMG_ROOT'],
                                 val_cfg['VAL_SET'],
                                 transform=None)  # dataset.graybackNormalize()
    assert ds_val
    val_loader = torch.utils.data.DataLoader(ds_val,
                                             batch_size=32,
                                             shuffle=False,
                                             num_workers=opt['WORKERS'],
                                             collate_fn=dataset.alignCollate(
                                                 imgH=train_cfg['IMG_H'],
                                                 imgW=train_cfg['MAX_W']))

    #### Model construction and Initialization. ####
    alphabet = keys.alphabet
    nClass = len(alphabet) + 1  # +1 for the CTC blank label

    if opt['N_GPU'] > 1:
        opt['RNN']['multi_gpu'] = True
    else:
        opt['RNN']['multi_gpu'] = False
    model = crann.CRANN(opt, nClass)

    #### Train/Val the model. ####
    converter = util.strLabelConverter(alphabet)
    criterion = CTCLoss()
    if opt['CUDA']:
        model.cuda()
        criterion.cuda()

    if opt['OPTIMIZER'] == 'RMSprop':
        optimizer = optim.RMSprop(model.parameters(), lr=opt['TRAIN']['LR'])
    elif opt['OPTIMIZER'] == 'Adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=opt['TRAIN']['LR'],
                               betas=(opt['TRAIN']['BETA1'], 0.999))
    elif opt['OPTIMIZER'] == 'SGD':
        optimizer = optim.SGD(model.parameters(), lr=opt['TRAIN']['LR'])
    else:
        optimizer = optim.Adadelta(model.parameters(), lr=opt['TRAIN']['LR'])

    start_epoch = 0
    if opt['VAL_ONLY']:
        print('=>loading pretrained model from %s for val only.' %
              opt['CRANN'])
        checkpoint = torch.load(opt['CRANN'])
        model.load_state_dict(checkpoint['state_dict'])
        val(model, val_loader, criterion, converter, 0, 0, True)
    elif opt['FINETUNE']:
        print('=>loading pretrained model from %s for finetuen.' %
              opt['CRANN'])
        checkpoint = torch.load(opt['CRANN'])
        # Copy only the CNN weights into the current model.
        # NOTE(review): assumes the checkpoint is a raw state_dict (not the
        # {'epoch', 'state_dict', ...} wrapper saved above) — confirm.
        model_dict = model.state_dict()
        cnn_dict = {
            "cnn." + k: v
            for k, v in checkpoint.items() if "cnn." + k in model_dict
        }
        model_dict.update(cnn_dict)
        model.load_state_dict(model_dict)
        for epoch in range(start_epoch, opt['EPOCHS']):
            adjust_lr(optimizer, opt['TRAIN']['LR'], epoch, opt['STEP'])
            train(model, train_loader, val_loader, criterion, optimizer, opt,
                  converter, epoch)
    elif opt['RESUME']:
        print('=>loading checkpoint from %s for resume training.' %
              opt['CRANN'])
        checkpoint = torch.load(opt['CRANN'])
        start_epoch = checkpoint['epoch'] + 1
        print('resume from epoch:{}'.format(start_epoch))
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        for epoch in range(start_epoch, opt['EPOCHS']):
            adjust_lr(optimizer, opt['TRAIN']['LR'], epoch, opt['STEP'])
            train(model, train_loader, val_loader, criterion, optimizer, opt,
                  converter, epoch)
    else:
        print('train from scratch.')
        for epoch in range(start_epoch, opt['EPOCHS']):
            adjust_lr(optimizer, opt['TRAIN']['LR'], epoch, opt['STEP'])
            train(model, train_loader, val_loader, criterion, optimizer, opt,
                  converter, epoch)
예제 #11
0
def val(net,
        _dataset1,
        _dataset2,
        _dataset3,
        epoch,
        step,
        criterion,
        max_iter=100):
    """Validate by averaging CRNN outputs over three aligned views of the
    same data, logging accuracy, recall and F1.

    Relies on module-level globals: params, dataset, utils, converter,
    crnn, image/image2/image3, text, length, log_dir, logger, time.
    The max_iter argument is overridden by len(data_loader1) below.
    """
    logger.info('Start val')
    net.eval()
    data_loader1 = torch.utils.data.DataLoader(
        _dataset1,
        shuffle=False,
        batch_size=params.batchSize,
        num_workers=int(params.workers),
        collate_fn=dataset.alignCollate(imgH=params.imgH,
                                        imgW=params.imgW,
                                        keep_ratio=params.keep_ratio))
    data_loader2 = torch.utils.data.DataLoader(
        _dataset2,
        shuffle=False,
        batch_size=params.batchSize,
        num_workers=int(params.workers),
        collate_fn=dataset.alignCollate(imgH=params.imgH,
                                        imgW=params.imgW,
                                        keep_ratio=params.keep_ratio))
    data_loader3 = torch.utils.data.DataLoader(
        _dataset3,
        shuffle=False,
        batch_size=params.batchSize,
        num_workers=int(params.workers),
        collate_fn=dataset.alignCollate(imgH=params.imgH,
                                        imgW=params.imgW,
                                        keep_ratio=params.keep_ratio))
    val_iter = iter(data_loader1)
    val_iter2 = iter(data_loader2)
    val_iter3 = iter(data_loader3)
    n_correct = 0
    loss_avg = utils.averager()
    # Iterate the full first loader (the max_iter argument is ignored).
    max_iter = len(data_loader1)
    record_dir = log_dir + 'epoch_%d_step_%d_data.txt' % (epoch, step)
    r = 1
    num_label, num_pred = params.total_num, 0

    start = time.time()
    # Context manager guarantees the record file is closed on error.
    with open(record_dir, "a") as f:
        for i in range(max_iter):
            # BUG FIX: iterator.next() is Python 2 only — use next().
            cpu_images, cpu_texts = next(val_iter)
            batch_size = cpu_images.size(0)
            utils.loadData(image, cpu_images)
            t, l = converter.encode(cpu_texts)
            utils.loadData(text, t)
            utils.loadData(length, l)
            cpu_images2, _ = next(val_iter2)
            utils.loadData(image2, cpu_images2)
            cpu_images3, _ = next(val_iter3)
            utils.loadData(image3, cpu_images3)
            with torch.no_grad():
                # Ensemble: average the per-view predictions.
                # NOTE(review): uses global `crnn`, not the `net` argument —
                # presumably the same model; confirm.
                preds = torch.mean(
                    torch.cat([
                        torch.unsqueeze(crnn(image), 0),
                        torch.unsqueeze(crnn(image2), 0),
                        torch.unsqueeze(crnn(image3), 0)
                    ], 0), 0)
            print('preds: ', preds.shape)
            # BUG FIX: preds_size was never defined (NameError at runtime).
            preds_size = torch.IntTensor([preds.size(0)] * batch_size)
            cost = criterion(preds, text, preds_size, length) / batch_size
            loss_avg.add(cost)
            _, preds = preds.max(2)
            preds = preds.transpose(1, 0).contiguous().view(-1)
            sim_preds = converter.decode(preds.data, preds_size.data,
                                         raw=False)
            if not isinstance(sim_preds, list):
                sim_preds = [sim_preds]
            for pred in sim_preds:
                f.write(str(r).zfill(6) + ".jpg " + pred + "\n")
                r += 1
            # Ground truths arrive as UTF-8 bytes; decode before comparing.
            list_1 = [raw.decode('utf-8', 'strict') for raw in cpu_texts]
            for pred, target in zip(sim_preds, list_1):
                if pred == target:
                    n_correct += 1

            num_pred += len(sim_preds)

    print("")

    # Show raw vs. decoded predictions for the LAST batch only.
    raw_preds = converter.decode(preds.data, preds_size.data,
                                 raw=True)[:params.n_test_disp]
    for raw_pred, pred, gt in zip(raw_preds, sim_preds, list_1):
        logger.info('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))

    logger.info('correct_num: %d' % (n_correct))
    logger.info('Total_num: %d' % (max_iter * params.batchSize))
    accuracy = float(n_correct) / num_pred
    recall = float(n_correct) / num_label
    logger.info(
        'Test loss: %f, accuray: %f, recall: %f, F1 score: %f, Cost : %.4fs per img'
        % (loss_avg.val(), accuracy, recall, 2 * accuracy * recall /
           (accuracy + recall + 1e-2), (time.time() - start) / max_iter))
예제 #12
0
    assert train_dataset
    if rc_params.random_sample:
        sampler = dataset.randomSequentialSampler(train_dataset,
                                                  rc_params.batchSize)
    else:
        sampler = None

    # images will be resize to 32*160
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=rc_params.batchSize,
        shuffle=False,
        sampler=sampler,
        num_workers=int(rc_params.workers),
        collate_fn=dataset.alignCollate(imgH=rc_params.imgH,
                                        imgW=rc_params.imgW,
                                        keep_ratio=rc_params.keep_ratio,
                                        rgb=True))

    # read test set
    # images will be resize to 32*160
    test_dataset = dataset.lmdbDataset(root=opt.valroot, rgb=True)

    nclass = len(rc_params.alphabet) + 1
    nc = 1

    converter = utils.strLabelConverter(rc_params.alphabet)
    # criterion = CTCLoss(size_average=True, length_average=True)
    criterion = CTCLoss(size_average=True)

    # cnn and rnn
    image = torch.FloatTensor(rc_params.batchSize, 3, rc_params.imgH,
예제 #13
0
    if params.random_sample:
        sampler = dataset.randomSequentialSampler(train_dataset,
                                                  params.batchSize)
    else:
        sampler = None
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.GPU_ID
    # images will be resize to 32*160
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=params.batchSize,
        shuffle=False,
        sampler=sampler,
        num_workers=int(params.workers),
        collate_fn=dataset.alignCollate(imgH=params.imgH,
                                        imgW=params.imgW,
                                        keep_ratio=params.keep_ratio,
                                        with_resize=params.with_resize,
                                        ratio_ranges=params.ratio_range))
    # read test set
    # images will be resize to 32*160
    test_dataset = dataset.lmdbDataset(root=opt.valroot)

    nclass = len(params.alphabet) + 1
    nc = 1

    converter = utils.strLabelConverter(params.alphabet)
    criterion = CTCLoss()

    # cnn and rnn
    image = torch.FloatTensor(params.batchSize, 1, params.imgH, params.imgH)
    text = torch.IntTensor(params.batchSize * 5)
예제 #14
0
파일: test.py 프로젝트: Narcissuscyn/OCR
def val(net, val_dataset, criterion, max_iter=100):
    """Evaluate `net` on val_dataset, writing "gt pred" lines to a results
    file and printing each pair.

    Relies on module-level globals: opt, dataset, utils, converter,
    image, text, length, Variable.

    Args:
        net: the recognition model to evaluate.
        val_dataset: dataset to iterate.
        criterion: CTC loss used only for loss averaging.
        max_iter: maximum number of batches to evaluate.
    """
    print('Start val')

    # Freeze parameters during evaluation.
    # BUG FIX: originally used the global `model` instead of the `net`
    # argument (here and in the forward pass below).
    for p in net.parameters():
        p.requires_grad = False

    net.eval()
    data_loader = torch.utils.data.DataLoader(val_dataset,
                                              shuffle=False,
                                              batch_size=opt.batchSize,
                                              num_workers=int(opt.workers),
                                              collate_fn=dataset.alignCollate(
                                                  imgH=opt.imgH,
                                                  imgW=opt.imgW,
                                                  keep_ratio=opt.keep_ratio))
    val_iter = iter(data_loader)

    n_correct = 0
    loss_avg = utils.averager()

    # BUG FIX: was max(max_iter, len(data_loader)), which exhausts the
    # iterator (StopIteration) whenever the loader is shorter than max_iter.
    max_iter = min(max_iter, len(data_loader))

    # Context manager guarantees the results file is closed on error.
    with open('/home/new/File/OCR/crnn/mycode/test_rst.txt', 'w') as dst_file:
        for i in range(max_iter):
            # BUG FIX: iterator.next() is Python 2 only — use next().
            cpu_images, cpu_texts = next(val_iter)
            batch_size = cpu_images.size(0)
            utils.loadData(image, cpu_images)
            t, l = converter.encode(cpu_texts)
            utils.loadData(text, t)
            utils.loadData(length, l)

            preds = net(image)
            preds_size = Variable(torch.IntTensor([preds.size(0)] *
                                                  batch_size))
            cost = criterion(preds, text, preds_size, length) / batch_size
            loss_avg.add(cost)

            _, preds = preds.max(2)
            preds = preds.transpose(1, 0).contiguous().view(-1)
            sim_preds = converter.decode(preds.data, preds_size.data,
                                         raw=False)

            # One "gt pred" line per sample (no line merging by image id).
            for pred, gt in zip(sim_preds, cpu_texts):
                dst_file.write(gt + ' ' + pred + '\n')
                print('pred:%-20s, gt: %-20s' % (pred, gt))
예제 #15
0
        os.mkdir('./expr')

    # read s_train set
    s_train_dataset = dataset.lmdbDataset(root=params.s_train_data)
    assert s_train_dataset
    if not params.random_sample:
        sampler = dataset.randomSequentialSampler(s_train_dataset, params.batchSize)
    else:
        sampler = None

    # images will be resize to 32*96
    s_train_loader = torch.utils.data.DataLoader(
        s_train_dataset, batch_size=params.batchSize,
        shuffle=True, sampler=sampler,
        num_workers=int(params.workers),
        collate_fn=dataset.alignCollate(imgH=params.imgH, imgW=96, keep_ratio=params.keep_ratio))

    # read s_train set
    m_train_dataset = dataset.lmdbDataset(root=params.m_train_data)
    assert m_train_dataset
    if not params.random_sample:
        sampler = dataset.randomSequentialSampler(m_train_dataset, params.batchSize)
    else:
        sampler = None

    # images will be resize to 32*384
    m_train_loader = torch.utils.data.DataLoader(
        m_train_dataset, batch_size=params.batchSize,
        shuffle=True, sampler=sampler,
        num_workers=int(params.workers),
        collate_fn=dataset.alignCollate(imgH=params.imgH, imgW=384, keep_ratio=params.keep_ratio))
예제 #16
0
    train_dataset = dataset.CCPD(trainRoot,
                                 requires_interpret=not opt.no_need_interpret)
    print('Training with custom dataset')
    focal_alpha = True
assert train_dataset
if not opt.random_sample:
    sampler = dataset.randomSequentialSampler(train_dataset, opt.batchSize)
else:
    sampler = None
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=opt.batchSize,
                                           sampler=sampler,
                                           num_workers=int(opt.workers),
                                           collate_fn=dataset.alignCollate(
                                               imgH=opt.imgH,
                                               imgW=opt.imgW,
                                               keep_ratio=opt.keep_ratio,
                                               requires_prob=focal_alpha))
if len(glob.glob(os.path.join(opt.valRoot, '*.mdb'))):
    test_dataset = dataset.lmdbDataset(root=opt.valRoot,
                                       transform=dataset.resizeNormalize(
                                           (100, 32)))
    print('Testing with lmdb dataset')
else:
    test_dataset = dataset.CCPD(opt.valRoot,
                                transform=dataset.resizeNormalize((100, 32)))
    print('Testing with custom dataset')

alphabet = utils.generate_alphabet()
print('Alphabet:', alphabet, '\n', len(alphabet))
nclass = len(alphabet) + 1
예제 #17
0
def main(config_yaml):
    '''
    Training/Finetune CNN_RNN_Attention Model.

    Args:
        config_yaml: path to the YAML configuration controlling data,
            optimizer, logging and checkpointing.
    '''
    #### Load config settings. ####
    # Close the config handle deterministically; FullLoader (PyYAML >= 5.1)
    # avoids the unsafe implicit loader of bare yaml.load.
    with open(config_yaml) as f:
        opt = yaml.load(f, Loader=yaml.FullLoader)
    if not os.path.isdir(opt['LOGGER_PATH']):
        os.mkdir(opt['LOGGER_PATH'])
    logger = Logger(opt['LOGGER_PATH'])
    if not os.path.isdir(opt['SAVE_PATH']):
        os.system('mkdir -p {0}'.format(opt['SAVE_PATH']))
    # Seed all RNGs from one randomly drawn seed.
    manualSeed = random.randint(1, 10000)
    random.seed(manualSeed)
    np.random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    cudnn.benchmark = True

    #### Set up DataLoader. ####
    train_cfg = opt['TRAIN']
    ds_cfg = train_cfg['DATA_SOURCE']
    print('Building up dataset:{}'.format(ds_cfg['TYPE']))
    if ds_cfg['TYPE'] == 'SYN_DATA':
        text_gen = util.TextGenerator(ds_cfg['GEN_SET'], ds_cfg['GEN_LEN'])
        ds_train = dataset.synthDataset(ds_cfg['FONT_ROOT'],
                                        ds_cfg['FONT_SIZE'], text_gen)
    elif ds_cfg['TYPE'] == 'IMG_DATA':
        ds_train = dataset.trainDataset(
            ds_cfg['IMG_ROOT'], ds_cfg['TRAIN_SET'],
            transform=None)  # dataset.graybackNormalize()
    assert ds_train
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=train_cfg['BATCH_SIZE'],
        shuffle=True,
        sampler=None,
        num_workers=opt['WORKERS'],
        collate_fn=dataset.alignCollate(imgH=train_cfg['IMG_H'],
                                        imgW=train_cfg['MAX_W']))

    val_cfg = opt['VALIDATION']
    ds_val = dataset.testDataset(val_cfg['IMG_ROOT'],
                                 val_cfg['VAL_SET'],
                                 transform=None)  # dataset.graybackNormalize()
    assert ds_val
    val_loader = torch.utils.data.DataLoader(ds_val,
                                             batch_size=16,
                                             shuffle=False,
                                             num_workers=opt['WORKERS'],
                                             collate_fn=dataset.alignCollate(
                                                 imgH=train_cfg['IMG_H'],
                                                 imgW=train_cfg['MAX_W']))

    #### Model construction and Initialization. ####
    alphabet = keys.alphabet
    nClass = len(alphabet) + 1  # +1 for the CTC blank label

    if opt['N_GPU'] > 1:
        opt['RNN']['multi_gpu'] = True
    else:
        opt['RNN']['multi_gpu'] = False
    model = crann.CRANN(opt, nClass)

    #### Train/Val the model. ####
    converter = util.strLabelConverter(alphabet)
    criterion = CTCLoss()
    if opt['CUDA']:
        model.cuda()
        criterion.cuda()

    if opt['OPTIMIZER'] == 'RMSprop':
        optimizer = optim.RMSprop(model.parameters(), lr=opt['TRAIN']['LR'])
    elif opt['OPTIMIZER'] == 'Adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=opt['TRAIN']['LR'],
                               betas=(opt['TRAIN']['BETA1'], 0.999))
    elif opt['OPTIMIZER'] == 'SGD':
        optimizer = optim.SGD(model.parameters(), lr=opt['TRAIN']['LR'])
    else:
        optimizer = optim.Adadelta(model.parameters(), lr=opt['TRAIN']['LR'])

    start_epoch = 0
    if opt['VAL_ONLY']:
        print('=>loading pretrained model from %s for val only.' %
              opt['CRANN'])
        checkpoint = torch.load(opt['CRANN'])
        model.load_state_dict(checkpoint['state_dict'])
        val(model, val_loader, criterion, converter, 0, 0, logger, True)
    elif opt['FINETUNE']:
        print('=>loading pretrained model from %s for finetuen.' %
              opt['CRANN'])
        checkpoint = torch.load(opt['CRANN'])
        # Copy only the CNN weights into the current model.
        # NOTE(review): assumes the checkpoint is a raw state_dict (not an
        # {'epoch', 'state_dict', ...} wrapper) — confirm.
        model_dict = model.state_dict()
        cnn_dict = {
            "cnn." + k: v
            for k, v in checkpoint.items() if "cnn." + k in model_dict
        }
        model_dict.update(cnn_dict)
        model.load_state_dict(model_dict)
        for epoch in range(start_epoch, opt['EPOCHS']):
            adjust_lr(optimizer, opt['TRAIN']['LR'], epoch, opt['STEP'])
            train(model, train_loader, val_loader, criterion, optimizer, opt,
                  converter, epoch, logger)
    elif opt['RESUME']:
        print('=>loading checkpoint from %s for resume training.' %
              opt['CRANN'])
        checkpoint = torch.load(opt['CRANN'])
        start_epoch = checkpoint['epoch'] + 1
        print('resume from epoch:{}'.format(start_epoch))
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        for epoch in range(start_epoch, opt['EPOCHS']):
            adjust_lr(optimizer, opt['TRAIN']['LR'], epoch, opt['STEP'])
            train(model, train_loader, val_loader, criterion, optimizer, opt,
                  converter, epoch, logger)
    else:
        print('train from scratch.')
        for epoch in range(start_epoch, opt['EPOCHS']):
            adjust_lr(optimizer, opt['TRAIN']['LR'], epoch, opt['STEP'])
            train(model, train_loader, val_loader, criterion, optimizer, opt,
                  converter, epoch, logger)
예제 #18
0
파일: train.py 프로젝트: yiwangchunyu/CVCR
def main(arg):
    """Train a CRNN on an lmdb dataset and periodically test/checkpoint.

    Args:
        arg: parsed command-line namespace (train_root, test_root, imgH,
            imgW, batch_size, num_workers, keep_ratio, num_class, opt,
            n_epoch, displayInterval, testInterval, saveInterval,
            model_dir, type).
    """
    print(arg)
    train_dataset = dataset.lmdbDataset(
        path=arg.train_root,
        # transform=dataset.resizeNormalize((imgW,imgH)),
    )
    test_dataset = dataset.lmdbDataset(
        path=arg.test_root,
        # transform=dataset.resizeNormalize((arg.imgW,arg.imgH)),
    )
    # Sanity probes: fail fast if the test lmdb is unreadable.
    _ = test_dataset[0]
    _ = len(test_dataset)
    train_loader = DataLoader(train_dataset,
                              num_workers=arg.num_workers,
                              batch_size=arg.batch_size,
                              collate_fn=dataset.alignCollate(
                                  imgH=arg.imgH,
                                  imgW=arg.imgW,
                                  keep_ratio=arg.keep_ratio),
                              shuffle=True,
                              drop_last=True)

    criterion = CTCLoss()
    converter = utils.Converter(arg.num_class)
    # nclass = num_class + 1 accounts for the CTC blank label.
    crnn = CRNN(imgH=arg.imgH, nc=3, nclass=arg.num_class + 1, nh=256)

    # custom weights initialization called on crnn
    def weights_init(m):
        # DCGAN-style init: N(0, 0.02) convs, N(1, 0.02) batch-norm scales.
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            m.weight.data.normal_(0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

    crnn.apply(weights_init)
    print(crnn)

    # Reusable buffers filled batch-by-batch via utils.loadData.
    image = torch.FloatTensor(arg.batch_size, 3, arg.imgH, arg.imgW)
    text = torch.IntTensor(arg.batch_size * 5)
    length = torch.IntTensor(arg.batch_size)

    image = Variable(image)
    text = Variable(text)
    length = Variable(length)

    # loss averager
    loss_avg = utils.averager()

    # setup optimizer
    if arg.opt == 'adam':
        optimizer = optim.Adam(crnn.parameters(), 0.01, betas=(0.5, 0.999))
    elif arg.opt == 'adadelta':
        optimizer = optim.Adadelta(crnn.parameters())
    else:
        optimizer = optim.RMSprop(crnn.parameters(), 0.01)

    for epoch in range(arg.n_epoch):
        train_iter = iter(train_loader)
        i = 0
        while i < len(train_loader):
            for p in crnn.parameters():
                p.requires_grad = True
            crnn.train()

            # BUG FIX: iterator.next() is Python 2 only — use next().
            cpu_images, cpu_texts = next(train_iter)
            batch_size = cpu_images.size(0)
            utils.loadData(image, cpu_images)
            text_labels, l = converter.encode(cpu_texts)
            utils.loadData(text, text_labels)
            utils.loadData(length, l)

            preds = crnn(image)
            preds_size = Variable(torch.IntTensor([preds.size(0)] *
                                                  batch_size))
            cost = criterion(preds, text, preds_size, length) / batch_size
            crnn.zero_grad()
            cost.backward()
            optimizer.step()

            loss_avg.add(cost)
            i += 1

            if i % arg.displayInterval == 0:
                print(
                    '[%d/%d][%d/%d] Loss: %f' %
                    (epoch, arg.n_epoch, i, len(train_loader), loss_avg.val()))
                loss_avg.reset()

            if i % arg.testInterval == 0:
                test(arg, crnn, test_dataset, criterion, image, text, length)

            # do checkpointing
            if i % arg.saveInterval == 0:
                name = '{0}/netCRNN_{1}_{2}_{3}_{4}.pth'.format(
                    arg.model_dir, arg.num_class, arg.type, epoch, i)
                torch.save(crnn.state_dict(), name)
                print('model saved at ', name)
    # Final snapshot after all epochs complete.
    torch.save(
        crnn.state_dict(),
        '{0}/netCRNN_{1}_{2}.pth'.format(arg.model_dir, arg.num_class,
                                         arg.type))
예제 #19
0
import models.crnn as crnn
import dataset
import utils

# Mini-batch size shared by the loader and the image buffer below.
batchSize = 64

# Root directory of the LMDB training database.
outputPath = '/image/.Eric/lmdb_train'

train_dataset = dataset.lmdbDataset(root=outputPath)

train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batchSize,
                                           shuffle=True,
                                           num_workers=4,
                                           collate_fn=dataset.alignCollate(
                                               imgH=32, imgW=320))

# Recognition alphabet (Chinese financial numerals and units).
alphabet = '零壹貳參肆伍陸柒捌玖拾佰仟萬億兆元整'

nclass = len(alphabet) + 1  # +1 for the CTC blank symbol
nc = 1  # single-channel (grayscale) input

converter = utils.strLabelConverter(alphabet)
# One CTC criterion is enough — the original assigned it twice.
criterion = CTCLoss()

# Use the computed nclass/nc instead of the hard-coded 19/1 so the output
# layer always matches the alphabet size (both evaluate to 19/1 today).
model = crnn.CRNN(32, nc, nclass, 256)
model = model.cuda()
model.load_state_dict(torch.load("/image/.Eric/modelResult/crnnV2.pkl"))

# Pre-allocated input buffer: (batch, channels, height, width).
image = torch.FloatTensor(batchSize, 1, 32, 32)
# Example #20
    assert train_dataset
    # Use the project's random-sequential sampler when requested; otherwise
    # fall back to the DataLoader's default (sequential, since shuffle=False).
    if params.random_sample:
        sampler = dataset.randomSequentialSampler(train_dataset,
                                                  params.batchSize)
    else:
        sampler = None

    # images will be resize to 32*160
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=params.batchSize,
        shuffle=False,
        sampler=sampler,
        num_workers=int(params.workers),
        collate_fn=dataset.alignCollate(imgH=params.imgH,
                                        imgW=params.imgW,
                                        keep_ratio=params.keep_ratio))
    train_iter = iter(train_loader)
    # NOTE(review): printing 5000 batch shapes looks like leftover debug code;
    # it consumes the iterator and may raise StopIteration if the loader has
    # fewer batches.  .next() is also the Python-2 form — on Python 3 use
    # next(train_iter) unless the iterator aliases next = __next__.
    for i in range(5000):
        print(train_iter.next()[0].shape)
    # read test set
    # images will be resize to 32*160
    # NOTE(review): reads opt.valroot while everything else uses `params` —
    # verify `opt` is in scope here; this may be a copy-paste slip.
    test_dataset = dataset.lmdbDataset(root=opt.valroot, rgb=params.rgb)

    nclass = len(params.alphabet) + 1  # +1 for the CTC blank symbol
    nc = 1  # single-channel (grayscale) input

    converter = utils.strLabelConverter(params.alphabet)
    criterion = CTCLoss(size_average=False, length_average=False)

    # cnn and rnn
# Example #21
cudnn.benchmark = True

if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

train_dataset = dataset.lmdbDataset(root=opt.trainroot)
assert train_dataset
# Choose between the project's width-grouping sampler and the default order.
if not opt.random_sample:
    sampler = dataset.randomSequentialSampler(train_dataset, opt.batchSize)
else:
    sampler = None
# DataLoader raises ValueError("sampler option is mutually exclusive with
# shuffle") when shuffle=True is combined with a custom sampler, so only
# shuffle when no sampler is in use.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=opt.batchSize,
    shuffle=(sampler is None), sampler=sampler,
    num_workers=int(opt.workers),
    collate_fn=dataset.alignCollate(imgH=opt.imgH, keep_ratio=opt.keep_ratio))
# Validation images are resized/normalized to a fixed 100x32.
test_dataset = dataset.lmdbDataset(
    root=opt.valroot, transform=dataset.resizeNormalize((100, 32)))

ngpu = int(opt.ngpu)
nh = int(opt.nh)
alphabet = opt.alphabet
nclass = len(alphabet) + 1  # +1 for the CTC blank symbol
nc = 1  # single-channel (grayscale) input

converter = utils.strLabelConverter(alphabet)
criterion = CTCLoss()


# custom weights initialization called on crnn
def weights_init(m):
# Example #22
indices = list(range(dataset_size))
# Reserve 5% of the dataset for validation.
split = int(np.floor(0.05 * dataset_size))
np.random.seed(manualSeed)
np.random.shuffle(indices)
# Train on the 95% after the split point and validate on the first 5%.
# (Previously train_indices was the FULL shuffled list, so every validation
# sample was also trained on — a data leak that inflates validation scores.)
train_indices, val_indices = indices[split:], indices[:split]

# Creating PT data samplers and loaders:
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

train_loader = torch.utils.data.DataLoader(full_dataset,
                                           batch_size=batchSize,
                                           sampler=train_sampler,
                                           num_workers=int(workers),
                                           collate_fn=dataset.alignCollate(
                                               imgH=imgH,
                                               imgW=imgW,
                                               keep_ratio=keep_ratio))

test_loader = torch.utils.data.DataLoader(full_dataset,
                                          batch_size=batchSize,
                                          sampler=valid_sampler,
                                          num_workers=int(workers),
                                          collate_fn=dataset.alignCollate(
                                              imgH=imgH,
                                              imgW=imgW,
                                              keep_ratio=keep_ratio))

# test_dataset = dataset.lmdbDataset(
#     root=opt.valroot, transform=dataset.resizeNormalize((100, 32)))

nclass = len(alphabet) + 1  # +1 for the CTC blank symbol
# Example #23
# File: grcnn.py  Project: happog/FudanOCR
def demo_grcnn(config_yaml):
    """Run a GRCNN recognition demo.

    Loads model/data settings from the YAML file at *config_yaml*, restores a
    CRANN checkpoint, runs inference over an LMDB test set, and writes the
    predictions to ./pred.txt and ./GRCNN_DEMO/result.txt.
    """

    import sys
    sys.path.append('./recognition_model/GRCNN')

    import torch
    import os
    from utils import keys
    from models import crann
    import dataset
    from utils import util
    import torch.nn.functional as F
    import io
    import yaml
    import tools.utils as utils
    import tools.dataset_lmdb as dataset_lmdb
    import torchvision.transforms as transforms
    import lmdb
    import cv2

    # These should come from the config file:
    # opt.model_path = 'checkpoints/grcnn_art/crann_11_1.pth'
    # batch_size = 16
    #imgH = 32
    # maxW = 100
    # num_workers = 4
    # cnn_model = 'grcnn'
    # rnn_model = 'compositelstm'
    # n_In = 512
    # n_Hidden = 256
    # test_set = '../art_test.txt'

    # from yacs.config import CfgNode as CN
    #
    # def read_config_file(config_file):
    #     # rebuild the config with yaml
    #     f = open(config_file)
    #     opt = CN.load_cfg(f)
    #     return opt
    #
    # opt = read_config_file(config_file)

    f = open(config_yaml, encoding='utf-8')
    # NOTE(review): yaml.load without an explicit Loader is deprecated and
    # unsafe on untrusted files — prefer yaml.safe_load for config data.
    opt = yaml.load(f)

    alphabet = keys.alphabet
    nClass = len(alphabet) + 1  # +1 for the CTC blank symbol
    converter = util.strLabelConverter(alphabet)

    model = crann.CRANN(opt, nClass).cuda()
    # Restore weights when the configured checkpoint exists; otherwise the
    # model runs with its freshly-initialized parameters.
    if os.path.isfile(opt['DEMO']['model_path']):
        print("=> loading checkpoint '{}'".format(opt['DEMO']['model_path']))
        checkpoint = torch.load(opt['DEMO']['model_path'])
        start_epoch = checkpoint['epoch']
        # best_pred = checkpoint['best_pred']
        model.load_state_dict(checkpoint['state_dict'])
        # print("=> loaded checkpoint '{}' (epoch {} accuracy {})"
        #       .format(opt.model_path, checkpoint['epoch'], best_pred))

    model.eval()

    # root, mappinggit

    train_set = dataset_lmdb.lmdbDataset(opt['DEMO']['test_set_lmdb'])

    # train_set = dataset.testDataset(opt['test_set'])  # dataset.graybackNormalize()
    test_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=opt['TRAIN']['BATCH_SIZE'],
        shuffle=False,
        num_workers=opt['WORKERS'],
        collate_fn=dataset.alignCollate(imgH=opt['TRAIN']['IMG_H'],
                                        imgW=opt['TRAIN']['MAX_W']))

    file = open('./pred.txt', 'w', encoding='utf-8')

    # Recreate an empty output directory for this run.
    try:
        import shutil
        shutil.rmtree('./GRCNN_DEMO')
        # os.makedirs('./MORAN_DEMO')
    except:
        pass
    os.makedirs('./GRCNN_DEMO')
    # NOTE(review): record_file is never closed in the visible code — if the
    # function ends here, wrap it (and `file`) in `with` blocks.
    record_file = open('./GRCNN_DEMO/result.txt', 'a', encoding='utf-8')

    index = 0
    for i, (cpu_images, targets) in enumerate(test_loader):

        # could be refactored further

        bsz = cpu_images.size(0)
        images = cpu_images.cuda()

        predict = model(images)
        # Every sequence in the batch has predict.size(0) CTC time steps.
        predict_len = torch.IntTensor([predict.size(0)] * bsz)
        _, acc = predict.max(2)
        acc = acc.transpose(1, 0).contiguous().view(-1)
        # Per-character max softmax probability, averaged over the sequence,
        # is used as a confidence score for each prediction.
        prob, _ = F.softmax(predict, dim=2).max(2)
        probilities = torch.mean(prob, dim=1)
        sim_preds = converter.decode(acc.data, predict_len.data, raw=False)

        cnt = 0
        for probility, pred, target in zip(probilities, sim_preds, targets):
            index += 1
            img_key = 'gt_%d' % index
            file.write('%s:\t\t\t\t%.3f%%\t%-20s\n' %
                       (img_key, probility.item() * 100, pred))

            # debug start
            # print(images[0].size)
            # debug end

            # cv2.imwrite('./GRCNN_DEMO/' + str(index) + '.jpg', (images[cnt].cpu().numpy() + 1.0) * 128)
            record_file.write('./GRCNN_DEMO/' + str(index) + '.jpg' + '  ' +
                              pred + '   ' + target + ' \n')
            cnt += 1

    file.close()