Example #1
def val(dataset, criterion, max_iter=1000):
    print('Start val')
    data_loader = torch.utils.data.DataLoader(
        dataset, shuffle=False, batch_size=opt.batchSize, num_workers=int(opt.workers)) # opt.batchSize
    val_iter = iter(data_loader)
    max_iter = min(max_iter, len(data_loader))
    n_correct = 0
    n_total = 0
    loss_avg = utils.averager()
    
    for i in range(max_iter):
        data = next(val_iter)
        if opt.BidirDecoder:
            cpu_images, cpu_texts, cpu_texts_rev = data
            utils.loadData(image, cpu_images)
            t, l = converter.encode(cpu_texts, scanned=True)
            t_rev, _ = converter.encode(cpu_texts_rev, scanned=True)
            utils.loadData(text, t)
            utils.loadData(text_rev, t_rev)
            utils.loadData(length, l)
            preds0, preds1 = MORAN(image, length, text, text_rev, test=True)
            cost = criterion(torch.cat([preds0, preds1], 0), torch.cat([text, text_rev], 0))
            preds0_prob, preds0 = preds0.max(1)
            preds0 = preds0.view(-1)
            preds0_prob = preds0_prob.view(-1)
            sim_preds0 = converter.decode(preds0.data, length.data)
            preds1_prob, preds1 = preds1.max(1)
            preds1 = preds1.view(-1)
            preds1_prob = preds1_prob.view(-1)
            sim_preds1 = converter.decode(preds1.data, length.data)
            sim_preds = []
            for j in range(cpu_images.size(0)):
                text_begin = 0 if j == 0 else length.data[:j].sum()
                if torch.mean(preds0_prob[text_begin:text_begin + len(sim_preds0[j].split('$')[0] + '$')]).item() > \
                        torch.mean(preds1_prob[text_begin:text_begin + len(sim_preds1[j].split('$')[0] + '$')]).item():
                    sim_preds.append(sim_preds0[j].split('$')[0]+'$')
                else:
                    sim_preds.append(sim_preds1[j].split('$')[0][-1::-1]+'$')
        else:
            cpu_images, cpu_texts = data
            utils.loadData(image, cpu_images)
            t, l = converter.encode(cpu_texts, scanned=True)
            utils.loadData(text, t)
            utils.loadData(length, l)
            preds = MORAN(image, length, text, text_rev, test=True)
            cost = criterion(preds, text)
            _, preds = preds.max(1)
            preds = preds.view(-1)
            sim_preds = converter.decode(preds.data, length.data)

        loss_avg.add(cost)
        for pred, target in zip(sim_preds, cpu_texts):
            if pred == target.lower():
                n_correct += 1
            n_total += 1

    print("correct / total: %d / %d, "  % (n_correct, n_total))
    accuracy = n_correct / float(n_total)
    print('Test loss: %f, accuracy: %f' % (loss_avg.val(), accuracy))
    return accuracy
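
For context: each snippet reads module-level globals (opt, converter, MORAN, the image/text/text_rev/length buffers, utils.loadData, utils.averager) that the surrounding training script defines. Below is a minimal sketch of what those buffers and helpers could look like; the shapes and exact definitions are assumptions for illustration, not the repo's code.

import torch

# Hypothetical stand-ins for the globals the snippets read.
image = torch.FloatTensor()    # image batch buffer, refilled each iteration
text = torch.LongTensor()      # encoded target labels
text_rev = torch.LongTensor()  # reversed targets for the bidirectional decoder
length = torch.IntTensor()     # per-sample label lengths

def loadData(v, data):
    # Copy a CPU batch into a reusable buffer tensor (what utils.loadData is
    # assumed to do here).
    with torch.no_grad():
        v.resize_(data.size()).copy_(data)

class averager(object):
    # Running mean of scalar losses, matching how utils.averager is called.
    def __init__(self):
        self.total, self.count = 0.0, 0
    def add(self, v):
        self.total += v.item() if torch.is_tensor(v) else float(v)
        self.count += 1
    def val(self):
        return self.total / self.count if self.count else 0.0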
Example #2
    def val(dataset, criterion, max_iter=10000, steps=None):
        data_loader = torch.utils.data.DataLoader(
            dataset,
            shuffle=False,
            batch_size=opt.batchSize,
            num_workers=int(opt.workers))  # opt.batchSize
        val_iter = iter(data_loader)
        max_iter = min(max_iter, len(data_loader))
        n_correct = 0
        n_total = 0
        distance = 0.0
        loss_avg = utils.averager()

        # f = open('./log.txt', 'a', encoding='utf-8')

        for i in range(max_iter):  # loop cap is set very high (rarely reached)
            data = next(val_iter)
            if opt.BidirDecoder:
                cpu_images, cpu_texts, cpu_texts_rev = data  # data is the batch yielded by the dataloader
                utils.loadData(image, cpu_images)
                t, l = converter.encode(cpu_texts,
                                        scanned=False)  # encode maps characters to ids
                t_rev, _ = converter.encode(cpu_texts_rev, scanned=False)
                utils.loadData(text, t)
                utils.loadData(text_rev, t_rev)
                utils.loadData(length, l)
                preds0, preds1 = MORAN(image,
                                       length,
                                       text,
                                       text_rev,
                                       debug=False,
                                       test=True,
                                       steps=steps)  # run the HARN model
                cost = criterion(torch.cat([preds0, preds1], 0),
                                 torch.cat([text, text_rev], 0))
                preds0_prob, preds0 = preds0.max(1)  # take the top-1 (highest-probability) result
                preds0 = preds0.view(-1)
                preds0_prob = preds0_prob.view(-1)  # flatten to 1-D
                sim_preds0 = converter.decode(preds0.data,
                                              length.data)  # decode ids back to characters
                preds1_prob, preds1 = preds1.max(1)
                preds1 = preds1.view(-1)
                preds1_prob = preds1_prob.view(-1)
                sim_preds1 = converter.decode(preds1.data, length.data)
                sim_preds = []  # the predicted strings
                for j in range(cpu_images.size(0)):  # join per-character outputs into one string per sample
                    text_begin = 0 if j == 0 else length.data[:j].sum()
                    if torch.mean(preds0_prob[text_begin:text_begin + len(sim_preds0[j].split('$')[0] + '$')]).item() > \
                            torch.mean(
                                preds1_prob[text_begin:text_begin + len(sim_preds1[j].split('$')[0] + '$')]).item():
                        sim_preds.append(sim_preds0[j].split('$')[0] + '$')
                    else:
                        sim_preds.append(sim_preds1[j].split('$')[0][-1::-1] +
                                         '$')
            else:  # the unidirectional branch (unused here)
                cpu_images, cpu_texts = data
                utils.loadData(image, cpu_images)
                t, l = converter.encode(cpu_texts, scanned=True)
                utils.loadData(text, t)
                utils.loadData(length, l)
                preds = MORAN(image, length, text, text_rev, test=True)
                cost = criterion(preds, text)
                _, preds = preds.max(1)
                preds = preds.view(-1)
                sim_preds = converter.decode(preds.data, length.data)

            loss_avg.add(cost)  # accumulate the running mean of the loss
            for pred, target in zip(
                    sim_preds, cpu_texts
            ):  # compare with ground truth: cpu_texts holds the targets, sim_preds the joined strings
                if pred == target.lower():  # exact-match count
                    n_correct += 1
                # f.write("pred %s\t      target %s\n" % (pred, target))
                distance += levenshtein(pred, target) / max(
                    len(pred), len(target))  # normalized Levenshtein distance
                n_total += 1  # one word done

        # f.close()

        # print and save: after the loop, dump the last batch's predictions to the log
        for pred, gt in zip(sim_preds, cpu_texts):
            gt = ''.join(gt.split(opt.sep))
            print('%-20s, gt: %-20s' % (pred, gt))

        print("correct / total: %d / %d, " % (n_correct, n_total))
        print('levenshtein distance: %f' % (distance / n_total))
        accuracy = n_correct / float(n_total)
        log.scalar_summary('Validation/levenshtein distance',
                           distance / n_total, steps)
        log.scalar_summary('Validation/loss', loss_avg.val(), steps)
        log.scalar_summary('Validation/accuracy', accuracy, steps)
        print('Test loss: %f, accuracy: %f' % (loss_avg.val(), accuracy))
        return accuracy
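
Examples #2 and #3 also report a length-normalized edit distance via a levenshtein helper that the snippets do not show. A standard two-row dynamic-programming implementation would be consistent with the call site, levenshtein(pred, target) / max(len(pred), len(target)); the version below is a sketch, not necessarily the repo's own:

def levenshtein(a, b):
    # Classic Wagner-Fischer edit distance, keeping only two rows.
    if len(a) < len(b):
        a, b = b, a
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        curr = [i]
        for j, cb in enumerate(b, 1):
            curr.append(min(prev[j] + 1,                 # deletion
                            curr[j - 1] + 1,             # insertion
                            prev[j - 1] + (ca != cb)))   # substitution
        prev = curr
    return prev[-1]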
Example #3
def val(dataset, criterion, max_iter=10000, steps=0):
    data_loader = torch.utils.data.DataLoader(
        dataset,
        shuffle=False,
        batch_size=args.batchSize,
        num_workers=args.workers)  # args.batchSize
    val_iter = iter(data_loader)
    max_iter = min(max_iter, len(data_loader))
    n_correct = 0
    n_total = 0
    distance = 0.0
    loss_avg = utils.averager()

    f = open('logger/log.txt', 'w', encoding='utf-8')

    for i in range(max_iter):
        data = next(val_iter)
        cpu_images, cpu_texts, cpu_texts_rev = data
        # utils.loadData(image, encode_coordinates_fn(cpu_images))
        utils.loadData(image, cpu_images)
        t, l = converter.encode(cpu_texts, scanned=True)
        t_rev, _ = converter.encode(cpu_texts_rev, scanned=True)
        utils.loadData(text, t)
        utils.loadData(text_rev, t_rev)
        utils.loadData(length, l)
        preds0, _, preds1, _ = MORAN(image,
                                     length,
                                     text,
                                     text_rev,
                                     debug=False,
                                     test=True,
                                     steps=steps)
        cost = criterion(torch.cat([preds0, preds1], 0),
                         torch.cat([text, text_rev], 0))
        preds0_prob, preds0 = preds0.max(1)
        preds0 = preds0.view(-1)
        preds0_prob = preds0_prob.view(-1)
        sim_preds0 = converter.decode(preds0.data, length.data)
        preds1_prob, preds1 = preds1.max(1)
        preds1 = preds1.view(-1)
        preds1_prob = preds1_prob.view(-1)
        sim_preds1 = converter.decode(preds1.data, length.data)
        sim_preds = []
        for j in range(cpu_images.size(0)):
            text_begin = 0 if j == 0 else length.data[:j].sum()
            if torch.mean(preds0_prob[text_begin:text_begin + len(sim_preds0[j].split('$')[0] + '$')]).item() > \
                    torch.mean(preds1_prob[text_begin:text_begin + len(sim_preds1[j].split('$')[0] + '$')]).item():
                sim_preds.append(sim_preds0[j].split('$')[0] + '$')
            else:
                sim_preds.append(sim_preds1[j].split('$')[0][-1::-1] + '$')

        # img_shape = cpu_images.shape[3] / 100, cpu_images.shape[2] / 100
        # input_seq = cpu_texts[0]
        # output_seq = sim_preds[0]
        # attention = alpha[0]
        # attention_image = showAttention(input_seq, output_seq, attention, img_shape)
        # log.image_summary('map/attention', [attention_image], steps)

        loss_avg.add(cost)
        for pred, target in zip(sim_preds, cpu_texts):
            if pred == target.lower():
                n_correct += 1
            f.write("pred %s\t\t\t\t\ttarget %s\n" % (pred, target))
            distance += levenshtein(pred, target) / max(len(pred), len(target))
            n_total += 1

    f.close()

    accuracy = n_correct / float(n_total)
    log.scalar_summary('Validation/levenshtein distance', distance / n_total,
                       steps)
    log.scalar_summary('Validation/loss', loss_avg.val(), steps)
    log.scalar_summary('Validation/accuracy', accuracy, steps)
    return accuracy
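
The trickiest part shared by all three examples is the inner loop that, per sample, keeps whichever decoder direction has the higher mean character confidence up to and including the '$' end token, flipping the right-to-left prediction back into reading order. Factored out as a standalone helper for clarity; the function name and signature are invented here, not from the repo:

EOS = '$'

def pick_bidirectional(sim0, sim1, prob0, prob1, lengths):
    # sim0/sim1: decoded strings from the forward/reversed decoders;
    # prob0/prob1: flattened per-character max probabilities;
    # lengths: per-sample target lengths (an IntTensor, as above).
    out, begin = [], 0
    for j, n in enumerate(lengths.tolist()):
        s0 = sim0[j].split(EOS)[0] + EOS
        s1 = sim1[j].split(EOS)[0] + EOS
        m0 = prob0[begin:begin + len(s0)].mean().item()
        m1 = prob1[begin:begin + len(s1)].mean().item()
        # Keep the more confident direction; reverse s1 back before appending.
        out.append(s0 if m0 > m1 else s1[:-1][::-1] + EOS)
        begin += n
    return out

With this helper, sim_preds = pick_bidirectional(sim_preds0, sim_preds1, preds0_prob, preds1_prob, length.data) would reproduce the per-sample selection loop in the snippets.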