def load_model(self):
    """Build the network described by ``self.opt`` and load its checkpoint.

    Derived fields are written back onto ``self.opt``: ``num_class``,
    ``ctc_num_class`` and, when ``self.opt.rgb`` is set, ``input_channel``.
    Either ``self.ctc_converter`` or ``self.attn_converter`` is created,
    depending on whether ``'CTC'`` appears in ``self.opt.Prediction``.
    The ``DataParallel``-wrapped model ends up in ``self.model`` in eval mode.
    """
    cfg = self.opt  # same object as self.opt; writes below are visible there

    # Class counts come from whichever label converter is in use.
    if 'CTC' not in cfg.Prediction:
        self.attn_converter = AttnLabelConverter(cfg.character)
        cfg.num_class = len(self.attn_converter.character)
        cfg.ctc_num_class = cfg.num_class - 1
    else:
        self.ctc_converter = CTCLabelConverter(cfg.character)
        cfg.ctc_num_class = len(self.ctc_converter.character)
        cfg.num_class = cfg.ctc_num_class + 1

    if cfg.rgb:
        cfg.input_channel = 3

    self.model = Model(cfg)
    print('model input parameters', cfg.imgH, cfg.imgW, cfg.num_fiducial,
          cfg.input_channel, cfg.output_channel, cfg.hidden_size,
          cfg.num_class, cfg.batch_max_length, cfg.Transformation,
          cfg.FeatureExtraction, cfg.SequenceModeling, cfg.Prediction)
    print(f"=====Use {cfg.Prediction} prediction result=====")

    has_cuda = torch.cuda.is_available()
    self.model = torch.nn.DataParallel(self.model)
    if has_cuda:
        self.model = self.model.cuda()

    # load model
    print('loading pretrained model from %s' % cfg.saved_model)
    if has_cuda:
        weights = torch.load(cfg.saved_model)
    else:
        # No GPU: remap the checkpoint's tensors onto the CPU.
        weights = torch.load(cfg.saved_model, map_location="cpu")
    self.model.load_state_dict(weights)
    self.model.eval()
