def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
    """Serialize a full training checkpoint to `filepath`.

    A fresh CRNN instance is built on the GPU and the current weights are
    copied into it, so the saved module is decoupled from the live training
    model. The checkpoint stores the model object itself plus the optimizer
    state, iteration counter, and learning rate.
    """
    print("Saving model and optimizer state at iteration {} to {}".format(
        iteration, filepath))
    # Clone the weights into a standalone module for serialization.
    model_for_saving = CRNN(**CRNN_config).cuda()
    model_for_saving.load_state_dict(model.state_dict())
    checkpoint = {
        'model': model_for_saving,
        'iteration': iteration,
        'optimizer': optimizer.state_dict(),
        'learning_rate': learning_rate,
    }
    torch.save(checkpoint, filepath)
# --- Evaluation script fragment: build a validation loader, restore a
# pretrained Chinese-OCR CRNN, and run validation.  Relies on names defined
# elsewhere (testdataset, opt, imgH/imgW/keep_ratio, keys, alignCollate,
# strLabelConverter, CRNN, OrderedDict, val).
val_loader = torch.utils.data.DataLoader(
    testdataset,
    shuffle=False,  # keep sample order stable for evaluation
    batch_size=opt.batch_size,
    num_workers=int(opt.workers),
    collate_fn=alignCollate(imgH=imgH, imgW=imgW, keep_ratio=keep_ratio))

# Chinese character set; CTC needs one extra class for the blank, hence +1.
alphabet = keys.alphabetChinese
print("char num ", len(alphabet))
model = CRNN(32, 1, len(alphabet) + 1, 256, 1)
converter = strLabelConverter(''.join(alphabet))

# Load checkpoint tensors onto CPU regardless of where they were saved.
state_dict = torch.load("../SceneOcr/model/ocr-lstm.pth",
                        map_location=lambda storage, loc: storage)
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k
    # Skip BatchNorm bookkeeping entries; copy every other parameter as-is.
    if "num_batches_tracked" not in k:
        # name = name.replace('module.', '')  # remove `module.`
        new_state_dict[name] = v
model.cuda()
model = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
# load params
# NOTE(review): loading happens AFTER the DataParallel wrap, so the checkpoint
# keys presumably already carry the 'module.' prefix (the commented-out strip
# above supports this) — verify against the checkpoint file.
model.load_state_dict(new_state_dict)
model.eval()
curAcc = val(model, converter, val_loader, max_iter=5)
# --- Single-image CRNN demo: restore a pretrained model, preprocess one
# image, and run a forward pass (decoding continues beyond this fragment).
import dataset
from PIL import Image
from models.crnn import CRNN

model_path = './data/crnn.pth'
img_path = './data/demo.png'
# 36 alphanumerics; the CRNN gets 37 classes = 36 chars + 1 CTC blank.
alphabet = '0123456789abcdefghijklmnopqrstuvwxyz'
model = CRNN(32, 1, 37, 256)
if torch.cuda.is_available():
    model = model.cuda()
print('loading pretrained model from %s' % model_path)
model.load_state_dict(torch.load(model_path))
converter = utils.strLabelConverter(alphabet)
# Resize + normalize to the training resolution (presumably width=100,
# height=32 — confirm against dataset.resizeNormalize).
transformer = dataset.resizeNormalize((100, 32))
image = Image.open(img_path).convert('L')  # grayscale: model expects 1 channel
image = transformer(image)
if torch.cuda.is_available():
    image = image.cuda()
