def __init__(self, data_dir, gt_mat_path, max_len=25):
    """Index the SynthText dataset from its MATLAB ground-truth file.

    Args:
        data_dir: root directory holding the SynthText images.
        gt_mat_path: path to the SynthText ``gt.mat`` annotation file.
        max_len: maximum text-label length kept per instance.
    """
    super(SynthTextLoader, self).__init__()
    self.data_dir = data_dir
    self.max_len = max_len
    # Parse the MATLAB annotation file once, up front.
    self.mat_contents = sio.loadmat(gt_mat_path)
    self.images_name = self.mat_contents['imnames'][0]
    self.num_samples = len(self.images_name)
    # Vocabulary lookup tables; the id->char map is not needed here.
    self.voc, self.char2id, _ = get_vocabulary("ALLCASES_SYMBOLS")
def __init__(self, data_dir, gt_dir, with_script=False, shuffle=False, max_len=25):
    """Index the ICDAR-2015 dataset from its image and ground-truth dirs.

    Args:
        data_dir: directory containing the dataset images.
        gt_dir: directory containing the per-image ground-truth files.
        with_script: whether ground truth carries a script/language field.
        shuffle: whether to shuffle the polygons of each image.
        max_len: maximum text-label length kept per instance.
    """
    super(ICDAR15Loader, self).__init__()
    self.data_dir = data_dir
    self.gt_dir = gt_dir
    self.with_script = with_script
    self.shuffle = shuffle  # shuffle the polygons
    self.max_len = max_len
    # Enumerate images once so num_samples is known immediately.
    self.images_path = self.get_images()
    self.num_samples = len(self.images_path)
    # Vocabulary lookup tables; the id->char map is not needed here.
    self.voc, self.char2id, _ = get_vocabulary("ALLCASES_SYMBOLS")
def test(args, cpks):
    """Evaluate a trained MultiInstanceRecognition checkpoint on the val set.

    Runs greedy decoding over the whole validation loader, prints each
    prediction next to its label, and reports overall word accuracy.

    Args:
        args: config namespace providing ``val_data_cfg`` and ``model_cfg``.
        cpks: path to the checkpoint (``.pth`` state dict) to evaluate.
    """
    assert isinstance(cpks, str)
    voc, char2id, id2char = get_vocabulary("ALLCASES_SYMBOLS")
    test_data = build_dataloader(args.val_data_cfg)
    print("test data: {}".format(len(test_data)))
    model = MultiInstanceRecognition(args.model_cfg).cuda()
    model.load_state_dict(torch.load(cpks))
    model.eval()
    pred_strs = []
    gt_strs = []
    for batch_data in test_data:
        torch.cuda.empty_cache()
        batch_imgs, batch_imgs_path, batch_rectangles, \
            batch_text_labels, batch_text_labels_mask, batch_words = batch_data
        if batch_imgs is None:
            # The loader yields None for unusable samples; skip them.
            continue
        batch_imgs = batch_imgs.cuda()
        batch_rectangles = batch_rectangles.cuda()
        batch_text_labels = batch_text_labels.cuda()
        with torch.no_grad():
            loss, decoder_logits = model(batch_imgs, batch_text_labels,
                                         batch_rectangles,
                                         batch_text_labels_mask)
        pred_labels = decoder_logits.argmax(dim=2).cpu().numpy()
        pred_value_str = idx2label(pred_labels, id2char, char2id)
        gt_str = batch_words
        # BUG FIX: the inner index used to reuse (shadow) the outer batch
        # loop variable `i`; use a distinct name.
        for j in range(len(gt_str[0])):
            print("predict: {} label: {}".format(pred_value_str[j],
                                                 gt_str[0][j]))
            pred_strs.append(pred_value_str[j])
            gt_strs.append(gt_str[0][j])
    val_dec_metrics_result = calc_metrics(pred_strs, gt_strs,
                                          metrics_type="accuracy")
    # BUG FIX: "{:3f}" set a field width of 3, not a precision; use "{:.3f}".
    print("test accuracy= {:.3f}".format(val_dec_metrics_result))
    print('---------')
