Example #1
def main(crnn, train_loader, val_loader, criterion, optimizer):

    crnn = crnn.to(device)
    criterion = criterion.to(device)
    Iteration = 0
    params.best_accuracy = 0.0
    while Iteration < params.niter:
        train(crnn, train_loader, criterion, Iteration)
        ## max_i limits how many batches are validated to save time; set it to len(val_loader) to validate on the whole test set
        accuracy = val(crnn, val_loader, criterion, Iteration, max_i=1000)
        # re-enable gradients for training (validation may have frozen them)
        for p in crnn.parameters():
            p.requires_grad = True

        if Iteration % 50 == 1:
            print("saving checkpoint...")
            torch.save(
                crnn.state_dict(),
                '{0}/crnn_Rec_done_{1}_{2}.pth'.format(params.experiment,
                                                       Iteration, accuracy))
            print("done")
        if accuracy > params.best_accuracy:
            params.best_accuracy = accuracy
            print('saving best acc....')
            torch.save(
                crnn.state_dict(),
                '{0}/crnn_Rec_done_{1}_{2}.pth'.format(params.experiment,
                                                       Iteration, accuracy))
            torch.save(crnn.state_dict(),
                       '{0}/crnn_best.pth'.format(params.experiment))
            print('done')
        # print("is best accuracy: {0}".format(accuracy > params.best_accuracy))
        Iteration += 1
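The loop above saves state dicts only. A minimal sketch of resuming from the best checkpoint it writes, reusing params.experiment and the global device from the snippet, could look like this:

import torch

# Sketch: reload the best checkpoint written by the loop above.
# map_location='cpu' keeps the load device-agnostic before moving the model.
state_dict = torch.load('{0}/crnn_best.pth'.format(params.experiment),
                        map_location='cpu')
crnn.load_state_dict(state_dict)
crnn = crnn.to(device)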
Example #2
def main():

    if not os.path.exists(opt.output):
        os.makedirs(opt.output)

    converter = utils.strLabelConverter(opt.alphabet)

    collate = dataset.AlignCollate()
    train_dataset = dataset.TextLineDataset(text_file=opt.train_list, transform=dataset.ResizeNormalize(100, 32), converter=converter)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.batchsize, shuffle=True,
                                               num_workers=opt.num_workers, collate_fn=collate)
    # note: this example builds the test set from the same opt.train_list file
    test_dataset = dataset.TextLineDataset(text_file=opt.train_list, transform=dataset.ResizeNormalize(100, 32), converter=converter)
    test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, batch_size=opt.batchsize,
                                              num_workers=opt.num_workers, collate_fn=collate)

    criterion = nn.CTCLoss()

    import models.crnn as crnn

    crnn = crnn.CRNN(opt.imgH, opt.nc, opt.num_classes, opt.nh)
    crnn.apply(utils.weights_init)
    if opt.pretrained != '':
        print('loading pretrained model from %s' % opt.pretrained)
        crnn.load_state_dict(torch.load(opt.pretrained), strict=False)
    print(crnn)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    crnn = crnn.to(device)
    criterion = criterion.to(device)


    # setup optimizer
    optimizer = optim.Adam(crnn.parameters(), lr=opt.lr)

    for epoch in range(opt.num_epochs):

        loss_avg = 0.0
        i = 0
        # build the batch iterator once per epoch
        train_iter = iter(train_loader)
        while i < len(train_loader):

            time0 = time.time()
            # train on one batch at a time
            cost = trainBatch(crnn, train_iter, criterion, optimizer, device)
            loss_avg += cost
            i += 1

            if i % opt.interval == 0:
                # report the mean loss over the last opt.interval batches
                print('[%d/%d][%d/%d] Loss: %f Time: %f s' %
                      (epoch, opt.num_epochs, i, len(train_loader),
                       loss_avg / opt.interval, time.time() - time0))
                loss_avg = 0.0



        if (epoch + 1) % opt.valinterval == 0:
            val(crnn, test_loader, criterion, converter=converter, device=device, max_iter=100)
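trainBatch is called above but not defined in this example. A minimal sketch of a single CTC training step is shown below; the batch layout (images, integer-encoded texts, text lengths) is an assumption based on the AlignCollate/converter setup, not the project's actual implementation.

import torch

# Hypothetical trainBatch(): one optimization step with nn.CTCLoss.
# Assumes each batch is (images, encoded_texts, text_lengths).
def trainBatch(crnn, train_iter, criterion, optimizer, device):
    images, texts, lengths = next(train_iter)
    images = images.to(device)

    preds = crnn(images)                                  # (T, N, num_classes)
    preds_size = torch.IntTensor([preds.size(0)] * preds.size(1))

    # nn.CTCLoss expects log-probabilities over the class dimension
    loss = criterion(preds.log_softmax(2), texts, preds_size, lengths)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()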
Example #3
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


crnn = crnn.CRNN(opt.imgH, nc, nclass, opt.nh)
crnn.apply(weights_init)
if opt.crnn != '':
    print('loading pretrained model from %s' % opt.crnn)
    crnn.load_state_dict(torch.load(opt.crnn), strict=False)
print(crnn)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
crnn = crnn.to(device)
criterion = criterion.to(device)

# loss averager
loss_avg = utils.averager()

# setup optimizer
# optimizer = optim.Adam(crnn.parameters(), lr=opt.lr, weight_decay=1e-4)
optimizer = optim.Adadelta(crnn.parameters(), lr=opt.lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                            step_size=2000,
                                            gamma=0.3)


def val(net, data_loader, criterion, max_iter=100):
    print('Start val')
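The val function in this example is cut off after its first line. A sketch of how such a CTC validation pass is commonly written follows; the (images, label_strings) batch layout and the global converter and device objects are assumptions, since the original body is not shown here.

import torch

# Hypothetical continuation of val(): greedy CTC decoding and exact-match accuracy.
def val(net, data_loader, criterion, max_iter=100):
    print('Start val')
    net.eval()
    n_correct, n_total = 0, 0
    with torch.no_grad():
        for i, (images, labels) in enumerate(data_loader):
            if i >= max_iter:
                break
            images = images.to(device)
            preds = net(images)                           # (T, N, num_classes)
            preds_size = torch.IntTensor([preds.size(0)] * preds.size(1))
            _, best = preds.max(2)                        # greedy: argmax per time step
            best = best.transpose(1, 0).contiguous().view(-1)
            decoded = converter.decode(best.data, preds_size.data, raw=False)
            for pred, target in zip(decoded, labels):
                n_correct += int(pred == target)
                n_total += 1
    net.train()
    accuracy = n_correct / max(n_total, 1)
    print('accuracy: %f' % accuracy)
    return accuracy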
Example #4
crnn.apply(weights_init)

image = torch.FloatTensor(opt.batchSize, 3, opt.imgH, opt.imgH)
text = torch.IntTensor(opt.batchSize * 5)
length = torch.IntTensor(opt.batchSize)

if opt.cuda != '-1':
    str_ids = opt.cuda.split(",")
    gpu_ids = []
    for str_id in str_ids:
        gpu_id = int(str_id)
        if gpu_id >= 0:
            gpu_ids.append(gpu_id)
    if len(gpu_ids) > 0:
        torch.cuda.set_device(gpu_ids[0])
        crnn.to(gpu_ids[0])
        crnn = torch.nn.DataParallel(crnn, device_ids=gpu_ids)
        image = image.to(gpu_ids[0])
        criterion = criterion.to(gpu_ids[0])
if opt.pretrained > -1:
    model_path = '{0}/netCRNN_{1}.pth'.format(opt.expr_dir, opt.pretrained)
    print('loading pretrained model from %s' % model_path)
    # crnn.load_state_dict(torch.load(opt.pretrained))
    crnn.load_state_dict(torch.load(model_path))
print(crnn)

# Variable is a no-op on PyTorch >= 0.4 and is kept only for backward compatibility
image = Variable(image)
text = Variable(text)
length = Variable(length)

# loss averager
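The pre-allocated image, text, and length buffers above are normally refilled in place for every batch. A sketch of that pattern is shown below; the loadData helper and the (images, label_strings) loader output are assumptions in the spirit of the usual CRNN training utilities, and nn.CTCLoss is assumed as the criterion.

import torch

# Hypothetical per-batch fill of the shared buffers above.
def loadData(v, data):
    # resize the shared buffer to the current batch and copy the data in place
    with torch.no_grad():
        v.resize_(data.size()).copy_(data)

cpu_images, cpu_texts = next(train_iter)          # assumed: (images, label strings)
batch_size = cpu_images.size(0)
loadData(image, cpu_images)
t, l = converter.encode(cpu_texts)                # strings -> int labels + lengths
loadData(text, t)
loadData(length, l)

preds = crnn(image)                               # (T, N, num_classes)
preds_size = torch.IntTensor([preds.size(0)] * batch_size)
cost = criterion(preds.log_softmax(2), text, preds_size, length)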
Example #5
crnn = crnn.CRNN(opt.imgH, nc, nclass, opt.nh)
crnn.apply(weights_init)
if opt.pretrained != '':
    print('loading pretrained model from %s' % opt.pretrained)
    crnn.load_state_dict(torch.load(opt.pretrained))
print(crnn)

image = torch.FloatTensor(opt.batchSize, 3, opt.imgH, opt.imgW)
text = torch.IntTensor(opt.batchSize * 5)
length = torch.IntTensor(opt.batchSize)


if torch.cuda.device_count() > 1:
    crnn = nn.DataParallel(crnn)
crnn.to(device)
image = image.to(device)
criterion = criterion.to(device)
# text = text.to(device)
# length = length.to(device)

image = Variable(image)
text = Variable(text)
length = Variable(length)

# loss averager
loss_avg = utils.averager()
epoch_loss_avg = utils.averager()
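utils.averager() tracks a running mean of the loss between log intervals; a minimal sketch of such a helper (an assumption, not the project's exact implementation) is shown below. Typical use is loss_avg.add(cost) after each batch, loss_avg.val() when printing, then loss_avg.reset().

import torch

# Hypothetical running-average helper in the spirit of utils.averager().
class Averager:
    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0

    def add(self, v):
        # accept either a scalar tensor or a plain Python number
        if torch.is_tensor(v):
            self.count += v.numel()
            self.sum += v.sum().item()
        else:
            self.count += 1
            self.sum += float(v)

    def val(self):
        return self.sum / self.count if self.count else 0.0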