def test(opt):
    """Evaluate a saved checkpoint and record its accuracy.

    Args:
        opt: Options namespace. ``num_class``, ``experiment_name`` and (when
            ``opt.rgb`` is set) ``input_channel`` are overwritten here.

    The checkpoint at ``opt.saved_model`` is loaded into a CUDA
    ``DataParallel`` model (a GPU is required), copied into
    ``./result/<experiment_name>/`` and evaluated. With
    ``opt.benchmark_all_eval`` the work is delegated to
    ``benchmark_all_eval``; otherwise the accuracy on ``opt.eval_data`` is
    printed and appended to ``log_evaluation.txt`` in the result directory.
    """
    import shutil  # local import: only needed to archive the checkpoint

    # model configuration
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).cuda()

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model))
    # The experiment name is the checkpoint path minus its first component,
    # with the remaining components joined by underscores.
    opt.experiment_name = '_'.join(opt.saved_model.split('/')[1:])

    # keep evaluation model and result logs
    result_dir = f'./result/{opt.experiment_name}'
    os.makedirs(result_dir, exist_ok=True)
    try:
        # shutil.copy instead of a shell `cp`: no shell quoting problems and
        # it works on non-POSIX systems as well.
        shutil.copy(opt.saved_model, result_dir)
    except OSError as err:
        # Archiving is best-effort; a failed copy must not abort evaluation.
        print(f'could not copy {opt.saved_model} to {result_dir}: {err}')

    # setup loss
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).cuda()
    else:
        # ignore [GO] token = ignore index 0
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).cuda()

    # evaluation
    model.eval()
    with torch.no_grad():  # no gradients are needed while evaluating
        if opt.benchmark_all_eval:
            # evaluation with 10 benchmark evaluation datasets
            benchmark_all_eval(model, criterion, converter, opt)
        else:
            AlignCollate_evaluation = AlignCollate(
                imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
            eval_data = hierarchical_dataset(root=opt.eval_data, opt=opt)
            evaluation_loader = torch.utils.data.DataLoader(
                eval_data,
                batch_size=opt.batch_size,
                shuffle=False,
                num_workers=int(opt.workers),
                collate_fn=AlignCollate_evaluation,
                pin_memory=True)
            _, accuracy_by_best_model, _, _, _, _, _ = validation(
                model, criterion, evaluation_loader, converter, opt)
            print(accuracy_by_best_model)
            with open(f'{result_dir}/log_evaluation.txt', 'a') as log:
                log.write(str(accuracy_by_best_model) + '\n')
def __init__(self, model_path, alphabet, image_size, device=None):
    """Load the alphabet and the DenseNet-RNN recognizer.

    Args:
        model_path: Checkpoint file; its ``'state_dict'`` entry is loaded
            into the model.
        alphabet: Path to a UTF-8 text file. The first character of every
            line becomes one symbol of the alphabet.
        image_size: Stored as ``self._image_size`` for later use.
        device: Passed to ``torch.load(map_location=...)`` and used as the
            target of ``model.to(...)``.
    """
    self._image_size = image_size
    self._device = device
    self._transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, ), std=(0.5, ))
    ])

    # Read in binary and decode per line so the encoding is always UTF-8,
    # independent of the platform default.
    with open(alphabet, mode='rb') as f:
        self._alphabet = ''.join(line.decode('utf-8')[0] for line in f)

    self._converter = CTCLabelConverter(self._alphabet, ignore_case=False)
    # One more class than alphabet symbols (presumably the CTC blank --
    # TODO confirm against CTCLabelConverter).
    num_classes = len(self._alphabet) + 1
    num_channels = 1
    self._model = DenseNetRNN(
        num_channels=num_channels,
        num_classes=num_classes,
        rnn=True,
        num_hidden=256,
        growth_rate=12,
        block_config=(3, 6, 9),  # (3,6,12,16),
        compression=0.5,
        num_init_features=64,
        bn_size=4,
        drop_rate=0,
        small_inputs=True,
        efficient=False)

    # Inference only: freeze every parameter.
    for param in self._model.parameters():
        param.requires_grad = False

    state_dict = torch.load(model_path, map_location=device)
    self._model.load_state_dict(state_dict['state_dict'])
    self._model = self._model.to(self._device)
    # The original code stored the moved model under the misspelled name
    # ``mode``; keep that attribute as an alias for existing callers.
    self.mode = self._model
    self._model.eval()