def idx2label(inputs, id2char=None, char2id=None):
    """Decode a batch of character-id sequences into strings.

    Decoding stops at the first EOS id; UNK ids are skipped.

    Args:
        inputs: 2-D numpy array of character ids, one row per sequence.
        id2char: id -> character map; built from the default vocabulary
            when not supplied.
        char2id: character -> id map (must match ``id2char``); built from
            the default vocabulary when not supplied.

    Returns:
        list[str] of decoded strings when ``inputs`` is a numpy array;
        otherwise ``inputs`` is returned unchanged (with a warning printed).
    """
    # BUG FIX: regenerate BOTH maps if either is missing — the original
    # only checked id2char, so passing id2char without char2id crashed.
    if id2char is None or char2id is None:
        _, char2id, id2char = get_vocabulary(voc_type="ALLCASES_SYMBOLS")
    eos_id = char2id['EOS']
    unk_id = char2id['UNK']

    def end_cut(ins):
        # Collect chars until EOS; UNK ids are dropped, not terminated on.
        cut_ins = []
        for char_id in ins:  # renamed: `id` shadowed the builtin
            if char_id == eos_id:
                break
            if char_id != unk_id:
                cut_ins.append(id2char[char_id])
        return cut_ins

    if isinstance(inputs, np.ndarray):
        assert len(inputs.shape) == 2, "input's rank should be 2"
        return [''.join(end_cut(ins)) for ins in inputs]
    print("input to idx2label should be numpy array")
    return inputs