image = image.view(1, *image.size())  # add a batch dimension
image = Variable(image)
model.eval()
preds = model(image)
_, preds = preds.max(2)  # greedy: most likely class per time step
class Demo(object):
    """Inference wrapper around a CRNN text recognizer.

    Restores a checkpoint at construction, then recognizes text from a
    single image (predict), a padded mini-batch (predict_batch), or a
    file/directory path (inference).
    """

    def __init__(self, args):
        # Restrict visible GPUs before any CUDA context is created.
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
        self.args = args
        self.alphabet = alphabetChinese
        nclass = len(self.alphabet) + 1  # +1 for the CTC blank class
        nc = 1  # grayscale input
        # NOTE(review): argument order here is (imgH, nc, nh, nclass) —
        # confirm it matches this project's CRNN constructor signature.
        self.net = CRNN(args.imgH, nc, args.nh, nclass)
        self.converter = utils.strLabelConverter(self.alphabet, ignore_case=False)
        self.transformer = resizeNormalize(args.imgH)
        print('loading pretrained model from %s' % args.model_path)
        checkpoint = torch.load(args.model_path)
        # The checkpoint may be a raw state dict or a training checkpoint
        # that nests the weights under 'model_state_dict'.
        if 'model_state_dict' in checkpoint.keys():
            checkpoint = checkpoint['model_state_dict']
        from collections import OrderedDict
        model_dict = OrderedDict()
        # Strip the 'module.' prefix left by DataParallel-saved checkpoints.
        for k, v in checkpoint.items():
            if 'module' in k:
                model_dict[k[7:]] = v
            else:
                model_dict[k] = v
        self.net.load_state_dict(model_dict)
        if args.cuda and torch.cuda.is_available():
            print('available gpus is,', torch.cuda.device_count())
            # NOTE(review): stock torch.nn.DataParallel has no `output_dim`
            # keyword (the gather axis is `dim`) — confirm whether this is a
            # project-local DataParallel wrapper.
            self.net = torch.nn.DataParallel(self.net, output_dim=1).cuda()
        self.net.eval()

    def predict(self, image):
        """Recognize text in a single image; returns the decoded string."""
        image = self.transformer(image)
        if torch.cuda.is_available():
            image = image.cuda()
        image = image.view(1, *image.size())  # add batch dimension
        image = Variable(image)
        preds = self.net(image)
        _, preds = preds.max(2)  # greedy best class per time step
        preds = preds.transpose(1, 0).contiguous().view(-1)
        preds_size = Variable(torch.IntTensor([preds.size(0)]))
        raw_pred = self.converter.decode(preds.data, preds_size.data, raw=True)
        sim_pred = self.converter.decode(preds.data, preds_size.data, raw=False)
        print('%-20s => %-20s' % (raw_pred, sim_pred))
        return sim_pred

    def predict_batch(self, images):
        """Recognize a list of images in mini-batches.

        Each image is transformed, then right-padded with zeros to the
        widest image in its batch so the batch can be stacked into one
        tensor.  Returns the list of decoded strings.
        """
        N = len(images)
        n_batch = N // self.args.batch_size
        n_batch += 1 if N % self.args.batch_size else 0  # trailing partial batch
        res = []
        # NOTE(review): the inner loops reuse `i`, shadowing the batch index.
        # Harmless (the outer `for` rebinds it each iteration) but confusing.
        for i in range(n_batch):
            batch = images[i*self.args.batch_size : min((i+1)*self.args.batch_size, N)]
            maxW = 0
            # First pass: transform and find the widest image.
            for i in range(len(batch)):
                batch[i] = self.transformer(batch[i])
                imgW = batch[i].shape[2]
                maxW = max(maxW, imgW)
            # Second pass: zero-pad narrower images on the right (dim 2).
            for i in range(len(batch)):
                if batch[i].shape[2] < maxW:
                    batch[i] = torch.cat((batch[i], torch.zeros((1, self.args.imgH, maxW-batch[i].shape[2]), dtype=batch[i].dtype)), 2)
            batch_imgs = torch.cat([t.unsqueeze(0) for t in batch], 0)
            # NOTE(review): unlike predict(), the batch tensor is never moved
            # to CUDA here — verify this works when the net lives on GPU.
            preds = self.net(batch_imgs)
            preds_size = Variable(torch.IntTensor([preds.size(0)]*len(batch)))
            _, preds = preds.max(2)
            preds = preds.transpose(1, 0).contiguous().view(-1)
            raw_preds = self.converter.decode(preds.data, preds_size.data, raw=True)
            sim_preds = self.converter.decode(preds.data, preds_size.data, raw=False)
            for raw_pred, sim_pred in zip(raw_preds, sim_preds):
                print('%-20s => %-20s' % (raw_pred, sim_pred))
            res.extend(sim_preds)
        return res

    def inference(self, image_path, batch_pred=False):
        """Run recognition on one file, or every image file in a directory.

        With batch_pred=False each image is decoded individually; with True
        all images are collected and decoded through predict_batch.
        Returns the list of decoded strings.
        """
        if os.path.isdir(image_path):
            file_list = os.listdir(image_path)
            # Keep only files whose extension is in the known image types.
            image_list = [os.path.join(image_path, i) for i in file_list if i.rsplit('.')[-1].lower() in img_types]
        else:
            image_list = [image_path]
        res = []
        images = []
        for img_path in image_list:
            image = Image.open(img_path).convert('L')  # grayscale
            if not batch_pred:
                sim_pred = self.predict(image)
                res.append(sim_pred)
            else:
                images.append(image)
        if batch_pred and images:
            res = self.predict_batch(images)
        return res
