Example #1
def data_loader():
    # train
    train_dataset = dataset.lmdbDataset(root=args.trainroot,
                                        transform=dataset.customResize())
    assert train_dataset
    if not params.random_sample:
        sampler = dataset.randomSequentialSampler(train_dataset,
                                                  params.batchSize)
    else:
        sampler = None
    # shuffle is mutually exclusive with a custom sampler in DataLoader, so shuffle only when no sampler is set
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=params.batchSize,
        shuffle=(sampler is None), sampler=sampler, num_workers=int(params.workers),
        collate_fn=dataset.alignCollate(imgH=params.imgH, imgW=params.imgW))

    # val
    val_dataset = dataset.lmdbDataset(root=args.valroot,
                                      transform=dataset.resizeNormalize(
                                          (params.imgW, params.imgH)))
    assert val_dataset
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             shuffle=True,
                                             batch_size=params.batchSize,
                                             num_workers=int(params.workers))

    return train_loader, val_loader
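
Several of these examples pass both shuffle=True and a custom sampler to DataLoader; recent PyTorch versions raise a ValueError for that combination, which is why the loader above shuffles only when no sampler is given. A minimal, self-contained sketch of the safe pattern, using a toy TensorDataset in place of dataset.lmdbDataset:

import torch
from torch.utils.data import DataLoader, TensorDataset, SequentialSampler

# toy dataset standing in for dataset.lmdbDataset
toy = TensorDataset(torch.arange(10).float().unsqueeze(1))

sampler = SequentialSampler(toy)  # any Sampler instance, or None
loader = DataLoader(toy,
                    batch_size=4,
                    shuffle=(sampler is None),  # shuffle only when no sampler is supplied
                    sampler=sampler)

for (batch,) in loader:
    print(batch.squeeze(1).tolist())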
Example #2
def val(net, test_dataset, criterion, max_iter=2):
    print('Start val')

    for p in crnn.parameters():
        p.requires_grad = False

    net.eval()
    data_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=opt.batchSize,
        num_workers=int(opt.workers),
        sampler=dataset.randomSequentialSampler(test_dataset, opt.batchSize),
        collate_fn=dataset.alignCollate(imgH=opt.imgH,
                                        imgW=opt.imgW,
                                        keep_ratio=opt.keep_ratio))
    val_iter = iter(data_loader)

    i = 0
    n_correct = 0
    loss_avg = utils.averager()
    test_distance = 0
    max_iter = min(max_iter, len(data_loader))
    for i in range(max_iter):
        data = next(val_iter)
        cpu_images, cpu_texts = data
        batch_size = cpu_images.size(0)
        utils.loadData(image, cpu_images)
        if ifUnicode:
            cpu_texts = [clean_txt(tx.decode('utf-8')) for tx in cpu_texts]
        t, l = converter.encode(cpu_texts)
        utils.loadData(text, t)
        utils.loadData(length, l)

        preds = crnn(image)
        preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
        cost = criterion(preds, text, preds_size, length) / batch_size
        loss_avg.add(cost)

        _, preds = preds.max(2)
        # preds = preds.squeeze(2)
        preds = preds.transpose(1, 0).contiguous().view(-1)
        sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
        for pred, target in zip(sim_preds, cpu_texts):
            if pred.strip() == target.strip():
                n_correct += 1
            # print(distance.levenshtein(pred.strip(), target.strip()))
            test_distance += distance.nlevenshtein(pred.strip(),
                                                   target.strip(),
                                                   method=2)
    raw_preds = converter.decode(preds.data, preds_size.data,
                                 raw=True)[:opt.n_test_disp]
    for raw_pred, pred, gt in zip(raw_preds, sim_preds, cpu_texts):
        print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))
    accuracy = n_correct / float(max_iter * opt.batchSize)
    test_distance = test_distance / float(max_iter * opt.batchSize)
    testLoss = loss_avg.val()
    #print('Test loss: %f, accuracy: %f' % (testLoss, accuracy))
    return testLoss, accuracy, test_distance
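
The converter here is the project's CTC label converter; decode(..., raw=False) turns the per-frame argmax indices into text. As a rough, self-contained illustration of greedy CTC decoding (collapse repeats, then drop the blank), here is a sketch; the alphabet and blank index are assumptions, not taken from the project:

def greedy_ctc_decode(indices, alphabet, blank=0):
    # collapse consecutive repeats, then drop the blank symbol
    out = []
    prev = None
    for idx in indices:
        if idx != blank and idx != prev:
            out.append(alphabet[idx - 1])  # assume index 0 is the blank, characters start at 1
        prev = idx
    return ''.join(out)

print(greedy_ctc_decode([0, 3, 3, 0, 1, 1, 2], alphabet='abc'))  # -> 'cab'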
Example #3
def data_loader():
    # train
    train_transform = ImgAugTransform()

    train_datasets = []
    for train_root in params.train_roots:
        train_dataset = dataset.lmdbDataset(root=train_root, transform=train_transform)
        train_datasets.append(train_dataset)
    train_dataset = torch.utils.data.ConcatDataset(train_datasets)

    assert train_dataset
    if not params.random_sample:
        sampler = dataset.randomSequentialSampler(train_dataset, params.batchSize)
    else:
        sampler = None

    # shuffle only when no sampler is supplied; DataLoader rejects shuffle=True together with a sampler
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=params.batchSize,
        shuffle=(sampler is None), sampler=sampler, num_workers=int(params.workers),
        collate_fn=dataset.alignCollate(imgH=params.imgH, imgW=params.imgW, keep_ratio=params.keep_ratio))
    
    # val
    val_dataset_list = []
    for val_root in params.val_roots:
        val_dataset = dataset.lmdbDataset(root=val_root, transform=dataset.processing_image((params.imgW, params.imgH)))
        val_dataset_list.append(val_dataset)

    val_dataset = torch.utils.data.ConcatDataset(val_dataset_list)
    assert val_dataset
    val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=True, batch_size=params.batchSize, num_workers=int(params.workers))
    
    return train_loader, val_loader, train_dataset, val_dataset
Example #4
def data_loader():
    # train
    transform = torchvision.transforms.Compose(
        [ImgAugTransform(), GridDistortion(prob=0.65)])
    train_dataset = dataset.lmdbDataset(root=args.trainroot,
                                        transform=transform)
    assert train_dataset
    if not params.random_sample:
        sampler = dataset.randomSequentialSampler(train_dataset,
                                                  params.batchSize)
    else:
        sampler = None
    # shuffle only when no sampler is supplied; DataLoader rejects shuffle=True together with a sampler
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=params.batchSize,
        shuffle=(sampler is None), sampler=sampler, num_workers=int(params.workers),
        collate_fn=dataset.alignCollate(imgH=params.imgH, imgW=params.imgW, keep_ratio=params.keep_ratio))

    # val
    transform = torchvision.transforms.Compose(
        [dataset.resizeNormalize((params.imgW, params.imgH))])
    val_dataset = dataset.lmdbDataset(root=args.valroot, transform=transform)
    assert val_dataset
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             shuffle=True,
                                             batch_size=params.batchSize,
                                             num_workers=int(params.workers))

    return train_loader, val_loader
Example #5
def initTrainDataLoader():
    print_msg("Loading training LMDB dataset: {}".format(dataset_dir))
    train_dataset = dataset.lmdbDataset(root=dataset_dir)
    assert train_dataset
    print_msg("Training LMDB dataset loaded successfully")

    if opt.random_sample:
        sampler = dataset.randomSequentialSampler(train_dataset, opt.batchSize)
    else:
        sampler = None
    loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batchSize,
        shuffle=(sampler is None), sampler=sampler,  # shuffle conflicts with an explicit sampler
        num_workers=int(opt.workers),
        collate_fn=dataset.alignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio=opt.keep_ratio))
    return loader
Example #6
def data_loader():
    # train
    train_dataset = dataset.lmdbDataset(root=args.trainroot,
                                        transform=Compose([
                                            Rotate(p=0.5, limit=(-15, 15), border_mode=cv2.BORDER_CONSTANT, value=255),
                                            CustomPiecewiseAffineTransform(),
                                        ]))
    assert train_dataset
    if not params.random_sample:
        sampler = dataset.randomSequentialSampler(train_dataset, params.batchSize)
    else:
        sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=params.batchSize,
        shuffle=(sampler is None),  # shuffle conflicts with an explicit sampler
        sampler=sampler, num_workers=int(params.workers),
        collate_fn=dataset.alignCollate(imgH=params.imgH, imgW=params.imgW,
                                        keep_ratio=params.keep_ratio))

    # val
    val_dataset = dataset.lmdbDataset(root=args.valroot, transform=dataset.resizeNormalize((params.imgW, params.imgH)))
    assert val_dataset
    val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=True, batch_size=params.batchSize,
                                             num_workers=int(params.workers))

    return train_loader, val_loader
Example #7
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
np.random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

cudnn.benchmark = True

if torch.cuda.is_available() and not opt.cuda:
    print(
        "WARNING: You have a CUDA device, so you should probably run with --cuda"
    )

train_dataset = dataset.lmdbDataset(root=opt.trainroot)
assert train_dataset
if not opt.random_sample:
    sampler = dataset.randomSequentialSampler(train_dataset, opt.batchSize)
else:
    sampler = None
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=opt.batchSize,
                                           shuffle=(sampler is None),  # shuffle conflicts with a sampler
                                           sampler=sampler,
                                           num_workers=int(opt.workers),
                                           collate_fn=dataset.alignCollate(
                                               imgH=opt.imgH,
                                               imgW=opt.imgW,
                                               keep_ratio=opt.keep_ratio))
test_dataset = dataset.lmdbDataset(root=opt.valroot,
                                   transform=dataset.resizeNormalize(
                                       (100, 32)))
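
dataset.randomSequentialSampler is project-specific, but anything passed as sampler= only needs to implement the torch.utils.data.Sampler interface: __iter__ over dataset indices and __len__. A generic sketch of such a sampler (the chunk-shuffling behaviour is illustrative, not the project's implementation):

import random
from torch.utils.data import Sampler

class ChunkShuffleSampler(Sampler):
    # Illustrative sampler: yields indices in fixed-size chunks whose order is shuffled each epoch.
    def __init__(self, data_source, chunk_size):
        self.num_samples = len(data_source)
        self.chunk_size = chunk_size

    def __iter__(self):
        starts = list(range(0, self.num_samples, self.chunk_size))
        random.shuffle(starts)
        for s in starts:
            yield from range(s, min(s + self.chunk_size, self.num_samples))

    def __len__(self):
        return self.num_samples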
Example #8
def val(net, valdataset, criterionAttention, criterionCTC, max_iter=100):
    print('Start val')

    for p in crnn.parameters():
        p.requires_grad = False

    net.eval()
    val_sampler = dataset.randomSequentialSampler(valdataset, opt.batchSize)
    data_loader = torch.utils.data.DataLoader(
        valdataset, batch_size=opt.batchSize,
        shuffle=False, sampler=val_sampler,
        num_workers=int(opt.workers),
        collate_fn=dataset.alignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio=opt.keep_ratio))
   # data_loader = torch.utils.data.DataLoader(
   #     dataset, shuffle=True, batch_size=opt.batchSize, num_workers=int(opt.workers))
    val_iter = iter(data_loader)

    i = 0
    n_correct = 0
    loss_avg = utils.averager()

    max_iter = min(max_iter, len(data_loader))
    for i in range(max_iter):
        data = next(val_iter)
        cpu_images, cpu_texts = data
        batch_size = cpu_images.size(0)
        utils.loadData(image, cpu_images)
        tAttention, lAttention = converterAttention.encode(cpu_texts)
        utils.loadData(textAttention, tAttention)
        utils.loadData(lengthAttention, lAttention)
        tCTC, lCTC = converterCTC.encode(cpu_texts)
        utils.loadData(textCTC, tCTC)
        utils.loadData(lengthCTC, lCTC)
       # print (image)

        if opt.lang:
            predsCTC, predsAttention = crnn(image, lengthAttention, textAttention)
        else:
            predsCTC, predsAttention = crnn(image, lengthAttention)
        costAttention = criterionAttention(predsAttention, textAttention)
        preds_size = Variable(torch.IntTensor([predsCTC.size(0)] * batch_size))
        costCTC = criterionCTC(predsCTC, textCTC, preds_size, lengthCTC) / batch_size
        loss_avg.add(costAttention)
        loss_avg.add(costCTC.cuda())

        _, predsAttention = predsAttention.max(1)
        predsAttention = predsAttention.view(-1)
        sim_predsAttention = converterAttention.decode(predsAttention.data, lengthAttention.data)
        for pred, target in zip(sim_predsAttention, cpu_texts):
            # regText = pred.decode('utf-8')
            regText = pred  # pred is already unicode, no conversion needed
            gtText = target.decode('utf-8')  # convert the str label to unicode (Python 2)
            print(regText, gtText)
            if regText == gtText:
                print("correct")
                print(regText, gtText)
                n_correct += 1

   # for pred, gt in zip(sim_preds, cpu_texts):
       # gt = ''.join(gt.split(opt.sep))
       # print('%-20s, gt: %-20s' % (pred, gt))

    accuracy = n_correct / float(max_iter * opt.batchSize)
    print('Test loss: %f, accuracy: %f' % (loss_avg.val(), accuracy))
Example #9
def val(net, valdataset, criterionAttention, criterionCTC, max_iter=100):
    print('Start val')

    for p in crnn.parameters():
        p.requires_grad = False

    net.eval()
    val_sampler = dataset.randomSequentialSampler(valdataset, opt.batchSize)
    data_loader = torch.utils.data.DataLoader(valdataset,
                                              batch_size=opt.batchSize,
                                              shuffle=False,
                                              sampler=val_sampler,
                                              num_workers=int(opt.workers),
                                              collate_fn=dataset.alignCollate(
                                                  imgH=opt.imgH,
                                                  imgW=opt.imgW,
                                                  keep_ratio=opt.keep_ratio))
    # data_loader = torch.utils.data.DataLoader(
    #     dataset, shuffle=True, batch_size=opt.batchSize, num_workers=int(opt.workers))
    val_iter = iter(data_loader)

    i = 0
    n_correct = 0
    n_correctCTC = 0
    n_correctAttention = 0
    distanceCTC = 0
    distanceAttention = 0
    sum_charNum = 0
    loss_avg = utils.averager()

    max_iter = min(max_iter, len(data_loader))
    for i in range(max_iter):
        data = next(val_iter)
        cpu_images, cpu_texts = data
        batch_size = cpu_images.size(0)
        utils.loadData(image, cpu_images)
        tAttention, lAttention = converterAttention.encode(cpu_texts)
        utils.loadData(textAttention, tAttention)
        utils.loadData(lengthAttention, lAttention)
        tCTC, lCTC = converterCTC.encode(cpu_texts)
        utils.loadData(textCTC, tCTC)
        utils.loadData(lengthCTC, lCTC)
        # print (image)

        if opt.lang:
            predsCTC, predsAttention = crnn(image, lengthAttention,
                                            textAttention)
        else:
            predsCTC, predsAttention = crnn(image, lengthAttention)
        costAttention = criterionAttention(predsAttention, textAttention)
        preds_size = Variable(torch.IntTensor([predsCTC.size(0)] * batch_size))
        costCTC = criterionCTC(predsCTC, textCTC, preds_size,
                               lengthCTC) / batch_size
        loss_avg.add(costAttention)
        loss_avg.add(costCTC.cuda())

        _, predsAttention = predsAttention.max(1)
        predsAttention = predsAttention.view(-1)
        sim_predsAttention = converterAttention.decode(predsAttention.data,
                                                       lengthAttention.data)

        _, predsCTC = predsCTC.max(2)
        predsCTC = predsCTC.transpose(1, 0).contiguous().view(-1)
        sim_predsCTC = converterCTC.decode(predsCTC.data,
                                           preds_size.data,
                                           raw=False)

        for i, cpu_text in enumerate(cpu_texts):
            gtText = cpu_text.decode('utf-8')
            CTCText = sim_predsCTC[i]
            if isinstance(CTCText, str):
                CTCText = CTCText.decode('utf-8')
            AttentionText = sim_predsAttention[i]
            print('gtText: %s' % gtText)
            print('CTCText: %s' % CTCText)
            print('AttentionText: %s' % AttentionText)
            if gtText == CTCText:
                n_correctCTC += 1
            if gtText == AttentionText:
                n_correctAttention += 1
            distanceCTC += Levenshtein.distance(CTCText, gtText)
            distanceAttention += Levenshtein.distance(AttentionText, gtText)
            sum_charNum = sum_charNum + len(gtText)

    correctCTC_accuracy = n_correctCTC / float(max_iter * batch_size)
    cerCTC = distanceCTC / float(sum_charNum)
    print('Test CERCTC: %f, accuracyCTC: %f' % (cerCTC, correctCTC_accuracy))
    correctAttention_accuracy = n_correctAttention / float(
        max_iter * batch_size)
    cerAttention = distanceAttention / float(sum_charNum)
    print('Test CERAttention: %f, accuracyAttention: %f' %
          (cerAttention, correctAttention_accuracy))
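
Example #9 reports a character error rate (CER) by summing Levenshtein distances over the ground-truth character count. A self-contained sketch of the same metric without the Levenshtein package (the sample strings are made up):

def edit_distance(a, b):
    # classic dynamic-programming Levenshtein distance
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1,               # deletion
                           cur[j - 1] + 1,            # insertion
                           prev[j - 1] + (ca != cb))) # substitution
        prev = cur
    return prev[-1]

preds = ['hello', 'wor1d']
gts = ['hello', 'world']
cer = sum(edit_distance(p, g) for p, g in zip(preds, gts)) / sum(len(g) for g in gts)
print(cer)  # 1 error over 10 characters -> 0.1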
Example #10
alphabetChinese = 'WH1JFj47VzuowRnx2eiD3bAvpUgZKd8fINQctGqPOsTLSEBM9lX0YhaCrkmy56'

trainP, testP = train_test_split(roots, test_size=0.1)  # character-class balance is not considered in this split
traindataset = PathDataset(trainP, alphabetChinese)
testdataset = PathDataset(testP, alphabetChinese)
print(testdataset[0])
batchSize = 100
workers = 8
imgH = 32
imgW = 280
keep_ratio = True
cuda = True
ngpu = 1
nh = 256
sampler = randomSequentialSampler(traindataset, batchSize)
train_loader = torch.utils.data.DataLoader(traindataset,
                                           batch_size=batchSize,
                                           shuffle=False,
                                           sampler=None,
                                           num_workers=int(workers),
                                           collate_fn=alignCollate(
                                               imgH=imgH,
                                               imgW=imgW,
                                               keep_ratio=keep_ratio))

train_iter = iter(train_loader)


def weights_init(m):
    classname = m.__class__.__name__
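
The snippet is cut off after the first line of weights_init. CRNN-style training scripts commonly initialize convolution and batch-norm layers by class name; the body below is an assumed completion for illustration, not the original code:

def weights_init(m):
    # assumed completion: initialize conv and batch-norm layers by class name
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)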
Example #11
def val(net, valdataset, criterionCTC, max_iter=100000):
    print('Start val')

    for p in crnn.parameters():
        p.requires_grad = False

    net.eval()
    val_batchSize = 1
    # val_batchSize = opt.batchSize
    val_sampler = dataset.randomSequentialSampler(valdataset, val_batchSize)
    data_loader = torch.utils.data.DataLoader(valdataset,
                                              batch_size=val_batchSize,
                                              shuffle=False,
                                              sampler=val_sampler,
                                              num_workers=int(opt.workers),
                                              collate_fn=dataset.alignCollate(
                                                  imgH=opt.imgH,
                                                  imgW=opt.imgW,
                                                  keep_ratio=opt.keep_ratio))
    # data_loader = torch.utils.data.DataLoader(
    #     dataset, shuffle=True, batch_size=opt.batchSize, num_workers=int(opt.workers))
    val_iter = iter(data_loader)

    i = 0
    n_correct = 0
    n_correctCTC = 0
    distanceCTC = 0
    sum_charNum = 0
    sum_imgNum = 0
    loss_avg = utils.averager()

    max_iter = min(max_iter, len(data_loader))
    for i in range(max_iter):
        data = next(val_iter)
        cpu_images, cpu_texts = data
        batch_size = cpu_images.size(0)
        sum_imgNum += batch_size
        #   print (type(cpu_images),type(cpu_texts))
        # print (cpu_images.size(),max_iter,len(cpu_texts))
        # exit(0)
        utils.loadData(image, cpu_images)
        #print (image)

        predsCTC = crnn(image)
        preds_size = Variable(torch.IntTensor([predsCTC.size(0)] * batch_size))

        _, predsCTC = predsCTC.max(2)
        predsCTC = predsCTC.transpose(1, 0).contiguous().view(-1)
        # print (predsCTC)
        sim_predsCTC = converterCTC.decode(predsCTC.data,
                                           preds_size.data,
                                           raw=False)
        # print (sim_predsCTC)
        #exit(0)

        for i, cpu_text in enumerate(cpu_texts):
            gtText = cpu_text.decode('utf-8')
            #CTCText = sim_predsCTC[i]
            CTCText = sim_predsCTC
            if isinstance(CTCText, str):
                CTCText = CTCText.decode('utf-8')
            if gtText == CTCText:
                n_correctCTC += 1
            distanceCTCline = Levenshtein.distance(CTCText, gtText)
            #if distanceCTCline / len(gtText) > 0.2:
            if gtText != CTCText:
                print('gtText: %s' % gtText)
                print('CTCText: %s' % CTCText)
            distanceCTC += distanceCTCline
            sum_charNum = sum_charNum + len(gtText)

    print('n_correctCTC: %d, max_iter: %d, batch_size: %d' %
          (n_correctCTC, max_iter, batch_size))
    correctCTC_accuracy = n_correctCTC / float(sum_imgNum)
    cerCTC = distanceCTC / float(sum_charNum)
    print('Test CERCTC: %f, accuracyCTC: %f' % (cerCTC, correctCTC_accuracy))
Example #12
    logger.info(opt)

    # tensorboardX
    writer = SummaryWriter(os.path.join(log_dir, 'tb_logs'))

    # store model path
    if not os.path.exists('./expr'):
        os.mkdir('./expr')
    os.environ["CUDA_VISIBLE_DEVICES"] = str(opt.GPU_ID)
    # read train set
    train_dataset = dataset.lmdbDataset(root=opt.trainroot,
                                        rgb=test_params.rgb,
                                        rand_hcrop=True)
    assert train_dataset
    if test_params.random_sample:
        sampler = dataset.randomSequentialSampler(train_dataset,
                                                  test_params.batchSize)
    else:
        sampler = None

    # images will be resized to 32*160
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=test_params.batchSize,
        shuffle=False,
        sampler=sampler,
        num_workers=int(test_params.workers),
        collate_fn=dataset.alignCollate(imgH=test_params.imgH,
                                        imgW=test_params.imgW,
                                        keep_ratio=test_params.keep_ratio))

    # read test set
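
Example #12 creates a tensorboardX SummaryWriter but the snippet ends before it is used. A minimal usage sketch (the tag names and values are made up):

from tensorboardX import SummaryWriter  # torch.utils.tensorboard.SummaryWriter exposes the same API

writer = SummaryWriter('./tb_logs_demo')
for step in range(3):
    writer.add_scalar('val/loss', 1.0 / (step + 1), step)      # made-up values
    writer.add_scalar('val/accuracy', 0.5 + 0.1 * step, step)
writer.close()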
Example #13
opt.manualSeed = random.randint(1, 10000)  # fix seed
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
np.random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

cudnn.benchmark = True

if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

train_dataset = dataset.lmdbDataset(root=opt.trainroot)
assert train_dataset
if not opt.random_sample:
    sampler = dataset.randomSequentialSampler(train_dataset, opt.batchSize)
else:
    sampler = None
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=opt.batchSize,
    shuffle=(sampler is None), sampler=sampler,  # shuffle conflicts with an explicit sampler
    num_workers=int(opt.workers),
    collate_fn=dataset.alignCollate(imgH=opt.imgH, keep_ratio=opt.keep_ratio))
test_dataset = dataset.lmdbDataset(
    root=opt.valroot, transform=dataset.resizeNormalize((100, 32)))

ngpu = int(opt.ngpu)
nh = int(opt.nh)
alphabet = opt.alphabet
nclass = len(alphabet) + 1
nc = 1
Example #14
    manualSeed = random.randint(1, 10000)  # fix seed
    random.seed(manualSeed)
    np.random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    cudnn.benchmark = True
    
    # store model path
    if not os.path.exists('./expr'):
        os.mkdir('./expr')

    # read s_train set
    s_train_dataset = dataset.lmdbDataset(root=params.s_train_data)
    assert s_train_dataset
    if not params.random_sample:
        sampler = dataset.randomSequentialSampler(s_train_dataset, params.batchSize)
    else:
        sampler = None

    # images will be resized to 32*96
    s_train_loader = torch.utils.data.DataLoader(
        s_train_dataset, batch_size=params.batchSize,
        shuffle=(sampler is None), sampler=sampler,  # shuffle conflicts with an explicit sampler
        num_workers=int(params.workers),
        collate_fn=dataset.alignCollate(imgH=params.imgH, imgW=96, keep_ratio=params.keep_ratio))

    # read m_train set
    m_train_dataset = dataset.lmdbDataset(root=params.m_train_data)
    assert m_train_dataset
    if not params.random_sample:
        sampler = dataset.randomSequentialSampler(m_train_dataset, params.batchSize)
    # read train set

    data_transform = transforms.Compose([
        transforms.ColorJitter(brightness=0.3, contrast=0.3),
        transforms.RandomAffine(degrees=0, scale=(0.9, 1.07), shear=3),
    ])
    train_dataset = dataset.lmdbDataset(root=opt.trainroot,
                                        transform=data_transform)
    val_dataset = dataset.lmdbDataset(root=opt.valroot,
                                      transform=data_transform)
    concat_dataset = torch.utils.data.ConcatDataset(
        [train_dataset, val_dataset])

    if not params.random_sample:
        sampler = dataset.randomSequentialSampler(concat_dataset,
                                                  params.batchSize)
    else:
        sampler = None

    # images will be resized to 32*320
    train_loader = torch.utils.data.DataLoader(
        concat_dataset,
        batch_size=params.batchSize,
        shuffle=(sampler is None),  # shuffle conflicts with an explicit sampler
        sampler=sampler,
        num_workers=int(params.workers),
        collate_fn=dataset.alignCollate(imgH=params.imgH,
                                        imgW=params.imgW,
                                        keep_ratio=params.keep_ratio))

    nclass = len(params.alphabet) + 1
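
The + 1 in nclass = len(params.alphabet) + 1 reserves one extra class for the CTC blank symbol. A self-contained sketch of the tensor shapes a CTC loss expects for such an output, using torch.nn.CTCLoss for illustration (the snippets above may bind CTC differently, and all sizes here are made up):

import torch
import torch.nn as nn

# T time steps, N batch items, C = nclass classes (e.g. 36 characters + 1 blank)
T, N, C = 26, 4, 37
log_probs = torch.randn(T, N, C).log_softmax(2)
targets = torch.randint(1, C, (N, 5), dtype=torch.long)   # label indices; 0 is reserved for the blank
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.full((N,), 5, dtype=torch.long)

ctc = nn.CTCLoss(blank=0)
loss = ctc(log_probs, targets, input_lengths, target_lengths)
print(loss.item())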