data_type='ICDAR13', num_instances=4, crop_ratio=0.0, crop_random=False, input_width=640, input_height=1280, keep_ratio=True, max_len=max_len, batch_size=1, num_works=1, shuffle=False) from data_tools.data_utils import get_vocabulary dim = 512 seq_len = 25 voc_len = len(get_vocabulary("ALLCASES_SYMBOLS")[0]) model_cfg = dict( dim=512, seq_len=25, voc_len=voc_len, num_instances=4, roi_size=(16, 64), feature_channels=[256, 512, 1024, 2048], fpn_out_channels=128, roi_feature_step=4, encoder_channels=512, embedding=dict( dim=dim, voc_len=voc_len, embedding_dim=512, pos_dim=seq_len,
def _next_valid_batch(data_loader, data_iter):
    """Fetch the next usable batch, restarting the iterator on exhaustion
    and skipping batches whose image tensor is None.

    Returns:
        (data_iter, batch): the (possibly re-created) iterator and a batch
        tuple whose first element is guaranteed not to be None.
    """
    while True:
        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(data_loader)
            batch = next(data_iter)
        if batch[0] is not None:
            return data_iter, batch


def train(cfg, args):
    """Train MultiInstanceRecognition with periodic validation/checkpointing.

    Args:
        cfg: config object with data cfgs, model cfg, lr, max_iters,
            train_verbose / val_iter / save_iter / lr_step intervals,
            save_name prefix and optional resume_from checkpoint path.
        args: runtime namespace with ``distributed`` and ``local_rank``.
    """
    logger = logging.getLogger('model training')
    train_data = build_dataloader(cfg.train_data_cfg, args.distributed)
    logger.info("train data: {}".format(len(train_data)))
    val_data = build_dataloader(cfg.val_data_cfg, args.distributed)
    logger.info("val data: {}".format(len(val_data)))

    model = MultiInstanceRecognition(cfg.model_cfg).cuda()
    if cfg.resume_from is not None:
        # BUG FIX: message had f-string braces without the f prefix and
        # referenced a nonexistent `opt`.
        logger.info('loading pretrained model from {}'.format(cfg.resume_from))
        model.load_state_dict(torch.load(cfg.resume_from))
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)
    voc, char2id, id2char = get_vocabulary("ALLCASES_SYMBOLS")

    # Collect only trainable parameters for the optimizer.
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    # BUG FIX: logging uses %-interpolation, not print-style extra args
    # (the count was silently dropped before).
    logger.info('Trainable params num : %s', sum(params_num))

    optimizer = optim.Adam(filtered_parameters, lr=cfg.lr, betas=(0.9, 0.999))
    lrScheduler = lr_scheduler.MultiStepLR(optimizer, [1, 2, 3], gamma=0.1)

    max_iters = cfg.max_iters
    start_iter = 0
    if cfg.resume_from is not None:
        # Checkpoint names end in "_<iter>.pth"; resume from that iteration.
        start_iter = int(cfg.resume_from.split('_')[-1].split('.')[0])
        # BUG FIX: missing f prefix — logged the literal braces before.
        logger.info('continue to train, start_iter: {}'.format(start_iter))

    train_data_iter = iter(train_data)
    val_data_iter = iter(val_data)
    start_time = time.time()
    for i in range(start_iter, max_iters):
        model.train()
        data_time_s = time.time()
        train_data_iter, batch_data = _next_valid_batch(train_data,
                                                        train_data_iter)
        batch_imgs, batch_imgs_path, batch_rectangles, \
            batch_text_labels, batch_text_labels_mask, batch_words = batch_data
        batch_imgs = batch_imgs.cuda(non_blocking=True)
        batch_rectangles = batch_rectangles.cuda(non_blocking=True)
        batch_text_labels = batch_text_labels.cuda(non_blocking=True)
        data_time = time.time() - data_time_s

        loss, decoder_logits = model(batch_imgs, batch_text_labels,
                                     batch_rectangles, batch_text_labels_mask)
        del batch_data
        loss = loss.mean()

        if i % cfg.train_verbose == 0:
            this_time = time.time() - start_time
            log_loss = loss.detach()
            if args.distributed:
                # BUG FIX: dist.reduce works in place and returns None; the
                # original rebound `loss` to None and then crashed on
                # loss.data / loss.backward(). Reduce a detached copy so the
                # backward pass still sees the local loss.
                log_loss = log_loss.clone()
                dist.reduce(log_loss, 0)
            log_info = ("train iter :{}, time: {:.2f}, data_time: {:.2f}, "
                        "Loss: {:.3f}".format(i, this_time, data_time,
                                              log_loss.item()))
            logger.info(log_info)
            torch.cuda.empty_cache()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        del loss

        if i % cfg.val_iter == 0:
            print("--------Val iteration---------")
            model.eval()
            val_data_iter, val_batch = _next_valid_batch(val_data,
                                                         val_data_iter)
            batch_imgs, batch_imgs_path, batch_rectangles, \
                batch_text_labels, batch_text_labels_mask, batch_words = \
                val_batch
            del val_batch
            batch_imgs = batch_imgs.cuda(non_blocking=True)
            batch_rectangles = batch_rectangles.cuda(non_blocking=True)
            batch_text_labels = batch_text_labels.cuda(non_blocking=True)
            with torch.no_grad():
                val_loss, val_pred_logits = model(batch_imgs,
                                                  batch_text_labels,
                                                  batch_rectangles,
                                                  batch_text_labels_mask)
            pred_labels = val_pred_logits.argmax(dim=2).cpu().numpy()
            pred_value_str = idx2label(pred_labels, id2char, char2id)
            # Flatten the per-image word lists into one list of labels.
            gt_str = []
            for words in batch_words:
                gt_str = gt_str + words
            val_dec_metrics_result = calc_metrics(pred_value_str, gt_str,
                                                  metrics_type="accuracy")
            this_time = time.time() - start_time
            if args.distributed:
                dist.reduce(val_loss, 0)
            # BUG FIX: the log formatted the (already deleted) training
            # `loss`; report the validation loss instead.
            log_info = ("val iter :{}, time: {:.2f} Loss: {:.3f}, "
                        "acc: {:.2f}".format(i, this_time,
                                             val_loss.mean().item(),
                                             val_dec_metrics_result))
            logger.info(log_info)
            del val_loss

        if (i + 1) % cfg.save_iter == 0:
            torch.save(model.state_dict(),
                       cfg.save_name + '_{}.pth'.format(i + 1))
        if i > 0 and i % cfg.lr_step == 0:
            # Step the LR schedule (milestones are counted in scheduler steps).
            lrScheduler.step()
            logger.info("lr step")
    print('end the training')