# Load SSD model PATH_TO_FROZEN_GRAPH = args.detection_model_path detection_graph = tf.Graph() with detection_graph.as_default(): od_graph_def = tf.GraphDef() with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as f: od_graph_def.ParseFromString(f.read()) tf.import_graph_def(od_graph_def, name='') # Load CRNN model alphabet = '0123456789abcdefghijklmnopqrstuvwxyz' crnn = CRNN(32, 1, 37, 256) if torch.cuda.is_available(): crnn = crnn.cuda() crnn.load_state_dict(torch.load(args.recognition_model_path)) converter = utils.strLabelConverter(alphabet) transformer = dataset.resizeNormalize((100, 32)) crnn.eval() # Open a video file or an image file cap = cv2.VideoCapture(args.input if args.input else 0) while cv2.waitKey(1) < 0: has_frame, frame = cap.read() if not has_frame: cv2.waitKey(0) break im_height, im_width, _ = frame.shape tf_frame = np.expand_dims(frame, axis=0)
def main():
    """Entry point: parse args, build data/model/optimizer/scheduler, train."""
    # Set parameters of the trainer
    global args, device
    args = parse_args()
    print('=' * 60)
    print(args)
    print('=' * 60)
    # Seed every RNG source for reproducibility.
    random.seed(args.manual_seed)
    np.random.seed(args.manual_seed)
    torch.manual_seed(args.manual_seed)
    if torch.cuda.is_available() and not args.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )
    if args.cuda and torch.cuda.is_available():
        device = torch.device('cuda')
        #cudnn.benchmark = True
    else:
        device = torch.device('cpu')

    # load alphabet from file
    if os.path.isfile(args.alphabet):
        alphabet = ''
        with open(args.alphabet, mode='rb') as f:
            # One character per line: keep the first decoded char of each line.
            for line in f.readlines():
                alphabet += line.decode('utf-8')[0]
        args.alphabet = alphabet
    converter = utils.CTCLabelConverter(args.alphabet, ignore_case=False)

    # data loader
    image_size = (args.image_h, args.image_w)
    collater = DatasetCollater(image_size, keep_ratio=args.keep_ratio)
    train_dataset = Dataset(mode='train', data_root=args.data_root,
                            transform=None)
    #sampler = RandomSequentialSampler(train_dataset, args.batch_size)
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   collate_fn=collater,
                                   shuffle=True,
                                   num_workers=args.workers)
    val_dataset = Dataset(mode='val', data_root=args.data_root, transform=None)
    # NOTE(review): the validation loader also uses shuffle=True — harmless
    # for aggregate metrics but unusual; confirm it is intentional.
    val_loader = data.DataLoader(val_dataset,
                                 batch_size=args.batch_size,
                                 collate_fn=collater,
                                 shuffle=True,
                                 num_workers=args.workers)

    # network: num_classes includes +1 for the CTC blank symbol.
    num_classes = len(args.alphabet) + 1
    num_channels = 1
    if args.arch == 'crnn':
        model = CRNN(args.image_h, num_channels, num_classes, args.num_hidden)
    elif args.arch == 'densenet':
        model = DenseNet(
            num_channels=num_channels,
            num_classes=num_classes,
            growth_rate=12,
            block_config=(3, 6, 9),  #(3,6,12,16),
            compression=0.5,
            num_init_features=64,
            bn_size=4,
            rnn=args.rnn,
            num_hidden=args.num_hidden,
            drop_rate=0,
            small_inputs=True,
            efficient=False)
    else:
        raise ValueError('unknown architecture {}'.format(args.arch))
    model = model.to(device)
    # Dry-run summary with a dummy batch of two 32x650 grayscale images.
    summary(model, torch.zeros((2, 1, 32, 650)).to(device))
    #print('='*60)
    #print(model)
    #print('='*60)

    # loss
    criterion = CTCLoss()
    criterion = criterion.to(device)

    # setup optimizer
    if args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=0.9,
                              weight_decay=args.weight_decay)
    elif args.optimizer == 'rmsprop':
        optimizer = optim.RMSprop(model.parameters(),
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               betas=(0.9, 0.999),
                               weight_decay=args.weight_decay)
    elif args.optimizer == 'adadelta':
        optimizer = optim.Adadelta(model.parameters(),
                                   weight_decay=args.weight_decay)
    else:
        raise ValueError('unknown optimizer {}'.format(args.optimizer))
    print('=' * 60)
    print(optimizer)
    print('=' * 60)

    # Define learning rate decay schedule
    global scheduler
    #exp_decay = math.exp(-0.1)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer,
                                                 gamma=args.decay_rate)
    #step_size = 10000
    #gamma_decay = 0.8
    #scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma_decay)
    #scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=gamma_decay)

    # initialize model from pretrained weights, if provided
    if args.pretrained and os.path.isfile(args.pretrained):
        print(">> Using pre-trained model '{}'".format(
            os.path.basename(args.pretrained)))
        state_dict = torch.load(args.pretrained)
        model.load_state_dict(state_dict)
        print("loading pretrained model done.")
    global is_best, best_accuracy
    is_best = False
    best_accuracy = 0.0
    start_epoch = 0

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            # load checkpoint weights and update model and optimizer
            print(">> Loading checkpoint:\n>> '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_accuracy = checkpoint['best_accuracy']
            print(">>>> loaded checkpoint:\n>>>> '{}' (epoch {})".format(
                args.resume, start_epoch))
            model.load_state_dict(checkpoint['state_dict'])
            #optimizer.load_state_dict(checkpoint['optimizer'])
            # important not to forget scheduler updating
            #scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.decay_rate, last_epoch=start_epoch - 1)
        else:
            print(">> No checkpoint found at '{}'".format(args.resume))

    # Create export dir if it doesnt exist; its name encodes the run settings.
    checkpoint = "{}".format(args.arch)
    checkpoint += "_{}".format(args.optimizer)
    checkpoint += "_lr_{}".format(args.lr)
    checkpoint += "_decay_rate_{}".format(args.decay_rate)
    checkpoint += "_bsize_{}".format(args.batch_size)
    checkpoint += "_height_{}".format(args.image_h)
    checkpoint += "_keep_ratio" if args.keep_ratio else "_width_{}".format(
        image_size[1])
    args.checkpoint = os.path.join(args.checkpoint, checkpoint)
    if not os.path.exists(args.checkpoint):
        os.makedirs(args.checkpoint)

    print('start training...')
    for epoch in range(start_epoch, args.max_epoch):
        # Adjust learning rate for each epoch.
        # NOTE(review): calling scheduler.step() before the epoch decays the
        # LR at epoch start; recent PyTorch expects it after optimizer steps
        # — confirm the intended decay timing.
        scheduler.step()

        # Train for one epoch on train set
        _ = train(train_loader, val_loader, model, criterion, optimizer,
                  epoch, converter)
def main():
    """Train a CRNN with CTC loss on the Hub dataset.

    Builds train/eval data loaders, constructs the network (optionally
    restoring pretrained weights), then alternates epoch-long training
    passes with a validation pass, saving the weights after every epoch.
    """
    config = Config()
    if not os.path.exists(config.expr_dir):
        os.makedirs(config.expr_dir)
    if torch.cuda.is_available() and not config.use_cuda:
        print("WARNING: You have a CUDA device, so you should probably set cuda in params.py to True")

    # Load the training dataset; variable-width images are aligned by the
    # collate function (only attached when CUDA is available, as original).
    train_dataset = HubDataset(config, "train", transform=None)
    train_kwargs = {'num_workers': 2, 'pin_memory': True,
                    'collate_fn': alignCollate(config.img_height, config.img_width, config.keep_ratio)} if torch.cuda.is_available() else {}
    training_data_batch = DataLoader(train_dataset, batch_size=config.train_batch_size,
                                     shuffle=True, drop_last=False, **train_kwargs)

    # Load the fixed-size validation dataset.
    eval_dataset = HubDataset(config, "eval", transform=transforms.Compose(
        [ResizeNormalize(config.img_height, config.img_width)]))
    eval_kwargs = {'num_workers': 2, 'pin_memory': False} if torch.cuda.is_available() else {}
    eval_data_batch = DataLoader(eval_dataset, batch_size=config.eval_batch_size,
                                 shuffle=False, drop_last=False, **eval_kwargs)

    # Variable-length validation loader, kept for reference:
    # eval_dataset = HubDataset(config, "eval")
    # eval_kwargs = {'num_workers': 2, 'pin_memory': False,
    #                'collate_fn': alignCollate(config.img_height, config.img_width, config.keep_ratio)} if torch.cuda.is_available() else {}
    # eval_data_batch = DataLoader(eval_dataset, batch_size=config.eval_batch_size, shuffle=False, drop_last=False, **eval_kwargs)

    # Define the network; +1 class for the CTC blank symbol.
    nclass = len(config.label_classes) + 1
    crnn = CRNN(config.img_height, config.nc, nclass, config.hidden_size,
                n_rnn=config.n_layers)

    # Restore pretrained weights when a path is configured.
    if config.pretrained != '':
        print('loading pretrained model from %s' % config.pretrained)
        crnn.load_state_dict(torch.load(config.pretrained))
    print(crnn)

    # Compute average for `torch.Variable` and `torch.Tensor`.
    loss_avg = utils.averager()
    # Convert between str and label.
    converter = utils.strLabelConverter(config.label_classes)
    criterion = CTCLoss()  # CTC loss for unsegmented sequence labeling

    # Reusable input placeholders; utils.loadData copies each batch into them.
    # FIX: the width slot previously passed img_height twice.
    image = torch.FloatTensor(config.train_batch_size, 3, config.img_height,
                              config.img_width)
    text = torch.LongTensor(config.train_batch_size * 5)
    length = torch.LongTensor(config.train_batch_size)

    if config.use_cuda and torch.cuda.is_available():
        criterion = criterion.cuda()
        image = image.cuda()
        crnn = crnn.to(config.device)
    image = Variable(image)
    text = Variable(text)
    length = Variable(length)

    # Choose the optimizer.
    if config.adam:
        optimizer = optim.Adam(crnn.parameters(), lr=config.lr,
                               betas=(config.beta1, 0.999))
    elif config.adadelta:
        optimizer = optim.Adadelta(crnn.parameters())
    else:
        optimizer = optim.RMSprop(crnn.parameters(), lr=config.lr)

    def val(net, criterion, eval_data_batch):
        """One pass over the eval set: prints loss, samples and accuracy."""
        print('Start val')
        for p in crnn.parameters():
            p.requires_grad = False
        net.eval()
        n_correct = 0
        loss_avg_eval = utils.averager()
        for data in eval_data_batch:
            cpu_images, cpu_texts = data
            batch_size = cpu_images.size(0)
            utils.loadData(image, cpu_images)
            t, l = converter.encode(cpu_texts)
            utils.loadData(text, t)
            utils.loadData(length, l)
            preds = crnn(image)
            preds_size = Variable(torch.LongTensor([preds.size(0)] * batch_size))
            cost = criterion(preds, text, preds_size, length) / batch_size
            loss_avg_eval.add(cost)  # accumulate validation loss

            # Greedy CTC decode: best class per time step, then collapse.
            _, preds = preds.max(2)
            preds = preds.transpose(1, 0).contiguous().view(-1)
            sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
            cpu_texts_decode = []
            for i in cpu_texts:
                cpu_texts_decode.append(i)
            # Exact-match accuracy count.
            for pred, target in zip(sim_preds, cpu_texts_decode):
                if pred == target:
                    n_correct += 1

        # Show a few raw/decoded/ground-truth samples from the last batch.
        raw_preds = converter.decode(preds.data, preds_size.data,
                                     raw=True)[:config.n_val_disp]
        for raw_pred, pred, gt in zip(raw_preds, sim_preds, cpu_texts_decode):
            print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))

        accuracy = n_correct / float(len(eval_dataset))
        # FIX: report the validation averager (loss_avg_eval), not the
        # training averager loss_avg, which is reset every display interval.
        print('Val loss: %f, accuray: %f' % (loss_avg_eval.val(), accuracy))

    def train(net, criterion, optimizer, data):
        """Run one optimization step on a single batch; returns the loss."""
        cpu_images, cpu_texts = data
        batch_size = cpu_images.size(0)  # actual size of this batch
        utils.loadData(image, cpu_images)
        t, l = converter.encode(cpu_texts)  # strings -> label indices
        utils.loadData(text, t)
        utils.loadData(length, l)
        optimizer.zero_grad()
        preds = net(image)
        preds_size = Variable(torch.LongTensor([preds.size(0)] * batch_size))
        cost = criterion(preds, text, preds_size, length) / batch_size
        cost.backward()
        optimizer.step()
        return cost

    for epoch in range(config.nepoch):
        i = 0
        for batch_data in training_data_batch:
            # Re-enable gradients (val() turns them off) and set train mode.
            for p in crnn.parameters():
                p.requires_grad = True
            crnn.train()
            cost = train(crnn, criterion, optimizer, batch_data)
            loss_avg.add(cost)
            i += 1
            if i % config.displayInterval == 0:
                print('[%d/%d][%d/%d] Loss: %f' %
                      (epoch, config.nepoch, i, len(training_data_batch),
                       loss_avg.val()))
                loss_avg.reset()
            # if i % config.valInterval == 0:
            #     val(crnn, criterion, eval_data_batch)
            #
            # # do checkpointing
            # if i % config.saveInterval == 0:
            #     torch.save(crnn.state_dict(), '{0}/netCRNN_{1}_{2}.pth'.format(config.expr_dir, epoch, i))
        # Validate and checkpoint at the end of every epoch.
        val(crnn, criterion, eval_data_batch)
        torch.save(crnn.state_dict(),
                   '{0}/netCRNN_{1}_end.pth'.format(config.expr_dir, epoch))