def train(opt):
    """Train a text-recognition model with periodic validation and checkpoints.

    Args:
        opt: Options namespace. ``select_data`` and ``batch_ratio`` are split
            on ``'-'`` in place; ``num_class`` and (when ``opt.rgb`` is set)
            ``input_channel`` are overwritten here.

    Logs, the option dump and checkpoints are written below
    ``./saved_models/<opt.experiment_name>/``. The loop ends by calling
    ``sys.exit()`` once iteration ``opt.num_iter`` has been processed, so this
    function never returns normally. A CUDA device is required.
    """
    save_dir = f'./saved_models/{opt.experiment_name}'

    # dataset preparation
    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    valid_dataset = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True)
    print('-' * 80)

    # model configuration
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:
            # for batchnorm: the initializer rejects these weights, so fall
            # back to filling them with ones.
            if 'weight' in name:
                param.data.fill_(1)

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).cuda()
    model.train()
    if opt.continue_model != '':
        if opt.without_prediction:
            load_model_without_prediction(opt.continue_model, model)
            print(f'loading pretrained model from {opt.continue_model}, without prediction layer')
        else:
            print(f'loading pretrained model from {opt.continue_model}')
            model.load_state_dict(torch.load(opt.continue_model))
    print("Model:")
    print(model)

    # setup loss
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).cuda()
    else:
        # ignore [GO] token = ignore index 0
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).cuda()
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr,
                                   rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # final options
    with open(f'{save_dir}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # start training
    start_iter = 0
    if opt.continue_model != '':
        print(f'continue to train, start_iter: {start_iter}')

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = 1e+6
    i = start_iter

    while True:
        # train part
        for p in model.parameters():
            p.requires_grad = True

        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.cuda()
        text, length = converter.encode(labels)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.permute(1, 0, 2)  # to use CTCLoss format
            cost = criterion(preds, text, preds_size, length)
        else:
            preds = model(image, text)
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        # gradient clipping with 5 (Default)
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if i % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            logging.info(f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}')
            # for log
            with open(f'{save_dir}/log_train.txt', 'a') as log:
                log.write(f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}\n')
                loss_avg.reset()

                model.eval()
                # Validation needs no gradients; skipping autograd
                # bookkeeping keeps memory use down.
                with torch.no_grad():
                    (valid_loss, current_accuracy, current_norm_ED, preds,
                     labels, infer_time, length_of_data) = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                for pred, gt in zip(preds[:5], labels[:5]):
                    if 'Attn' in opt.Prediction:
                        # Cut at the end-of-sentence token. partition()
                        # leaves the string intact when '[s]' is absent,
                        # whereas slicing with find() == -1 would silently
                        # drop the last character.
                        pred = pred.partition('[s]')[0]
                        gt = gt.partition('[s]')[0]

                    print(f'{pred:20s}, gt: {gt:20s}, {str(pred == gt)}')
                    log.write(f'{pred:20s}, gt: {gt:20s}, {str(pred == gt)}\n')

                valid_log = f'[{i}/{opt.num_iter}] valid loss: {valid_loss:0.5f}'
                valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}'
                log.write(valid_log + '\n')

                # keep best accuracy model
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(),
                               f'{save_dir}/mtl_best_accuracy.pth')
                if current_norm_ED < best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(),
                               f'{save_dir}/best_norm_ED.pth')
                best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}'
                logging.info(best_model_log)
                log.write(best_model_log + '\n')

        # save a snapshot every 50000 iterations
        if (i + 1) % 50000 == 0:
            torch.save(model.state_dict(), f'{save_dir}/iter_{i+1}.pth')

        if i == opt.num_iter:
            logging.info('end the training')
            sys.exit()
        i += 1
