def dataset_test(lmdb_path, batch_size):
    """Smoke-test an LMDB dataset: iterate every batch and print it.

    Uses a non-shuffling, single-process loader so output order is
    deterministic and easy to eyeball.
    """
    dataset = LmdbDataset(lmdb_path)
    loader = DataLoader(dataset, batch_size, shuffle=False, num_workers=0)
    # Unpack (img, label) directly in the loop header instead of a
    # separate `data` tuple.
    for batch_idx, (img, label) in enumerate(loader):
        print(batch_idx, img, label)
        print(batch_idx, img.shape, label.shape)
def setUp(self):
    """Build the validation dataset and loader with ImageNet-style
    preprocessing (resize 256, center-crop 224, normalize)."""
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    pipeline = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    self.dataset = LmdbDataset("val.lmdb", transform=pipeline)
    self.dataloader = DataLoader(self.dataset, batch_size=4,
                                 shuffle=True, num_workers=4)
def main(opt):
    """Generate corrupted copies of an LMDB evaluation set.

    For the clean baseline plus every (corruption, severity) pair, and for
    every sub-dataset under ``opt.eval_data``, each image is resized to
    100x32, optionally corrupted, and saved as a PNG under
    ``corrupted-data/<corruption-name>/<dataset>/``.

    Args:
        opt: parsed options; must provide ``seed`` (int) and
            ``eval_data`` (path to a directory of LMDB sub-datasets).
    """
    # 'None' corresponds to the clean data
    transforms = [None]
    # Make tests reproducible
    rng = np.random.default_rng(opt.seed)
    corruptions = [
        Curve(rng=rng), Distort(rng), Stretch(rng), Rotate(rng=rng),
        Perspective(rng), Shrink(rng), TranslateX(rng), TranslateY(rng),
        VGrid(rng), HGrid(rng), Grid(rng), RectGrid(rng), EllipseGrid(rng),
        GaussianNoise(rng), ShotNoise(rng), ImpulseNoise(rng),
        SpeckleNoise(rng), GaussianBlur(rng), DefocusBlur(rng),
        MotionBlur(rng), GlassBlur(rng), ZoomBlur(rng), Contrast(rng),
        Brightness(rng), JpegCompression(rng), Pixelate(rng), Fog(rng),
        Snow(rng), Frost(rng), Rain(rng), Shadow(rng), Posterize(rng),
        Solarize(rng), Invert(rng), Equalize(rng), AutoContrast(rng),
        Sharpness(rng), Color(rng)
    ]
    # Generate partial functions for the three severity levels.
    # BUG FIX: the original iterated range(1), so despite the comment only
    # severity level 0 was ever produced; iterate all three levels (0..2).
    for c in corruptions:
        for level in range(3):
            p = partial(c, mag=level)
            p.__name__ = '{}-{}'.format(c.__class__.__name__, level)
            transforms.append(p)
    for tr in transforms:
        name = 'Clean' if tr is None else tr.__name__
        for d in os.listdir(opt.eval_data):
            outdir = os.path.join('corrupted-data', name, d)
            # exist_ok=True so a re-run does not crash on directories
            # created by a previous (possibly interrupted) run.
            os.makedirs(outdir, exist_ok=True)
            for i, (img, label) in enumerate(
                    LmdbDataset(os.path.join(opt.eval_data, d), opt, tr)):
                print(outdir, i)
                img = img.resize((100, 32))
                if tr is not None:
                    img = tr(img)
                img.save(os.path.join(outdir, '{:04d}.png'.format(i)))
# Move the model to the configured device; wrap for multi-GPU inference.
model = model.to(config.device)
if config.device == 'cuda':
    model = torch.nn.DataParallel(model)
model.eval()

# Standard scene-text recognition benchmark sets.
test_data_dir = "../data/data_lmdb_release/evaluation/"
test_data_set = [
    "IIIT5k_3000", "SVT", "IC03_867", "IC13_1015", "IC15_1811", "SVTP",
    "CUTE80"
]
device = config.device
for test_data in test_data_set:
    path = test_data_dir + test_data
    test_dataset = LmdbDataset(path, config.lmdb_config)
    data_loader = DataLoader(
        test_dataset,
        batch_size=config.batch_size,
        num_workers=4,
        shuffle=False,
        pin_memory=True,
        drop_last=False,
    )
    # Append the sample count to the dataset name for reporting.
    test_data += '(%d)' % (len(test_dataset))
    targets = []
    pred_rec = []
    for i, data_in in enumerate(data_loader):
        if test_dataset.use_bidecoder:
            imgs, labels1, labels2, lengths = data_in
