Code example #1: SynthText dataset loader
import scipy.io as sio
from torch.utils.data import Dataset

from data_tools.data_utils import get_vocabulary


class SynthTextLoader(Dataset):  # base class assumed; the source omits it
    def __init__(self, data_dir, gt_mat_path, max_len=25):
        super(SynthTextLoader, self).__init__()
        self.data_dir = data_dir
        # SynthText ships a single gt.mat holding the annotations for every image
        self.mat_contents = sio.loadmat(gt_mat_path)
        self.images_name = self.mat_contents['imnames'][0]
        self.num_samples = len(self.images_name)
        self.max_len = max_len  # maximum transcription length
        self.voc, self.char2id, _ = get_vocabulary("ALLCASES_SYMBOLS")
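
Both loaders build their vocabulary through get_vocabulary from data_tools.data_utils, which the source never shows. A minimal sketch of the interface these calls imply (the exact character set and token order are assumptions; the code above only requires 'EOS' and 'UNK' entries in char2id and an id-to-char mapping):

import string

def get_vocabulary(voc_type="ALLCASES_SYMBOLS"):
    # Sketch only: returns (voc, char2id, id2char). The real implementation
    # lives in data_tools.data_utils and may order its tokens differently.
    if voc_type != "ALLCASES_SYMBOLS":
        raise NotImplementedError(voc_type)
    voc = ['EOS', 'UNK'] + list(string.ascii_letters + string.digits + string.punctuation)
    char2id = {ch: i for i, ch in enumerate(voc)}
    id2char = {i: ch for i, ch in enumerate(voc)}
    return voc, char2id, id2char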
Code example #2: ICDAR15 dataset loader
class ICDAR15Loader(Dataset):  # base class assumed, as above
    def __init__(self, data_dir, gt_dir, with_script=False, shuffle=False, max_len=25):
        super(ICDAR15Loader, self).__init__()
        self.data_dir = data_dir
        self.images_path = self.get_images()  # collect the image paths under data_dir
        self.num_samples = len(self.images_path)
        self.gt_dir = gt_dir
        self.max_len = max_len
        self.with_script = with_script  # whether the gt lines include a script field
        self.shuffle = shuffle  # shuffle the polygons within an image
        self.voc, self.char2id, _ = get_vocabulary("ALLCASES_SYMBOLS")
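
The shuffle flag is only stored here; the reordering itself happens elsewhere in the loader. A minimal sketch of the idea, using a hypothetical helper (not in the source) that applies one permutation to the polygons and their transcriptions so the two stay aligned:

import numpy as np

def shuffle_polygons(polygons, words):
    # Hypothetical helper: one random permutation, applied to both lists,
    # keeps polygon i paired with its transcription i.
    order = np.random.permutation(len(polygons))
    return [polygons[k] for k in order], [words[k] for k in order]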
Code example #3: evaluation loop
import torch

from data_tools.data_utils import get_vocabulary
# build_dataloader, MultiInstanceRecognition, idx2label and calc_metrics are
# project-local; their module paths are not shown in the source.


def test(args, cpks):
    # cpks: path to a saved state_dict checkpoint (.pth)
    assert isinstance(cpks, str)

    voc, char2id, id2char = get_vocabulary("ALLCASES_SYMBOLS")

    test_data = build_dataloader(args.val_data_cfg)
    print("test data: {}".format(len(test_data)))

    model = MultiInstanceRecognition(args.model_cfg).cuda()
    # model = MMDataParallel(model).cuda()
    model.load_state_dict(torch.load(cpks))
    model.eval()
    pred_strs = []
    gt_strs = []
    for i, batch_data in enumerate(test_data):
        torch.cuda.empty_cache()
        batch_imgs, batch_imgs_path, batch_rectangles, \
        batch_text_labels, batch_text_labels_mask, batch_words = \
            batch_data
        if batch_imgs is None:
            continue
        batch_imgs = batch_imgs.cuda()
        batch_rectangles = batch_rectangles.cuda()
        batch_text_labels = batch_text_labels.cuda()
        with torch.no_grad():
            loss, decoder_logits = model(batch_imgs, batch_text_labels,
                                         batch_rectangles,
                                         batch_text_labels_mask)

        pred_labels = decoder_logits.argmax(dim=2).cpu().numpy()
        pred_value_str = idx2label(pred_labels, id2char, char2id)
        gt_str = batch_words

        for j in range(len(gt_str[0])):
            print("predict: {} label: {}".format(pred_value_str[j],
                                                 gt_str[0][j]))
            pred_strs.append(pred_value_str[j])
            gt_strs.append(gt_str[0][j])

        # running exact-match accuracy over all samples seen so far
        val_dec_metrics_result = calc_metrics(pred_strs,
                                              gt_strs,
                                              metrics_type="accuracy")

        print("test accuracy = {:.3f}".format(val_dec_metrics_result))
        print('---------')
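
calc_metrics is not shown in the source; for metrics_type="accuracy" the loop above implies per-string exact matching. A minimal sketch under that assumption:

def calc_metrics(pred_strs, gt_strs, metrics_type="accuracy"):
    # Sketch only: fraction of predictions that exactly match the ground
    # truth. The real implementation may normalise case or punctuation.
    assert metrics_type == "accuracy"
    assert len(pred_strs) == len(gt_strs)
    if not gt_strs:
        return 0.0
    return sum(p == g for p, g in zip(pred_strs, gt_strs)) / len(gt_strs)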
Code example #4: decoding predicted ids back to strings
import numpy as np

from data_tools.data_utils import get_vocabulary


def idx2label(inputs, id2char=None, char2id=None):
    # Convert a batch of id sequences into strings, cutting at the first EOS.
    if id2char is None:
        _, char2id, id2char = get_vocabulary(voc_type="ALLCASES_SYMBOLS")

    def end_cut(ins):
        # collect characters up to the first EOS; drop UNK tokens
        cut_ins = []
        for idx in ins:
            if idx == char2id['EOS']:
                break
            if idx != char2id['UNK']:
                cut_ins.append(id2char[idx])
        return cut_ins

    if isinstance(inputs, np.ndarray):
        assert inputs.ndim == 2, "input's rank should be 2"
        return [''.join(end_cut(ins)) for ins in inputs]
    print("input to idx2label should be a numpy array")
    return inputs
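
A quick usage check (the ids come from get_vocabulary, so 'a' and 'b' are assumed to be in the vocabulary):

_, char2id, id2char = get_vocabulary("ALLCASES_SYMBOLS")
# encode "ab", terminate with EOS, pad to a fixed length
seq = [char2id['a'], char2id['b'], char2id['EOS'], char2id['EOS']]
print(idx2label(np.array([seq]), id2char, char2id))  # -> ['ab']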
Code example #5: data and model configuration
# The opening of this dict is truncated in the source; the variable name is
# assumed.
val_data_cfg = dict(data_type='ICDAR13',
                    num_instances=4,
                    crop_ratio=0.0,
                    crop_random=False,
                    input_width=640,
                    input_height=1280,
                    keep_ratio=True,
                    max_len=max_len,  # max_len is defined earlier in the original config
                    batch_size=1,
                    num_works=1,
                    shuffle=False)
from data_tools.data_utils import get_vocabulary

dim = 512
seq_len = 25
voc_len = len(get_vocabulary("ALLCASES_SYMBOLS")[0])
model_cfg = dict(
    dim=dim,
    seq_len=seq_len,
    voc_len=voc_len,
    num_instances=4,
    roi_size=(16, 64),
    feature_channels=[256, 512, 1024, 2048],
    fpn_out_channels=128,
    roi_feature_step=4,
    encoder_channels=512,
    embedding=dict(
        dim=dim,
        voc_len=voc_len,
        embedding_dim=512,
        pos_dim=seq_len,
        # (the remainder of this config is truncated in the source)
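
Neither build_dataloader nor MultiInstanceRecognition is shown in the source. A minimal sketch of how a dict-driven builder might consume a cfg like the one above (the data_type-to-class mapping and the key handling are assumptions):

from torch.utils.data import DataLoader

def build_dataloader(data_cfg, distributed=False):
    # Sketch only: split the cfg into DataLoader kwargs and dataset kwargs,
    # then dispatch on data_type. The real builder also knows 'ICDAR13',
    # handles the crop/resize options above, and sets up distributed samplers.
    cfg = dict(data_cfg)                  # copy so the pops stay local
    data_type = cfg.pop('data_type')
    batch_size = cfg.pop('batch_size')
    num_workers = cfg.pop('num_works')    # key name as used in the cfg above
    shuffle = cfg.pop('shuffle')
    datasets = {'SynthText': SynthTextLoader, 'ICDAR15': ICDAR15Loader}
    dataset = datasets[data_type](**cfg)  # remaining keys reach the dataset
    return DataLoader(dataset, batch_size=batch_size,
                      num_workers=num_workers, shuffle=shuffle)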
Code example #6: training loop
import logging
import time

import numpy as np
import torch
import torch.distributed as dist
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel
from torch.optim import lr_scheduler

from data_tools.data_utils import get_vocabulary
# build_dataloader, MultiInstanceRecognition, idx2label and calc_metrics are
# project-local; their module paths are not shown in the source.


def train(cfg, args):
    logger = logging.getLogger('model training')
    train_data = build_dataloader(cfg.train_data_cfg, args.distributed)
    logger.info("train data: {}".format(len(train_data)))
    val_data = build_dataloader(cfg.val_data_cfg, args.distributed)
    logger.info("val data: {}".format(len(val_data)))

    model = MultiInstanceRecognition(cfg.model_cfg).cuda()
    if cfg.resume_from is not None:
        logger.info('loading pretrained model from {}'.format(cfg.resume_from))
        model.load_state_dict(torch.load(cfg.resume_from))
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)
    voc, char2id, id2char = get_vocabulary("ALLCASES_SYMBOLS")

    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    logger.info('Trainable params num: {}'.format(sum(params_num)))
    optimizer = optim.Adam(filtered_parameters, lr=cfg.lr, betas=(0.9, 0.999))
    lrScheduler = lr_scheduler.MultiStepLR(optimizer, [1, 2, 3], gamma=0.1)

    max_iters = cfg.max_iters
    start_iter = 0
    if cfg.resume_from is not None:
        start_iter = int(cfg.resume_from.split('_')[-1].split('.')[0])
        logger.info('continue to train, start_iter: {}'.format(start_iter))

    train_data_iter = iter(train_data)
    val_data_iter = iter(val_data)
    start_time = time.time()
    for i in range(start_iter, max_iters):
        model.train()
        try:
            batch_data = next(train_data_iter)
        except StopIteration:
            train_data_iter = iter(train_data)
            batch_data = next(train_data_iter)
        data_time_s = time.time()
        batch_imgs, batch_imgs_path, batch_rectangles, \
        batch_text_labels, batch_text_labels_mask, batch_words = \
            batch_data
        while batch_imgs is None:
            batch_data = next(train_data_iter)
            batch_imgs, batch_imgs_path, batch_rectangles, \
            batch_text_labels, batch_text_labels_mask, batch_words = \
                batch_data

        batch_imgs = batch_imgs.cuda(non_blocking=True)
        batch_rectangles = batch_rectangles.cuda(non_blocking=True)
        batch_text_labels = batch_text_labels.cuda(non_blocking=True)
        data_time = time.time() - data_time_s
        loss, decoder_logits = model(batch_imgs, batch_text_labels,
                                     batch_rectangles, batch_text_labels_mask)
        del batch_data

        loss = loss.mean()

        if i % cfg.train_verbose == 0:
            this_time = time.time() - start_time
            log_loss = loss.detach().clone()
            if args.distributed:
                # dist.reduce works in place and returns None; reduce a
                # detached copy so backward() below is unaffected
                dist.reduce(log_loss, 0)
            log_info = "train iter: {}, time: {:.2f}, data_time: {:.2f}, loss: {:.3f}".format(
                i, this_time, data_time, log_loss.item())
            logger.info(log_info)
            torch.cuda.empty_cache()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        del loss
        if i % cfg.val_iter == 0:
            logger.info("--------Val iteration---------")
            model.eval()

            try:
                val_batch = next(val_data_iter)
            except StopIteration:
                val_data_iter = iter(val_data)
                val_batch = next(val_data_iter)

            batch_imgs, batch_imgs_path, batch_rectangles, \
            batch_text_labels, batch_text_labels_mask, batch_words = \
                val_batch
            while batch_imgs is None:
                val_batch = next(val_data_iter)
                batch_imgs, batch_imgs_path, batch_rectangles, \
                batch_text_labels, batch_text_labels_mask, batch_words = \
                    val_batch
            del val_batch
            batch_imgs = batch_imgs.cuda(non_blocking=True)
            batch_rectangles = batch_rectangles.cuda(non_blocking=True)
            batch_text_labels = batch_text_labels.cuda(non_blocking=True)
            with torch.no_grad():
                val_loss, val_pred_logits = model(batch_imgs,
                                                  batch_text_labels,
                                                  batch_rectangles,
                                                  batch_text_labels_mask)
            pred_labels = val_pred_logits.argmax(dim=2).cpu().numpy()
            pred_value_str = idx2label(pred_labels, id2char, char2id)
            # flatten the per-image word lists into one flat list of strings
            gt_str = []
            for words in batch_words:
                gt_str.extend(words)
            val_dec_metrics_result = calc_metrics(pred_value_str,
                                                  gt_str,
                                                  metrics_type="accuracy")
            this_time = time.time() - start_time
            if args.distributed:
                # dist.reduce is in place and returns None
                dist.reduce(val_loss, 0)
            log_info = "val iter: {}, time: {:.2f}, loss: {:.3f}, acc: {:.2f}".format(
                i, this_time, val_loss.mean().item(), val_dec_metrics_result)
            logger.info(log_info)
            del val_loss
        if (i + 1) % cfg.save_iter == 0:
            torch.save(model.state_dict(),
                       cfg.save_name + '_{}.pth'.format(i + 1))
        if i > 0 and i % cfg.lr_step == 0:  # step the learning-rate schedule
            lrScheduler.step()
            logger.info("lr step")
    logger.info('end of training')
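
The entry point is not included in the source. Since the evaluation code references MMDataParallel, the project likely depends on mmcv, so mmcv's Config.fromfile is a plausible config loader; a hypothetical launcher (all argument names assumed):

if __name__ == '__main__':
    import argparse
    from mmcv import Config

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', required=True)
    parser.add_argument('--distributed', action='store_true')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()

    if args.distributed:
        # one process per GPU, launched e.g. via torch.distributed.launch
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')

    train(Config.fromfile(args.config), args)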