# --- Pruning script fragment: load a trained CRNN and begin rebuilding its
# convolutional layers with pruned channel counts.  (Cut off mid-loop in
# this view; the loop body continues beyond it.)
random.seed(opt.manualSeed)
np.random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
transformer = dataset.resizeNormalize((100, 32))
nclass = len(opt.alphabet) + 1  # +1 for the CTC blank class
nc = 1  # grayscale input
converter = misc.strLabelConverter(opt.alphabet)
crnn = CRNN(opt.imgH, nc, nclass, opt.nh)
if opt.pretrained != '':
    print('loading pretrained model from %s' % opt.pretrained)
    # strict=False tolerates missing/unexpected keys in the checkpoint.
    crnn.load_state_dict(load_multi(opt.pretrained), strict=False)

# Process pruned conv2d and batchnorm2d, store them in a dictionary
crnn_l = list(crnn.cnn._modules.items())
last_channels = [0]
# NOTE(review): this freshly built CRNN is immediately overwritten by the
# deepcopy on the next line — the first construction appears redundant.
crnn_new = CRNN(opt.imgH, nc, nclass, opt.nh)
crnn_new = copy.deepcopy(crnn)
new_dict = {}
for i in range(len(crnn_l)):
    module = crnn_l[i][1]
    if isinstance(module, torch.nn.Conv2d):
        out_channels = get_out_channel(module)
        new = set_weight_conv(last_channels, out_channels, module)