def train():
    """Train the AdvancedEAST text detector on LMDB data.

    Builds train/validation loaders, optionally warm-starts from the best
    checkpoint of the next-smaller task size, then runs an epoch loop that
    logs training loss, periodically evaluates precision/recall/F1 and
    saves the best model by mean F1 score.
    """
    """ dataset preparation """
    train_dataset_lmdb = LmdbDataset(cfg.lmdb_trainset_dir_name)
    val_dataset_lmdb = LmdbDataset(cfg.lmdb_valset_dir_name)
    train_loader = torch.utils.data.DataLoader(
        train_dataset_lmdb, batch_size=cfg.batch_size,
        collate_fn=data_collate, shuffle=True,
        num_workers=int(cfg.workers), pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        val_dataset_lmdb, batch_size=cfg.batch_size,
        collate_fn=data_collate,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(cfg.workers), pin_memory=True)

    # -------------------- training process ---------------------------
    model = advancedEAST()
    # Warm-start: each task id (384/512/640/736) resumes from the best
    # checkpoint of the next-smaller input size.
    if int(cfg.train_task_id[-3:]) != 256:
        id_num = cfg.train_task_id[-3:]
        idx_dic = {'384': 256, '512': 384, '640': 512, '736': 640}
        model.load_state_dict(torch.load('./saved_model/3T{}_best_loss.pth'.format(idx_dic[id_num])))
    elif os.path.exists('./saved_model/3T{}_best_loss.pth'.format(cfg.train_task_id)):
        model.load_state_dict(torch.load('./saved_model/3T{}_best_loss.pth'.format(cfg.train_task_id)))
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.decay)
    loss_func = quad_loss

    train_Loss_list = []
    val_Loss_list = []

    '''start training'''
    start_iter = 0
    if cfg.saved_model != '':
        try:
            start_iter = int(cfg.saved_model.split('_')[-1].split('.')[0])
            print('continue to train, start_iter: {}'.format(start_iter))
        except Exception as e:
            print(e)

    # BUG FIX: `start_time = time.time()` was assigned twice in a row in
    # the original; a single assignment suffices.
    start_time = time.time()
    best_mF1_score = 0
    i = start_iter
    step_num = 0
    loss_avg = Averager()
    val_loss_avg = Averager()
    eval_p_r_f = eval_pre_rec_f1()

    # NOTE(review): loop structure reconstructed from whitespace-mangled
    # source — statement grouping inside the epoch loop should be
    # confirmed against the original repository.
    while True:
        model.train()
        # training -----------------------------
        for image_tensors, labels, gt_xy_list in train_loader:
            step_num += 1
            batch_x = image_tensors.to(device).float()
            batch_y = labels.to(device).float()  # float64 -> float32
            out = model(batch_x)
            loss = loss_func(batch_y, out)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_avg.add(loss)
            train_Loss_list.append(loss_avg.val())
            if i == 5 or (i + 1) % 10 == 0:
                eval_p_r_f.add(out, gt_xy_list)  # very time-consuming!!!

        # save model per 100 epochs.
        if (i + 1) % 1e+2 == 0:
            torch.save(model.state_dict(), './saved_models/{}/{}_iter_{}.pth'.format(cfg.train_task_id, cfg.train_task_id, step_num+1))
        print('Epoch:[{}/{}] Training Loss: {:.3f}'.format(i + 1, cfg.epoch_num, train_Loss_list[-1].item()))
        loss_avg.reset()
        if i == 5 or (i + 1) % 10 == 0:
            mPre, mRec, mF1_score = eval_p_r_f.val()
            print('Training meanPrecision:{:.2f}% meanRecall:{:.2f}% meanF1-score:{:.2f}%'.format(mPre, mRec, mF1_score))
            eval_p_r_f.reset()

        # evaluation --------------------------------
        if (i + 1) % cfg.valInterval == 0:
            elapsed_time = time.time() - start_time
            print('Elapsed time:{}s'.format(round(elapsed_time)))
            model.eval()
            # IMPROVEMENT: validation needs no autograd graphs —
            # torch.no_grad() saves memory and time.
            with torch.no_grad():
                for image_tensors, labels, gt_xy_list in valid_loader:
                    batch_x = image_tensors.to(device)
                    batch_y = labels.to(device).float()  # float64 -> float32
                    out = model(batch_x)
                    loss = loss_func(batch_y, out)
                    val_loss_avg.add(loss)
                    val_Loss_list.append(val_loss_avg.val())
                    eval_p_r_f.add(out, gt_xy_list)
            mPre, mRec, mF1_score = eval_p_r_f.val()
            print('validation meanPrecision:{:.2f}% meanRecall:{:.2f}% meanF1-score:{:.2f}%'.format(mPre, mRec, mF1_score))
            eval_p_r_f.reset()
            if mF1_score > best_mF1_score:  # keep the best model so far
                best_mF1_score = mF1_score
                torch.save(model.state_dict(), './saved_models/{}/{}_best_mF1_score_{:.3f}.pth'.format(cfg.train_task_id, cfg.train_task_id, mF1_score))
                torch.save(model.state_dict(), './saved_model/{}_best_mF1_score.pth'.format(cfg.train_task_id))
            print('Validation loss:{:.3f}'.format(val_loss_avg.val().item()))
            val_loss_avg.reset()

        if i == cfg.epoch_num:
            torch.save(model.state_dict(), './saved_models/{}/{}_iter_{}.pth'.format(cfg.train_task_id, cfg.train_task_id, i+1))
            print('End the training')
            break
        i += 1
    sys.exit()