def train(opt):
    """Train the multi-task (CTC + attention) recognition model.

    Args:
        opt: Options namespace. ``num_class``, ``ctc_num_class`` and (when
            ``opt.rgb`` is set) ``input_channel`` are overwritten here.

    The total loss is
    ``opt.ctc_weight * ctc_loss + (1 - opt.ctc_weight) * attention_loss``.
    Logs, the option dump and checkpoints are written below
    ``<opt.outPath>/<opt.experiment_name>/``. The loop ends by calling
    ``sys.exit()`` once iteration ``opt.num_iter`` has been processed, so this
    function never returns normally.
    """
    train_dataset = Batch_Dataset(opt)

    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    valid_dataset = LmdbDataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True)
    print('-' * 80)

    # model configuration
    ctc_converter = CTCLabelConverter(opt.character, opt.subword)
    attn_converter = AttnLabelConverter(opt.character, opt.subword,
                                        opt.batch_max_length)
    opt.num_class = len(attn_converter.character)
    opt.ctc_num_class = len(ctc_converter.character)
    print("ctc num class {}".format(len(ctc_converter.character)))
    print("attention num class {}".format(len(attn_converter.character)))

    if opt.rgb:
        opt.input_channel = 3
    model = MyModel(opt)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            # The placeholder is named, so the value must be passed by
            # keyword; a positional argument raises KeyError('name').
            print('Skip {name} as it is already initialized'.format(name=name))
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:
            # The initializer rejected this tensor; fall back to filling
            # weights with ones.
            if 'weight' in name:
                param.data.fill_(1)

    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.continue_model != '':
        print('loading pretrained model from {}'.format(opt.continue_model))
        model.load_state_dict(torch.load(opt.continue_model))

    # setup loss
    ctc_criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    attn_criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)
    loss_avg = Averager()

    # keep only the parameters that require gradients
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))

    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr,
                                   rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # final options
    with open(osj(opt.outPath, '{}/opt.txt'.format(opt.experiment_name)),
              'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += '{}: {}\n'.format(str(k), str(v))
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # start training
    start_iter = 0
    if opt.continue_model != '':
        print('continue to train, start_iter: {}'.format(start_iter))

    start_time = time.time()
    best_accuracy = -1
    i = start_iter

    while True:
        # train part
        for p in model.parameters():
            p.requires_grad = True

        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        ctc_text, ctc_length = ctc_converter.encode(labels)
        attn_text, attn_length = attn_converter.encode(labels)
        batch_size = image.size(0)

        # ctc loss
        ctc_preds, attn_preds = model(image, attn_text)
        ctc_preds = ctc_preds.log_softmax(2)
        preds_size = torch.IntTensor([ctc_preds.size(1)] * batch_size)
        ctc_preds = ctc_preds.permute(1, 0, 2)
        ctc_cost = ctc_criterion(ctc_preds, ctc_text, preds_size, ctc_length)

        # attn loss
        target = attn_text[:, 1:]
        attn_cost = attn_criterion(
            attn_preds.view(-1, attn_preds.shape[-1]),
            target.contiguous().view(-1))
        cost = opt.ctc_weight * ctc_cost + (1.0 - opt.ctc_weight) * attn_cost

        model.zero_grad()
        cost.backward()
        # gradient clipping with 5 (Default)
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if i % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            logging.info('[{}/{}] Loss: {:0.5f} elapsed_time: {:0.5f}'.format(
                i, opt.num_iter, loss_avg.val(), elapsed_time))
            # for log
            with open(
                    osj(opt.outPath,
                        '{}/log_train.txt'.format(opt.experiment_name)),
                    'a') as log:
                log.write(
                    '[{}/{}] Loss: {:0.5f} elapsed_time: {:0.5f}\n'.format(
                        i, opt.num_iter, loss_avg.val(), elapsed_time))
                loss_avg.reset()

                model.eval()
                with torch.no_grad():
                    (valid_loss, current_accuracy, ctc_accuracy,
                     current_norm_ED, preds, labels, infer_time,
                     length_of_data) = mtl_validation(
                        model, ctc_criterion, attn_criterion, valid_loader,
                        ctc_converter, attn_converter, opt)
                model.train()

                for pred, gt in zip(preds[:5], labels[:5]):
                    # Cut at the end-of-sentence token. partition() leaves
                    # the string intact when '[s]' is absent, whereas
                    # slicing with find() == -1 would drop the last char.
                    pred = pred.partition('[s]')[0]
                    gt = gt.partition('[s]')[0]
                    print('{:20s}, gt: {:20s}, {}'.format(
                        pred, gt, str(pred == gt)))
                    log.write('{:20s}, gt: {:20s}, {}\n'.format(
                        pred, gt, str(pred == gt)))

                valid_log = '[{}/{}] valid loss: {:0.5f}'.format(
                    i, opt.num_iter, valid_loss)
                valid_log += ' accuracy: {:0.3f}'.format(current_accuracy)
                log.write(valid_log + '\n')

                # save best accuracy model
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        osj(opt.outPath, '{}/best_accuracy.pth'.format(
                            opt.experiment_name)))
                best_model_log = 'best_accuracy: {:0.3f}'.format(best_accuracy)
                logging.info(best_model_log)
                log.write(best_model_log + '\n')

        # save a snapshot every 50000 iterations
        if (i + 1) % 50000 == 0:
            torch.save(
                model.state_dict(),
                osj(opt.outPath,
                    '{}/iter_{}.pth'.format(opt.experiment_name, i + 1)))

        if i == opt.num_iter:
            logging.info('end the training')
            sys.exit()
        i += 1