class Trainer(object):
    """CRNN training harness: data loading, optimization, validation and
    checkpointing, all driven by the module-level `args` namespace."""

    def __init__(self):
        # Restrict visible GPUs before any CUDA context is created.
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
        if args.chars_file == '':
            self.alphabet = alphabetChinese
        else:
            self.alphabet = utils.load_chars(args.chars_file)
        nclass = len(self.alphabet) + 1  # +1 for the CTC blank class
        nc = 1  # grayscale input
        # NOTE(review): argument order here is (imgH, nc, nh, nclass) —
        # confirm it matches this project's CRNN constructor signature.
        self.net = CRNN(args.imgH, nc, args.nh, nclass)
        self.train_dataloader, self.val_dataloader = self.dataloader(
            self.alphabet)
        self.criterion = CTCLoss()
        self.optimizer = self.get_optimizer()
        self.converter = utils.strLabelConverter(self.alphabet,
                                                 ignore_case=False)
        self.best_acc = 0.00001  # tiny epsilon so any real accuracy beats it
        model_name = '%s' % (args.dataset_name)
        if not os.path.exists(args.save_prefix):
            os.mkdir(args.save_prefix)
        args.save_prefix += model_name
        if args.pretrained != '':
            print('loading pretrained model from %s' % args.pretrained)
            checkpoint = torch.load(args.pretrained)
            # Full training checkpoints nest the weights under
            # 'model_state_dict' and also carry epoch/best_acc for resuming.
            if 'model_state_dict' in checkpoint.keys():
                # self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                args.start_epoch = checkpoint['epoch']
                self.best_acc = checkpoint['best_acc']
                checkpoint = checkpoint['model_state_dict']
            from collections import OrderedDict
            model_dict = OrderedDict()
            # Strip the 'module.' prefix left by DataParallel-saved weights.
            for k, v in checkpoint.items():
                if 'module' in k:
                    model_dict[k[7:]] = v
                else:
                    model_dict[k] = v
            self.net.load_state_dict(model_dict)
        if not args.cuda and torch.cuda.is_available():
            print(
                "WARNING: You have a CUDA device, so you should probably run with --cuda"
            )
        elif args.cuda and torch.cuda.is_available():
            print('available gpus is ', torch.cuda.device_count())
            # NOTE(review): stock torch.nn.DataParallel has no `output_dim`
            # keyword (the gather axis is `dim`) — confirm whether this is a
            # project-local DataParallel wrapper.
            self.net = torch.nn.DataParallel(self.net, output_dim=1).cuda()
            self.criterion = self.criterion.cuda()

    def dataloader(self, alphabet):
        """Build the train loader and (when val_dir exists) the val loader."""
        # train_transform = transforms.Compose(
        #     [transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
        #      resizeNormalize(args.imgH)])
        # train_dataset = BaseDataset(args.train_dir, alphabet, transform=train_transform)
        train_dataset = NumDataset(args.train_dir, alphabet,
                                   transform=resizeNormalize(args.imgH))
        train_dataloader = DataLoader(dataset=train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.num_workers,
                                      pin_memory=True)
        if os.path.exists(args.val_dir):
            # val_dataset = BaseDataset(args.val_dir, alphabet, transform=resizeNormalize(args.imgH))
            val_dataset = NumDataset(args.val_dir, alphabet, mode='test',
                                     transform=resizeNormalize(args.imgH))
            val_dataloader = DataLoader(dataset=val_dataset,
                                        batch_size=args.batch_size,
                                        shuffle=False,
                                        num_workers=args.num_workers,
                                        pin_memory=True)
        else:
            val_dataloader = None
        return train_dataloader, val_dataloader

    def get_optimizer(self):
        """Create the optimizer selected by args.optimizer (RMSprop default)."""
        if args.optimizer == 'sgd':
            optimizer = optim.SGD(
                self.net.parameters(),
                lr=args.lr,
                momentum=args.momentum,
                weight_decay=args.wd,
            )
        elif args.optimizer == 'adam':
            optimizer = optim.Adam(
                self.net.parameters(),
                lr=args.lr,
                betas=(args.beta1, 0.999),
            )
        else:
            optimizer = optim.RMSprop(
                self.net.parameters(),
                lr=args.lr,
                momentum=args.momentum,
                weight_decay=args.wd,
            )
        return optimizer

    def train(self):
        """Main training loop with periodic logging, validation and saving."""
        logging.basicConfig()
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
        log_file_path = args.save_prefix + '_train.log'
        log_dir = os.path.dirname(log_file_path)
        if log_dir and not os.path.exists(log_dir):
            os.mkdir(log_dir)
        fh = logging.FileHandler(log_file_path)
        logger.addHandler(fh)
        logger.info(args)
        logger.info('Start training from [Epoch {}]'.format(args.start_epoch +
                                                            1))
        losses = utils.Averager()
        train_accuracy = utils.Averager()
        for epoch in range(args.start_epoch, args.nepoch):
            self.net.train()
            btic = time.time()
            # NOTE(review): btic is set once per epoch, so the logged
            # samples/sec is measured since epoch start, not per interval.
            for i, (imgs, labels) in enumerate(self.train_dataloader):
                batch_size = imgs.size()[0]
                imgs = imgs.cuda()
                preds = self.net(imgs).cpu()
                # length: per-sample character counts for the batch;
                # text: indices of all target characters, concatenated.
                text, length = self.converter.encode(
                    labels
                )
                preds_size = torch.IntTensor([preds.size(0)] * batch_size)
                loss_avg = self.criterion(preds, text, preds_size,
                                          length) / batch_size
                self.optimizer.zero_grad()
                loss_avg.backward()
                self.optimizer.step()
                losses.update(loss_avg.item(), batch_size)

                # Greedy CTC decode for a running training accuracy.
                _, preds_m = preds.max(2)
                preds_m = preds_m.transpose(1, 0).contiguous().view(-1)
                sim_preds = self.converter.decode(preds_m.data,
                                                  preds_size.data,
                                                  raw=False)
                n_correct = 0
                for pred, target in zip(sim_preds, labels):
                    if pred == target:
                        n_correct += 1
                train_accuracy.update(n_correct, batch_size, MUL_n=False)
                if args.log_interval and not (i + 1) % args.log_interval:
                    logger.info(
                        '[Epoch {}/{}][Batch {}/{}], Speed: {:.3f} samples/sec, Loss:{:.3f}'
                        .format(epoch + 1, args.nepoch, i + 1,
                                len(self.train_dataloader),
                                batch_size / (time.time() - btic),
                                losses.val()))
                    losses.reset()
            logger.info(
                'Training accuracy: {:.3f}, [#correct:{} / #total:{}]'.format(
                    train_accuracy.val(), train_accuracy.sum,
                    train_accuracy.count))
            train_accuracy.reset()
            if args.val_interval and not (epoch + 1) % args.val_interval:
                acc = self.validate(logger)
                if acc > self.best_acc:
                    self.best_acc = acc
                    save_path = '{:s}_best.pth'.format(args.save_prefix)
                    torch.save(
                        {
                            'epoch': epoch,
                            'model_state_dict': self.net.state_dict(),
                            # 'optimizer_state_dict': self.optimizer.state_dict(),
                            'best_acc': self.best_acc,
                        }, save_path)
                    logging.info("best acc is:{:.3f}".format(self.best_acc))
            # NOTE(review): `acc` is only bound when the validation branch
            # above ran this epoch — if save_interval fires on an epoch where
            # val_interval did not, this raises NameError. Confirm the two
            # intervals are kept in sync.
            if args.save_interval and not (epoch + 1) % args.save_interval:
                save_path = '{:s}_{:04d}_{:.3f}.pth'.format(
                    args.save_prefix, epoch + 1, acc)
                torch.save(
                    {
                        'epoch': epoch,
                        'model_state_dict': self.net.state_dict(),
                        # 'optimizer_state_dict': self.optimizer.state_dict(),
                        'best_acc': self.best_acc,
                    }, save_path)

    def validate(self, logger):
        """Evaluate on the val loader; returns exact-match accuracy (0 when
        no validation loader was built)."""
        if self.val_dataloader is None:
            return 0
        logger.info('Start validate.')
        losses = utils.Averager()
        self.net.eval()
        n_correct = 0
        with torch.no_grad():
            for i, (imgs, labels) in enumerate(self.val_dataloader):
                batch_size = imgs.size()[0]
                imgs = imgs.cuda()
                preds = self.net(imgs).cpu()
                # length: per-sample character counts; text: concatenated
                # target character indices.
                text, length = self.converter.encode(
                    labels
                )
                preds_size = torch.IntTensor(
                    [preds.size(0)] * batch_size)  # timestep * batchsize
                loss_avg = self.criterion(preds, text, preds_size,
                                          length) / batch_size
                losses.update(loss_avg.item(), batch_size)
                _, preds = preds.max(2)
                preds = preds.transpose(1, 0).contiguous().view(-1)
                sim_preds = self.converter.decode(preds.data, preds_size.data,
                                                  raw=False)
                for pred, target in zip(sim_preds, labels):
                    if pred == target:
                        n_correct += 1
        accuracy = n_correct / float(losses.count)
        logger.info(
            'Evaling loss: {:.3f}, accuracy: {:.3f}, [#correct:{} / #total:{}]'
            .format(losses.val(), accuracy, n_correct, losses.count))
        return accuracy