Example #1
    def _get_dataset(self, name=None):
        """
        The private _get_dataset function allows databases to override
        get_dataset and add a map function without the map function being
        applied recursively when multiple dataset names are given.
        """
        if name is None:
            raise TypeError(
                f'Missing dataset_name, use e.g.: {self.dataset_names}')

        if isinstance(name, str):
            pass
        elif isinstance(name, typing.Iterable) and not isinstance(name, dict):
            datasets = [self._get_dataset(n) for n in name]
            return lazy_dataset.concatenate(*datasets)
        else:
            raise TypeError(
                f'Argument type {type(name)} of {name} is not allowed! '
                f'Expected str, list or tuple.')

        # The resulting dataset is immutable anyway due to pickle in
        # `lazy_dataset.from_dict`. The code here avoids storing the
        # resulting dataset more than once in memory. Discuss with CBJ for
        # details.
        try:
            return self._dataset_weak_ref_dict[name]
        except KeyError:
            pass

        examples = self.get_examples(name)
        ds = lazy_dataset.from_dict(examples, name=name)

        self._dataset_weak_ref_dict[name] = ds

        return ds
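
For context, the weak-reference cache used above has to be created elsewhere, typically in the database's constructor. Below is a minimal sketch of the mechanism, assuming the cache is a weakref.WeakValueDictionary; the setup code is not part of the example above.

import weakref

import lazy_dataset

# Assumed setup (normally done in the database's __init__): a
# WeakValueDictionary lets cached datasets be garbage collected as soon as
# no other reference to them exists.
_dataset_weak_ref_dict = weakref.WeakValueDictionary()

examples = {'ex1': {'value': 1}, 'ex2': {'value': 2}}
ds = lazy_dataset.from_dict(examples, name='train')
_dataset_weak_ref_dict['train'] = ds

# While `ds` is still referenced, the cached dataset is reused instead of
# being rebuilt from the examples dict.
assert _dataset_weak_ref_dict['train'] is ds
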
Example #2
def prepare_iterable(db,
                     datasets: List[str],
                     batch_size,
                     chunk_size,
                     prefetch=True,
                     iterator_slice=None,
                     shuffle=True):
    """
    This is re-used in the evaluate script
    """
    # Create an iterator from the datasets (a simple concatenation of the
    # single datasets)
    if isinstance(datasets, str):
        datasets = datasets.split(',')
    iterator = db.get_dataset(datasets)

    # TODO: this does not make too much sense when we have multiple datasets
    if iterator_slice is not None:
        iterator = iterator[iterator_slice]

    # Determine the number of speakers in each example
    def add_num_speakers(example):
        example.update(num_speakers=len(example['speaker_id']))
        return example

    iterator = iterator.map(add_num_speakers)

    # Group iterators by number of speakers so that all examples in a batch
    # have the same number of speakers
    iterators = list(iterator.groupby(lambda x: x['num_speakers']).values())

    chunker = RandomChunkSingle(chunk_size, chunk_keys=('y', 's'), axis=-1)
    iterators = [
        iterator.map(pre_batch_transform).map(chunker).shuffle(
            reshuffle=shuffle).batch(batch_size).map(
                pt.data.batch.Sorter('num_samples')).map(
                    pt.data.utils.collate_fn) for iterator in iterators
    ]

    iterator = lazy_dataset.concatenate(*iterators).shuffle(reshuffle=shuffle)

    # FilterExceptions are only raised inside the chunking code if the
    # example is too short. If min_length <= 0 or chunk_size == -1, no filter
    # exception is raised.
    catch_exception = chunker.chunk_size != -1 and chunker.min_length > 0
    if prefetch:
        iterator = iterator.prefetch(8,
                                     16,
                                     catch_filter_exception=catch_exception)
    elif catch_exception:
        iterator = iterator.catch()

    return iterator
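
The groupby step is what guarantees that a batch never mixes examples with different speaker counts. Here is a toy illustration of the same groupby/batch/concatenate pattern using only lazy_dataset, without the padertorch transforms assumed above; the example data is made up.

import lazy_dataset

examples = {
    'a': {'example_id': 'a', 'speaker_id': ['s1']},
    'b': {'example_id': 'b', 'speaker_id': ['s1', 's2']},
    'c': {'example_id': 'c', 'speaker_id': ['s3']},
    'd': {'example_id': 'd', 'speaker_id': ['s4', 's5']},
}
ds = lazy_dataset.from_dict(examples)
ds = ds.map(lambda ex: {**ex, 'num_speakers': len(ex['speaker_id'])})

# One sub-dataset per speaker count, batched separately, then re-concatenated.
groups = ds.groupby(lambda ex: ex['num_speakers'])
batched = [group.batch(2) for group in groups.values()]
merged = lazy_dataset.concatenate(*batched)

for batch in merged:
    # Every batch only contains examples with the same number of speakers.
    assert len({ex['num_speakers'] for ex in batch}) == 1
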
Example #3
    def get_datasets(self, dataset_names, use_weakref=True):
        """
        Returns a single Iterator over specified datasets.

        Adds the example_id and dataset_name to each example dict.

        :param dataset_names: list or str specifying the datasets of interest.
            If None, an iterator over the complete database will be returned.
        :return: A single iterator over the specified datasets.
        """
        dataset_names = to_list(dataset_names, item_type=str)
        iterators = list()
        for dataset_name in dataset_names:
            if use_weakref:
                try:
                    it = self._iterator_weak_ref_dict[dataset_name]
                except KeyError:
                    pass
                else:
                    iterators.append(it)
                    continue
            try:
                examples = self._get_dataset_from_database_dict(dataset_name)
            except KeyError:
                import difflib
                similar = difflib.get_close_matches(
                    dataset_name,
                    self.dataset_names,
                    n=5,
                    cutoff=0,
                )
                raise KeyError(dataset_name, f'close_matches: {similar}', self)
            if len(examples) == 0:
                # If somebody needs empty datasets, add an option to this
                # function to allow empty datasets.
                raise RuntimeError(
                    f'The requested dataset {dataset_name!r} is empty.')

            for example_id in examples.keys():
                examples[example_id][EXAMPLE_ID] = example_id
                examples[example_id][DATASET_NAME] = dataset_name

            # Convert values to binary, because deepcopy on binary is faster
            # This is important for CHiME5
            ds = lazy_dataset.from_dict(examples)

            if use_weakref:
                self._iterator_weak_ref_dict[dataset_name] = ds

            iterators.append(ds)

        return lazy_dataset.concatenate(*iterators)
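
The difflib call above only improves the KeyError message; in isolation it behaves as follows (the dataset names here are made up for illustration).

import difflib

dataset_names = ['train_clean', 'train_noisy', 'dev_clean', 'test_clean']

# cutoff=0 returns even poor matches, so the error message always lists up
# to n candidate names for a misspelled dataset.
matches = difflib.get_close_matches('train_claen', dataset_names, n=5, cutoff=0)
print(matches)  # best match first, e.g. ['train_clean', ...]
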
Example #4
def test_concatenate_function():
    ds_train = get_dataset()
    ds_predict = get_dataset_predict()

    ds = lazy_dataset.concatenate(ds_train, ds_predict)
    example_ids = [e['example_id'] for e in ds]
    assert example_ids == [f'example_id_{i}' for i in range(1, 6)]

    assert ds['example_id_1']['example_id'] == 'example_id_1'
    assert ds['example_id_5']['example_id'] == 'example_id_5'
    assert ds[0]['example_id'] == 'example_id_1'
    assert ds[-1]['example_id'] == 'example_id_5'
    assert ds[:1][0]['example_id'] == 'example_id_1'

    ds = lazy_dataset.concatenate([ds_train, ds_predict])
    example_ids = [e['example_id'] for e in ds]
    assert example_ids == [f'example_id_{i}' for i in range(1, 6)]

    assert ds['example_id_1']['example_id'] == 'example_id_1'
    assert ds['example_id_5']['example_id'] == 'example_id_5'
    assert ds[0]['example_id'] == 'example_id_1'
    assert ds[-1]['example_id'] == 'example_id_5'
    assert ds[:1][0]['example_id'] == 'example_id_1'
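
The helpers get_dataset and get_dataset_predict are not shown in this example. Fixtures along the following lines would satisfy the assertions above; the exact split of the five example ids between the two datasets is an assumption.

import lazy_dataset

def get_dataset():
    # Hypothetical training fixture: example_id_1 .. example_id_3
    return lazy_dataset.from_dict({
        f'example_id_{i}': {'example_id': f'example_id_{i}'}
        for i in range(1, 4)
    })

def get_dataset_predict():
    # Hypothetical prediction fixture: example_id_4 .. example_id_5
    return lazy_dataset.from_dict({
        f'example_id_{i}': {'example_id': f'example_id_{i}'}
        for i in range(4, 6)
    })
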
Example #5
def prepare_iterable(db,
                     datasets: List[str],
                     batch_size,
                     chunk_size,
                     prefetch=True,
                     iterator_slice=None,
                     shuffle=True):
    """
    This is re-used in the evaluate script
    """
    # Create an iterator from the datasets (a simple concatenation of the
    # single datasets) and determine the number of speakers in each example
    if isinstance(datasets, str):
        datasets = datasets.split(',')

    iterator = db.get_dataset(datasets).map(
        lambda x: (x.update(num_speakers=len(x['speaker_id'])), x)[1])

    # TODO: this does not make too much sense when we have multiple datasets
    if iterator_slice is not None:
        iterator = iterator[iterator_slice]

    # Group iterators by number of speakers so that all examples in a batch
    # have the same number of speakers
    iterators = list(iterator.groupby(lambda x: x['num_speakers']).values())

    iterators = [
        iterator.map(pre_batch_transform).map(
            RandomChunkSingle(
                chunk_size, chunk_keys=('y', 's'),
                axis=-1)).shuffle(reshuffle=shuffle).batch(batch_size).map(
                    lambda batch: sorted(
                        batch,
                        key=lambda example: example['num_samples'],
                        reverse=True,
                    )).map(pt.data.utils.collate_fn) for iterator in iterators
    ]

    iterator = lazy_dataset.concatenate(*iterators).shuffle(reshuffle=shuffle)

    if prefetch:
        iterator = iterator.prefetch(8, 16, catch_filter_exception=True)
    elif chunk_size > 0:
        iterator = iterator.catch()

    return iterator
Example #6
def test_concatenate_function_raises_on_non_dataset_instances():
    ds_train = get_dataset()
    not_a_ds = dict()
    with pytest.raises(TypeError):
        lazy_dataset.concatenate(ds_train, not_a_ds)
Example #7
def test_concatenate_function_raises_on_empty_list():
    with pytest.raises(ValueError):
        lazy_dataset.concatenate()
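
Examples #6 and #7 together pin down the error contract of lazy_dataset.concatenate: a ValueError for a call without any datasets and a TypeError for arguments that are not datasets. A direct illustration:

import lazy_dataset

ds = lazy_dataset.from_dict({'a': {'x': 1}})

try:
    lazy_dataset.concatenate()        # no datasets at all
except ValueError as error:
    print('empty call rejected:', error)

try:
    lazy_dataset.concatenate(ds, {})  # a plain dict is not a dataset
except TypeError as error:
    print('non-dataset rejected:', error)
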
Example #8
def get_datasets(use_noisy, split, fold, curated_reps, mixup_probs, extractor,
                 augmenter, num_workers, batch_size, prefetch_buffer,
                 max_padding_rate, bucket_expiration, event_bucketing, debug):
    # prepare database
    database_json = jsons_dir / f'fsd_kaggle_2019_split{split}.json'
    db = JsonDatabase(database_json)

    def add_noisy_flag(example):
        example['is_noisy'] = example['dataset'] != 'train_curated'
        return example

    extractor = Extractor(**extractor)
    augmenter = Augmenter(extractor=extractor, **augmenter)

    curated_train_data = db.get_dataset('train_curated').map(add_noisy_flag)
    extractor.initialize_labels(curated_train_data)
    if debug:
        curated_train_data = curated_train_data.shuffle(
        )[:len(curated_train_data) // 10]
    extractor.initialize_norm(dataset_name='train_curated',
                              dataset=curated_train_data,
                              max_workers=num_workers)

    if fold is not None:
        curated_train_data, validation_set = split_dataset(curated_train_data,
                                                           fold=fold,
                                                           seed=0)
    else:
        validation_set = None

    if use_noisy:
        noisy_train_data = db.get_dataset('train_noisy').map(add_noisy_flag)
        if debug:
            noisy_train_data = noisy_train_data.shuffle(
            )[:len(noisy_train_data) // 10]
        extractor.initialize_norm(dataset_name='train_noisy',
                                  dataset=noisy_train_data,
                                  max_workers=num_workers)
        training_set = lazy_dataset.concatenate(curated_train_data,
                                                noisy_train_data)
    else:
        training_set = curated_train_data
    batch_norm_tuning_set = training_set

    if mixup_probs is not None:
        training_set = MixUpDataset(training_set,
                                    training_set,
                                    mixin_probs=mixup_probs)
    if curated_reps > 0:
        print('curated reps:', curated_reps)
        curated_train_data = lazy_dataset.from_dict({
            f'{example["example_id"]}_{i}': example
            for i in range(curated_reps) for example in curated_train_data
        })
        if mixup_probs is not None:
            curated_train_data = MixUpDataset(curated_train_data,
                                              curated_train_data,
                                              mixin_probs=mixup_probs)
        training_set = lazy_dataset.concatenate(training_set,
                                                curated_train_data)

    print('Length of training set', len(training_set))

    bucket_cls = EventTimeSeriesBucket if event_bucketing \
        else DynamicTimeSeriesBucket

    def prepare_iterable(dataset, drop_incomplete=False):
        return dataset.prefetch(
            num_workers=num_workers,
            buffer_size=prefetch_buffer,
            catch_filter_exception=True).batch_dynamic_bucket(
                bucket_cls=bucket_cls,
                batch_size=batch_size,
                len_key='seq_len',
                max_padding_rate=max_padding_rate,
                expiration=bucket_expiration,
                drop_incomplete=drop_incomplete,
                sort_key='seq_len',
                reverse_sort=True).map(Collate())

    training_set = prepare_iterable(
        training_set.map(augmenter).shuffle(reshuffle=True),
        drop_incomplete=True)
    batch_norm_tuning_set = prepare_iterable(
        batch_norm_tuning_set.map(extractor), drop_incomplete=True)
    if validation_set is not None:
        validation_set = prepare_iterable(validation_set.map(extractor))

    return training_set, validation_set, batch_norm_tuning_set
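
The curated_reps branch oversamples the curated data by materializing each repetition under a unique key. The same pattern in isolation, with hypothetical example ids:

import lazy_dataset

curated = lazy_dataset.from_dict({
    'ex1': {'example_id': 'ex1'},
    'ex2': {'example_id': 'ex2'},
})

curated_reps = 3
# Each repetition is stored under a unique key ('<example_id>_<i>'), which
# mirrors the oversampling pattern used in get_datasets above.
repeated = lazy_dataset.from_dict({
    f'{example["example_id"]}_{i}': example
    for i in range(curated_reps)
    for example in curated
})
assert len(repeated) == curated_reps * len(curated)
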
Example #9
def get_datasets(
        use_noisy, split, relabeled, fold, curated_reps, mixup_probs,
        audio_reader, stft, mel_transform, augmenter, num_workers, batch_size,
        prefetch_buffer, max_padding_rate, bucket_expiration, event_bucketing,
        debug
):
    # prepare database
    database_json = jsons_dir / \
        f'fsd_kaggle_2019_split{split}{"_relabeled" if relabeled else ""}.json'
    db = JsonDatabase(database_json)

    def add_noisy_flag(example):
        example['is_noisy'] = example['dataset'] != 'train_curated'
        return example

    audio_reader = AudioReader(**audio_reader)
    stft = STFT(**stft)
    mel_transform = MelTransform(**mel_transform)
    normalizer = Normalizer(storage_dir=str(storage_dir))
    augmenter = Augmenter(**augmenter)

    curated_train_data = db.get_dataset('train_curated').map(add_noisy_flag)

    event_encoder = MultiHotLabelEncoder(
        label_key='events', storage_dir=storage_dir
    )
    event_encoder.initialize_labels(
        dataset=curated_train_data, verbose=True
    )

    if debug:
        curated_train_data = curated_train_data.shuffle()[:500]

    normalizer.initialize_norm(
        dataset_name='train_curated',
        dataset=curated_train_data.map(audio_reader).map(stft).map(mel_transform),
        max_workers=num_workers,
    )

    if fold is not None:
        curated_train_data, validation_set = split_dataset(
            curated_train_data, fold=fold, seed=0
        )
    else:
        validation_set = None

    if use_noisy:
        noisy_train_data = db.get_dataset('train_noisy').map(add_noisy_flag)
        if debug:
            noisy_train_data = noisy_train_data.shuffle()[:500]

        normalizer.initialize_norm(
            dataset_name='train_noisy',
            dataset=noisy_train_data.map(audio_reader).map(stft).map(mel_transform),
            max_workers=num_workers,
        )
        training_set = lazy_dataset.concatenate(curated_train_data, noisy_train_data)
    else:
        training_set = curated_train_data
    batch_norm_tuning_set = training_set

    if mixup_probs is not None:
        training_set = MixUpDataset(
            training_set, training_set, mixin_probs=mixup_probs
        )
    if curated_reps > 0:
        print('curated reps:', curated_reps)
        curated_train_data = lazy_dataset.from_dict({
            f'{example["example_id"]}_{i}': example
            for i in range(curated_reps)
            for example in curated_train_data
        })
        if mixup_probs is not None:
            curated_train_data = MixUpDataset(
                curated_train_data, curated_train_data, mixin_probs=mixup_probs
            )
        training_set = lazy_dataset.concatenate(
            training_set, curated_train_data
        )

    print('Length of training set', len(training_set))
    if validation_set is not None:
        print('Length of validation set', len(validation_set))

    def finalize(example):
        x = example['features']
        example_ = {
            'example_id': example['example_id'],
            'dataset': example['dataset'],
            'is_noisy': np.array(example['is_noisy']).astype(np.float32),
            'features': x.astype(np.float32),
            'seq_len': x.shape[1],
        }
        if 'events' in example:
            example_['events'] = example['events']
        return example_

    bucket_cls = EventTimeSeriesBucket if event_bucketing \
        else DynamicTimeSeriesBucket

    def prepare_iterable(dataset, drop_incomplete=False):
        return dataset.map(event_encoder).map(finalize).prefetch(
            num_workers=num_workers, buffer_size=prefetch_buffer,
            catch_filter_exception=True
        ).batch_dynamic_bucket(
            bucket_cls=bucket_cls, batch_size=batch_size, len_key='seq_len',
            max_padding_rate=max_padding_rate, expiration=bucket_expiration,
            drop_incomplete=drop_incomplete, sort_key='seq_len',
            reverse_sort=True
        ).map(Collate())

    training_set = prepare_iterable(
        training_set.map(audio_reader).map(stft).map(mel_transform).map(normalizer).map(augmenter).shuffle(reshuffle=True),
        drop_incomplete=True
    )
    batch_norm_tuning_set = prepare_iterable(
        batch_norm_tuning_set.map(audio_reader).map(stft).map(mel_transform).map(normalizer),
        drop_incomplete=True
    )
    if validation_set is not None:
        validation_set = prepare_iterable(
            validation_set.map(audio_reader).map(stft).map(mel_transform).map(normalizer)
        )

    return training_set, validation_set, batch_norm_tuning_set