state_dict = torch.load(opt.MORAN) else: state_dict = torch.load(opt.MORAN, map_location='cpu') MORAN_state_dict_rename = OrderedDict() for k, v in state_dict.items(): name = k.replace("module.", "") # remove `module.` MORAN_state_dict_rename[name] = v MORAN.load_state_dict(MORAN_state_dict_rename, strict=True) # 创建图像[batch,channel,h,w]和标签的tensor image = torch.FloatTensor(opt.batchSize, nc, opt.imgH, opt.imgW) text = torch.LongTensor(opt.batchSize * 5) text_rev = torch.LongTensor(opt.batchSize * 5) length = torch.IntTensor(opt.batchSize) if opt.cuda: MORAN.cuda() MORAN = torch.nn.DataParallel(MORAN, device_ids=range(opt.ngpu)) image = image.cuda() text = text.cuda() text_rev = text_rev.cuda() criterion = criterion.cuda() image = Variable(image) text = Variable(text) text_rev = Variable(text_rev) length = Variable(length) # loss averager loss_avg = utils.averager() # setup optimizer
class Recognizer:
    """Wraps a pretrained MORAN model for batched scene-text recognition.

    Pipeline: BGR numpy image -> grayscale PIL -> resize/normalize to 100x32
    -> MORAN forward pass -> decoded string (text up to the '$' stop symbol).
    """

    def __init__(self, model_path):
        """Build the MORAN network and load weights from `model_path`.

        Weights saved under nn.DataParallel have a `module.` prefix on every
        key; it is stripped before load_state_dict.
        """
        alphabet = '0:1:2:3:4:5:6:7:8:9:a:b:c:d:e:f:g:h:i:j:k:l:m:n:o:p:q:r:s:t:u:v:w:x:y:z:$'
        self.cuda_flag = torch.cuda.is_available()
        if self.cuda_flag:
            self.MORAN = MORAN(1, len(alphabet.split(':')), 256, 32, 100, BidirDecoder=True, CUDA=self.cuda_flag)
            self.MORAN = self.MORAN.cuda()
        else:
            # CPU build needs an explicit input tensor type.
            self.MORAN = MORAN(1, len(alphabet.split(':')), 256, 32, 100, BidirDecoder=True,
                               inputDataType='torch.FloatTensor', CUDA=self.cuda_flag)
        print('loading pretrained model from %s' % model_path)
        if self.cuda_flag:
            state_dict = torch.load(model_path)
        else:
            state_dict = torch.load(model_path, map_location='cpu')
        MORAN_state_dict_rename = OrderedDict()
        for k, v in state_dict.items():
            name = k.replace("module.", "")  # remove `module.`
            MORAN_state_dict_rename[name] = v
        self.MORAN.load_state_dict(MORAN_state_dict_rename)
        # Inference only: freeze all parameters and switch to eval mode.
        for p in self.MORAN.parameters():
            p.requires_grad = False
        self.MORAN.eval()
        self.converter = utils.strLabelConverterForAttention(alphabet, ':')
        self.transformer = dataset.resizeNormalize((100, 32))

    def preprocess(self, img):
        """Convert one BGR numpy image into a 1xCxHxW normalized tensor.

        `img[..., ::-1]` flips BGR to RGB before the grayscale conversion.
        """
        image = Image.fromarray(img[..., ::-1]).convert('L')
        image = self.transformer(image)
        image = image.view(1, *image.size())
        return image

    def predict(self, img_batch):
        """Run the MORAN forward pass on a preprocessed image batch.

        Returns the raw model output together with the per-sample length
        tensor needed for decoding.
        """
        batch_size = int(img_batch.size(0))
        if self.cuda_flag:
            img_batch = img_batch.cuda()
        # img_batch = Variable(img_batch)
        # 5 slots per sample is the converter's worst-case label buffer.
        text = torch.LongTensor(batch_size * 5)
        length = torch.IntTensor(batch_size)
        # text = Variable(text)
        # length = Variable(length)
        max_iter = 20  # maximum decoded characters per sample
        t, l = self.converter.encode(['0' * max_iter] * batch_size)
        utils.loadData(text, t)
        utils.loadData(length, l)
        output = self.MORAN(img_batch, length, text, text, test=True, debug=True)
        return output, length

    def post_process(self, output, length):
        """Decode model logits into strings for both reading directions.

        Everything after the '$' end-of-sequence marker is discarded.
        """
        preds, preds_reverse = output[0]
        # demo = output[1]
        _, preds = preds.max(1)
        _, preds_reverse = preds_reverse.max(1)
        sim_preds = self.converter.decode(preds.data, length.data)
        sim_preds = list(map(lambda x: x.strip().split('$')[0], sim_preds))
        sim_preds_reverse = self.converter.decode(preds_reverse.data, length.data)
        sim_preds_reverse = list(
            map(lambda x: x.strip().split('$')[0], sim_preds_reverse))
        return sim_preds, sim_preds_reverse

    def __call__(self, images):
        """Recognize text in a list of BGR numpy images; returns strings.

        A single image is duplicated before inference and the duplicate
        result dropped afterwards — presumably a workaround for a
        batch-size-1 issue in the model; TODO confirm.
        """
        unit_size = len(images) == 1
        if unit_size:
            images = images * 2
        img_tensors = []
        for img in images:
            img_tensors.append(self.preprocess(img))
        img_batch = torch.cat(img_tensors)
        output, length = self.predict(img_batch)
        sim_preds, sim_preds_reverse = self.post_process(output, length)
        if unit_size:
            sim_preds = sim_preds[:1]
        return sim_preds
