Example #1
def trainBatch(net, optimizer):
    data = next(train_iter)  # Python 3: iterators have no .next() method
    cpu_images, cpu_texts = data
    batch_size = cpu_images.size(0)
    # image, text and length are preallocated module-level tensors that
    # utils.loadData fills in place with the current batch
    utils.loadData(image, cpu_images)
    t, l = converter.encode(cpu_texts)
    utils.loadData(text, t)
    utils.loadData(length, l)

    preds = crnn(image)
    # every sample in the batch shares the same output sequence length preds.size(0)
    preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
    H, cost = seg_ctc_ent_cost(preds,
                               text,
                               preds_size,
                               length,
                               uni_rate=opt.uni_rate)
    # entropy-regularized objective: blend the CTC cost with the negative
    # entropy H, weighted by opt.h_rate
    h_cost = (1 - opt.h_rate) * cost - opt.h_rate * H
    cost_sum = h_cost.data.sum()
    inf = float("inf")
    # skip the parameter update when the loss is infinite or implausibly large
    if cost_sum == inf or cost_sum == -inf or cost_sum > 200 * batch_size:
        print("Warning: received an inf loss, setting loss value to 0")
        return torch.zeros(H.size()), torch.zeros(cost.size()), torch.zeros(
            h_cost.size())

    crnn.zero_grad()
    h_cost.backward()
    # note: renamed to clip_grad_norm_ in PyTorch >= 0.4
    torch.nn.utils.clip_grad_norm(crnn.parameters(), opt.max_norm)
    optimizer.step()
    return H / batch_size, cost / batch_size, h_cost / batch_size
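For context, a minimal sketch of how a function like this is typically driven from an outer training loop. Here train_loader and opt.nepoch are assumed names introduced for illustration; train_iter, crnn, optimizer and opt are the globals the function above already relies on:

# Hypothetical driver loop for trainBatch; train_loader and opt.nepoch are
# assumed names, everything else matches the globals used above.
for epoch in range(opt.nepoch):
    train_iter = iter(train_loader)      # fresh iterator each epoch
    for step in range(len(train_loader)):
        for p in crnn.parameters():
            p.requires_grad = True       # re-enable grads after validation
        crnn.train()
        H, cost, h_cost = trainBatch(crnn, optimizer)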
Example #2
def val(net, dataset, max_iter=100):
    print('Start val')

    # freeze parameters during validation (net and the global crnn refer
    # to the same model here)
    for p in crnn.parameters():
        p.requires_grad = False

    net.eval()
    data_loader = torch.utils.data.DataLoader(dataset,
                                              shuffle=True,
                                              batch_size=opt.batchSize,
                                              num_workers=int(opt.workers))
    val_iter = iter(data_loader)

    n_correct = 0
    # loss averager
    avg_h_val = utils.averager()
    avg_cost_val = utils.averager()
    avg_h_cost_val = utils.averager()

    if opt.eval_all:
        max_iter = len(data_loader)
    else:
        max_iter = min(max_iter, len(data_loader))

    for i in range(max_iter):
        data = next(val_iter)  # Python 3: iterators have no .next() method
        cpu_images, cpu_texts = data
        batch_size = cpu_images.size(0)
        utils.loadData(image, cpu_images)
        t, l = converter.encode(cpu_texts)
        utils.loadData(text, t)
        utils.loadData(length, l)

        preds = crnn(image)  # (seq_len, batch, n_classes), e.g. 26 x 64 x 96
        preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
        H, cost = seg_ctc_ent_cost(preds,
                                   text,
                                   preds_size,
                                   length,
                                   uni_rate=opt.uni_rate)
        h_cost = (1 - opt.h_rate) * cost - opt.h_rate * H
        avg_h_val.add(H / batch_size)
        avg_cost_val.add(cost / batch_size)
        avg_h_cost_val.add(h_cost / batch_size)

        _, preds = preds.max(2)  # greedy step: argmax over classes -> (seq_len, batch)
        preds = preds.transpose(1, 0).contiguous().view(-1)
        sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
        for idx, (pred, target) in enumerate(zip(sim_preds, cpu_texts)):
            if pred == target.lower():
                n_correct += 1

    # display a few raw/decoded predictions from the last validation batch
    raw_preds = converter.decode(preds.data, preds_size.data,
                                 raw=True)[:opt.n_test_disp]
    for raw_pred, pred, gt in zip(raw_preds, sim_preds, cpu_texts):
        print('%-30s => %-30s, gt: %-30s' % (raw_pred, pred, gt))

    # note: assumes every batch is full; a short final batch slightly
    # deflates the reported accuracy
    accuracy = n_correct / float(max_iter * opt.batchSize)
    print(
        'Test H: %f, Cost: %f, H Cost: %f, accuracy: %f' %
        (avg_h_val.val(), avg_cost_val.val(), avg_h_cost_val.val(), accuracy))
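The preds.max(2) / converter.decode path above is standard greedy CTC decoding: take the argmax class at each time step, collapse consecutive repeats, then drop blanks. A self-contained sketch of that collapse rule, assuming blank index 0 (the actual converter in this code may use a different blank index):

def greedy_ctc_decode(indices, blank=0):
    # collapse repeated labels first, then remove blanks:
    # a label is kept only if it differs from its predecessor and is not blank
    out, prev = [], None
    for idx in indices:
        if idx != prev and idx != blank:
            out.append(idx)
        prev = idx
    return out

print(greedy_ctc_decode([1, 1, 0, 1, 2, 2, 0]))  # -> [1, 1, 2]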