class CRNN_OCR(object):
    """Single-line text recognizer built on ``DenseNetRNN`` with CTC decoding."""

    def __init__(self, model_path, alphabet, image_size, device=None):
        """Load the alphabet and the DenseNet-RNN recognizer.

        Args:
            model_path: Checkpoint file; its ``'state_dict'`` entry is loaded
                into the model.
            alphabet: Path to a UTF-8 text file. The first character of every
                line becomes one symbol of the alphabet.
            image_size: Indexable; ``image_size[0]`` is used as the target
                height and ``image_size[1]`` as the padded width in
                ``predict``.
            device: Passed to ``torch.load(map_location=...)`` and used as
                the target of ``model.to(...)``.
        """
        self._image_size = image_size
        self._device = device
        self._transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, ), std=(0.5, ))
        ])

        # Read in binary and decode per line so the encoding is always
        # UTF-8, independent of the platform default.
        with open(alphabet, mode='rb') as f:
            self._alphabet = ''.join(line.decode('utf-8')[0] for line in f)

        self._converter = CTCLabelConverter(self._alphabet, ignore_case=False)
        # One more class than alphabet symbols (presumably the CTC blank --
        # TODO confirm against CTCLabelConverter).
        num_classes = len(self._alphabet) + 1
        num_channels = 1
        self._model = DenseNetRNN(
            num_channels=num_channels,
            num_classes=num_classes,
            rnn=True,
            num_hidden=256,
            growth_rate=12,
            block_config=(3, 6, 9),  # (3,6,12,16),
            compression=0.5,
            num_init_features=64,
            bn_size=4,
            drop_rate=0,
            small_inputs=True,
            efficient=False)

        # Inference only: freeze every parameter.
        for param in self._model.parameters():
            param.requires_grad = False

        state_dict = torch.load(model_path, map_location=device)
        self._model.load_state_dict(state_dict['state_dict'])
        self._model = self._model.to(self._device)
        # The original code stored the moved model under the misspelled name
        # ``mode``; keep that attribute as an alias for existing callers.
        self.mode = self._model
        self._model.eval()

    def predict(self, image):
        """Recognize the text in a 2-D (single-channel) image array.

        The image is scaled to the configured height, capped at the
        configured width, then padded horizontally by replicating its border
        pixels. The split between left and right padding is drawn at random,
        so the network input can differ between calls on the same image.

        Args:
            image: Array with ``ndim == 2``.

        Returns:
            The decoded text as returned by the label converter.
        """
        assert image.ndim == 2, 'predict() expects a 2-D image array'

        target_h, max_w = self._image_size[0], self._image_size[1]
        height, width = image.shape[0], image.shape[1]
        scale = target_h / height
        # Clamp to at least one pixel: a very narrow input would otherwise
        # round to width 0, which cv2.resize rejects.
        width = max(1, min(int(round(width * scale)), max_w))
        image = cv2.resize(np.array(image), (width, target_h),
                           interpolation=cv2.INTER_CUBIC)

        # random padding
        left = random.randint(0, max_w - width)
        right = max_w - width - left
        image = cv2.copyMakeBorder(image, 0, 0, left, right,
                                   cv2.BORDER_REPLICATE)

        image = self._transform(image)
        image = torch.unsqueeze(image, 0).to(self._device)  # B, C, H, W

        with torch.no_grad():
            _, pred = self._model(image)

        # Greedy decoding: best class index along dimension 2.
        score, pred_index = pred.max(2)
        pred_index = pred_index.contiguous().view(-1).cpu()

        text, valid_index = self._converter.decode(
            pred_index, torch.IntTensor([pred_index.size(0)]), raw=False)
        return text