def train(field):
    """Train a CRNN recognizer with CTC loss for one document *field*.

    Args:
        field: name of the field to train on; 'address' and 'psb' use
            batch size 1 because their image widths vary.

    Side effects: reads ./cn-alphabet.json, writes periodic checkpoints
    under ``model_path`` and prints running loss every 500 iterations.
    """
    alphabet = ''.join(json.load(open('./cn-alphabet.json', 'rb')))
    nclass = len(alphabet) + 1  # add the dash '-' (CTC blank)
    batch_size = BATCH_SIZE
    if field == 'address' or field == 'psb':
        batch_size = 1  # image length varies
    converter = LabelConverter(alphabet)
    criterion = CTCLoss(zero_infinity=True)
    crnn = CRNN(IMAGE_HEIGHT, nc, nclass, number_hidden)
    crnn.apply(weights_init)
    image_transform = transforms.Compose([
        Rescale(IMAGE_HEIGHT),
        transforms.ToTensor(),
        Normalize()
    ])
    dataset = LmdbDataset(db_path, field, image_transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                            num_workers=4)
    # Pre-allocated buffers that are refilled with each batch's data.
    image = torch.FloatTensor(batch_size, 3, IMAGE_HEIGHT, IMAGE_HEIGHT)
    text = torch.IntTensor(batch_size * 5)
    length = torch.IntTensor(batch_size)
    image = Variable(image)
    text = Variable(text)
    length = Variable(length)
    loss_avg = utils.averager()
    optimizer = optim.RMSprop(crnn.parameters(), lr=lr)
    if torch.cuda.is_available():
        crnn.cuda()
        crnn = nn.DataParallel(crnn)
        image = image.cuda()
        criterion = criterion.cuda()

    def train_batch(net, iteration):
        # One optimizer step over the next batch from *iteration*.
        # BUG FIX: `iteration.next()` is Python-2 iterator syntax and
        # raises AttributeError on Python 3 — use the builtin next().
        data = next(iteration)
        cpu_images, cpu_texts = data
        batch_size = cpu_images.size(0)
        utils.load_data(image, cpu_images)
        t, l = converter.encode(cpu_texts)
        utils.load_data(text, t)
        utils.load_data(length, l)
        preds = crnn(image)
        preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
        cost = criterion(preds, text, preds_size, length) / batch_size
        crnn.zero_grad()
        cost.backward()
        optimizer.step()
        return cost

    nepoch = 25
    for epoch in range(nepoch):
        train_iter = iter(dataloader)
        i = 0
        while i < len(dataloader):
            for p in crnn.parameters():
                p.requires_grad = True
            crnn.train()
            cost = train_batch(crnn, train_iter)
            loss_avg.add(cost)
            i += 1
            if i % 500 == 0:
                print('%s [%d/%d][%d/%d] Loss: %f' %
                      (datetime.datetime.now(), epoch, nepoch, i,
                       len(dataloader), loss_avg.val()))
                loss_avg.reset()
            # do checkpointing
            if i % 500 == 0:
                torch.save(
                    crnn.state_dict(),
                    f'{model_path}crnn_{field}_{epoch}_{i}.pth')
