import os
import pickle

import numpy as np
import albumentations as albu
from torch.utils.data import ConcatDataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

# LmdbDataset and AlignCollate are project-specific (LMDB-backed dataset and
# resize/pad collate function); adjust this import path to the repo layout.
from dataset import LmdbDataset, AlignCollate


def get_dataset(data_dir, voc_type, max_len, num_samples):
    """Build a single LmdbDataset, or a ConcatDataset when data_dir is a list."""
    if isinstance(data_dir, list):
        dataset_list = []
        for data_dir_ in data_dir:
            dataset_list.append(LmdbDataset(data_dir_, voc_type, max_len, num_samples))
        dataset = ConcatDataset(dataset_list)
    else:
        dataset = LmdbDataset(data_dir, voc_type, max_len, num_samples)
    print('total image: ', len(dataset))
    return dataset
def get_data(data_dir, voc_type, max_len, num_samples, height, width, batch_size,
             workers, is_train, keep_ratio):
    """Build the dataset and a DataLoader; shuffle and drop_last only when training."""
    if isinstance(data_dir, list):
        dataset_list = []
        for data_dir_ in data_dir:
            dataset_list.append(LmdbDataset(data_dir_, voc_type, max_len, num_samples))
        dataset = ConcatDataset(dataset_list)
    else:
        dataset = LmdbDataset(data_dir, voc_type, max_len, num_samples)
    print('total image: ', len(dataset))

    if is_train:
        data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=workers,
                                 shuffle=True, pin_memory=True, drop_last=True,
                                 collate_fn=AlignCollate(imgH=height, imgW=width, keep_ratio=keep_ratio))
    else:
        data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=workers,
                                 shuffle=False, pin_memory=True, drop_last=False,
                                 collate_fn=AlignCollate(imgH=height, imgW=width, keep_ratio=keep_ratio))
    return dataset, data_loader
def get_data(data_dir, voc_type, max_len, num_samples, height, width, batch_size,
             workers, is_train, keep_ratio, augment=False):
    """Variant with an optional albumentations transform passed through to LmdbDataset."""
    transform = albu.Compose([
        albu.RGBShift(p=0.5),
        albu.RandomBrightnessContrast(p=0.5),
        albu.OpticalDistortion(distort_limit=0.1, shift_limit=0.1, p=0.5),
    ]) if augment else None

    if isinstance(data_dir, list):
        dataset = ConcatDataset([
            LmdbDataset(data_dir_, voc_type, max_len, num_samples, transform)
            for data_dir_ in data_dir
        ])
    else:
        dataset = LmdbDataset(data_dir, voc_type, max_len, num_samples, transform)
    print('total image: ', len(dataset))

    data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=workers,
                             shuffle=is_train, pin_memory=True, drop_last=is_train,
                             collate_fn=AlignCollate(imgH=height, imgW=width, keep_ratio=keep_ratio))
    return dataset, data_loader
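# Hedged usage sketch (not part of the original code): one way the augmenting
# get_data variant above might be called for training. The LMDB paths, the
# vocabulary name, and the sample-cap value are illustrative placeholders. The
# variant is bound to a local alias first, because a later definition in this
# file reuses the name get_data.
_get_data_with_augment = get_data

def _example_augmented_train_loader():
    _, train_loader = _get_data_with_augment(
        data_dir=['path/to/synth_lmdb', 'path/to/real_lmdb'],  # placeholder paths
        voc_type='ALLCASES_SYMBOLS',   # assumed vocabulary identifier
        max_len=25,
        num_samples=1000000,           # assumed per-LMDB sample cap
        height=32, width=100,
        batch_size=128, workers=4,
        is_train=True, keep_ratio=False,
        augment=True)                  # enables the albumentations pipeline
    return train_loader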
def get_data(data_dir, voc_type, max_len, num_samples, height, width, batch_size,
             workers, is_train, keep_ratio, n_max_samples=-1):
    """Variant that can train on a fixed random subset of n_max_samples images."""
    if isinstance(data_dir, list):
        dataset_list = []
        for data_dir_ in data_dir:
            dataset_list.append(LmdbDataset(data_dir_, voc_type, max_len, num_samples))
        dataset = ConcatDataset(dataset_list)
    else:
        dataset = LmdbDataset(data_dir, voc_type, max_len, num_samples)
    print('total image: ', len(dataset))

    if n_max_samples > 0:
        n_all_samples = len(dataset)
        assert n_max_samples < n_all_samples
        # Make the sampled indices static for every run by caching them on disk.
        sample_indices_cache_file = '.sample_indices.cache.pkl'
        if os.path.exists(sample_indices_cache_file):
            with open(sample_indices_cache_file, 'rb') as fin:
                sample_indices = pickle.load(fin)
            print('load sample indices from sample_indices_cache_file: ', n_max_samples)
        else:
            sample_indices = np.random.choice(n_all_samples, n_max_samples, replace=False)
            with open(sample_indices_cache_file, 'wb') as fout:
                pickle.dump(sample_indices, fout)
            print('random sample: ', n_max_samples)
        sub_sampler = SubsetRandomSampler(sample_indices)
    else:
        sub_sampler = None

    if is_train:
        # sampler and shuffle are mutually exclusive, so shuffle only when no sampler is set.
        data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=workers,
                                 sampler=sub_sampler, shuffle=(sub_sampler is None),
                                 pin_memory=True, drop_last=True,
                                 collate_fn=AlignCollate(imgH=height, imgW=width, keep_ratio=keep_ratio))
    else:
        # Evaluation always iterates the full dataset in order, ignoring the sampler.
        data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=workers,
                                 shuffle=False, pin_memory=True, drop_last=False,
                                 collate_fn=AlignCollate(imgH=height, imgW=width, keep_ratio=keep_ratio))
    return dataset, data_loader
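# Hedged usage sketch (not part of the original code): the n_max_samples
# variant above draws a fixed random subset for training and caches the chosen
# indices in .sample_indices.cache.pkl, so repeated runs train on the same
# subset. Paths and hyperparameters here are illustrative placeholders.
def _example_subsampled_train_loader():
    _, train_loader = get_data(
        data_dir='path/to/large_train_lmdb',  # placeholder path
        voc_type='ALLCASES_SYMBOLS',          # assumed vocabulary identifier
        max_len=25,
        num_samples=1000000,                  # assumed per-LMDB sample cap
        height=32, width=100,
        batch_size=128, workers=4,
        is_train=True, keep_ratio=False,
        n_max_samples=200000)                 # train on a fixed 200k-image subset
    return train_loader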