def _get_dataset(self, name=None):
    """
    The private get_dataset function allows databases to overwrite
    get_dataset and add a map function without recursively calling the map
    function in the case of multiple dataset names.
    """
    if name is None:
        raise TypeError(
            f'Missing dataset_name, use e.g.: {self.dataset_names}')

    if isinstance(name, str):
        pass
    elif isinstance(name, typing.Iterable) and not isinstance(name, dict):
        datasets = [self._get_dataset(n) for n in name]
        return lazy_dataset.concatenate(*datasets)
    else:
        raise TypeError(
            f'Argument type {type(name)} of {name} is not allowed! '
            f'Expected str, list or tuple.'
        )

    # The resulting dataset is immutable anyway due to pickle in
    # `lazy_dataset.from_dict`. This code avoids storing the resulting
    # dataset more than once in memory. Discuss with CBJ for details.
    try:
        return self._dataset_weak_ref_dict[name]
    except KeyError:
        pass

    examples = self.get_examples(name)
    ds = lazy_dataset.from_dict(examples, name=name)

    self._dataset_weak_ref_dict[name] = ds

    return ds
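# Hedged usage sketch (assumption): `MiniDatabase` and its toy example dicts
# are invented for illustration; they only provide the attributes that
# `_get_dataset` touches. `lazy_dataset.from_dict` and
# `lazy_dataset.concatenate` are the real library API.
import typing
import weakref

import lazy_dataset


class MiniDatabase:
    def __init__(self, data):
        self._data = data
        self._dataset_weak_ref_dict = weakref.WeakValueDictionary()

    @property
    def dataset_names(self):
        return tuple(self._data.keys())

    def get_examples(self, name):
        return self._data[name]

    # Reuse the function defined above as a method.
    _get_dataset = _get_dataset


db = MiniDatabase({
    'train': {'a': {'x': 1}, 'b': {'x': 2}},
    'test': {'c': {'x': 3}},
})
ds = db._get_dataset(['train', 'test'])  # list of names -> concatenation
assert [example['x'] for example in ds] == [1, 2, 3]
ds_train = db._get_dataset('train')      # str -> single dataset
assert db._get_dataset('train') is ds_train  # served from the weak-ref dict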
def prepare_iterable(
        db, datasets: List[str], batch_size, chunk_size,
        prefetch=True, iterator_slice=None, shuffle=True
):
    """
    This is re-used in the evaluate script
    """
    # Create an iterator from the datasets (a simple concatenation of the
    # single datasets)
    if isinstance(datasets, str):
        datasets = datasets.split(',')
    iterator = db.get_dataset(datasets)

    # TODO: this does not make too much sense when we have multiple datasets
    if iterator_slice is not None:
        iterator = iterator[iterator_slice]

    # Determine the number of speakers in each example
    def add_num_speakers(example):
        example.update(num_speakers=len(example['speaker_id']))
        return example

    iterator = iterator.map(add_num_speakers)

    # Group iterators by number of speakers so that all examples in a batch
    # have the same number of speakers
    iterators = list(iterator.groupby(lambda x: x['num_speakers']).values())

    chunker = RandomChunkSingle(chunk_size, chunk_keys=('y', 's'), axis=-1)
    iterators = [
        iterator
        .map(pre_batch_transform)
        .map(chunker)
        .shuffle(reshuffle=shuffle)
        .batch(batch_size)
        .map(pt.data.batch.Sorter('num_samples'))
        .map(pt.data.utils.collate_fn)
        for iterator in iterators
    ]
    iterator = lazy_dataset.concatenate(*iterators).shuffle(reshuffle=shuffle)

    # FilterExceptions are only raised inside the chunking code if the
    # example is too short. If min_length <= 0 or chunk_size == -1, no
    # filter exception is raised.
    catch_exception = chunker.chunk_size != -1 and chunker.min_length > 0
    if prefetch:
        iterator = iterator.prefetch(
            8, 16, catch_filter_exception=catch_exception)
    elif catch_exception:
        iterator = iterator.catch()

    return iterator
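# Hedged sketch (assumption): the group-then-batch pattern from above on a
# toy dataset, so that every batch only contains examples with the same
# number of speakers. The example dicts are invented; `.map`, `.groupby`,
# `.batch` and `lazy_dataset.concatenate` are the real lazy_dataset API.
import lazy_dataset

examples = {
    'ex1': {'speaker_id': ['A', 'B']},
    'ex2': {'speaker_id': ['C', 'D', 'E']},
    'ex3': {'speaker_id': ['F', 'G']},
}
ds = lazy_dataset.from_dict(examples).map(
    lambda ex: {**ex, 'num_speakers': len(ex['speaker_id'])})
grouped = list(ds.groupby(lambda ex: ex['num_speakers']).values())
batched = lazy_dataset.concatenate(*[group.batch(2) for group in grouped])
for batch in batched:
    # All examples in a batch share the same number of speakers.
    assert len({ex['num_speakers'] for ex in batch}) == 1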
def get_datasets(self, dataset_names, use_weakref=True):
    """
    Returns a single Iterator over the specified datasets.

    Adds the example_id and dataset_name to each example dict.

    :param dataset_names: list or str specifying the datasets of interest.
        If None, an iterator over the complete database will be returned.
    :return: A lazy dataset concatenating the requested datasets.
    """
    dataset_names = to_list(dataset_names, item_type=str)
    iterators = list()
    for dataset_name in dataset_names:
        if use_weakref:
            try:
                it = self._iterator_weak_ref_dict[dataset_name]
            except KeyError:
                pass
            else:
                iterators.append(it)
                continue
        try:
            examples = self._get_dataset_from_database_dict(dataset_name)
        except KeyError:
            import difflib
            similar = difflib.get_close_matches(
                dataset_name, self.dataset_names, n=5, cutoff=0,
            )
            raise KeyError(dataset_name, f'close_matches: {similar}', self)
        if len(examples) == 0:
            # When somebody needs empty datasets, add an option to this
            # function to allow empty datasets.
            raise RuntimeError(
                f'The requested dataset {dataset_name!r} is empty.'
            )

        for example_id in examples.keys():
            examples[example_id][EXAMPLE_ID] = example_id
            examples[example_id][DATASET_NAME] = dataset_name

        # Convert values to binary, because deepcopy on binary is faster.
        # This is important for CHiME5.
        ds = lazy_dataset.from_dict(examples)

        if use_weakref:
            self._iterator_weak_ref_dict[dataset_name] = ds

        iterators.append(ds)

    return lazy_dataset.concatenate(*iterators)
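# Hedged sketch (assumption): what the KeyError branch above produces for a
# misspelled dataset name. `available` is an invented list of names; only
# `difflib.get_close_matches` is the (standard library) call used above.
import difflib

available = ['train_curated', 'train_noisy', 'test']
similar = difflib.get_close_matches('train_cur', available, n=5, cutoff=0)
# `similar` lists all names, most similar first, so the raised error reads
# roughly: KeyError('train_cur', "close_matches: ['train_curated', ...]", db)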
def test_concatenate_function():
    ds_train = get_dataset()
    ds_predict = get_dataset_predict()

    ds = lazy_dataset.concatenate(ds_train, ds_predict)
    example_ids = [e['example_id'] for e in ds]
    assert example_ids == [f'example_id_{i}' for i in range(1, 6)]
    assert ds['example_id_1']['example_id'] == 'example_id_1'
    assert ds['example_id_5']['example_id'] == 'example_id_5'
    assert ds[0]['example_id'] == 'example_id_1'
    assert ds[-1]['example_id'] == 'example_id_5'
    assert ds[:1][0]['example_id'] == 'example_id_1'

    ds = lazy_dataset.concatenate([ds_train, ds_predict])
    example_ids = [e['example_id'] for e in ds]
    assert example_ids == [f'example_id_{i}' for i in range(1, 6)]
    assert ds['example_id_1']['example_id'] == 'example_id_1'
    assert ds['example_id_5']['example_id'] == 'example_id_5'
    assert ds[0]['example_id'] == 'example_id_1'
    assert ds[-1]['example_id'] == 'example_id_5'
    assert ds[:1][0]['example_id'] == 'example_id_1'
def prepare_iterable(
        db, datasets: List[str], batch_size, chunk_size,
        prefetch=True, iterator_slice=None, shuffle=True
):
    """
    This is re-used in the evaluate script
    """
    # Create an iterator from the datasets (a simple concatenation of the
    # single datasets) and determine the number of speakers in each example
    if isinstance(datasets, str):
        datasets = datasets.split(',')
    iterator = db.get_dataset(datasets).map(
        lambda x: (x.update(num_speakers=len(x['speaker_id'])), x)[1])

    # TODO: this does not make too much sense when we have multiple datasets
    if iterator_slice is not None:
        iterator = iterator[iterator_slice]

    # Group the iterator by number of speakers so that all examples in a
    # batch have the same number of speakers
    iterators = list(iterator.groupby(lambda x: x['num_speakers']).values())

    iterators = [
        iterator.map(pre_batch_transform).map(
            RandomChunkSingle(chunk_size, chunk_keys=('y', 's'), axis=-1)
        ).shuffle(reshuffle=shuffle).batch(batch_size).map(
            lambda batch: sorted(
                batch,
                key=lambda example: example['num_samples'],
                reverse=True,
            )
        ).map(pt.data.utils.collate_fn)
        for iterator in iterators
    ]
    iterator = lazy_dataset.concatenate(*iterators).shuffle(reshuffle=shuffle)

    if prefetch:
        iterator = iterator.prefetch(8, 16, catch_filter_exception=True)
    elif chunk_size > 0:
        iterator = iterator.catch()

    return iterator
def test_concatenate_function_raises_on_non_dataset_instances():
    ds_train = get_dataset()
    not_a_ds = dict()
    with pytest.raises(TypeError):
        lazy_dataset.concatenate(ds_train, not_a_ds)
def test_concatenate_function_raises_on_empty_list():
    with pytest.raises(ValueError):
        lazy_dataset.concatenate()
def get_datasets(
        use_noisy, split, fold, curated_reps, mixup_probs, extractor,
        augmenter, num_workers, batch_size, prefetch_buffer,
        max_padding_rate, bucket_expiration, event_bucketing, debug
):
    # prepare database
    database_json = jsons_dir / f'fsd_kaggle_2019_split{split}.json'
    db = JsonDatabase(database_json)

    def add_noisy_flag(example):
        example['is_noisy'] = example['dataset'] != 'train_curated'
        return example

    extractor = Extractor(**extractor)
    augmenter = Augmenter(extractor=extractor, **augmenter)

    curated_train_data = db.get_dataset('train_curated').map(add_noisy_flag)
    extractor.initialize_labels(curated_train_data)
    if debug:
        curated_train_data = curated_train_data.shuffle()[
            :len(curated_train_data) // 10]
    extractor.initialize_norm(
        dataset_name='train_curated',
        dataset=curated_train_data,
        max_workers=num_workers,
    )
    if fold is not None:
        curated_train_data, validation_set = split_dataset(
            curated_train_data, fold=fold, seed=0)
    else:
        validation_set = None

    if use_noisy:
        noisy_train_data = db.get_dataset('train_noisy').map(add_noisy_flag)
        if debug:
            noisy_train_data = noisy_train_data.shuffle()[
                :len(noisy_train_data) // 10]
        extractor.initialize_norm(
            dataset_name='train_noisy',
            dataset=noisy_train_data,
            max_workers=num_workers,
        )
        training_set = lazy_dataset.concatenate(
            curated_train_data, noisy_train_data)
    else:
        training_set = curated_train_data
    batch_norm_tuning_set = training_set

    if mixup_probs is not None:
        training_set = MixUpDataset(
            training_set, training_set, mixin_probs=mixup_probs)
    if curated_reps > 0:
        print('curated reps:', curated_reps)
        curated_train_data = lazy_dataset.from_dict({
            f'{example["example_id"]}_{i}': example
            for i in range(curated_reps)
            for example in curated_train_data
        })
        if mixup_probs is not None:
            curated_train_data = MixUpDataset(
                curated_train_data, curated_train_data,
                mixin_probs=mixup_probs)
        training_set = lazy_dataset.concatenate(
            training_set, curated_train_data)
    print('Length of training set', len(training_set))

    bucket_cls = EventTimeSeriesBucket if event_bucketing \
        else DynamicTimeSeriesBucket

    def prepare_iterable(dataset, drop_incomplete=False):
        return dataset.prefetch(
            num_workers=num_workers,
            buffer_size=prefetch_buffer,
            catch_filter_exception=True,
        ).batch_dynamic_bucket(
            bucket_cls=bucket_cls,
            batch_size=batch_size,
            len_key='seq_len',
            max_padding_rate=max_padding_rate,
            expiration=bucket_expiration,
            drop_incomplete=drop_incomplete,
            sort_key='seq_len',
            reverse_sort=True,
        ).map(Collate())

    training_set = prepare_iterable(
        training_set.map(augmenter).shuffle(reshuffle=True),
        drop_incomplete=True,
    )
    batch_norm_tuning_set = prepare_iterable(
        batch_norm_tuning_set.map(extractor), drop_incomplete=True)
    if validation_set is not None:
        validation_set = prepare_iterable(validation_set.map(extractor))
    return training_set, validation_set, batch_norm_tuning_set
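# Hedged sketch (assumption): the `curated_reps` oversampling trick from
# above on a toy dataset: a dataset is repeated by rebuilding it with
# suffixed example ids. The example dicts are invented; `from_dict` and
# `concatenate` are the real lazy_dataset API.
import lazy_dataset

curated = lazy_dataset.from_dict({
    'ex1': {'example_id': 'ex1', 'label': 'dog'},
    'ex2': {'example_id': 'ex2', 'label': 'cat'},
})
curated_reps = 3
repeated = lazy_dataset.from_dict({
    f'{example["example_id"]}_{i}': example
    for i in range(curated_reps)
    for example in curated
})
assert len(repeated) == curated_reps * len(curated)
training_set = lazy_dataset.concatenate(curated, repeated)
assert len(training_set) == (curated_reps + 1) * len(curated)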
def get_datasets(
        use_noisy, split, relabeled, fold, curated_reps, mixup_probs,
        audio_reader, stft, mel_transform, augmenter, num_workers,
        batch_size, prefetch_buffer, max_padding_rate, bucket_expiration,
        event_bucketing, debug
):
    # prepare database
    database_json = jsons_dir / \
        f'fsd_kaggle_2019_split{split}{"_relabeled" if relabeled else ""}.json'
    db = JsonDatabase(database_json)

    def add_noisy_flag(example):
        example['is_noisy'] = example['dataset'] != 'train_curated'
        return example

    audio_reader = AudioReader(**audio_reader)
    stft = STFT(**stft)
    mel_transform = MelTransform(**mel_transform)
    normalizer = Normalizer(storage_dir=str(storage_dir))
    augmenter = Augmenter(**augmenter)

    curated_train_data = db.get_dataset('train_curated').map(add_noisy_flag)
    event_encoder = MultiHotLabelEncoder(
        label_key='events', storage_dir=storage_dir
    )
    event_encoder.initialize_labels(dataset=curated_train_data, verbose=True)
    if debug:
        curated_train_data = curated_train_data.shuffle()[:500]
    normalizer.initialize_norm(
        dataset_name='train_curated',
        dataset=curated_train_data.map(audio_reader).map(stft).map(mel_transform),
        max_workers=num_workers,
    )
    if fold is not None:
        curated_train_data, validation_set = split_dataset(
            curated_train_data, fold=fold, seed=0
        )
    else:
        validation_set = None

    if use_noisy:
        noisy_train_data = db.get_dataset('train_noisy').map(add_noisy_flag)
        if debug:
            noisy_train_data = noisy_train_data.shuffle()[:500]
        normalizer.initialize_norm(
            dataset_name='train_noisy',
            dataset=noisy_train_data.map(audio_reader).map(stft).map(mel_transform),
            max_workers=num_workers,
        )
        training_set = lazy_dataset.concatenate(
            curated_train_data, noisy_train_data
        )
    else:
        training_set = curated_train_data
    batch_norm_tuning_set = training_set

    if mixup_probs is not None:
        training_set = MixUpDataset(
            training_set, training_set, mixin_probs=mixup_probs
        )
    if curated_reps > 0:
        print('curated reps:', curated_reps)
        curated_train_data = lazy_dataset.from_dict({
            f'{example["example_id"]}_{i}': example
            for i in range(curated_reps)
            for example in curated_train_data
        })
        if mixup_probs is not None:
            curated_train_data = MixUpDataset(
                curated_train_data, curated_train_data,
                mixin_probs=mixup_probs
            )
        training_set = lazy_dataset.concatenate(
            training_set, curated_train_data
        )
    print('Length of training set', len(training_set))
    if validation_set is not None:
        # validation_set is None when no fold is given.
        print('Length of validation set', len(validation_set))

    def finalize(example):
        x = example['features']
        example_ = {
            'example_id': example['example_id'],
            'dataset': example['dataset'],
            'is_noisy': np.array(example['is_noisy']).astype(np.float32),
            'features': x.astype(np.float32),
            'seq_len': x.shape[1],
        }
        if 'events' in example:
            example_['events'] = example['events']
        return example_

    bucket_cls = EventTimeSeriesBucket if event_bucketing \
        else DynamicTimeSeriesBucket

    def prepare_iterable(dataset, drop_incomplete=False):
        return dataset.map(event_encoder).map(finalize).prefetch(
            num_workers=num_workers,
            buffer_size=prefetch_buffer,
            catch_filter_exception=True,
        ).batch_dynamic_bucket(
            bucket_cls=bucket_cls,
            batch_size=batch_size,
            len_key='seq_len',
            max_padding_rate=max_padding_rate,
            expiration=bucket_expiration,
            drop_incomplete=drop_incomplete,
            sort_key='seq_len',
            reverse_sort=True,
        ).map(Collate())

    training_set = prepare_iterable(
        training_set.map(audio_reader).map(stft).map(mel_transform)
        .map(normalizer).map(augmenter).shuffle(reshuffle=True),
        drop_incomplete=True,
    )
    batch_norm_tuning_set = prepare_iterable(
        batch_norm_tuning_set.map(audio_reader).map(stft)
        .map(mel_transform).map(normalizer),
        drop_incomplete=True,
    )
    if validation_set is not None:
        validation_set = prepare_iterable(
            validation_set.map(audio_reader).map(stft)
            .map(mel_transform).map(normalizer)
        )
    return training_set, validation_set, batch_norm_tuning_set