config = TainTestConfig()
# TensorBoard writer for this run.
writer = SummaryWriter(
    log_dir="../data/logs/" + config.name,
    flush_secs=30,
)

# Scale batch size, loader workers and validation interval with the
# number of available GPUs.
n_device = torch.cuda.device_count()
config.batch_size = 256 * n_device
config.iter_to_valid = 128 * 8

train_dataset = torch.utils.data.ConcatDataset(
    [LmdbDataset(path, config.lmdb_config) for path in config.train_data])
data_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    num_workers=2 * n_device,
    shuffle=True,
    pin_memory=True,
    drop_last=True,
)

# Small evaluation subset for periodic validation.
path = "../data/lmdbs/evaluation/IIIT5K_3000"
config.batch_size = 1024
config.lmdb_config.num_samples = 1000
test_dataset = LmdbDataset(path, config.lmdb_config)
def demo(opt):
    """Run recognition over an LMDB demo set, print/log predictions and
    save images of failed samples; prints total accuracy at the end."""
    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)
    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW,
                                     keep_ratio_with_pad=opt.PAD)
    demo_data = LmdbDataset(root=opt.image_folder, opt=opt, mode='Val')
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size, shuffle=False,
        num_workers=int(opt.workers), collate_fn=AlignCollate_demo,
        pin_memory=True, drop_last=True)

    log = open(f'./log_demo_result.txt', 'a')

    # predict
    model.eval()
    fail_count, sample_count = 0, 0
    record_count = 1
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)
                # Select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                preds_index = preds_index.view(-1)
                preds_str = converter.decode(preds_index, preds_size)
            else:
                preds = model(image, text_for_pred, is_train=False)
                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            dashed_line = '-' * 80
            head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'
            print(f'{dashed_line}\n{head}\n{dashed_line}')
            log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)
            for image_tensor, gt, pred, pred_max_prob in zip(
                    image_tensors, image_path_list, preds_str, preds_max_prob):
                if 'Attn' in opt.Prediction:
                    pred_EOS = pred.find('[s]')
                    pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                    pred_max_prob = pred_max_prob[:pred_EOS]

                if pred_max_prob.shape[0] > 0:
                    # calculate confidence score (= product of pred_max_prob)
                    confidence_score = pred_max_prob.cumprod(dim=0)[-1]
                else:
                    confidence_score = 0.0

                # Compare uppercase alphanumerics only (ignore punctuation/case).
                compare_gt = "".join(x.upper() for x in gt if x.isalnum())
                compare_pred = "".join(x.upper() for x in pred if x.isalnum())
                if compare_gt != compare_pred:
                    fail_count += 1
                    print(
                        f'{gt:25s}\t{pred:25s}\tFail\t{confidence_score:0.4f}\t{record_count}\n'
                    )
                    im = to_pil_image(image_tensor)
                    try:
                        im.save(
                            os.path.join(
                                'result',
                                f'{fail_count}_{compare_pred}_{compare_gt}.jpeg'))
                    except Exception as e:
                        print(f'Error: {e} {fail_count}_{compare_pred}_{compare_gt}')
                        exit(1)
                else:
                    pass
                sample_count += 1
                record_count += 1
    log.close()
    print(f'total accuracy: {(sample_count-fail_count)/sample_count:.2f}')
def setUp(self):
    """Create the training dataset and a shuffling loader for the tests."""
    self.dataset = LmdbDataset("train_lmdb")
    self.dataloader = DataLoader(
        self.dataset,
        batch_size=32,
        shuffle=True,
        num_workers=4,
    )