def demo(opt):
    """Recognize every image under ``opt.image_folder`` and print the results.

    Args:
        opt: Options namespace. ``num_class`` and (when ``opt.rgb`` is set)
            ``input_channel`` are overwritten here.

    Runs on the GPU when CUDA is available, otherwise on the CPU. For each
    batch a small table of ``image_path`` / predicted label is printed.
    """
    # model configuration
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    use_cuda = torch.cuda.is_available()
    model = torch.nn.DataParallel(model)
    if use_cuda:
        model = model.cuda()
        int_tensor, long_tensor = torch.cuda.IntTensor, torch.cuda.LongTensor
    else:
        int_tensor, long_tensor = torch.IntTensor, torch.LongTensor

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    # Without a GPU the checkpoint's tensors are remapped onto the CPU.
    model.load_state_dict(
        torch.load(opt.saved_model,
                   map_location=None if use_cuda else "cpu"))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW,
                                     keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data,
        batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo,
        pin_memory=True)

    # predict
    model.eval()
    for image_tensors, image_path_list in demo_loader:
        batch_size = image_tensors.size(0)
        with torch.no_grad():
            image = image_tensors.cuda() if use_cuda else image_tensors
            # For max length prediction
            length_for_pred = int_tensor([opt.batch_max_length] * batch_size)
            text_for_pred = long_tensor(
                batch_size, opt.batch_max_length + 1).fill_(0)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred).log_softmax(2)

                # Select max probability (greedy decoding) then decode index
                # to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.permute(1, 0, 2).max(2)
                preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
                preds_str = converter.decode(preds_index.data,
                                             preds_size.data)
            else:
                preds = model(image, text_for_pred, is_train=False)

                # select max probability (greedy decoding) then decode index
                # to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            print('-' * 80)
            print('image_path\tpredicted_labels')
            print('-' * 80)
            for img_name, pred in zip(image_path_list, preds_str):
                if 'Attn' in opt.Prediction:
                    # Prune after the "end of sentence" token ([s]).
                    # partition() keeps the whole string when the token is
                    # missing; slicing with find() == -1 would silently drop
                    # the last character.
                    pred = pred.partition('[s]')[0]

                print(f'{img_name}\t{pred}')
