Esempio n. 1
0
def prepare_iterable(
        db, dataset: str, batch_size, return_keys=None, prefetch=True,
        iterator_slice=None
):
    """Assemble the batched example pipeline for one dataset of ``db``.

    Args:
        db: Database object; must expose ``read_fn`` and
            ``get_iterator_by_names``.
        dataset: Name handed to ``db.get_iterator_by_names``.
        batch_size: Number of examples grouped together by ``.batch``.
        return_keys: Forwarded unchanged to ``pre_batch_transform``.
        prefetch: When truthy, the finished pipeline is wrapped with
            ``.prefetch(4, 8)``.
        iterator_slice: Optional index/slice applied to the raw iterator
            before any processing; ``None`` leaves it untouched.

    Returns:
        The composed iterator (prefetching if requested).
    """
    def longest_first(batch):
        # Order the examples of one batch by descending "num_frames".
        return sorted(
            batch,
            key=lambda example: example["num_frames"],
            reverse=True,
        )

    reader = AudioReader(
        audio_keys=[OBSERVATION, SPEECH_SOURCE],
        read_fn=db.read_fn,
    )
    examples = db.get_iterator_by_names(dataset)
    if iterator_slice is not None:
        examples = examples[iterator_slice]

    # Per-example stage: read the audio, then apply the pre-batch transform.
    examples = examples.map(reader)
    pre_batch = partial(pre_batch_transform, return_keys=return_keys)
    examples = examples.map(pre_batch)

    # Shuffle once (no reshuffling), then group into batches.
    batches = examples.shuffle(reshuffle=False).batch(batch_size)

    # Per-batch stage: sort, collate, post-process.
    batches = batches.map(longest_first)
    batches = batches.map(pt.data.utils.collate_fn)
    batches = batches.map(post_batch_transform)

    # Simulates reshuffle to some degree
    batches = batches.tile(reps=50, shuffle=True)

    if not prefetch:
        return batches
    # Positional arguments are presumably (num_workers, buffer_size) --
    # TODO confirm against the dataset library's prefetch signature.
    return batches.prefetch(4, 8)
Esempio n. 2
0
def get_validation_iterator(database: JsonAudioDatabase):
    """Return the prefetching validation iterator of ``database``.

    Each example from ``database.get_dataset_validation()`` is passed
    through an ``AudioReader`` and then ``change_example_structure``; the
    result is prefetched with 4 workers and a buffer size of 4.
    """
    # AudioReader is a specialized function to read audio organized
    # in a json as described in pb.database.database
    reader = AudioReader(
        audio_keys=[K.OBSERVATION, K.NOISE_IMAGE, K.SPEECH_IMAGE],
    )
    examples = database.get_dataset_validation()
    examples = examples.map(reader)
    examples = examples.map(change_example_structure)
    return examples.prefetch(num_workers=4, buffer_size=4)
Esempio n. 3
0
 def read_audio(self, example):
     """Apply an ``AudioReader`` to ``example``; meant to be mapped on an iterator.

     A new reader is configured from ``self.opts.audio_keys`` and
     ``self.database.read_fn`` on every call, and its result for
     ``example`` is returned.
     """
     reader = AudioReader(
         audio_keys=self.opts.audio_keys,
         read_fn=self.database.read_fn,
     )
     return reader(example)