def train(opt):
    """Main training loop for a scene-text recognition model.

    Builds train/valid loaders, configures the model and loss by
    ``opt.Prediction`` (CTC vs attention), then iterates until
    ``opt.num_iter``, validating and checkpointing periodically.
    """
    """ prepare training and validation datasets """
    transform = transforms.Compose([
        ToTensor(),
    ])
    train_dataset = LmdbDataset(opt.train_data, opt=opt, transform=transform)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=int(opt.workers),
    )
    valid_dataset = LmdbDataset(root=opt.valid_data, opt=opt, transform=transform)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=int(opt.workers),
    )
    print('-' * 80)

    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)
    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception as e:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    model = model.to(device)
    model.train()
    if opt.continue_model != '':
        print(f'loading pretrained model from {opt.continue_model}')
        model.load_state_dict(torch.load(opt.continue_model))
    print("Model:")
    print(model)

    """ setup loss """
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        # ignore [GO] token = ignore index 0
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)
    # loss averager
    loss_avg = Averager()

    # filter that only require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr,
                                   rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    """ final options """
    with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    if opt.continue_model != '':
        start_iter = int(opt.continue_model.split('_')[-1].split('.')[0])
        print(f'continue to train, start_iter: {start_iter}')

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = 1e+6
    i = start_iter

    while True:
        # train part
        for image_tensors, labels in train_loader:
            image = image_tensors.to(device)
            # text: [index, index, ..., index], length: [10, 8]
            text, length = converter.encode(
                labels, batch_max_length=opt.batch_max_length)
            batch_size = image.size(0)
            if 'CTC' in opt.Prediction:
                preds = model(image, text).log_softmax(2)
                preds_size = torch.IntTensor([preds.size(1)] * batch_size).to(device)
                preds = preds.permute(1, 0, 2)  # to use CTCLoss format
                # To avoid ctc_loss issue, disable cudnn for the computation
                # of the ctc_loss:
                # https://github.com/jpuigcerver/PyLaia/issues/16
                torch.backends.cudnn.enabled = False
                # preds: (seq_len, batch, num_class); preds_size repeats the
                # sequence length per sample; length holds each label length.
                cost = criterion(preds, text, preds_size, length)
                torch.backends.cudnn.enabled = True
            else:
                preds = model(image, text[:, :-1])  # align with Attention.forward
                target = text[:, 1:]  # without [GO] Symbol
                cost = criterion(preds.view(-1, preds.shape[-1]),
                                 target.contiguous().view(-1))

            model.zero_grad()
            cost.backward()
            # gradient clipping with 5 (Default)
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
            optimizer.step()

            loss_avg.add(cost)

            # validation part
            if i % opt.valInterval == 0:
                elapsed_time = time.time() - start_time
                print(
                    f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}'
                )
                # for log
                with open(f'./saved_models/{opt.experiment_name}/log_train.txt',
                          'a') as log:
                    log.write(
                        f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}\n'
                    )
                    loss_avg.reset()

                    model.eval()
                    with torch.no_grad():
                        valid_loss, current_accuracy, current_norm_ED, preds, labels, infer_time, length_of_data = validation(
                            model, criterion, valid_loader, converter, opt)
                    model.train()

                    # show a few prediction/ground-truth pairs
                    for pred, gt in zip(preds[:5], labels[:5]):
                        if 'Attn' in opt.Prediction:
                            pred = pred[:pred.find('[s]')]
                            gt = gt[:gt.find('[s]')]
                        print(f'{pred:20s}, gt: {gt:20s}, {str(pred == gt)}')
                        log.write(f'{pred:20s}, gt: {gt:20s}, {str(pred == gt)}\n')

                    valid_log = f'[{i}/{opt.num_iter}] valid loss: {valid_loss:0.5f}'
                    valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}'
                    print(valid_log)
                    log.write(valid_log + '\n')

                    # keep best accuracy model
                    if current_accuracy > best_accuracy:
                        best_accuracy = current_accuracy
                        torch.save(
                            model.state_dict(),
                            f'./saved_models/{opt.experiment_name}/best_accuracy.pth'
                        )
                    if current_norm_ED < best_norm_ED:
                        best_norm_ED = current_norm_ED
                        torch.save(
                            model.state_dict(),
                            f'./saved_models/{opt.experiment_name}/best_norm_ED.pth'
                        )
                    best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}'
                    print(best_model_log)
                    log.write(best_model_log + '\n')

            # save model per 1e+5 iter.
            if (i + 1) % 1e+5 == 0:
                torch.save(
                    model.state_dict(),
                    f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

            if i == opt.num_iter:
                print('end the training')
                sys.exit()
            i += 1
def test(opt):
    """Evaluate a trained OCR model on an LMDB test set and log word
    accuracy plus ICDAR2019 normalized-edit-distance character accuracy."""
    lib.print_model_settings(locals().copy())

    if 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
        text_len = opt.batch_max_length + 2
    else:
        converter = CTCLabelConverter(opt.character)
        text_len = opt.batch_max_length

    opt.classes = converter.character

    """ dataset preparation """
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    valid_dataset = LmdbDataset(root=opt.test_data, opt=opt)
    test_data_sampler = data_sampler(valid_dataset, shuffle=False,
                                     distributed=False)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=False,  # 'True' to check training progress with validation function.
        sampler=test_data_sampler,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True,
        drop_last=False)
    print('-' * 80)

    opt.num_class = len(converter.character)
    ocrModel = ModelV1(opt).to(device)

    ## Loading pre-trained files
    print(f'loading pretrained ocr model from {opt.saved_ocr_model}')
    checkpoint = torch.load(opt.saved_ocr_model,
                            map_location=lambda storage, loc: storage)
    ocrModel.load_state_dict(checkpoint)

    evalCntr = 0
    fCntr = 0
    c1_s1_input_correct = 0.0
    c1_s1_input_ed_correct = 0.0

    for vCntr, (image_input_tensors, labels_gt) in enumerate(valid_loader):
        print(vCntr)
        image_input_tensors = image_input_tensors.to(device)
        text_gt, length_gt = converter.encode(
            labels_gt, batch_max_length=opt.batch_max_length)

        with torch.no_grad():
            currBatchSize = image_input_tensors.shape[0]
            length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                              currBatchSize).to(device)

            # Run OCR prediction
            if 'CTC' in opt.Prediction:
                preds = ocrModel(image_input_tensors, text_gt, is_train=False)
                preds_size = torch.IntTensor([preds.size(1)] *
                                             image_input_tensors.shape[0])
                _, preds_index = preds.max(2)
                preds_str_gt_1 = converter.decode(preds_index.data,
                                                  preds_size.data)
            else:
                preds = ocrModel(image_input_tensors, text_gt[:, :-1],
                                 is_train=False)  # align with Attention.forward
                _, preds_index = preds.max(2)
                preds_str_gt_1 = converter.decode(preds_index, length_for_pred)
                for idx, pred in enumerate(preds_str_gt_1):
                    pred_EOS = pred.find('[s]')
                    # prune after "end of sentence" token ([s])
                    preds_str_gt_1[idx] = pred[:pred_EOS]

        for trImgCntr in range(image_input_tensors.shape[0]):
            # ocr accuracy
            c1_s1_input_gt = labels_gt[trImgCntr]
            c1_s1_input_ocr = preds_str_gt_1[trImgCntr]
            if c1_s1_input_gt == c1_s1_input_ocr:
                c1_s1_input_correct += 1

            # ICDAR2019 Normalized Edit Distance
            if len(c1_s1_input_gt) == 0 or len(c1_s1_input_ocr) == 0:
                c1_s1_input_ed_correct += 0
            elif len(c1_s1_input_gt) > len(c1_s1_input_ocr):
                c1_s1_input_ed_correct += 1 - edit_distance(
                    c1_s1_input_ocr, c1_s1_input_gt) / len(c1_s1_input_gt)
            else:
                c1_s1_input_ed_correct += 1 - edit_distance(
                    c1_s1_input_ocr, c1_s1_input_gt) / len(c1_s1_input_ocr)

            evalCntr += 1
            fCntr += 1

    avg_c1_s1_input_wer = c1_s1_input_correct / float(evalCntr)
    avg_c1_s1_input_cer = c1_s1_input_ed_correct / float(evalCntr)

    with open(os.path.join(opt.exp_dir, opt.exp_name, 'log_test.txt'),
              'a') as log:
        # training loss and validation loss
        loss_log = f'Word Acc: {avg_c1_s1_input_wer:0.5f}, Test Input Char Acc: {avg_c1_s1_input_cer:0.5f}'
        print(loss_log)
        log.write(loss_log + "\n")
type=int, default=1, help='the number of input channel of Feature extractor') parser.add_argument('--output_channel', type=int, default=512, help='the number of output channel of Feature extractor') parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state') parser.add_argument('--include_space', type=bool, default=False) opt = parser.parse_args() opt.character = [] with open(os.path.join(opt.train_data, 'kr_labels.txt'), 'r') as f: lines = f.readlines() for line in lines: ch = line.strip().split()[1] if len(ch) != 1: print(f'{ch}s length is greater than 1') opt.character.append(ch) if opt.include_space: opt.character.append(' ') dataset = LmdbDataset('data/train/printed', opt) for i in range(1, 1000): dataset.__getitem__(i)