Ejemplo n.º 1
0
def _read_split_list(base_dir: str, list_filename: str) -> list:
    """
    Reads a split list file and resolves every entry against the base directory
    :param base_dir: directory that holds the list file and that its entries are relative to (str)
    :param list_filename: name of the list file inside base_dir (str)
    :return: normalized paths, one per non-blank line of the list file (list of str)
    """
    # The context manager guarantees the handle is closed even if reading fails.
    with open(os.path.join(base_dir, list_filename)) as list_file:
        # Blank lines are skipped: joining an empty entry would yield base_dir itself.
        return [os.path.normpath(os.path.join(base_dir, line.strip())) for line in list_file if line.strip()]


def get_list_of_wav_paths(data_version: str, n_augmentations: [int, str] = 0) -> tuple:
    """
    Retrieves the list of filepaths that belong to train, validation and test
    :param n_augmentations: specify the number of augmentations to use or specify "all" to load all the available ones
    (int|str)
    :param data_version: specifies the version of the data to use (str {"0.01", "0.02"})
    :return: list of training paths, list of validation paths and list of test paths (list of lists)
    :raises ValueError: if n_augmentations is neither an int nor the string "all"
    """
    # Resolve the base folder once; every list file and the first training folder live under it.
    training_path = get_training_data_path(data_version=data_version)
    folders = [training_path]
    # bool is a subclass of int, so it is excluded explicitly to keep rejecting True/False.
    if isinstance(n_augmentations, int) and not isinstance(n_augmentations, bool):
        folders += [get_augmented_data_folder(data_version=data_version, folder=str(f)) for f in range(n_augmentations)]
    elif n_augmentations == "all":
        base = get_augmented_data_path(data_version=data_version)
        folders += [os.path.join(base, f) for f in os.listdir(base)]
    else:
        raise ValueError(f"'n_augmentations' parameter value not recognized as a valid argument ('all'|int): {n_augmentations}")

    for path in folders:
        if not os.listdir(path):
            warnings.warn(f"Attempting to load files from an empty folder: {path}")

    list_test = _read_split_list(training_path, "testing_list.txt")
    list_val = _read_split_list(training_path, "validation_list.txt")

    list_train = flatten([list(recursive_listdir(os.path.normpath(folder))) for folder in folders])
    list_train = [p for p in list_train if "background_noise" not in p and p.endswith("wav")]
    # Training files are whatever remains once test and validation entries are removed;
    # np.setdiff1d also de-duplicates and sorts the result.
    list_train = np.setdiff1d(list_train, list_test + list_val).tolist()
    return list_train, list_val, list_test
Ejemplo n.º 2
0
def decompress_dataset(data_version: str):
    """
    Retrieves the downloaded data and decompresses it
    :param data_version: specifies the version of the data to use (str {"0.01", "0.02"})
    :raises FileNotFoundError: if the dataset archive does not exist at the expected location
    """
    fname = get_dataset_filepath(data_version=data_version)
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if not os.path.exists(fname):
        raise FileNotFoundError(f"Dataset archive not found: {fname}")
    # The context manager closes the archive even if extraction fails midway.
    with tarfile.open(fname, "r:gz") as tar:
        # NOTE(review): extractall trusts the member paths stored in the archive; if the
        # download source is not fully trusted, validate members (or use an extraction
        # filter where the running Python version supports it) before extracting.
        tar.extractall(path=get_training_data_path(data_version=data_version))
Ejemplo n.º 3
0
def load_random_real_noise_clip(data_version: str) -> np.array:
    """
    Picks one .wav file at random from the _background_noise_ folder and returns its clip
    :param data_version: specifies the version of the data to use (str {"0.01", "0.02"})
    :return: real noise clip (np.array)
    """
    noise_dir = os.path.join(get_training_data_path(data_version=data_version), "_background_noise_")
    # Only .wav entries are candidates; anything else in the folder is ignored.
    wav_names = [name for name in os.listdir(noise_dir) if name.endswith(".wav")]
    chosen_name = random.choice(wav_names)
    # read_wavfile returns a pair; only its second element (the clip) is used here.
    _, samples = read_wavfile(os.path.join(noise_dir, chosen_name))
    return samples
Ejemplo n.º 4
0
def load_real_noise_clips(data_version: str) -> list:
    """
    Loads all the available real noise clips, which are located under the _background_noise_ folder
    :param data_version: specifies the version of the data to use (str {"0.01", "0.02"})
    :return: list of real noise clips, one per .wav file found, in os.listdir order (list of np.array)
    """
    noise_dir = os.path.join(get_training_data_path(data_version=data_version), "_background_noise_")
    # Only .wav entries are loaded; anything else in the folder is ignored.
    wav_names = [name for name in os.listdir(noise_dir) if name.endswith(".wav")]
    clips = []
    for name in wav_names:
        # read_wavfile returns a pair; only its second element (the clip) is kept.
        _, samples = read_wavfile(os.path.join(noise_dir, name))
        clips.append(samples)
    return clips
Ejemplo n.º 5
0
 def test_download_and_decompress_data(self):
     """
     Downloads and decompresses dataset version "0.02" when the archive is not already present,
     then checks that the training data folder contains more than 10 entries
     """
     data_version = "0.02"
     filepath = get_dataset_filepath(data_version=data_version)
     if os.path.exists(filepath):
         # Report an explicit skip instead of a vacuous pass when nothing is exercised.
         self.skipTest("dataset archive already present; download and decompression not exercised")
     # this will run in Travis
     download_dataset(data_version=data_version)
     self.assertTrue(os.path.exists(filepath))
     decompress_dataset(data_version=data_version)
     extracted_entries = os.listdir(get_training_data_path(data_version=data_version))
     self.assertGreater(len(extracted_entries), 10)
Ejemplo n.º 6
0
 def test_get_training_data_path(self):
     """
     Checks that the path returned for the "unit_testing" version exists, then removes that directory
     """
     created_dir = get_training_data_path("unit_testing")
     path_exists = os.path.exists(created_dir)
     self.assertTrue(path_exists)
     # Clean up so the (empty) directory is not left behind after the test.
     os.rmdir(created_dir)