Example #1
def val(net, val_loader, criterion, epoch, max_i=1000):
    print('================Start val=================')
    for p in net.parameters():
        p.requires_grad = False
    net.eval()
    i = 0
    n_correct = 0
    n_all = 0
    loss_avg = utils.averager()

    for i_batch, (image, index) in enumerate(val_loader):
        image = image.to(device)
        print('image.shape:', image.shape)
        label = utils.get_batch_label(val_dataset, index)
        # [41,batch,nclass]
        preds = net(image)
        batch_size = image.size(0)
        # index = np.array(index.data.numpy())
        label_text, label_length = converter.encode(label)
        preds_size = torch.IntTensor([preds.size(0)] * batch_size)
        cost = criterion(preds, label_text, preds_size,
                         label_length) / batch_size
        loss_avg.add(cost)
        # [41,batch]
        _, preds = preds.max(2)
        preds = preds.transpose(1, 0).contiguous().view(-1)
        # preds = preds.transpose(1, 0).reshape(-1)
        sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
        print('label:', label[:2])
        print('sim_preds:', sim_preds[:2])
        # print(list(zip(sim_preds, label)))

        n_all += len(label)
        for pred, target in list(zip(sim_preds, label)):
            if pred == target:
                n_correct += 1

        if (i_batch + 1) % params.displayInterval == 0:
            print('[%d/%d][%d/%d]' %
                  (epoch, params.epochs, i_batch, len(val_loader)))
        if i_batch == max_i:
            break
    raw_preds = converter.decode(preds.data, preds_size.data,
                                 raw=True)[:params.n_test_disp]
    for raw_pred, pred, gt in zip(raw_preds, sim_preds, label):
        print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))
    #
    # print('n_correct:',n_correct)
    # accuracy = n_correct / float(max_i * params.val_batchSize)
    accuracy = n_correct / n_all
    print('Test loss: %f, accuracy: %f' % (loss_avg.val(), accuracy))
    return accuracy
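The shape bookkeeping above (preds of shape [41, batch, nclass], one preds_size entry per sample, and per-label lengths from converter.encode) matches what torch.nn.CTCLoss expects. A minimal, self-contained sketch of that contract, with random tensors standing in for the model output and the encoded labels (all sizes here are illustrative):

import torch
import torch.nn as nn

T, N, C = 41, 32, 6736                                  # time steps, batch size, number of classes (blank = 0)
preds = torch.randn(T, N, C).log_softmax(2)             # stand-in for crnn(image): per-frame log-probabilities
preds_size = torch.IntTensor([T] * N)                   # every sample uses all T frames
label_length = torch.randint(1, 10, (N,), dtype=torch.int32)
label_text = torch.randint(1, C, (int(label_length.sum()),), dtype=torch.int32)  # concatenated targets

criterion = nn.CTCLoss(reduction='sum', zero_infinity=True)
cost = criterion(preds, label_text, preds_size, label_length) / N   # per-sample loss, as in the example above
print(cost.item())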
Example #2
def val(net, criterion, max_iter=3):
    # print('Start val')

    for p in net.parameters():
        p.requires_grad = False

    net.eval()
    val_iter = iter(test_loader)

    i = 0
    n_correct = 0
    loss_avg = utils.averager()

    max_iter = min(max_iter, len(test_loader))
    for i in range(max_iter):
        data = next(val_iter)
        cpu_images, cpu_texts = data
        batch_size = cpu_images.size(0)
        utils.loadData(image, cpu_images)
        if ifUnicode:
            cpu_texts = [clean_txt(tx.decode('utf-8')) for tx in cpu_texts]
        t, l = converter.encode(cpu_texts)
        utils.loadData(text, t)
        utils.loadData(length, l)

        preds = net(image)
        preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
        cost = criterion(preds, text, preds_size, length) / batch_size
        loss_avg.add(cost)

        _, preds = preds.max(2)
        preds = preds.transpose(1, 0).contiguous().view(-1)
        sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
        for pred, target in zip(sim_preds, cpu_texts):
            if pred.strip() == target.strip():
                n_correct += 1

    raw_preds = converter.decode(preds.data, preds_size.data, raw=True)[:opt.n_test_disp]
    # for raw_pred, pred, gt in zip(raw_preds, sim_preds, cpu_texts):
         # print((pred, gt))
    accuracy = n_correct / float(max_iter * opt.batchSize)
    testLoss = loss_avg.val()
    # print('Test loss: %f, accuray: %f' % (testLoss, accuracy))
    return testLoss, accuracy
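converter.decode(..., raw=False) is a project utility, but the greedy CTC decoding it stands for (take the per-frame argmax, collapse consecutive repeats, drop the blank index) is easy to sketch in isolation. The blank index of 0 and the toy alphabet below are assumptions for illustration only:

def greedy_ctc_decode(indices, alphabet, blank=0):
    # Collapse repeated indices, drop blanks, and map the survivors to characters.
    chars = []
    prev = None
    for idx in indices:
        if idx != blank and idx != prev:
            chars.append(alphabet[idx - 1])   # index 0 is reserved for the CTC blank
        prev = idx
    return ''.join(chars)

print(greedy_ctc_decode([1, 1, 0, 2, 2, 2, 0, 1], 'abc'))   # -> 'aba'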
Example #3
def main(crnn, train_loader, val_loader, criterion, optimizer):

    crnn = crnn.to(device)
    criterion = criterion.to(device)
    for epoch in range(params.epochs):
        # if epoch < 1:
        train(crnn, train_loader, criterion, epoch)
        # max_i caps the time spent on validation; to validate on the whole test set, set it to len(val_loader)
        accuracy = val(crnn, val_loader, criterion, epoch, max_i=1000)
        for p in crnn.parameters():
            p.requires_grad = True
        # if accuracy > params.best_accuracy:
        torch.save(
            crnn.state_dict(),
            '{0}/crnn_Rec_done_{1}_{2}.pth'.format(params.experiment, epoch,
                                                   accuracy))
        torch.save(crnn.state_dict(),
                   '{0}/crnn_best.pth'.format(params.experiment))
        print("is best accuracy: {0}".format(accuracy > params.best_accuracy))
Example #4
def trainBatch(net, criterion, optimizer, flage=False):
    data = next(train_iter)
    cpu_images, cpu_texts = data  # decode utf-8 to unicode
    if ifUnicode:
        cpu_texts = [clean_txt(tx.decode('utf-8')) for tx in cpu_texts]
    
    batch_size = cpu_images.size(0)
    utils.loadData(image, cpu_images)
    t, l = converter.encode(cpu_texts)
    utils.loadData(text, t)
    utils.loadData(length, l)

    preds = net(image)
    preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
    cost = criterion(preds, text, preds_size, length) / batch_size
    net.zero_grad()
    cost.backward()
    if flage:
        lr = 0.0001
        optimizer = optim.Adadelta(net.parameters(), lr=lr)
    optimizer.step()
    return cost
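trainBatch pulls its batch from the module-level train_iter and relies on the image/text/length buffers filled by utils.loadData, so it is normally driven from an outer epoch loop that resets the iterator. A rough sketch of that loop, assuming opt.niter holds the epoch count (the name is an assumption) and train_loader is the usual DataLoader:

for epoch in range(opt.niter):              # opt.niter: assumed epoch count
    train_iter = iter(train_loader)         # trainBatch consumes this module-level iterator
    for i in range(len(train_loader)):
        crnn.train()
        cost = trainBatch(crnn, criterion, optimizer)
        loss_avg.add(cost)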
Example #5
def train(crnn, train_loader, criterion, epoch):
    for p in crnn.parameters():
        p.requires_grad = True
    crnn.train()
    #loss averager
    loss_avg = utils.averager()
    for i_batch, (image, index) in enumerate(train_loader):
        #[b,c,h,w] [32,1,32,160]
        image = image.to(device)
        print('image.shape:', image.shape)
        batch_size = image.size(0)
        #['xxx','xxxx',...batch]
        label = utils.get_batch_label(dataset, index)
        #[41,batch,nclass]
        preds = crnn(image)
        # print('preds.shape',preds.shape)
        # index = np.array(index.data.numpy())
        # label_text: concatenated label indices; label_length: [len(label[0]), len(label[1]), ...]
        label_text, label_length = converter.encode(label)
        # print('label_text:', len(label_text))
        # print('label_length:', label_length)
        #[41,41,41,...]*batch
        preds_size = torch.IntTensor([preds.size(0)] * batch_size)
        # print('preds.shape, label_text.shape, preds_size.shape, label_length.shape',preds.shape, label_text.shape, preds_size.shape, label_length.shape)
        # torch.Size([41, 32, 6736]) torch.Size([320]) torch.Size([320]) torch.Size([320])
        cost = criterion(preds, label_text, preds_size,
                         label_length) / batch_size
        # print('cost:',cost)
        crnn.zero_grad()
        cost.backward()
        optimizer.step()

        loss_avg.add(cost)

        if (i_batch + 1) % params.displayInterval == 0:
            print('[%d/%d][%d/%d] Loss: %f' %
                  (epoch, params.epochs, i_batch, len(train_loader),
                   loss_avg.val()))
            loss_avg.reset()
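The shape comment above (torch.Size([41, 32, 6736]) predictions against flat label_text and label_length tensors) implies that converter.encode flattens all of the batch's labels into a single 1-D index tensor and returns the per-label lengths alongside it. A toy sketch of that behaviour, with a made-up two-character alphabet (index 0 stays reserved for the CTC blank):

import torch

def encode(labels, char_to_index):
    # Concatenate all label strings into one index tensor, plus per-label lengths.
    text = torch.IntTensor([char_to_index[c] for s in labels for c in s])
    length = torch.IntTensor([len(s) for s in labels])
    return text, length

text, length = encode(['ab', 'ba', 'a'], {'a': 1, 'b': 2})
print(text.tolist(), length.tolist())   # [1, 2, 2, 1, 1] [2, 2, 1]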
Example #6
    crnn.cuda()
    # crnn = torch.nn.DataParallel(crnn, device_ids=range(opt.ngpu))
    # image = image.cuda()
    device = torch.device('cuda:0')
    criterion = criterion.cuda()

# image = Variable(image)
# text = Variable(text)
# length = Variable(length)

# loss averager
loss_avg = utils.averager()

# setup optimizer
if config.adam:
    optimizer = optim.Adam(crnn.parameters(), lr=config.lr, betas=(config.beta1, 0.999))
elif config.adadelta:
    optimizer = optim.Adadelta(crnn.parameters(), lr=config.lr)
else:
    optimizer = optim.RMSprop(crnn.parameters(), lr=config.lr)


def val(net, dataset, criterion, max_iter=100):
    print('Start val')
    for p in net.parameters():
        p.requires_grad = False

    num_correct, num_all = val_model(config.val_infofile,
                                     net,
                                     True,
                                     log_file='compare-' + config.saved_model_prefix + '.log')
    accuracy = num_correct / num_all

    print('ocr_acc: %f' % (accuracy))
Example #7
File: train.py Project: FLming/CRNN
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)

    num_classes = len(alphabet.alphabet) + 1
    converter = utils.StrLabelConverter(alphabet.alphabet)

    trainloader, validloader = prepare_dataloader()

    crnn = crnn.CRNN(num_classes).to(device)

    criterion = torch.nn.CTCLoss().to(device)
    if args.adam:
        optimizer = optim.Adam(crnn.parameters(), lr=args.lr)
    elif args.rmsprop:
        optimizer = optim.RMSprop(crnn.parameters(), lr=args.lr)
    else:
        optimizer = optim.Adadelta(crnn.parameters())

    if args.pretrained != '':
        print('loading pretrained model from {}'.format(args.pretrained))
        crnn.load_state_dict(torch.load(args.pretrained))

    crnn.train()
    for epoch in range(args.num_epoch):

        train(trainloader, crnn, converter, criterion, optimizer)

        if epoch % args.eval_epoch == 0:
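The snippet is cut off right after the eval_epoch check. Going by the pattern in Example #3, the body of that branch usually runs validation and writes a checkpoint; the validate helper and the checkpoint filename below are assumptions for illustration, not the project's actual code:

if epoch % args.eval_epoch == 0:
    accuracy = validate(validloader, crnn, converter, criterion)   # hypothetical helper
    torch.save(crnn.state_dict(),
               os.path.join(args.save_path,
                            'crnn_epoch{}_{:.4f}.pth'.format(epoch, accuracy)))
    crnn.train()   # switch back to training mode after evaluation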