コード例 #1
0
ファイル: main.py プロジェクト: LiqunW/MORAN
def trainBatch():
    # 获取一个batch的数据 [images,label]
    data = train_iter.next()
    if opt.BidirDecoder:
        cpu_images, cpu_texts, cpu_texts_rev = data
        utils.loadData(image, cpu_images)
        t, l = converter.encode(cpu_texts, scanned=True)
        t_rev, _ = converter.encode(cpu_texts_rev, scanned=True)
        utils.loadData(text, t)
        utils.loadData(text_rev, t_rev)
        utils.loadData(length, l)
        # 双向lstm有两个结果
        preds0, preds1 = MORAN(image, length, text, text_rev)
        cost = criterion(torch.cat([preds0, preds1], 0),
                         torch.cat([text, text_rev], 0))
    else:
        cpu_images, cpu_texts = data
        utils.loadData(image, cpu_images)
        # 标签和每个标签的长度
        t, l = converter.encode(cpu_texts, scanned=True)
        utils.loadData(text, t)
        utils.loadData(length, l)
        # 单向lstm一个结果
        preds = MORAN(image, length, text, text_rev)
        cost = criterion(preds, text)

    MORAN.zero_grad()
    cost.backward()
    optimizer.step()
    return cost
コード例 #2
0
ファイル: main.py プロジェクト: Muran337287/MORAN_v2
def trainBatch():
    data = train_iter.next()
    if opt.BidirDecoder:
        cpu_images, cpu_texts, cpu_texts_rev = data  #读取标签数据
        utils.loadData(image, cpu_images)  #将图像数据赋值给image
        t, l = converter.encode(cpu_texts, scanned=True)  #将文本编码成类别标签
        t_rev, _ = converter.encode(cpu_texts_rev, scanned=True)  #将反向文本编码成类别标签
        utils.loadData(text, t)  #将正向文本标签赋予t
        utils.loadData(text_rev, t_rev)  #将反向文本标签赋予t_rev
        utils.loadData(length, l)
        preds0, preds1 = MORAN(image, length, text, text_rev)  #输出正向和反向识别概率
        cost = criterion(torch.cat([preds0, preds1], 0),
                         torch.cat([text, text_rev], 0))  #计算交叉熵损失
    else:
        cpu_images, cpu_texts = data
        utils.loadData(image, cpu_images)
        t, l = converter.encode(cpu_texts, scanned=True)
        utils.loadData(text, t)
        utils.loadData(length, l)
        preds = MORAN(image, length, text, text_rev)
        cost = criterion(preds, text)

    MORAN.zero_grad()
    cost.backward()
    optimizer.step()
    return cost
コード例 #3
0
ファイル: HARN.py プロジェクト: happog/FudanOCR
    def trainBatch(steps):
        data = train_iter.next()
        if opt.BidirDecoder:
            cpu_images, cpu_texts, cpu_texts_rev = data
            utils.loadData(image, cpu_images)
            t, l = converter.encode(cpu_texts, scanned=True)
            t_rev, _ = converter.encode(cpu_texts_rev, scanned=True)
            utils.loadData(text, t)
            utils.loadData(text_rev, t_rev)
            utils.loadData(length, l)
            preds0, preds1 = MORAN(image, length, text, text_rev)
            cost = criterion(torch.cat([preds0, preds1], 0),
                             torch.cat([text, text_rev], 0))
        else:
            cpu_images, cpu_texts = data
            utils.loadData(image, cpu_images)
            t, l = converter.encode(cpu_texts, scanned=True)
            utils.loadData(text, t)
            utils.loadData(length, l)
            preds = MORAN(image, length, text, text_rev)
            cost = criterion(preds, text)

        MORAN.zero_grad()
        cost.backward()  # 反向传播
        optimizer.step()  # 优化器
        return cost
コード例 #4
0
def trainBatch():
    data = train_iter.next()
    cpu_images, cpu_texts, cpu_texts_rev = data
    # utils.loadData(image, encode_coordinates_fn(cpu_images))
    utils.loadData(image, cpu_images)
    t, l = converter.encode(cpu_texts, scanned=True)
    t_rev, _ = converter.encode(cpu_texts_rev, scanned=True)
    utils.loadData(text, t)
    utils.loadData(text_rev, t_rev)
    utils.loadData(length, l)
    preds0, preds1 = MORAN(image, length, text, text_rev)
    cost = criterion(torch.cat([preds0, preds1], 0),
                     torch.cat([text, text_rev], 0))

    MORAN.zero_grad()
    cost.backward()
    optimizer.step()
    return cost