def __init__(self, root, url=URL, folder_in_archive=FOLDER_IN_ARCHIVE, download=False):

    if url in [
        "speech_commands_v0.01",
        "speech_commands_v0.02",
    ]:
        base_url = "https://storage.googleapis.com/download.tensorflow.org/data/"
        ext_archive = ".tar.gz"

        url = os.path.join(base_url, url + ext_archive)

    basename = os.path.basename(url)
    archive = os.path.join(root, basename)

    basename = basename.rsplit(".", 2)[0]
    folder_in_archive = os.path.join(folder_in_archive, basename)

    self._path = os.path.join(root, folder_in_archive)

    if download:
        if not os.path.isdir(self._path):
            if not os.path.isfile(archive):
                download_url(url, root)
            extract_archive(archive, self._path)

    walker = walk_files(self._path, suffix=".wav", prefix=True)
    walker = filter(lambda w: HASH_DIVIDER in w and EXCEPT_FOLDER not in w, walker)
    self._walker = list(walker)
def __init__(self):
    self.path = './data/'
    walker = walk_files(self.path, suffix=self._ext_audio, prefix=False, remove_suffix=False)
    self._walker = list(walker)
def __init__(self, root: str) -> None:
    self._path = os.path.join(root, 'edus80')
    walker = walk_files(
        self._path, suffix=self._ext_audio, prefix=False, remove_suffix=True
    )
    self._walker = list(walker)
def test_walk_files(self):
    """walk_files should traverse files in alphabetical order"""
    n_ites = 0
    for i, path in enumerate(dataset_utils.walk_files(self.root, '.txt', prefix=True)):
        found = os.path.join(self.root, path)
        assert found == self.expected[i]
        n_ites += 1
    assert n_ites == len(self.expected)
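# For reference, a minimal sketch of the behavior the test above exercises: sorted
# recursive traversal with suffix filtering, an optional directory prefix, and optional
# suffix removal. `walk_files_sketch` is a hypothetical stand-in, not the library function;
# the real implementation may differ in details.
import os
from typing import Iterable


def walk_files_sketch(root: str,
                      suffix: str,
                      prefix: bool = False,
                      remove_suffix: bool = False) -> Iterable[str]:
    """Yield files under `root` ending with `suffix`, in alphabetical order."""
    root = os.path.expanduser(root)
    for dirpath, dirs, files in os.walk(root):
        dirs.sort()    # sorting in place makes os.walk descend sub-directories alphabetically
        files.sort()
        for f in files:
            if not f.endswith(suffix):
                continue
            if remove_suffix:
                f = f[: -len(suffix)]
            if prefix:
                f = os.path.join(dirpath, f)  # keep the directory prefix on the yielded path
            yield f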
def __init__(
    self,
    root: str,
    url: str = URL,
    folder_in_archive: str = FOLDER_IN_ARCHIVE,
    download: bool = False,
    subset: Any = None,
) -> None:

    # super(GTZAN, self).__init__()
    self.root = root
    self.url = url
    self.folder_in_archive = folder_in_archive
    self.download = download
    self.subset = subset

    assert subset is None or subset in ["training", "validation", "testing"], (
        "When `subset` not None, it must take a value from "
        + "{'training', 'validation', 'testing'}."
    )

    archive = os.path.basename(url)
    archive = os.path.join(root, archive)
    self._path = os.path.join(root, folder_in_archive)

    if download:
        if not os.path.isdir(self._path):
            if not os.path.isfile(archive):
                checksum = _CHECKSUMS.get(url, None)
                download_url(url, root, hash_value=checksum, hash_type="md5")
            extract_archive(archive)

    if not os.path.isdir(self._path):
        raise RuntimeError(
            "Dataset not found. Please use `download=True` to download it."
        )

    if self.subset is None:
        walker = walk_files(self._path, suffix=self._ext_audio, prefix=False, remove_suffix=True)
        self._walker = list(walker)
    else:
        if self.subset == "training":
            self._walker = filtered_train
        elif self.subset == "validation":
            self._walker = filtered_valid
        elif self.subset == "testing":
            self._walker = filtered_test
def __init__(
    self,
    root,
    url=URL,
    folder_in_archive=FOLDER_IN_ARCHIVE,
    download=False,
    transform=None,
    target_transform=None,
    return_dict=False,
):

    if not return_dict:
        warnings.warn(
            "In the next version, the item returned will be a dictionary. "
            "Please use `return_dict=True` to enable this behavior now, "
            "and suppress this warning.",
            DeprecationWarning,
        )

    if transform is not None or target_transform is not None:
        warnings.warn(
            "In the next version, transforms will not be part of the dataset. "
            "Please remove the option `transform=True` and "
            "`target_transform=True` to suppress this warning.",
            DeprecationWarning,
        )

    self.transform = transform
    self.target_transform = target_transform
    self.return_dict = return_dict

    archive = os.path.basename(url)
    archive = os.path.join(root, archive)
    self._path = os.path.join(root, folder_in_archive)

    if download:
        if not os.path.isdir(self._path):
            if not os.path.isfile(archive):
                download_url(url, root)
            extract_archive(archive)

    if not os.path.isdir(self._path):
        raise RuntimeError(
            "Dataset not found. Please use `download=True` to download it."
        )

    walker = walk_files(self._path, suffix=self._ext_audio, prefix=False, remove_suffix=True)
    self._walker = list(walker)
def __init__(self, root: str, url: str, folder_in_archive: str = FOLDER_IN_ARCHIVE, download: bool = False) -> None: if url in [ "dev-clean", "dev-other", "test-clean", "test-other", "train-clean-100", "train-clean-360", "train-other-500", ]: ext_archive = ".tar.gz" base_url = "http://www.openslr.org/resources/12/" url = os.path.join(base_url, url + ext_archive) basename = os.path.basename(url) archive = os.path.join(root, basename) basename = basename.split(".")[0] folder_in_archive = os.path.join(folder_in_archive, basename) self._path = os.path.join(root, folder_in_archive) if download: if not os.path.isdir(self._path): if not os.path.isfile(archive): checksum = _CHECKSUMS.get(url, None) download_url(url, root, hash_value=checksum) extract_archive(archive) audio_transforms = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128) for root, dirs, files in os.walk(self._path): if len(files) != 0: for file in files: if file.split('.')[-1]==self._ext_wav.split('.')[-1]: file_audio = os.path.join(root, file) waveform, _ = torchaudio.load(file_audio) spec = audio_transforms(waveform) file_spec = os.path.join(root, file.split('.')[0]+ self._ext_wav) torch.save(spec, file_spec) walker = walk_files( self._path, suffix=self._ext_mel, prefix=False, remove_suffix=True ) self._walker = list(walker)
def __init__(self,
             root: str,
             url: str = URL,
             folder_in_archive: str = FOLDER_IN_ARCHIVE,
             download: bool = False,
             downsample: bool = False,
             transform: Any = None,
             target_transform: Any = None) -> None:

    if downsample:
        warnings.warn(
            "In the next version, transforms will not be part of the dataset. "
            "Please use `downsample=False` to enable this behavior now, "
            "and suppress this warning.")

    if transform is not None or target_transform is not None:
        warnings.warn(
            "In the next version, transforms will not be part of the dataset. "
            "Please remove the option `transform=True` and "
            "`target_transform=True` to suppress this warning.")

    self.downsample = downsample
    self.transform = transform
    self.target_transform = target_transform

    archive = os.path.basename(url)
    archive = os.path.join(root, archive)
    self._path = os.path.join(root, folder_in_archive)

    if download:
        if not os.path.isdir(self._path):
            if not os.path.isfile(archive):
                checksum = _CHECKSUMS.get(url, None)
                download_url(url, root, hash_value=checksum, hash_type="md5")
            extract_archive(archive)

    if not os.path.isdir(self._path):
        raise RuntimeError(
            "Dataset not found. Please use `download=True` to download it."
        )

    walker = walk_files(self._path, suffix=self._ext_audio, prefix=False, remove_suffix=True)
    walker = filter(lambda w: self._except_folder not in w, walker)
    self._walker = list(walker)
def __init__(self, root_path: str) -> None:
    # str.replace returns a new string; assign the result back
    root_path = root_path.replace("\\", "/")
    self._path = root_path
    classes_file_path = os.path.join(root_path, CLASSES_FILE)

    walker = walk_files(self._path, suffix=".wav", prefix=True)
    walker = filter(lambda w: HASH_DIVIDER in w and EXCEPT_FOLDER not in w, walker)
    self._walker = list(walker)

    with open(classes_file_path, 'rb') as handle:
        self.classes = pickle.load(handle)
def __init__(self, root: str, url: str = URL, folder_in_archive: str = FOLDER_IN_ARCHIVE, download: bool = False, downsample: bool = False, transform: Any = None, target_transform: Any = None) -> None: if downsample: warnings.warn( "In the next version, transforms will not be part of the dataset. " "Please use `downsample=False` to enable this behavior now, " "and suppress this warning." ) if transform is not None or target_transform is not None: warnings.warn( "In the next version, transforms will not be part of the dataset. " "Please remove the option `transform=True` and " "`target_transform=True` to suppress this warning." ) self.downsample = downsample self.transform = transform self.target_transform = target_transform archive = os.path.basename(url) archive = os.path.join(root, archive) self._path = os.path.join(root, folder_in_archive) if download: raise RuntimeError( "This Dataset is no longer available. " "Please use `VCTK_092` class to download the latest version." ) if not os.path.isdir(self._path): raise RuntimeError( "Dataset not found. Please use `VCTK_092` class " "with `download=True` to donwload the latest version." ) walker = walk_files( self._path, suffix=self._ext_audio, prefix=False, remove_suffix=True ) walker = filter(lambda w: self._except_folder not in w, walker) self._walker = list(walker)
def tokenizer_librispeech(limit=10, path="../librispeech/LibriSpeech/", version="train-clean-360"):
    ext_txt = ".trans.txt"
    ext_audio = ".flac"
    path = os.path.join(path, version)

    walker = walk_files(path, suffix=ext_audio, prefix=False, remove_suffix=True)
    walker = list(walker)

    all_sentences = []
    for i, fileid in enumerate(walker):
        if i == limit:
            break
        sentence = load_librispeech_item(fileid, path=path, ext_audio=ext_audio,
                                         ext_txt=ext_txt, text_only=True)
        sentence = preprocess_sentence(sentence)
        all_sentences.append(sentence)

    print(" ===== There are {} transcriptions in the dataset. ===== ".format(len(all_sentences)))
    return tokenize(all_sentences)
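# Hedged sketch of how a LibriSpeech fileid produced by walk_files(remove_suffix=True)
# resolves back to its audio file and transcript line. The "<speaker>-<chapter>-<utterance>"
# naming is the standard LibriSpeech layout; `resolve_librispeech_item` is a hypothetical
# helper for illustration, not the `load_librispeech_item` used above.
import os


def resolve_librispeech_item(fileid: str, path: str,
                             ext_audio: str = ".flac",
                             ext_txt: str = ".trans.txt"):
    speaker_id, chapter_id, utterance_id = fileid.split("-")
    chapter_dir = os.path.join(path, speaker_id, chapter_id)
    file_audio = os.path.join(chapter_dir, fileid + ext_audio)
    file_text = os.path.join(chapter_dir, speaker_id + "-" + chapter_id + ext_txt)
    with open(file_text) as f:
        for line in f:
            utt_id, _, transcript = line.strip().partition(" ")
            if utt_id == fileid:
                return file_audio, transcript
    raise FileNotFoundError(f"Transcript for {fileid} not found in {file_text}")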
def __init__( self, root: Union[str, Path], url: str = URL, folder_in_archive: str = FOLDER_IN_ARCHIVE, download: bool = False, ) -> None: if url in [ "dev-clean", "dev-other", "test-clean", "test-other", "train-clean-100", "train-clean-360", "train-other-500", ]: ext_archive = ".tar.gz" base_url = "http://www.openslr.org/resources/60/" url = os.path.join(base_url, url + ext_archive) # Get string representation of 'root' in case Path object is passed root = os.fspath(root) basename = os.path.basename(url) archive = os.path.join(root, basename) basename = basename.split(".")[0] folder_in_archive = os.path.join(folder_in_archive, basename) self._path = os.path.join(root, folder_in_archive) if download: if not os.path.isdir(self._path): if not os.path.isfile(archive): checksum = _CHECKSUMS.get(url, None) download_url(url, root, hash_value=checksum) extract_archive(archive) walker = walk_files(self._path, suffix=self._ext_audio, prefix=False, remove_suffix=True) self._walker = list(walker)
def __init__(self, root, url=URL, folder_in_archive=FOLDER_IN_ARCHIVE,
             download=False, preprocess=False):

    if url in [
        "dev-clean",
        "dev-other",
        "test-clean",
        "test-other",
        "train-clean-100",
        "train-clean-360",
        "train-other-500",
    ]:
        ext_archive = ".tar.gz"
        base_url = "http://www.openslr.org/resources/12/"

        url = os.path.join(base_url, url + ext_archive)

    basename = os.path.basename(url)
    archive = os.path.join(root, basename)

    basename = basename.split(".")[0]
    folder_in_archive = os.path.join(folder_in_archive, basename)

    self._path = os.path.join(root, folder_in_archive)

    if download:
        if not os.path.isdir(self._path):
            if not os.path.isfile(archive):
                download_url(url, root)
            extract_archive(archive)

    walker = walk_files(self._path, suffix=self._ext_audio, prefix=False, remove_suffix=True)
    self._walker = list(walker)

    if preprocess:
        self.preprocess_embeddings(self._path, self._ext_audio, self._ext_embed)
def __init__(self, tokenizer, limit=None, n_channels=1, n_frames=128, sr=16000,
             n_fft=2048, max_target_length=40, n_mels=40, hop_length=512, power=1.0,
             n_mfcc=39, duration=10, path="../librispeech/LibriSpeech/",
             version="train-clean-360"):
    'Initialization'
    self.tokenizer = tokenizer
    self.limit = limit
    self.n_channels = n_channels
    self.n_frames = n_frames
    self.sr = sr
    self.n_fft = n_fft
    self.n_mels = n_mels
    self.hop_length = hop_length
    self.power = power
    self.n_mfcc = n_mfcc
    self.duration = duration

    self._path = os.path.join(path, version)
    walker = walk_files(self._path, suffix=self._ext_audio, prefix=False, remove_suffix=True)
    self._walker = list(walker)[:limit]
    self.max_length = max_target_length
def __init__(self,
             root: Union[str, Path],
             url: str = URL,
             folder_in_archive: str = FOLDER_IN_ARCHIVE,
             download: bool = False,
             transform: Any = None,
             target_transform: Any = None) -> None:

    if transform is not None or target_transform is not None:
        warnings.warn(
            "In the next version, transforms will not be part of the dataset. "
            "Please remove the option `transform=True` and "
            "`target_transform=True` to suppress this warning.")

    self.transform = transform
    self.target_transform = target_transform

    # Get string representation of 'root' in case Path object is passed
    root = os.fspath(root)

    archive = os.path.basename(url)
    archive = os.path.join(root, archive)
    self._path = os.path.join(root, folder_in_archive)

    if download:
        if not os.path.isdir(self._path):
            if not os.path.isfile(archive):
                checksum = _CHECKSUMS.get(url, None)
                download_url(url, root, hash_value=checksum, hash_type="md5")
            extract_archive(archive)

    if not os.path.isdir(self._path):
        raise RuntimeError(
            "Dataset not found. Please use `download=True` to download it."
        )

    walker = walk_files(self._path, suffix=self._ext_audio, prefix=False, remove_suffix=True)
    self._walker = list(walker)
def __init__(self, root: str, audio_folder: str, text_file: str): """ Args: root (str): root folder of the dataset audio_folder (str): folder with the audio files inside root folder text_file (str): path to the file with the text transcriptions of the audio files inside root folder """ self._root = root self._audio_folder = audio_folder walker = walk_files(root, suffix=self._ext_audio, prefix=False, remove_suffix=True) self._walker = list(walker) text_path = os.path.join(root, text_file) with open(text_path, "r") as text_file: text = unicode_csv_reader(text_file, delimiter="|", quoting=csv.QUOTE_NONE) self._text = list(text) # Delete first row of csv with the information about the columns self._text.pop(0) assert len(self._walker) == len(self._text), \ "Number of audiofiles is different from number of texts"
def __init__(self, root: str, audio_folder: str, text_file: str):
    self._root = root
    self._audio_folder = audio_folder
    self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    walker = walk_files(root, suffix=self._ext_audio, prefix=False, remove_suffix=True)
    self._walker = list(walker)

    text_path = os.path.join(root, text_file)
    with open(text_path, "r") as text_file:
        text = unicode_csv_reader(text_file, delimiter="|", quoting=csv.QUOTE_NONE)
        self._text = list(text)

    # Delete first row of csv with the information about the columns
    self._text.pop(0)

    assert len(self._walker) == len(self._text), \
        "Number of audiofiles is different from number of texts"
def __init__(self,
             root: str,
             url: str = URL,
             folder_in_archive: str = FOLDER_IN_ARCHIVE,
             download: bool = False,
             transform: Any = None,
             target_transform: Any = None) -> None:

    if transform is not None or target_transform is not None:
        warnings.warn(
            "In the next version, transforms will not be part of the dataset. "
            "Please remove the option `transform=True` and "
            "`target_transform=True` to suppress this warning."
        )

    self.transform = transform
    self.target_transform = target_transform

    archive = os.path.basename(url)
    archive = os.path.join(root, archive)
    self._path = os.path.join(root, folder_in_archive)

    if download:
        if not os.path.isdir(self._path):
            if not os.path.isfile(archive):
                download_url(url, root)
            extract_archive(archive)

    if not os.path.isdir(self._path):
        raise RuntimeError(
            "Dataset not found. Please use `download=True` to download it."
        )

    walker = walk_files(
        self._path, suffix=self._ext_audio, prefix=False, remove_suffix=True
    )
    self._walker = list(walker)
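# Across these constructors, `self._walker` is just a list of file identifiers; the rest of
# the Dataset protocol is typically a pair of methods like the hedged sketch below. The
# helper `load_item` is a placeholder, not a function defined in any of the snippets here.
def __len__(self) -> int:
    return len(self._walker)


def __getitem__(self, n: int):
    fileid = self._walker[n]
    # Resolve the identifier back to the audio (and label/transcript) on disk.
    return load_item(fileid, self._path, self._ext_audio)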
def __init__(self,
             root,
             sample_rate: int,
             num_noise_to_load: int = 3,
             noise_to_load: list = None,
             num_synthetic_noise: int = 0,
             folder_in_archive=FOLDER_IN_ARCHIVE,
             url=DEMAND_JSON,
             download=False,
             transform=None):
    assert sample_rate in (16000, 48000)

    available_noise = set([
        i['key'].split('_')[0] for i in url['files']
        if (str(sample_rate // 1000) in i['key'])
    ])
    self.available_noise = list(available_noise)
    self.available_noise.sort()

    self.num_noise_to_load = num_noise_to_load
    if noise_to_load is None:
        self.noise_to_load = self.available_noise[:num_noise_to_load]
    else:
        assert all([i in self.available_noise for i in noise_to_load])
        self.noise_to_load = noise_to_load

    self.transform = transform

    urls_to_load = [[
        i['links']['self'] for i in url['files']
        if i['key'] == f'{noise}_{int(sample_rate / 1000)}k.zip'
    ][0] for noise in self.noise_to_load]

    self._path = os.path.join(root, folder_in_archive)
    archive_list = [
        os.path.join(self._path, f'{noise}_{int(sample_rate / 1000)}k.zip')
        for noise in self.noise_to_load
    ]

    if download:
        for archive, url, data_name in zip(archive_list, urls_to_load, self.noise_to_load):
            if os.path.isdir(os.path.join(self._path, data_name)):
                continue
            if not os.path.isfile(archive):
                logging.info(f'Loading {archive}')
                folder_to_load = os.path.split(archive)[0]
                os.makedirs(folder_to_load, exist_ok=True)
                download_url(url, folder_to_load)
            extract_archive(archive)
            os.remove(archive)

    if not os.path.isdir(self._path):
        raise RuntimeError(
            "Dataset not found. Please use `download=True` to download it."
        )

    walker = walk_files(self._path, suffix=self._ext_audio, prefix=True, remove_suffix=True)
    self._walker = list(walker)
    for i in range(num_synthetic_noise):
        self._walker.append(os.path.join(self._path, 'synthetic', str(i)))
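# Hedged usage sketch for the DEMAND-style noise dataset above. The class name `DEMAND`
# and the noise keys ("DKITCHEN", "DLIVING") are assumptions inferred from the
# `<noise>_<rate>k.zip` file-naming pattern used in the constructor.
noise_set = DEMAND(
    root="./data",
    sample_rate=16000,
    noise_to_load=["DKITCHEN", "DLIVING"],
    download=True,
)
print(noise_set.available_noise)  # all noise environments offered at this sample rate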
def __init__(self, root: str, url: str = URL, split: str = "train", folder_in_archive: str = FOLDER_IN_ARCHIVE, download: bool = False) -> None: ''' :param root: 数据集的根目录 :type root: str :param url: 数据集版本,默认为v0.02 :type url: str, optional :param split: 数据集划分,可以是 ``"train", "test", "val"``,默认为 ``"train"`` :type split: str, optional :param folder_in_archive: 解压后的目录名称,默认为 ``"SpeechCommands"`` :type folder_in_archive: str, optional :param download: 是否下载数据,默认为False :type download: bool, optional SpeechCommands语音数据集,出自 `Speech Commands: A Dataset for Limited-Vocabulary Speech Recognition <https://arxiv.org/abs/1804.03209>`_,根据给出的测试集与验证集列表进行了划分,包含v0.01与v0.02两个版本。 数据集包含三大类单词的音频: #. 指令单词,共10个,"Yes", "No", "Up", "Down", "Left", "Right", "On", "Off", "Stop", "Go". 对于v0.02,还额外增加了5个:"Forward", "Backward", "Follow", "Learn", "Visual". #. 0~9的数字,共10个:"One", "Two", "Three", "Four", "Five", "Six", "Seven", "Eight", "Nine". #. 非关键词,可以视为干扰词,共10个:"Bed", "Bird", "Cat", "Dog", "Happy", "House", "Marvin", "Sheila", "Tree", "Wow". v0.01版本包含共计30类,64,727个音频片段,v0.02版本包含共计35类,105,829个音频片段。更详细的介绍参见前述论文,以及数据集的README。 ''' self.split = verify_str_arg(split, "split", ("train", "val", "test")) if url in [ "speech_commands_v0.01", "speech_commands_v0.02", ]: base_url = "https://storage.googleapis.com/download.tensorflow.org/data/" ext_archive = ".tar.gz" url = os.path.join(base_url, url + ext_archive) basename = os.path.basename(url) archive = os.path.join(root, basename) basename = basename.rsplit(".", 2)[0] folder_in_archive = os.path.join(folder_in_archive, basename) self._path = os.path.join(root, folder_in_archive) if download: if not os.path.isdir(self._path): if not os.path.isfile(archive): checksum = _CHECKSUMS.get(url, None) download_url(url, root, hash_value=checksum, hash_type="md5") extract_archive(archive, self._path) elif not os.path.isdir(self._path): raise FileNotFoundError( "Audio data not found. Please specify \"download=True\" and try again." ) if self.split == "train": record = os.path.join(self._path, TRAIN_RECORD) if os.path.exists(record): with open(record, 'r') as f: self._walker = list([line.rstrip('\n') for line in f]) else: print("No training list, generating...") walker = walk_files(self._path, suffix=".wav", prefix=True) walker = filter( lambda w: HASH_DIVIDER in w and EXCEPT_FOLDER not in w, walker) walker = map(lambda w: os.path.relpath(w, self._path), walker) walker = set(walker) val_record = os.path.join(self._path, VAL_RECORD) with open(val_record, 'r') as f: val_walker = set([line.rstrip('\n') for line in f]) test_record = os.path.join(self._path, TEST_RECORD) with open(test_record, 'r') as f: test_walker = set([line.rstrip('\n') for line in f]) walker = walker - val_walker - test_walker self._walker = list(walker) with open(record, 'w') as f: f.write('\n'.join(self._walker)) print("Training list generated!") else: if self.split == "val": record = os.path.join(self._path, VAL_RECORD) else: record = os.path.join(self._path, TEST_RECORD) with open(record, 'r') as f: self._walker = list([line.rstrip('\n') for line in f])
def __init__(self, path):
    walker = walk_files(path, suffix=self._ext_audio, prefix=False, remove_suffix=True)
    self._walker = list(walker)
def __init__(
    self,
    root: Union[str, Path],
    url: str = URL,
    folder_in_archive: str = FOLDER_IN_ARCHIVE,
    download: bool = False,
    subset: Optional[str] = None,
) -> None:

    assert subset is None or subset in ["training", "validation", "testing"], (
        "When `subset` not None, it must take a value from "
        + "{'training', 'validation', 'testing'}."
    )

    if url in [
        "speech_commands_v0.01",
        "speech_commands_v0.02",
    ]:
        base_url = "https://storage.googleapis.com/download.tensorflow.org/data/"
        ext_archive = ".tar.gz"

        url = os.path.join(base_url, url + ext_archive)

    # Get string representation of 'root' in case Path object is passed
    root = os.fspath(root)

    basename = os.path.basename(url)
    archive = os.path.join(root, basename)

    basename = basename.rsplit(".", 2)[0]
    folder_in_archive = os.path.join(folder_in_archive, basename)

    self._path = os.path.join(root, folder_in_archive)

    if download:
        if not os.path.isdir(self._path):
            if not os.path.isfile(archive):
                checksum = _CHECKSUMS.get(url, None)
                download_url(url, root, hash_value=checksum, hash_type="md5")
            extract_archive(archive, self._path)

    if subset == "validation":
        self._walker = _load_list(self._path, "validation_list.txt")
    elif subset == "testing":
        self._walker = _load_list(self._path, "testing_list.txt")
    elif subset == "training":
        excludes = set(
            _load_list(self._path, "validation_list.txt", "testing_list.txt"))
        walker = walk_files(self._path, suffix=".wav", prefix=True)
        self._walker = [
            w for w in walker
            if HASH_DIVIDER in w
            and EXCEPT_FOLDER not in w
            and os.path.normpath(w) not in excludes
        ]
    else:
        walker = walk_files(self._path, suffix=".wav", prefix=True)
        self._walker = [
            w for w in walker
            if HASH_DIVIDER in w and EXCEPT_FOLDER not in w
        ]
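# Hedged usage sketch for the subset-aware constructor above. The class name SPEECHCOMMANDS
# and the item layout (waveform, sample_rate, label, speaker_id, utterance_number) follow
# torchaudio's convention and are assumptions here, not guaranteed by the snippet itself.
from torch.utils.data import DataLoader

train_set = SPEECHCOMMANDS("./data", download=True, subset="training")
val_set = SPEECHCOMMANDS("./data", subset="validation")

loader = DataLoader(train_set, batch_size=1, shuffle=True)
waveform, sample_rate, label, speaker_id, utterance_number = next(iter(loader))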
def __init__(self,
             label_dict: Dict,
             root: str,
             transform: Optional[Callable] = None,
             url: Optional[str] = URL,
             split: Optional[str] = "train",
             folder_in_archive: Optional[str] = FOLDER_IN_ARCHIVE,
             download: Optional[bool] = False) -> None:
    '''
    :param label_dict: dictionary mapping class folder names to labels
    :type label_dict: Dict
    :param root: root directory of the dataset
    :type root: str
    :param transform: A function/transform that takes in a raw audio clip
    :type transform: Callable, optional
    :param url: version of the dataset, defaults to v0.02
    :type url: str, optional
    :param split: dataset split, one of ``"train", "test", "val"``, defaults to ``"train"``
    :type split: str, optional
    :param folder_in_archive: name of the directory after extraction, defaults to ``"SpeechCommands"``
    :type folder_in_archive: str, optional
    :param download: whether to download the data, defaults to False
    :type download: bool, optional

    The SpeechCommands speech dataset from `Speech Commands: A Dataset for Limited-Vocabulary
    Speech Recognition <https://arxiv.org/abs/1804.03209>`_, split according to the provided
    testing and validation lists. Both v0.01 and v0.02 are available.

    The dataset contains audio for three broad categories of words:

    #. Command words, 10 in total: "Yes", "No", "Up", "Down", "Left", "Right", "On", "Off",
       "Stop", "Go". For v0.02, 5 more were added: "Forward", "Backward", "Follow", "Learn",
       "Visual".

    #. The digits 0 to 9, 10 in total: "Zero", "One", "Two", "Three", "Four", "Five", "Six",
       "Seven", "Eight", "Nine".

    #. Non-keywords, which can be regarded as distractor words, 10 in total: "Bed", "Bird",
       "Cat", "Dog", "Happy", "House", "Marvin", "Sheila", "Tree", "Wow".

    v0.01 contains 30 classes with 64,727 audio clips in total; v0.02 contains 35 classes with
    105,829 audio clips in total. See the paper above and the dataset's README for more details.

    The implementation is based on torchaudio with extended functionality, and also refers to
    `the original paper's implementation
    <https://github.com/romainzimmer/s2net/blob/b073f755e70966ef133bbcd4a8f0343354f5edcd/data.py>`_.
    '''
    self.split = verify_str_arg(split, "split", ("train", "val", "test"))
    self.label_dict = label_dict
    self.transform = transform

    if url in [
        "speech_commands_v0.01",
        "speech_commands_v0.02",
    ]:
        base_url = "https://storage.googleapis.com/download.tensorflow.org/data/"
        ext_archive = ".tar.gz"

        url = os.path.join(base_url, url + ext_archive)

    basename = os.path.basename(url)
    archive = os.path.join(root, basename)

    basename = basename.rsplit(".", 2)[0]
    folder_in_archive = os.path.join(folder_in_archive, basename)

    self._path = os.path.join(root, folder_in_archive)

    if download:
        if not os.path.isdir(self._path):
            if not os.path.isfile(archive):
                checksum = _CHECKSUMS.get(url, None)
                download_url(url, root, hash_value=checksum, hash_type="md5")
            extract_archive(archive, self._path)
    elif not os.path.isdir(self._path):
        raise FileNotFoundError(
            "Audio data not found. Please specify \"download=True\" and try again.")

    if self.split == "train":
        record = os.path.join(self._path, TRAIN_RECORD)
        if os.path.exists(record):
            with open(record, 'r') as f:
                self._walker = list([line.rstrip('\n') for line in f])
        else:
            print("No training list, generating...")
            walker = walk_files(self._path, suffix=".wav", prefix=True)
            walker = filter(lambda w: HASH_DIVIDER in w and EXCEPT_FOLDER not in w, walker)
            walker = map(lambda w: os.path.relpath(w, self._path), walker)
            walker = set(walker)

            val_record = os.path.join(self._path, VAL_RECORD)
            with open(val_record, 'r') as f:
                val_walker = set([line.rstrip('\n') for line in f])

            test_record = os.path.join(self._path, TEST_RECORD)
            with open(test_record, 'r') as f:
                test_walker = set([line.rstrip('\n') for line in f])

            walker = walker - val_walker - test_walker
            self._walker = list(walker)
            with open(record, 'w') as f:
                f.write('\n'.join(self._walker))
            print("Training list generated!")

        # Per-sample weights for class-balanced sampling of the training split
        labels = [self.label_dict.get(os.path.split(relpath)[0]) for relpath in self._walker]
        label_weights = 1. / np.unique(labels, return_counts=True)[1]
        label_weights /= np.sum(label_weights)
        self.weights = torch.DoubleTensor([label_weights[label] for label in labels])
    else:
        if self.split == "val":
            record = os.path.join(self._path, VAL_RECORD)
        else:
            record = os.path.join(self._path, TEST_RECORD)
        with open(record, 'r') as f:
            self._walker = list([line.rstrip('\n') for line in f])
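# The per-sample weights built in the "train" branch above are shaped for
# torch.utils.data.WeightedRandomSampler, which rebalances classes when sampling.
# A hedged usage sketch follows; the class name SpeechCommandsDataset and the
# LABEL_DICT mapping are assumptions for illustration only.
from torch.utils.data import DataLoader, WeightedRandomSampler

train_set = SpeechCommandsDataset(label_dict=LABEL_DICT, root="./data",
                                  split="train", download=True)
sampler = WeightedRandomSampler(train_set.weights, num_samples=len(train_set.weights))
loader = DataLoader(train_set, batch_size=32, sampler=sampler)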