def train_moran_v2(config_file):
    """Train a MORAN v2 recognition model.

    Reads all hyper-parameters from the YAML file at `config_file` (via
    yacs), builds the LMDB train/validation datasets, optionally restores
    pretrained weights from `opt.MORAN`, and runs the train/validate loop,
    checkpointing to `opt.experiment` whenever validation accuracy improves.
    """
    import sys
    sys.path.append('./recognition_model/MORAN_V2')
    import argparse
    import random
    import torch
    import torch.backends.cudnn as cudnn
    import torch.optim as optim
    import torch.utils.data
    from torch.autograd import Variable
    import numpy as np
    import os
    import tools.utils as utils
    import tools.dataset as dataset
    import time
    from collections import OrderedDict
    from models.moran import MORAN
    from alphabet.wordlist import result
    from yacs.config import CfgNode as CN
    # from wordlistart import result
    # NOTE: this second import overrides the `result` imported above.
    from alphabet.wordlistlsvt import result

    def read_config_file(config_file):
        # Rebuild the options object from the YAML config file.
        f = open(config_file)
        opt = CN.load_cfg(f)
        return opt

    opt = read_config_file(config_file)

    # Modify
    opt.alphabet = result

    assert opt.ngpu == 1, "Multi-GPU training is not supported yet, due to the variant lengths of the text in a batch."

    if opt.experiment is None:
        opt.experiment = 'expr'
    os.system('mkdir {0}'.format(opt.experiment))

    opt.manualSeed = random.randint(1, 10000)  # fix seed
    print("Random Seed: ", opt.manualSeed)
    random.seed(opt.manualSeed)
    np.random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    cudnn.benchmark = True
    print(opt)

    if not torch.cuda.is_available():
        assert not opt.cuda, 'You don\'t have a CUDA device.'
    if torch.cuda.is_available() and not opt.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    train_nips_dataset = dataset.lmdbDataset(root=opt.train_nips,
        transform=dataset.resizeNormalize((opt.imgW, opt.imgH)), reverse=opt.BidirDecoder)
    assert train_nips_dataset
    '''
    train_cvpr_dataset = dataset.lmdbDataset(root=opt.train_cvpr,
        transform=dataset.resizeNormalize((opt.imgW, opt.imgH)), reverse=opt.BidirDecoder)
    assert train_cvpr_dataset
    '''
    '''
    train_dataset = torch.utils.data.ConcatDataset([train_nips_dataset, train_cvpr_dataset])
    '''
    train_dataset = train_nips_dataset
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batchSize,
        shuffle=False, sampler=dataset.randomSequentialSampler(train_dataset, opt.batchSize),
        num_workers=int(opt.workers))
    test_dataset = dataset.lmdbDataset(root=opt.valroot,
        transform=dataset.resizeNormalize((opt.imgW, opt.imgH)), reverse=opt.BidirDecoder)

    nclass = len(opt.alphabet.split(opt.sep))  # number of character classes
    nc = 1  # single (grayscale) input channel

    converter = utils.strLabelConverterForAttention(opt.alphabet, opt.sep)
    criterion = torch.nn.CrossEntropyLoss()

    if opt.cuda:
        MORAN = MORAN(nc, nclass, opt.nh, opt.targetH, opt.targetW,
                      BidirDecoder=opt.BidirDecoder, CUDA=opt.cuda)
    else:
        # CPU build needs an explicit input tensor type.
        MORAN = MORAN(nc, nclass, opt.nh, opt.targetH, opt.targetW,
                      BidirDecoder=opt.BidirDecoder, inputDataType='torch.FloatTensor', CUDA=opt.cuda)

    if opt.MORAN != '':
        # Restore pretrained weights; strip the `module.` prefix that
        # nn.DataParallel adds to every state-dict key.
        print('loading pretrained model from %s' % opt.MORAN)
        if opt.cuda:
            state_dict = torch.load(opt.MORAN)
        else:
            state_dict = torch.load(opt.MORAN, map_location='cpu')
        MORAN_state_dict_rename = OrderedDict()
        for k, v in state_dict.items():
            name = k.replace("module.", "")  # remove `module.`
            MORAN_state_dict_rename[name] = v
        MORAN.load_state_dict(MORAN_state_dict_rename, strict=True)

    # Pre-allocated buffers, refilled in place each batch via utils.loadData.
    image = torch.FloatTensor(opt.batchSize, nc, opt.imgH, opt.imgW)
    text = torch.LongTensor(opt.batchSize * 5)
    text_rev = torch.LongTensor(opt.batchSize * 5)
    length = torch.IntTensor(opt.batchSize)

    if opt.cuda:
        MORAN.cuda()
        MORAN = torch.nn.DataParallel(MORAN, device_ids=range(opt.ngpu))
        image = image.cuda()
        text = text.cuda()
        text_rev = text_rev.cuda()
        criterion = criterion.cuda()

    image = Variable(image)
    text = Variable(text)
    text_rev = Variable(text_rev)
    length = Variable(length)

    # loss averager
    loss_avg = utils.averager()

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(MORAN.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
    elif opt.adadelta:
        optimizer = optim.Adadelta(MORAN.parameters(), lr=opt.lr)
    elif opt.sgd:
        optimizer = optim.SGD(MORAN.parameters(), lr=opt.lr, momentum=0.9)
    else:
        optimizer = optim.RMSprop(MORAN.parameters(), lr=opt.lr)

    def levenshtein(s1, s2):
        # Classic dynamic-programming edit distance between two strings.
        if len(s1) < len(s2):
            return levenshtein(s2, s1)
        # len(s1) >= len(s2)
        if len(s2) == 0:
            return len(s1)
        previous_row = range(len(s2) + 1)
        for i, c1 in enumerate(s1):
            current_row = [i + 1]
            for j, c2 in enumerate(s2):
                insertions = previous_row[j + 1] + 1
                deletions = current_row[j] + 1
                substitutions = previous_row[j] + (c1 != c2)
                current_row.append(min(insertions, deletions, substitutions))
            previous_row = current_row
        return previous_row[-1]

    def val(dataset, criterion, max_iter=1000):
        # Evaluate on `dataset`; returns exact-match accuracy and appends
        # per-sample predictions to ./log.txt.
        print('Start val')
        data_loader = torch.utils.data.DataLoader(
            dataset, shuffle=False, batch_size=opt.batchSize, num_workers=int(opt.workers))  # opt.batchSize
        val_iter = iter(data_loader)
        max_iter = min(max_iter, len(data_loader))
        n_correct = 0
        n_total = 0
        distance = 0.0
        loss_avg = utils.averager()
        f = open('./log.txt', 'a', encoding='utf-8')
        for i in range(max_iter):
            data = val_iter.next()
            if opt.BidirDecoder:
                cpu_images, cpu_texts, cpu_texts_rev = data
                utils.loadData(image, cpu_images)
                t, l = converter.encode(cpu_texts, scanned=True)
                t_rev, _ = converter.encode(cpu_texts_rev, scanned=True)
                utils.loadData(text, t)
                utils.loadData(text_rev, t_rev)
                utils.loadData(length, l)
                preds0, preds1 = MORAN(image, length, text, text_rev, test=True)
                cost = criterion(torch.cat([preds0, preds1], 0), torch.cat([text, text_rev], 0))
                preds0_prob, preds0 = preds0.max(1)
                preds0 = preds0.view(-1)
                preds0_prob = preds0_prob.view(-1)
                sim_preds0 = converter.decode(preds0.data, length.data)
                preds1_prob, preds1 = preds1.max(1)
                preds1 = preds1.view(-1)
                preds1_prob = preds1_prob.view(-1)
                sim_preds1 = converter.decode(preds1.data, length.data)
                sim_preds = []
                # Pick, per sample, the direction with the higher mean
                # character probability; the reversed branch is flipped back.
                for j in range(cpu_images.size(0)):
                    text_begin = 0 if j == 0 else length.data[:j].sum()
                    if torch.mean(preds0_prob[text_begin:text_begin+len(sim_preds0[j].split('$')[0]+'$')]).data[0] >\
                            torch.mean(preds1_prob[text_begin:text_begin+len(sim_preds1[j].split('$')[0]+'$')]).data[0]:
                        sim_preds.append(sim_preds0[j].split('$')[0]+'$')
                    else:
                        sim_preds.append(sim_preds1[j].split('$')[0][-1::-1]+'$')
            else:
                cpu_images, cpu_texts = data
                utils.loadData(image, cpu_images)
                t, l = converter.encode(cpu_texts, scanned=True)
                utils.loadData(text, t)
                utils.loadData(length, l)
                preds = MORAN(image, length, text, text_rev, test=True)
                cost = criterion(preds, text)
                _, preds = preds.max(1)
                preds = preds.view(-1)
                sim_preds = converter.decode(preds.data, length.data)
            loss_avg.add(cost)
            for pred, target in zip(sim_preds, cpu_texts):
                if pred == target.lower():
                    n_correct += 1
                f.write("预测 %s 目标 %s\n" % ( pred, target ) )
                distance += levenshtein(pred, target) / max(len(pred), len(target))
                n_total += 1
        f.close()
        print("correct / total: %d / %d, " % (n_correct, n_total))
        print('levenshtein distance: %f' % (distance/n_total))
        accuracy = n_correct / float(n_total)
        print('Test loss: %f, accuray: %f' % (loss_avg.val(), accuracy))
        return accuracy

    def trainBatch():
        # One optimization step on the next training batch; returns the loss.
        data = train_iter.next()
        if opt.BidirDecoder:
            cpu_images, cpu_texts, cpu_texts_rev = data
            utils.loadData(image, cpu_images)
            t, l = converter.encode(cpu_texts, scanned=True)
            t_rev, _ = converter.encode(cpu_texts_rev, scanned=True)
            utils.loadData(text, t)
            utils.loadData(text_rev, t_rev)
            utils.loadData(length, l)
            preds0, preds1 = MORAN(image, length, text, text_rev)
            cost = criterion(torch.cat([preds0, preds1], 0), torch.cat([text, text_rev], 0))
        else:
            cpu_images, cpu_texts = data
            utils.loadData(image, cpu_images)
            t, l = converter.encode(cpu_texts, scanned=True)
            utils.loadData(text, t)
            utils.loadData(length, l)
            preds = MORAN(image, length, text, text_rev)
            cost = criterion(preds, text)
        MORAN.zero_grad()
        cost.backward()
        optimizer.step()
        return cost

    t0 = time.time()
    acc = 0
    acc_tmp = 0
    for epoch in range(opt.niter):
        train_iter = iter(train_loader)
        i = 0
        while i < len(train_loader):
            # print("in main, number of iterations: %d" % len(train_loader))
            if i % opt.valInterval == 0:
                # Freeze the network and validate; checkpoint on improvement.
                for p in MORAN.parameters():
                    p.requires_grad = False
                MORAN.eval()
                acc_tmp = val(test_dataset, criterion)
                if acc_tmp > acc:
                    acc = acc_tmp
                    torch.save(MORAN.state_dict(), '{0}/{1}_{2}.pth'.format(
                        opt.experiment, i, str(acc)[:6]))
            if i % opt.saveInterval == 0:
                torch.save(MORAN.state_dict(), '{0}/{1}_{2}.pth'.format(
                    opt.experiment, epoch, i))
            # Unfreeze for the training step.
            for p in MORAN.parameters():
                p.requires_grad = True
            MORAN.train()
            cost = trainBatch()
            loss_avg.add(cost)
            if i % opt.displayInterval == 0:
                t1 = time.time()
                print ('Epoch: %d/%d; iter: %d/%d; Loss: %f; time: %.2f s;' %
                       (epoch, opt.niter, i, len(train_loader), loss_avg.val(), t1-t0)),
                loss_avg.reset()
                t0 = time.time()
            i += 1
# Build the MORAN network for inference and load pretrained weights.
# NOTE: `model_path` is expected to be defined by the surrounding script.
# FIX: the alphabet previously read "...q:r:8:t..." — the digit '8' appeared
# twice and the letter 's' was missing, which corrupts the label converter's
# character<->id mapping. Every other alphabet in this codebase uses 's'.
alphabet = '0:1:2:3:4:5:6:7:8:9:a:b:c:d:e:f:g:h:i:j:k:l:m:n:o:p:q:r:s:t:u:v:w:x:y:z:$'
target_height = 32   # model input height
target_width = 100   # model input width
cuda_flag = False
if torch.cuda.is_available():
    cuda_flag = True
    MORAN = MORAN(1, len(alphabet.split(':')), 256, target_height, target_width,
                  BidirDecoder=True, CUDA=cuda_flag)
    MORAN = MORAN.cuda()
else:
    # CPU build needs an explicit input tensor type.
    MORAN = MORAN(1, len(alphabet.split(':')), 256, target_height, target_width,
                  BidirDecoder=True, inputDataType='torch.FloatTensor', CUDA=cuda_flag)
print('loading pretrained model from %s' % model_path)
if cuda_flag:
    state_dict = torch.load(model_path)
else:
    state_dict = torch.load(model_path, map_location='cpu')
def train_HARN(config_file):
    """Train the HARN recognition model (MORAN variant with a logger).

    Reads hyper-parameters from the YAML file at `config_file` (via yacs),
    builds LMDB train/validation datasets, optionally restores pretrained
    weights from `opt.MORAN`, and runs a step-counted train/validate loop,
    writing scalar summaries through `tools.logger` and checkpointing to
    `opt.experiment` whenever validation accuracy improves.
    """
    import sys
    sys.path.append('./recognition_model/HARN')
    import argparse
    import os
    import random
    import io
    import sys
    import time
    from models.moran import MORAN
    import tools.utils as utils
    import torch.optim as optim
    import numpy as np
    import torch.backends.cudnn as cudnn
    import torch.utils.data
    import tools.dataset as dataset
    from torch.autograd import Variable
    from collections import OrderedDict
    from tools.logger import logger
    # from wordlist import result
    # from wordlistlsvt import result
    import warnings
    warnings.filterwarnings('ignore')
    # os.environ['CUDA_VISIBLE_DEVICES'] = '1'  # pin to a specific GPU
    # os.environ['CUDA_VISIBLE_DEVICES'] = '5'  # pin to a specific GPU
    from yacs.config import CfgNode as CN

    def read_config_file(config_file):
        # Rebuild the options object from the YAML config file.
        f = open(config_file)
        opt = CN.load_cfg(f)
        return opt

    opt = read_config_file(config_file)  # the loaded YAML options
    print("配置文件", opt)
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
    # Modify
    # opt.alphabet = result
    assert opt.ngpu == 1, "Multi-GPU training is not supported yet, due to the variant lengths of the text in a batch."
    if opt.experiment is None:
        opt.experiment = 'expr'
    os.system('mkdir {0}'.format(opt.experiment))
    opt.manualSeed = random.randint(1, 10000)  # fix seed
    random.seed(opt.manualSeed)
    np.random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    cudnn.benchmark = True
    # ---------save logger---------#
    log = logger('./asrn_se50_OCRdata_50_logger')
    # log = logger('./logger/asrn_se50_lsvt_50')  # log directory; change as needed
    # -----------------------------#
    if not torch.cuda.is_available():
        assert not opt.cuda, 'You don\'t have a CUDA device.'
    if torch.cuda.is_available() and not opt.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    train_nips_dataset = dataset.lmdbDataset(root=opt.train_nips,
                                             transform=dataset.resizeNormalize(
                                                 (opt.imgW, opt.imgH)),
                                             reverse=opt.BidirDecoder,
                                             alphabet=opt.alphabet)
    assert train_nips_dataset
    '''
    train_cvpr_dataset = dataset.lmdbDataset(root=opt.train_cvpr,
        transform=dataset.resizeNormalize((opt.imgW, opt.imgH)), reverse=opt.BidirDecoder)
    assert train_cvpr_dataset
    '''
    '''
    train_dataset = torch.utils.data.ConcatDataset([train_nips_dataset, train_cvpr_dataset])
    '''
    train_dataset = train_nips_dataset
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batchSize,
        shuffle=False, sampler=dataset.randomSequentialSampler(train_dataset, opt.batchSize),
        num_workers=int(opt.workers))
    test_dataset = dataset.lmdbDataset(root=opt.valroot,
                                       transform=dataset.resizeNormalize(
                                           (opt.imgW, opt.imgH)),
                                       reverse=opt.BidirDecoder)

    # Number of classes: 36 for English; for Chinese it is the size of the
    # wordlist (the system only recognizes a file named wordlist.py — rename
    # the file you want to use accordingly).
    nclass = len(
        opt.alphabet.split(opt.sep)
    )
    nc = 1  # single (grayscale) input channel

    # Assigns each character an id, e.g. 中(2) 国(30) 人(65); the converter
    # translates between ids and characters.
    converter = utils.strLabelConverterForAttention(
        opt.alphabet, opt.sep)
    criterion = torch.nn.CrossEntropyLoss()

    if opt.cuda:
        MORAN = MORAN(nc, nclass, opt.nh, opt.targetH, opt.targetW,
                      BidirDecoder=opt.BidirDecoder, CUDA=opt.cuda, log=log)
    else:
        # CPU build needs an explicit input tensor type.
        MORAN = MORAN(nc, nclass, opt.nh, opt.targetH, opt.targetW,
                      BidirDecoder=opt.BidirDecoder,
                      inputDataType='torch.FloatTensor', CUDA=opt.cuda, log=log)

    if opt.MORAN != '':
        # Restore pretrained weights; strip the `module.` prefix that
        # nn.DataParallel adds to every state-dict key.
        print('loading pretrained model from %s' % opt.MORAN)
        if opt.cuda:
            state_dict = torch.load(opt.MORAN)
        else:
            state_dict = torch.load(opt.MORAN, map_location='cpu')
        MORAN_state_dict_rename = OrderedDict()
        for k, v in state_dict.items():
            name = k.replace("module.", "")  # remove `module.`
            MORAN_state_dict_rename[name] = v
        MORAN.load_state_dict(MORAN_state_dict_rename, strict=True)

    # Pre-allocated buffers, refilled in place each batch via utils.loadData.
    image = torch.FloatTensor(opt.batchSize, nc, opt.imgH, opt.imgW)
    text = torch.LongTensor(opt.batchSize * 5)
    text_rev = torch.LongTensor(opt.batchSize * 5)
    length = torch.IntTensor(opt.batchSize)

    if opt.cuda:
        MORAN.cuda()
        MORAN = torch.nn.DataParallel(MORAN, device_ids=range(opt.ngpu))
        image = image.cuda()
        text = text.cuda()
        text_rev = text_rev.cuda()
        criterion = criterion.cuda()

    image = Variable(image)  # wrap the image buffer as a Variable for the model
    text = Variable(text)
    text_rev = Variable(text_rev)
    length = Variable(length)

    # loss averager
    loss_avg = utils.averager()

    # setup optimizer — optimizer choice; Adam is what is used here
    if opt.adam:
        optimizer = optim.Adam(MORAN.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
    elif opt.adadelta:
        optimizer = optim.Adadelta(MORAN.parameters(), lr=opt.lr)
    elif opt.sgd:
        optimizer = optim.SGD(MORAN.parameters(), lr=opt.lr, momentum=0.9)
    else:
        optimizer = optim.RMSprop(MORAN.parameters(), lr=opt.lr)

    def levenshtein(s1, s2):
        # Levenshtein distance, a form of edit distance.
        if len(s1) < len(s2):
            return levenshtein(s2, s1)
        # len(s1) >= len(s2)
        if len(s2) == 0:
            return len(s1)
        previous_row = range(len(s2) + 1)
        for i, c1 in enumerate(s1):
            current_row = [i + 1]
            for j, c2 in enumerate(s2):
                insertions = previous_row[j + 1] + 1
                deletions = current_row[j] + 1
                substitutions = previous_row[j] + (c1 != c2)
                current_row.append(min(insertions, deletions, substitutions))
            previous_row = current_row
        return previous_row[-1]

    def val(dataset, criterion, max_iter=10000, steps=None):
        # Evaluate on `dataset`; returns exact-match accuracy and logs
        # scalar summaries (loss, accuracy, edit distance) at `steps`.
        data_loader = torch.utils.data.DataLoader(
            dataset, shuffle=False, batch_size=opt.batchSize, num_workers=int(opt.workers))  # opt.batchSize
        val_iter = iter(data_loader)
        max_iter = min(max_iter, len(data_loader))
        n_correct = 0
        n_total = 0
        distance = 0.0
        loss_avg = utils.averager()
        # f = open('./log.txt', 'a', encoding='utf-8')
        for i in range(max_iter):  # large loop bound (usually capped by the loader length)
            data = val_iter.next()
            if opt.BidirDecoder:
                cpu_images, cpu_texts, cpu_texts_rev = data  # one batch from the dataloader
                utils.loadData(image, cpu_images)
                t, l = converter.encode(cpu_texts, scanned=False)  # encode characters into ids
                t_rev, _ = converter.encode(cpu_texts_rev, scanned=False)
                utils.loadData(text, t)
                utils.loadData(text_rev, t_rev)
                utils.loadData(length, l)
                preds0, preds1 = MORAN(image, length, text, text_rev,
                                       debug=False, test=True, steps=steps)  # run the HARN model
                cost = criterion(torch.cat([preds0, preds1], 0), torch.cat([text, text_rev], 0))
                preds0_prob, preds0 = preds0.max(1)  # take the top-1 result
                preds0 = preds0.view(-1)
                preds0_prob = preds0_prob.view(-1)  # reshape (flatten, apparently)
                sim_preds0 = converter.decode(preds0.data, length.data)  # decode ids back to characters
                preds1_prob, preds1 = preds1.max(1)
                preds1 = preds1.view(-1)
                preds1_prob = preds1_prob.view(-1)
                sim_preds1 = converter.decode(preds1.data, length.data)
                sim_preds = []  # predicted strings
                # Join individual characters into strings, picking per sample
                # the direction with the higher mean character probability.
                for j in range(cpu_images.size(0)):
                    text_begin = 0 if j == 0 else length.data[:j].sum()
                    if torch.mean(preds0_prob[text_begin:text_begin + len(sim_preds0[j].split('$')[0] + '$')]).item() > \
                            torch.mean(
                                preds1_prob[text_begin:text_begin + len(sim_preds1[j].split('$')[0] + '$')]).item():
                        sim_preds.append(sim_preds0[j].split('$')[0] + '$')
                    else:
                        sim_preds.append(sim_preds1[j].split('$')[0][-1::-1] + '$')
            else:  # the unused non-bidirectional case
                cpu_images, cpu_texts = data
                utils.loadData(image, cpu_images)
                t, l = converter.encode(cpu_texts, scanned=True)
                utils.loadData(text, t)
                utils.loadData(length, l)
                preds = MORAN(image, length, text, text_rev, test=True)
                cost = criterion(preds, text)
                _, preds = preds.max(1)
                preds = preds.view(-1)
                sim_preds = converter.decode(preds.data, length.data)
            loss_avg.add(cost)  # accumulate the running mean loss
            # Compare against ground truth: cpu_texts are the labels,
            # sim_preds the joined prediction strings.
            for pred, target in zip(
                    sim_preds, cpu_texts
            ):
                if pred == target.lower():  # exact match count
                    n_correct += 1
                # f.write("pred %s\t target %s\n" % (pred, target))
                distance += levenshtein(pred, target) / max(
                    len(pred), len(target))  # normalized Levenshtein distance
                n_total += 1  # finished one word
        # f.close()
        # print and save — dump results to the log after the run
        for pred, gt in zip(sim_preds, cpu_texts):
            gt = ''.join(gt.split(opt.sep))
            print('%-20s, gt: %-20s' % (pred, gt))
        print("correct / total: %d / %d, " % (n_correct, n_total))
        print('levenshtein distance: %f' % (distance / n_total))
        accuracy = n_correct / float(n_total)
        log.scalar_summary('Validation/levenshtein distance', distance / n_total, steps)
        log.scalar_summary('Validation/loss', loss_avg.val(), steps)
        log.scalar_summary('Validation/accuracy', accuracy, steps)
        print('Test loss: %f, accuray: %f' % (loss_avg.val(), accuracy))
        return accuracy

    def trainBatch(steps):
        # One optimization step on the next training batch; returns the loss.
        # (`steps` is currently unused inside the function body.)
        data = train_iter.next()
        if opt.BidirDecoder:
            cpu_images, cpu_texts, cpu_texts_rev = data
            utils.loadData(image, cpu_images)
            t, l = converter.encode(cpu_texts, scanned=True)
            t_rev, _ = converter.encode(cpu_texts_rev, scanned=True)
            utils.loadData(text, t)
            utils.loadData(text_rev, t_rev)
            utils.loadData(length, l)
            preds0, preds1 = MORAN(image, length, text, text_rev)
            cost = criterion(torch.cat([preds0, preds1], 0), torch.cat([text, text_rev], 0))
        else:
            cpu_images, cpu_texts = data
            utils.loadData(image, cpu_images)
            t, l = converter.encode(cpu_texts, scanned=True)
            utils.loadData(text, t)
            utils.loadData(length, l)
            preds = MORAN(image, length, text, text_rev)
            cost = criterion(preds, text)
        MORAN.zero_grad()
        cost.backward()  # backpropagation
        optimizer.step()  # optimizer update
        return cost

    t0 = time.time()
    acc, acc_tmp = 0, 0
    print(' === HARN === ')
    for epoch in range(opt.niter):
        print(" === Loading Train Data === ")
        train_iter = iter(train_loader)
        i = 0
        print(" === start training === ")
        while i < len(train_loader):  # len(): dataset size
            # print("in main, number of iterations: %d" % len(train_loader))
            steps = i + epoch * len(train_loader)  # global step, used to decide when to save/print
            if steps % opt.valInterval == 0:
                # Freeze the network and validate; checkpoint on improvement.
                for p in MORAN.parameters():
                    p.requires_grad = False
                MORAN.eval()
                print('---------------Please Waiting----------------')  # training progress notice
                acc_tmp = val(test_dataset, criterion, steps=steps)
                if acc_tmp > acc:
                    acc = acc_tmp
                    try:
                        time.sleep(0.01)
                        torch.save(
                            MORAN.state_dict(),
                            '{0}/{1}_{2}.pth'.format(opt.experiment, i,
                                                     str(acc)[:6]))
                        print(".pth")
                    except RuntimeError:
                        print("RuntimeError")
                        pass
            # Unfreeze for the training step.
            for p in MORAN.parameters():
                p.requires_grad = True
            MORAN.train()
            cost = trainBatch(steps)
            loss_avg.add(cost)
            t1 = time.time()
            # niter is the number of epochs set in the parameters
            print('Epoch: %d/%d; iter: %d/%d; Loss: %f; time: %.2f s;' %
                  (epoch, opt.niter, i, len(train_loader), loss_avg.val(), t1 - t0)),
            log.scalar_summary('train loss', loss_avg.val(), steps)
            log.scalar_summary('speed batches/persec', steps / (time.time() - t0), steps)
            loss_avg.reset()
            t0 = time.time()
            '''
            if i % 100 == 0:
                t1 = time.time()  # niter是参数部分设置的epoch数量
                print('Epoch: %d/%d; iter: %d/%d; Loss: %f; time: %.2f s;' %
                      (epoch, opt.niter, i, len(train_loader), loss_avg.val(), t1 - t0)),
                log.scalar_summary('train loss', loss_avg.val(), i)  # 拟合到90多/拟合到1,完全收敛,训练充分
                log.scalar_summary('speed batches/persec', i / (time.time() - t0), i)
                loss_avg.reset()
                t0 = time.time()
            '''
            '''
            if steps % opt.displayInterval == 0:
                t1 = time.time()  # niter是参数部分设置的epoch数量
                print('Epoch: %d/%d; iter: %d/%d; Loss: %f; time: %.2f s;' %
                      (epoch, opt.niter, i, len(train_loader), loss_avg.val(), t1 - t0)),
                log.scalar_summary('train loss', loss_avg.val(), steps)  # 拟合到90多/拟合到1,完全收敛,训练充分
                log.scalar_summary('speed batches/persec', steps / (time.time() - t0), steps)
                loss_avg.reset()
                t0 = time.time()
            '''
            i += 1
class text_recognize(object): def __init__(self): r = rospkg.RosPack() self.path = r.get_path('moran_text_recog') self.prob_threshold = 0.90 self.cv_bridge = CvBridge() self.commodity_list = [] self.read_commodity(r.get_path('text_msgs') + "/config/commodity_list.txt") self.alphabet = '0:1:2:3:4:5:6:7:8:9:a:b:c:d:e:f:g:h:i:j:k:l:m:n:o:p:q:r:s:t:u:v:w:x:y:z:$' self.means = (0.485, 0.456, 0.406) self.stds = (0.229, 0.224, 0.225) self.bbox_thres = 1500 self.color_map = [(255,0,0),(0,255,0),(0,0,255),(255,255,0),(255,255,255)] # 0 90 180 270 noise self.objects = [] self.is_compressed = False self.cuda_use = torch.cuda.is_available() if self.cuda_use: cuda_flag = True self.network = MORAN(1, len(self.alphabet.split(':')), 256, 32, 100, BidirDecoder=True, CUDA=cuda_flag) self.network = self.network.cuda() else: self.network = MORAN(1, len(self.alphabet.split(':')), 256, 32, 100, BidirDecoder=True, inputDataType='torch.FloatTensor', CUDA=cuda_flag) model_name = "moran.pth" print "Moran Model Parameters number: " + str(self.count_parameters(self.network)) if self.cuda_use: state_dict = torch.load(os.path.join(self.path, "weights/", model_name)) else: state_dict = torch.load(os.path.join(self.path, "weights/", model_name), map_location='cpu') MORAN_state_dict_rename = OrderedDict() for k, v in state_dict.items(): name = k.replace("module.", "") # remove `module.` MORAN_state_dict_rename[name] = v self.network.load_state_dict(MORAN_state_dict_rename) self.converter = utils.strLabelConverterForAttention(self.alphabet, ':') self.transformer = dataset.resizeNormalize((100, 32)) for p in self.network.parameters(): p.requires_grad = False self.network.eval() #### Publisher self.image_pub = rospy.Publisher("~predict_img", Image, queue_size = 1) self.mask = rospy.Publisher("~mask", Image, queue_size = 1) self.img_bbox_pub = rospy.Publisher("~predict_bbox", Image, queue_size = 1) #### Service self.predict_ser = rospy.Service("~text_recognize_server", text_recognize_srv, 
self.srv_callback) image_sub1 = rospy.Subscriber('/text_detection_array', text_detection_array, self.callback, queue_size = 1) ### msg filter # image_sub = message_filters.Subscriber('/camera/color/image_raw', Image) # depth_sub = message_filters.Subscriber('/camera/aligned_depth_to_color/image_raw', Image) # ts = message_filters.TimeSynchronizer([image_sub, depth_sub], 10) # ts.registerCallback(self.callback) print "============ Ready ============" def read_commodity(self, path): for line in open(path, "r"): line = line.rstrip('\n') self.commodity_list.append(line) print "Node (text_recognize): Finish reading list" def count_parameters(self, model): return sum(p.numel() for p in model.parameters() if p.requires_grad) def callback(self, msg): try: if self.is_compressed: np_arr = np.fromstring(msg.image, np.uint8) cv_image = cv2.imdecode(np_arr, cv2.IMREAD_COLOR) else: cv_image = self.cv_bridge.imgmsg_to_cv2(msg.image, "bgr8") except CvBridgeError as e: print(e) predict_img, mask = self.predict(msg, cv_image) img_bbox = cv_image.copy() try: self.image_pub.publish(self.cv_bridge.cv2_to_imgmsg(predict_img, "bgr8")) self.img_bbox_pub.publish(self.cv_bridge.cv2_to_imgmsg(img_bbox, "bgr8")) self.mask.publish(self.cv_bridge.cv2_to_imgmsg(mask, "8UC1")) except CvBridgeError as e: print(e) def srv_callback(self, req): resp = text_recognize_srvResponse() try: if self.is_compressed: np_arr = np.fromstring(req.data.image, np.uint8) cv_image = cv2.imdecode(np_arr, cv2.IMREAD_COLOR) else: cv_image = self.cv_bridge.imgmsg_to_cv2(req.data.image, "bgr8") except CvBridgeError as e: resp.state = e print(e) predict_img, mask = self.predict(req.data, cv_image, req.direct) img_bbox = cv_image.copy() try: self.image_pub.publish(self.cv_bridge.cv2_to_imgmsg(predict_img, "bgr8")) self.img_bbox_pub.publish(self.cv_bridge.cv2_to_imgmsg(img_bbox, "bgr8")) resp.mask = self.cv_bridge.cv2_to_imgmsg(mask, "8UC1") self.mask.publish(self.cv_bridge.cv2_to_imgmsg(mask, "8UC1")) except CvBridgeError 
as e: resp.state = e print(e) return resp def predict(self, msg, img, rot=0): # # Preprocessing gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) (rows, cols, channels) = img.shape mask = np.zeros([rows, cols], dtype = np.uint8) for text_bb in msg.text_array: if (text_bb.box.ymax - text_bb.box.ymin) * (text_bb.box.xmax - text_bb.box.xmin) < self.bbox_thres: continue start = time.time() image = gray[text_bb.box.ymin:text_bb.box.ymax, text_bb.box.xmin:text_bb.box.xmax] image = Im.fromarray(image) image = self.transformer(image) if self.cuda_use: image = image.cuda() image = image.view(1, *image.size()) image = Variable(image) text = torch.LongTensor(1 * 5) length = torch.IntTensor(1) text = Variable(text) length = Variable(length) max_iter = 20 t, l = self.converter.encode('0'*max_iter) utils.loadData(text, t) utils.loadData(length, l) output = self.network(image, length, text, text, test=True, debug=True) preds, preds_reverse = output[0] demo = output[1] _, preds = preds.max(1) _, preds_reverse = preds_reverse.max(1) sim_preds = self.converter.decode(preds.data, length.data) sim_preds = sim_preds.strip().split('$')[0] sim_preds_reverse = self.converter.decode(preds_reverse.data, length.data) sim_preds_reverse = sim_preds_reverse.strip().split('$')[0] # print('\nResult:\n' + 'Left to Right: ' + sim_preds + '\nRight to Left: ' + sim_preds_reverse + '\n\n') print "Text Recognize Time : {}".format(time.time() - start) _cont = [] for p in text_bb.contour: point = [] point.append(p.point[0]) point.append(p.point[1]) _cont.append(point) _cont = np.array(_cont, np.int32) if sim_preds in self.commodity_list: cv2.rectangle(img, (text_bb.box.xmin, text_bb.box.ymin),(text_bb.box.xmax, text_bb.box.ymax), self.color_map[rot], 3) cv2.putText(img, sim_preds, (text_bb.box.xmin, text_bb.box.ymin), 0, 1, (0, 255, 255),3) pix = self.commodity_list.index(sim_preds) + rot*len(self.commodity_list) if pix in np.unique(mask): cv2.fillConvexPoly(mask, _cont, pix + 4*len(self.commodity_list)) 
else: cv2.fillConvexPoly(mask, _cont, pix) else: correct, conf, _bool = self.conf_of_word(sim_preds) # print conf if _bool: cv2.putText(img, correct + "{:.2f}".format(conf), (text_bb.box.xmin, text_bb.box.ymin), 0, 1, (0, 255, 255),3) cv2.rectangle(img, (text_bb.box.xmin, text_bb.box.ymin),(text_bb.box.xmax, text_bb.box.ymax), (255, 255, 255), 2) pix = self.commodity_list.index(correct) + rot*len(self.commodity_list) if pix in np.unique(mask): cv2.fillConvexPoly(mask, _cont, pix + 4*len(self.commodity_list)) else: cv2.fillConvexPoly(mask, _cont, pix) # else: # cv2.putText(img, sim_preds, (text_bb.box.xmin, text_bb.box.ymin), 0, 1, (0, 0, 0),3) # cv2.rectangle(img, (text_bb.box.xmin, text_bb.box.ymin),(text_bb.box.xmax, text_bb.box.ymax), (0, 0, 0), 2) return img, mask def conf_of_word(self, target): ### Edit distance # print target _recheck = False total = np.zeros(len(self.commodity_list)) for i in range(1, len(self.commodity_list)): size_x = len(self.commodity_list[i]) + 1 size_y = len(target) + 1 matrix = np.zeros ((size_x, size_y)) for x in xrange(size_x): matrix [x, 0] = x for y in xrange(size_y): matrix [0, y] = y for x in xrange(1, size_x): for y in xrange(1, size_y): if self.commodity_list[i][x-1] == target[y-1]: matrix [x,y] = min( matrix[x-1, y] + 1, matrix[x-1, y-1], matrix[x, y-1] + 1 ) else: matrix [x,y] = min( matrix[x-1,y] + 1, matrix[x-1,y-1] + 1, matrix[x,y-1] + 1 ) # print (matrix) total[i] = (size_x - matrix[size_x-1, size_y-1]) / float(size_x) if self.commodity_list[i] == "kleenex" and 0.3 < total[i] < 0.77: _list = ["kloonex", "kloonox","kleeper", "killer", "kleem", "kleers", "kluting", "klates",\ "kleams", "kreamer", "klea", "kleas", "kletter","keenier","vooney", "wooner", "whonex"] _recheck = True elif self.commodity_list[i] == "andes" and 0.3 < total[i] < 0.77: _list = ["anders", "findes","windes"] # "andor", _recheck = True elif self.commodity_list[i] == "vanish" and 0.3 < total[i] < 0.77: _list = ["varish"] _recheck = True # elif 
self.commodity_list[i] == "crayola" and 0.3 < total[i] < 0.77: # _list = ["casions"] # _recheck = True if _recheck == True: for _str in _list: size_x = len(_str) + 1 size_y = len(target) + 1 matrix = np.zeros ((size_x, size_y)) for x in xrange(size_x): matrix [x, 0] = x for y in xrange(size_y): matrix [0, y] = y for x in xrange(1, size_x): for y in xrange(1, size_y): if _str[x-1] == target[y-1]: matrix [x,y] = min( matrix[x-1, y] + 1, matrix[x-1, y-1], matrix[x, y-1] + 1 ) else: matrix [x,y] = min( matrix[x-1,y] + 1, matrix[x-1,y-1] + 1, matrix[x,y-1] + 1 ) score_temp = (size_x - matrix[size_x-1, size_y-1]) / float(size_x) if total[i] < score_temp: total[i] = score_temp if 0.77 > total[i] > 0.68: total[i] = 0.77 _recheck = False # print target, total[i], self.commodity_list[i] return self.commodity_list[np.argmax(total)], np.max(total), np.max(total) >= 0.77 ## 0.66 ### old method # total = np.zeros(len(self.commodity_list)) # for i in range(1, len(self.commodity_list)): # # if self.commodity_list[i] != "raisins": # # continue # err = 0 ## error # _len = len(self.commodity_list[i]) # arr = -10 * np.ones(_len) # for j in range(len(target)): # index = self.commodity_list[i].find(target[j]) # if index == -1: # err += 1 # else: # upper = arr[index+1] if index != _len - 1 else -10 # if arr[index] == -10 and upper == -10: # arr[index] = j # else: # index = self.commodity_list[i].find(target[j], index + 1) # while index != -1: # lower = arr[index-1] if index != 0 else -10 # upper = arr[index+1] if index != _len - 1 else -10 # if (arr[index] - lower) == 1 or (upper - arr[index]) == 1: # index = self.commodity_list[i].find(target[j], index + 1) # else: # arr[index] = j # break # score = 0 # score for word # for j in range(_len - 1): # if arr[j+1] - arr[j] == 1: # score += 1 # total[i] = float(score) / (_len + err - 1) # # print score, _len, err, arr # return self.commodity_list[np.argmax(total)], np.max(total), np.max(total) >= 0.5 def onShutdown(self): 
rospy.loginfo("Shutdown.")