def trainBatch():
    """Run one training step on the next batch drawn from ``train_iter``.

    Uses module-level state defined elsewhere in the file: ``train_iter``,
    ``opt``, ``utils``, ``converter``, ``MORAN``, ``criterion``,
    ``optimizer`` and the objects ``image``, ``text``, ``text_rev`` and
    ``length`` that ``utils.loadData`` fills in.

    Returns:
        The value returned by ``criterion`` for this batch.

    Raises:
        StopIteration: propagated from ``next()`` when ``train_iter`` has no
            more batches.
    """
    # Fetch one batch of data [images, labels]. The built-in next() is used
    # because Python 3 iterators expose __next__, not a .next() method.
    data = next(train_iter)
    if opt.BidirDecoder:
        cpu_images, cpu_texts, cpu_texts_rev = data
        utils.loadData(image, cpu_images)
        t, l = converter.encode(cpu_texts, scanned=True)
        t_rev, _ = converter.encode(cpu_texts_rev, scanned=True)
        utils.loadData(text, t)
        utils.loadData(text_rev, t_rev)
        utils.loadData(length, l)
        # The bidirectional decoder yields two results.
        preds0, preds1 = MORAN(image, length, text, text_rev)
        cost = criterion(
            torch.cat([preds0, preds1], 0),
            torch.cat([text, text_rev], 0),
        )
    else:
        cpu_images, cpu_texts = data
        utils.loadData(image, cpu_images)
        # Labels and the length of each label.
        t, l = converter.encode(cpu_texts, scanned=True)
        utils.loadData(text, t)
        utils.loadData(length, l)
        # The unidirectional decoder yields a single result. text_rev is
        # still passed, although this branch does not load anything into it.
        preds = MORAN(image, length, text, text_rev)
        cost = criterion(preds, text)

    # Reset the model's gradients, backpropagate the loss, apply the update.
    MORAN.zero_grad()
    cost.backward()
    optimizer.step()
    return cost
def trainBatch():
    """Run one training step on the next batch drawn from ``train_iter``.

    Depends on module-level names defined elsewhere in the file
    (``train_iter``, ``opt``, ``utils``, ``converter``, ``MORAN``,
    ``criterion``, ``optimizer``, ``image``, ``text``, ``text_rev``,
    ``length``).

    Returns:
        The value returned by ``criterion`` for this batch.

    Raises:
        StopIteration: propagated from ``next()`` when ``train_iter`` has no
            more batches.
    """
    # Use the built-in next(): Python 3 iterators implement __next__ and do
    # not provide a .next() method.
    data = next(train_iter)
    if opt.BidirDecoder:
        # Batch carries images, texts and reversed texts.
        cpu_images, cpu_texts, cpu_texts_rev = data
        # Load the image data into `image`.
        utils.loadData(image, cpu_images)
        # Encode the text into class labels (plus lengths).
        t, l = converter.encode(cpu_texts, scanned=True)
        # Encode the reversed text into class labels.
        t_rev, _ = converter.encode(cpu_texts_rev, scanned=True)
        # Load the forward labels into `text`.
        utils.loadData(text, t)
        # Load the reversed labels into `text_rev`.
        utils.loadData(text_rev, t_rev)
        utils.loadData(length, l)
        # Forward and reverse recognition outputs.
        preds0, preds1 = MORAN(image, length, text, text_rev)
        # Loss over both directions at once; the original comment describes
        # criterion as cross-entropy -- TODO confirm against its definition.
        cost = criterion(
            torch.cat([preds0, preds1], 0),
            torch.cat([text, text_rev], 0),
        )
    else:
        cpu_images, cpu_texts = data
        utils.loadData(image, cpu_images)
        t, l = converter.encode(cpu_texts, scanned=True)
        utils.loadData(text, t)
        utils.loadData(length, l)
        # text_rev is passed through even though this branch never loads it.
        preds = MORAN(image, length, text, text_rev)
        cost = criterion(preds, text)

    # Reset gradients, backpropagate, then let the optimizer update.
    MORAN.zero_grad()
    cost.backward()
    optimizer.step()
    return cost
def trainBatch(steps):
    """Run one training step on the next batch drawn from ``train_iter``.

    Depends on module-level names defined elsewhere in the file
    (``train_iter``, ``opt``, ``utils``, ``converter``, ``MORAN``,
    ``criterion``, ``optimizer``, ``image``, ``text``, ``text_rev``,
    ``length``).

    Args:
        steps: Not read anywhere in this function; kept so existing callers
            that pass it continue to work.

    Returns:
        The value returned by ``criterion`` for this batch.

    Raises:
        StopIteration: propagated from ``next()`` when ``train_iter`` has no
            more batches.
    """
    # Use the built-in next(): Python 3 iterators implement __next__ and do
    # not provide a .next() method.
    data = next(train_iter)
    if opt.BidirDecoder:
        cpu_images, cpu_texts, cpu_texts_rev = data
        utils.loadData(image, cpu_images)
        t, l = converter.encode(cpu_texts, scanned=True)
        t_rev, _ = converter.encode(cpu_texts_rev, scanned=True)
        utils.loadData(text, t)
        utils.loadData(text_rev, t_rev)
        utils.loadData(length, l)
        preds0, preds1 = MORAN(image, length, text, text_rev)
        # Both decoder outputs are scored together against both label sets.
        cost = criterion(
            torch.cat([preds0, preds1], 0),
            torch.cat([text, text_rev], 0),
        )
    else:
        cpu_images, cpu_texts = data
        utils.loadData(image, cpu_images)
        t, l = converter.encode(cpu_texts, scanned=True)
        utils.loadData(text, t)
        utils.loadData(length, l)
        # text_rev is passed through even though this branch never loads it.
        preds = MORAN(image, length, text, text_rev)
        cost = criterion(preds, text)

    MORAN.zero_grad()
    cost.backward()   # backpropagation
    optimizer.step()  # optimizer update
    return cost
def trainBatch():
    """Run one bidirectional training step on the next ``train_iter`` batch.

    This variant always expects each batch to unpack into three items
    (images, texts, reversed texts). It depends on module-level names defined
    elsewhere in the file (``train_iter``, ``utils``, ``converter``,
    ``MORAN``, ``criterion``, ``optimizer``, ``image``, ``text``,
    ``text_rev``, ``length``).

    Returns:
        The value returned by ``criterion`` for this batch.

    Raises:
        StopIteration: propagated from ``next()`` when ``train_iter`` has no
            more batches.
    """
    # Use the built-in next(): Python 3 iterators implement __next__ and do
    # not provide a .next() method.
    data = next(train_iter)
    cpu_images, cpu_texts, cpu_texts_rev = data

    # Disabled alternative kept from the original:
    # utils.loadData(image, encode_coordinates_fn(cpu_images))
    utils.loadData(image, cpu_images)

    t, l = converter.encode(cpu_texts, scanned=True)
    t_rev, _ = converter.encode(cpu_texts_rev, scanned=True)
    utils.loadData(text, t)
    utils.loadData(text_rev, t_rev)
    utils.loadData(length, l)

    preds0, preds1 = MORAN(image, length, text, text_rev)
    # Both decoder outputs are scored together against both label sets.
    cost = criterion(
        torch.cat([preds0, preds1], 0),
        torch.cat([text, text_rev], 0),
    )

    # Reset gradients, backpropagate, then let the optimizer update.
    MORAN.zero_grad()
    cost.backward()
    optimizer.step()
    return cost