class OcrRec:
    """Text-line recognizer wrapping the multi-task (CTC + attention) model."""

    def __init__(self, opt=None):
        """Configure the recognizer and load the bundled checkpoint.

        Args:
            opt: Optional options object; a fresh ``ConfigOpt`` is used when
                it is omitted or falsy. The architecture fields and
                ``saved_model`` are overwritten on it below regardless of
                what the caller set.
        """
        self.max_length = 25
        self.opt = ConfigOpt()
        if opt:
            self.opt = opt
        self.batch_size = 1
        self.model = None
        self.cur_path = os.path.abspath(os.path.dirname(__file__))
        # The checkpoint is resolved relative to this module's directory.
        self.opt.saved_model = os.path.join(self.cur_path,
                                            "models/mtl_best_accuracy.pth")
        self.opt.Transformation = 'None'  # None|TPS
        self.opt.FeatureExtraction = 'ResNet'  # VGG|RCNN|ResNet
        self.opt.SequenceModeling = 'BiLSTM'  # None|BiLSTM
        # CTC|Attn (use CTC or Attention in inference stage)
        self.opt.Prediction = 'CTC'
        self.opt.output_channel = 768
        self.opt.hidden_size = 384
        self.opt.mtl = True
        self.ctc_converter = None
        self.attn_converter = None
        self.load_model()

    def load_model(self):
        """Build the network described by ``self.opt`` and load its checkpoint.

        Writes ``num_class``, ``ctc_num_class`` and (when ``self.opt.rgb`` is
        set) ``input_channel`` onto ``self.opt``, creates the label converter
        matching ``self.opt.Prediction`` and leaves the
        ``DataParallel``-wrapped model in ``self.model`` in eval mode.
        """
        if 'CTC' in self.opt.Prediction:
            self.ctc_converter = CTCLabelConverter(self.opt.character)
            self.opt.ctc_num_class = len(self.ctc_converter.character)
            self.opt.num_class = self.opt.ctc_num_class + 1
        else:
            self.attn_converter = AttnLabelConverter(self.opt.character)
            self.opt.num_class = len(self.attn_converter.character)
            self.opt.ctc_num_class = self.opt.num_class - 1

        if self.opt.rgb:
            self.opt.input_channel = 3
        self.model = Model(self.opt)
        print('model input parameters', self.opt.imgH, self.opt.imgW,
              self.opt.num_fiducial, self.opt.input_channel,
              self.opt.output_channel, self.opt.hidden_size,
              self.opt.num_class, self.opt.batch_max_length,
              self.opt.Transformation, self.opt.FeatureExtraction,
              self.opt.SequenceModeling, self.opt.Prediction)
        print(f"=====Use {self.opt.Prediction} prediction result=====")

        use_cuda = torch.cuda.is_available()
        self.model = torch.nn.DataParallel(self.model)
        if use_cuda:
            self.model = self.model.cuda()

        # load model
        print('loading pretrained model from %s' % self.opt.saved_model)
        # Without a GPU the checkpoint's tensors are remapped onto the CPU.
        self.model.load_state_dict(
            torch.load(self.opt.saved_model,
                       map_location=None if use_cuda else "cpu"))
        self.model.eval()

    def text_rec(self, img):
        """Recognize the text in one image.

        The image is converted to mode ``'L'``, resized to height
        ``self.opt.imgH`` with its width/height ratio kept, and run through
        the model.

        Args:
            img: A PIL image, or a path to an image file.

        Returns:
            The recognized string for the image.

        Raises:
            FileNotFoundError: ``img`` is a string that does not name an
                existing file.
            ValueError: ``self.opt.Prediction`` contains neither ``'CTC'``
                nor ``'Attn'``.
        """
        if isinstance(img, str):
            if not os.path.isfile(img):
                # Fail with a clear message instead of an AttributeError on
                # ``str.mode`` further down.
                raise FileNotFoundError(f'image file not found: {img}')
            img = Image.open(img)
        if img.mode != 'L':
            img = img.convert('L')

        # Scale to the model's input height, keeping the aspect ratio.
        ratio = self.opt.imgH / img.size[1]
        target_w = int(img.size[0] * ratio)
        transformer = InferResizeNormalize((target_w, self.opt.imgH))
        img = transformer(img)
        img = img.view(1, *img.size())  # add the batch dimension

        with torch.no_grad():
            if torch.cuda.is_available():
                img = img.cuda()
                int_tensor = torch.cuda.IntTensor
                long_tensor = torch.cuda.LongTensor
            else:
                int_tensor = torch.IntTensor
                long_tensor = torch.LongTensor
            length_for_pred = int_tensor(
                [self.opt.batch_max_length] * self.batch_size)
            text_for_pred = long_tensor(
                self.batch_size, self.opt.batch_max_length + 1).fill_(0)

            if 'CTC' in self.opt.Prediction:
                # The model returns two outputs; the first is decoded here.
                preds, _ = self.model(img, text_for_pred, is_train=False)
                preds = preds.softmax(2)

                # Select max probability (greedy decoding) then decode index
                # to character
                preds_size = torch.IntTensor(
                    [preds.size(1)] * self.batch_size)
                _, preds_index = preds.permute(1, 0, 2).max(2)
                preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
                preds_str = self.ctc_converter.decode(preds_index.data,
                                                      preds_size.data)
            elif 'Attn' in self.opt.Prediction:
                # The second model output is decoded here.
                _, preds = self.model(img, text_for_pred, is_train=False)

                # select max probability (greedy decoding) then decode index
                # to character
                _, preds_index = preds.max(2)
                preds_str = self.attn_converter.decode(preds_index,
                                                       length_for_pred)
                # Cut at the end-of-sentence token. partition() keeps the
                # whole string when '[s]' is missing; slicing with
                # find() == -1 would silently drop the last character.
                preds_str = [pred.partition('[s]')[0] for pred in preds_str]
            else:
                raise ValueError(
                    f'unsupported Prediction mode: {self.opt.Prediction}')

        return preds_str[0]