Code example #1
File: main.py  Project: peternara/AutoSTR-OCR
import numpy as np
from torch.utils.data import ConcatDataset, DataLoader, SubsetRandomSampler

# LmdbDataset, CustomDataset, and AlignCollate used across these snippets are
# project-local modules; their import paths vary by repository.


def get_dataloader(synthetic_dataset, real_dataset, height, width, batch_size,
                   workers, is_train, keep_ratio):
    num_synthetic_dataset = len(synthetic_dataset)
    num_real_dataset = len(real_dataset)

    # Shuffle the synthetic indices, then drop the first num_real_dataset of
    # them so the mixed set keeps exactly num_synthetic_dataset samples.
    synthetic_indices = list(np.random.permutation(num_synthetic_dataset))
    synthetic_indices = synthetic_indices[num_real_dataset:]
    # Real samples sit after the synthetic ones in the ConcatDataset below,
    # so their indices are offset by num_synthetic_dataset.
    real_indices = list(
        np.random.permutation(num_real_dataset) + num_synthetic_dataset)
    concated_indices = synthetic_indices + real_indices
    assert len(concated_indices) == num_synthetic_dataset

    sampler = SubsetRandomSampler(concated_indices)
    concated_dataset = ConcatDataset([synthetic_dataset, real_dataset])
    print('total image: ', len(concated_dataset))

    data_loader = DataLoader(concated_dataset,
                             batch_size=batch_size,
                             num_workers=workers,
                             shuffle=False,  # must stay False when a sampler is given
                             pin_memory=True,
                             drop_last=True,
                             sampler=sampler,
                             collate_fn=AlignCollate(imgH=height,
                                                     imgW=width,
                                                     keep_ratio=keep_ratio))
    return concated_dataset, data_loader
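
The same mixing idea as a self-contained sketch, using toy TensorDatasets
(every name below is illustrative, not from the project):

import numpy as np
import torch
from torch.utils.data import (ConcatDataset, DataLoader,
                              SubsetRandomSampler, TensorDataset)

synthetic = TensorDataset(torch.zeros(100, 3))  # stand-in for the synthetic set
real = TensorDataset(torch.ones(10, 3))         # stand-in for the real set

# Keep the epoch length at len(synthetic): drop len(real) random synthetic
# indices, then append the real indices, offset past the synthetic block.
syn_idx = list(np.random.permutation(len(synthetic)))[len(real):]
real_idx = list(np.random.permutation(len(real)) + len(synthetic))
sampler = SubsetRandomSampler(syn_idx + real_idx)

loader = DataLoader(ConcatDataset([synthetic, real]),
                    batch_size=8, sampler=sampler)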
Code example #2
def get_dataset(data_dir, voc_type, max_len, num_samples):
    if isinstance(data_dir, list):
        dataset_list = []
        for data_dir_ in data_dir:
            dataset_list.append(
                LmdbDataset(data_dir_, voc_type, max_len, num_samples))
        dataset = ConcatDataset(dataset_list)
    else:
        dataset = LmdbDataset(data_dir, voc_type, max_len, num_samples)
    print('total image: ', len(dataset))
    return dataset
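
A hypothetical call (paths and vocabulary type are placeholders): the
isinstance check above accepts either a single LMDB directory or a list.

train_set = get_dataset(['/data/lmdb/synth90k', '/data/lmdb/synthtext'],
                        voc_type='ALLCASES_SYMBOLS', max_len=100,
                        num_samples=float('inf'))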
Code example #3
File: main.py  Project: kkalla/SEED
def get_data_txt(data_dir,
                 gt_file_path,
                 embed_dir,
                 voc_type,
                 max_len,
                 num_samples,
                 height,
                 width,
                 batch_size,
                 workers,
                 is_train,
                 keep_ratio):
    if isinstance(data_dir, list) and len(data_dir) > 1:
        dataset_list = []
        for data_dir_, gt_file_, embed_dir_ in zip(data_dir,
                                                   gt_file_path, embed_dir):
            # dataset_list.append(LmdbDataset(data_dir_, voc_type, max_len, num_samples))
            dataset_list.append(CustomDataset(
                data_dir_, gt_file_, embed_dir_, voc_type, max_len, num_samples))
        dataset = ConcatDataset(dataset_list)
    else:
        # Note: a single-element list also takes this branch (the check above
        # requires len(data_dir) > 1) and is passed to CustomDataset as-is.
        # dataset = LmdbDataset(data_dir, voc_type, max_len, num_samples)
        dataset = CustomDataset(data_dir, gt_file_path,
                                embed_dir, voc_type, max_len, num_samples)
    print('total image: ', len(dataset))

    if is_train:
        """
        data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=workers,
          shuffle=True, pin_memory=True, drop_last=True,
          collate_fn=AlignCollate(imgH=height, imgW=width, keep_ratio=keep_ratio))
        """
        data_loader = DataLoader(dataset, batch_size=batch_size,
                                 num_workers=workers,
                                 shuffle=True, pin_memory=True, drop_last=True,
                                 collate_fn=AlignCollate(
                                     imgH=height,
                                     imgW=width,
                                     keep_ratio=keep_ratio))
    else:
        data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=workers,
                                 shuffle=False, pin_memory=True, drop_last=False,
                                 collate_fn=AlignCollate(imgH=height, imgW=width, keep_ratio=keep_ratio))

    return dataset, data_loader
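
The training and evaluation loaders above differ only in shuffle and
drop_last: dropping the last partial batch keeps training batch shapes
constant, while evaluation keeps every sample in a fixed order.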
Code example #4
File: main.py  Project: kkalla/SEED
def get_data_lmdb(data_dir, voc_type,
                  max_len, num_samples,
                  height, width,
                  batch_size, workers,
                  is_train, keep_ratio,
                  voc_file=None):
    if isinstance(data_dir, list):
        dataset_list = []
        for data_dir_ in data_dir:
            dataset_list.append(LmdbDataset(
                data_dir_, voc_type, max_len, num_samples, voc_file=voc_file))
        dataset = ConcatDataset(dataset_list)
    else:
        dataset = LmdbDataset(data_dir, voc_type, max_len,
                              num_samples, voc_file=voc_file)
    print('total image: ', len(dataset))

    if is_train:
        data_loader = DataLoader(
            dataset, batch_size=batch_size,
            num_workers=workers,
            shuffle=True, pin_memory=True, drop_last=True,
            collate_fn=AlignCollate(imgH=height,
                                    imgW=width,
                                    keep_ratio=keep_ratio)
        )
    else:
        data_loader = DataLoader(
            dataset, batch_size=batch_size,
            num_workers=workers,
            shuffle=False, pin_memory=True,
            drop_last=False,
            collate_fn=AlignCollate(imgH=height,
                                    imgW=width,
                                    keep_ratio=keep_ratio)
        )

    return dataset, data_loader
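
This is the LMDB-only variant of example #3; the optional voc_file
(presumably a custom vocabulary definition) is forwarded to every
LmdbDataset so all concatenated shards decode labels the same way.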
Code example #5
File: main.py  Project: aparpara/aster.pytorch
def get_data(data_dir,
             voc_type,
             max_len,
             num_samples,
             height,
             width,
             batch_size,
             workers,
             is_train,
             keep_ratio,
             augment=False):
    transform = albu.Compose([
        albu.RGBShift(p=0.5),
        albu.RandomBrightnessContrast(p=0.5),
        albu.OpticalDistortion(distort_limit=0.1, shift_limit=0.1, p=0.5)
    ]) if augment else None

    if isinstance(data_dir, list):
        dataset = ConcatDataset([
            LmdbDataset(data_dir_, voc_type, max_len, num_samples, transform)
            for data_dir_ in data_dir
        ])
    else:
        dataset = LmdbDataset(data_dir, voc_type, max_len, num_samples,
                              transform)
    print('total image: ', len(dataset))

    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             num_workers=workers,
                             shuffle=is_train,
                             pin_memory=True,
                             drop_last=is_train,
                             collate_fn=AlignCollate(imgH=height,
                                                     imgW=width,
                                                     keep_ratio=keep_ratio))

    return dataset, data_loader
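
Two design points worth noting: shuffle=is_train and drop_last=is_train
collapse the separate train/eval branches seen in examples #3 and #4, and
the albumentations pipeline is built only when augment=True. How such a
pipeline is applied to an image (the call site inside LmdbDataset is not
shown here, so this is a hedged, self-contained illustration):

import albumentations as albu
import numpy as np

transform = albu.Compose([
    albu.RGBShift(p=0.5),
    albu.RandomBrightnessContrast(p=0.5),
    albu.OpticalDistortion(distort_limit=0.1, shift_limit=0.1, p=0.5),
])
image = np.random.randint(0, 256, (32, 100, 3), dtype=np.uint8)  # toy HxWxC image
augmented = transform(image=image)['image']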
Code example #6
File: main.py  Project: peternara/AutoSTR-OCR
def get_data(data_dir,
             voc_type,
             max_len,
             num_samples,
             height,
             width,
             batch_size,
             workers,
             is_train,
             keep_ratio,
             n_max_samples=-1):
    if isinstance(data_dir, list):
        dataset_list = []
        for data_dir_ in data_dir:
            dataset_list.append(
                LmdbDataset(data_dir_, voc_type, max_len, num_samples))
        dataset = ConcatDataset(dataset_list)
    else:
        dataset = LmdbDataset(data_dir, voc_type, max_len, num_samples)
    print('total image: ', len(dataset))

    if n_max_samples > 0:
        n_all_samples = len(dataset)
        assert n_max_samples < n_all_samples
        # make sample indices static for every run
        sample_indices_cache_file = '.sample_indices.cache.pkl'
        if os.path.exists(sample_indices_cache_file):
            with open(sample_indices_cache_file, 'rb') as fin:
                sample_indices = pickle.load(fin)
            print('loaded sample indices from cache file: ', n_max_samples)
        else:
            sample_indices = np.random.choice(n_all_samples,
                                              n_max_samples,
                                              replace=False)
            with open(sample_indices_cache_file, 'wb') as fout:
                pickle.dump(sample_indices, fout)
            print('random sample: ', n_max_samples)
        sub_sampler = SubsetRandomSampler(sample_indices)
    else:
        sub_sampler = None

    if is_train:
        data_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=workers,
            sampler=sub_sampler,
            shuffle=(sub_sampler is None),  # a sampler excludes shuffle=True
            pin_memory=True,
            drop_last=True,
            collate_fn=AlignCollate(imgH=height,
                                    imgW=width,
                                    keep_ratio=keep_ratio))
    else:
        data_loader = DataLoader(dataset,
                                 batch_size=batch_size,
                                 num_workers=workers,
                                 shuffle=False,
                                 pin_memory=True,
                                 drop_last=False,
                                 collate_fn=AlignCollate(
                                     imgH=height,
                                     imgW=width,
                                     keep_ratio=keep_ratio))

    return dataset, data_loader
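
The cache file pins the random subset across runs, so experiments with the
same n_max_samples stay comparable. The same pattern as a self-contained
sketch (only the cache file name comes from the snippet; the rest is
illustrative):

import os
import pickle

import numpy as np
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(1000, dtype=torch.float32))
cache_file = '.sample_indices.cache.pkl'

if os.path.exists(cache_file):
    # Later runs reuse the subset chosen on the first run.
    with open(cache_file, 'rb') as fin:
        indices = pickle.load(fin)
else:
    # First run: draw the subset once and persist it.
    indices = np.random.choice(len(dataset), 100, replace=False)
    with open(cache_file, 'wb') as fout:
        pickle.dump(indices, fout)

loader = DataLoader(dataset, batch_size=16,
                    sampler=SubsetRandomSampler(indices))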
Code example #7
File: main.py  Project: tranbaohieu/SAFL_pytorch
def training_dataset(opt):
    # img_w is not defined in this function; it presumably comes from the
    # enclosing module. Wrapping a single generator in ConcatDataset is a
    # no-op that keeps the return type uniform with multi-dataset setups.
    return ConcatDataset([
        training_data_generator.get_training_data_generator(
            img_h=opt.imgH, img_w=img_w)
    ])