コード例 #1
0
ファイル: speechcommands.py プロジェクト: tuxzz/audio
    def __init__(self,
                 root,
                 url=URL,
                 folder_in_archive=FOLDER_IN_ARCHIVE,
                 download=False):
        """Index the SpeechCommands dataset, optionally downloading it first.

        Args:
            root: directory where the dataset is (or will be) stored.
            url: a known version name ("speech_commands_v0.01"/"v0.02")
                or a full download URL.
            folder_in_archive: top-level folder name for the extracted data.
            download: download and extract the archive when the data folder
                is missing.
        """
        # A bare version name expands into the full download URL.
        if url in ("speech_commands_v0.01", "speech_commands_v0.02"):
            url = os.path.join(
                "https://storage.googleapis.com/download.tensorflow.org/data/",
                url + ".tar.gz")

        archive_name = os.path.basename(url)
        archive = os.path.join(root, archive_name)

        # Strip the ".tar.gz" double extension for the dataset folder name.
        dataset_folder = archive_name.rsplit(".", 2)[0]
        self._path = os.path.join(root, folder_in_archive, dataset_folder)

        if download and not os.path.isdir(self._path):
            if not os.path.isfile(archive):
                download_url(url, root)
            extract_archive(archive, self._path)

        # Keep only real samples: paths containing the speaker hash divider
        # and outside the background-noise folder.
        self._walker = [
            w for w in walk_files(self._path, suffix=".wav", prefix=True)
            if HASH_DIVIDER in w and EXCEPT_FOLDER not in w
        ]
コード例 #2
0
 def __init__(self):
     """Index all audio files found under the fixed ./data/ directory."""
     self.path = './data/'
     audio_files = walk_files(self.path,
                              suffix=self._ext_audio,
                              prefix=False,
                              remove_suffix=False)
     self._walker = list(audio_files)
コード例 #3
0
ファイル: eduskunta.py プロジェクト: vjoki/fsl-experi
    def __init__(self, root: str) -> None:
        """Index all audio files under ``<root>/edus80``."""
        self._path = os.path.join(root, 'edus80')

        self._walker = list(
            walk_files(self._path,
                       suffix=self._ext_audio,
                       prefix=False,
                       remove_suffix=True))
コード例 #4
0
ファイル: utils_test.py プロジェクト: suraj813/audio
 def test_walk_files(self):
     """walk_files should traverse files in alphabetical order"""
     n_seen = 0
     walker = dataset_utils.walk_files(self.root, '.txt', prefix=True)
     for idx, rel_path in enumerate(walker):
         # Each yielded path must match the pre-sorted expectation list.
         assert os.path.join(self.root, rel_path) == self.expected[idx]
         n_seen += 1
     # The walker must yield exactly as many entries as expected.
     assert n_seen == len(self.expected)
コード例 #5
0
ファイル: gtzan.py プロジェクト: yuvrajmetrani2/audio
    def __init__(
        self,
        root: str,
        url: str = URL,
        folder_in_archive: str = FOLDER_IN_ARCHIVE,
        download: bool = False,
        subset: Any = None,
    ) -> None:
        """Index the GTZAN genre dataset, optionally downloading it first.

        Args:
            root: directory where the dataset is (or will be) stored.
            url: download URL of the archive.
            folder_in_archive: dataset folder name under ``root``.
            download: download and extract the archive when missing.
            subset: None (all files) or one of "training"/"validation"/
                "testing" (pre-filtered file lists).
        """
        self.root = root
        self.url = url
        self.folder_in_archive = folder_in_archive
        self.download = download
        self.subset = subset

        assert subset is None or subset in (
            "training", "validation", "testing"
        ), ("When `subset` not None, it must take a value from " +
            "{'training', 'validation', 'testing'}.")

        archive = os.path.join(root, os.path.basename(url))
        self._path = os.path.join(root, folder_in_archive)

        if download and not os.path.isdir(self._path):
            if not os.path.isfile(archive):
                download_url(url,
                             root,
                             hash_value=_CHECKSUMS.get(url, None),
                             hash_type="md5")
            extract_archive(archive)

        if not os.path.isdir(self._path):
            raise RuntimeError(
                "Dataset not found. Please use `download=True` to download it."
            )

        if self.subset is None:
            self._walker = list(
                walk_files(self._path,
                           suffix=self._ext_audio,
                           prefix=False,
                           remove_suffix=True))
        elif self.subset == "training":
            self._walker = filtered_train
        elif self.subset == "validation":
            self._walker = filtered_valid
        else:
            # The assert above guarantees the only remaining value.
            self._walker = filtered_test
コード例 #6
0
ファイル: yesno.py プロジェクト: music-apps/pytorch-audio
    def __init__(
        self,
        root,
        url=URL,
        folder_in_archive=FOLDER_IN_ARCHIVE,
        download=False,
        transform=None,
        target_transform=None,
        return_dict=False,
    ):
        """Index the YESNO dataset, optionally downloading it first.

        Args:
            root: directory where the dataset is (or will be) stored.
            url: download URL of the archive.
            folder_in_archive: dataset folder name under ``root``.
            download: download and extract the archive when missing.
            transform / target_transform: deprecated; warn when set.
            return_dict: opt in to the upcoming dict-item format.
        """
        if not return_dict:
            warnings.warn(
                "In the next version, the item returned will be a dictionary. "
                "Please use `return_dict=True` to enable this behavior now, "
                "and suppress this warning.",
                DeprecationWarning,
            )

        if transform is not None or target_transform is not None:
            warnings.warn(
                "In the next version, transforms will not be part of the dataset. "
                "Please remove the option `transform=True` and "
                "`target_transform=True` to suppress this warning.",
                DeprecationWarning,
            )

        self.transform = transform
        self.target_transform = target_transform
        self.return_dict = return_dict

        self._path = os.path.join(root, folder_in_archive)
        archive = os.path.join(root, os.path.basename(url))

        if download and not os.path.isdir(self._path):
            if not os.path.isfile(archive):
                download_url(url, root)
            extract_archive(archive)

        if not os.path.isdir(self._path):
            raise RuntimeError(
                "Dataset not found. Please use `download=True` to download it."
            )

        self._walker = list(
            walk_files(self._path,
                       suffix=self._ext_audio,
                       prefix=False,
                       remove_suffix=True))
    def __init__(self,
                 root: str,
                 url: str,
                 folder_in_archive: str = FOLDER_IN_ARCHIVE,
                 download: bool = False) -> None:
        """LibriSpeech-style dataset that pre-computes mel spectrograms.

        Args:
            root: directory where the dataset is (or will be) stored.
            url: a known subset name (e.g. "train-clean-100") or a full
                download URL.
            folder_in_archive: top-level folder name inside the archive.
            download: when True, download/extract the data if missing and
                pre-compute a mel spectrogram next to every audio file.
        """
        # A bare subset name expands into the full OpenSLR download URL.
        if url in [
            "dev-clean",
            "dev-other",
            "test-clean",
            "test-other",
            "train-clean-100",
            "train-clean-360",
            "train-other-500",
        ]:
            ext_archive = ".tar.gz"
            base_url = "http://www.openslr.org/resources/12/"

            url = os.path.join(base_url, url + ext_archive)

        basename = os.path.basename(url)
        archive = os.path.join(root, basename)

        basename = basename.split(".")[0]
        folder_in_archive = os.path.join(folder_in_archive, basename)

        self._path = os.path.join(root, folder_in_archive)

        if download:
            if not os.path.isdir(self._path):
                if not os.path.isfile(archive):
                    checksum = _CHECKSUMS.get(url, None)
                    download_url(url, root, hash_value=checksum)
                extract_archive(archive)

            # Pre-compute a mel spectrogram for every audio file so items
            # can later be loaded as tensors directly.
            audio_transforms = torchaudio.transforms.MelSpectrogram(
                sample_rate=16000, n_mels=128)
            # Renamed loop variable (was `root`) to avoid shadowing the
            # constructor parameter.
            for dirpath, _dirs, files in os.walk(self._path):
                for file in files:
                    if file.split('.')[-1] == self._ext_wav.split('.')[-1]:
                        file_audio = os.path.join(dirpath, file)
                        waveform, _ = torchaudio.load(file_audio)
                        spec = audio_transforms(waveform)
                        # BUG FIX: save with the mel extension. The original
                        # appended `self._ext_wav`, which overwrote the source
                        # audio file with the tensor and left no `_ext_mel`
                        # files for the walker below to find.
                        file_spec = os.path.join(
                            dirpath, file.split('.')[0] + self._ext_mel)
                        torch.save(spec, file_spec)

        walker = walk_files(
            self._path, suffix=self._ext_mel, prefix=False, remove_suffix=True
        )
        self._walker = list(walker)
コード例 #8
0
ファイル: vctk.py プロジェクト: oceanos74/audio
    def __init__(self,
                 root: str,
                 url: str = URL,
                 folder_in_archive: str = FOLDER_IN_ARCHIVE,
                 download: bool = False,
                 downsample: bool = False,
                 transform: Any = None,
                 target_transform: Any = None) -> None:
        """Index the VCTK dataset, optionally downloading it first.

        Args:
            root: directory where the dataset is (or will be) stored.
            url: download URL of the archive.
            folder_in_archive: dataset folder name under ``root``.
            download: download and extract the archive when missing.
            downsample: deprecated; emits a warning when True.
            transform / target_transform: deprecated; warn when set.
        """
        if downsample:
            # BUG FIX: the last message fragment was previously passed as a
            # second positional argument to warnings.warn (the `category`
            # parameter), which raises TypeError because a str is not a
            # Warning subclass. The pieces are now one concatenated message.
            warnings.warn(
                "In the next version, transforms will not be part of the dataset. "
                "Please use `downsample=False` to enable this behavior now, "
                "and suppress this warning.")

        if transform is not None or target_transform is not None:
            warnings.warn(
                "In the next version, transforms will not be part of the dataset. "
                "Please remove the option `transform=True` and "
                "`target_transform=True` to suppress this warning.")

        self.downsample = downsample
        self.transform = transform
        self.target_transform = target_transform

        archive = os.path.basename(url)
        archive = os.path.join(root, archive)
        self._path = os.path.join(root, folder_in_archive)

        if download:
            if not os.path.isdir(self._path):
                if not os.path.isfile(archive):
                    checksum = _CHECKSUMS.get(url, None)
                    download_url(url,
                                 root,
                                 hash_value=checksum,
                                 hash_type="md5")
                extract_archive(archive)

        if not os.path.isdir(self._path):
            raise RuntimeError(
                "Dataset not found. Please use `download=True` to download it."
            )

        walker = walk_files(self._path,
                            suffix=self._ext_audio,
                            prefix=False,
                            remove_suffix=True)
        # Skip files inside the excluded folder.
        walker = filter(lambda w: self._except_folder not in w, walker)
        self._walker = list(walker)
コード例 #9
0
    def __init__(self, root_path: str) -> None:
        """Index a speech-commands-style dataset rooted at ``root_path``.

        Args:
            root_path: dataset root directory; backslashes are normalized
                to forward slashes.
        """
        # BUG FIX: str.replace returns a new string (str is immutable); the
        # original discarded the result, so the path normalization never
        # took effect.
        root_path = root_path.replace("\\", "/")

        self._path = root_path

        classes_file_path = os.path.join(root_path, CLASSES_FILE)

        # Keep only real samples: paths containing the speaker hash divider
        # and outside the background-noise folder.
        walker = walk_files(self._path, suffix=".wav", prefix=True)
        walker = filter(lambda w: HASH_DIVIDER in w and EXCEPT_FOLDER not in w,
                        walker)
        self._walker = list(walker)

        # SECURITY NOTE: pickle.load can execute arbitrary code — only load
        # class files from trusted sources.
        with open(classes_file_path, 'rb') as handle:
            self.classes = pickle.load(handle)
コード例 #10
0
ファイル: vctk.py プロジェクト: mohammedgomaa/audio-1
    def __init__(self,
                 root: str,
                 url: str = URL,
                 folder_in_archive: str = FOLDER_IN_ARCHIVE,
                 download: bool = False,
                 downsample: bool = False,
                 transform: Any = None,
                 target_transform: Any = None) -> None:
        """Legacy VCTK dataset (no longer downloadable; use ``VCTK_092``).

        Args:
            root: directory containing the already-extracted dataset.
            url: kept for backward compatibility; only used to derive the
                (unused) archive path.
            folder_in_archive: dataset folder name under ``root``.
            download: must be False — this dataset can no longer be
                downloaded; a RuntimeError is raised when True.
            downsample: deprecated; emits a warning when True.
            transform / target_transform: deprecated; warn when set.
        """
        if downsample:
            warnings.warn(
                "In the next version, transforms will not be part of the dataset. "
                "Please use `downsample=False` to enable this behavior now, "
                "and suppress this warning."
            )

        if transform is not None or target_transform is not None:
            warnings.warn(
                "In the next version, transforms will not be part of the dataset. "
                "Please remove the option `transform=True` and "
                "`target_transform=True` to suppress this warning."
            )

        self.downsample = downsample
        self.transform = transform
        self.target_transform = target_transform

        archive = os.path.basename(url)
        archive = os.path.join(root, archive)
        self._path = os.path.join(root, folder_in_archive)

        if download:
            raise RuntimeError(
                "This Dataset is no longer available. "
                "Please use `VCTK_092` class to download the latest version."
            )

        if not os.path.isdir(self._path):
            # BUG FIX: corrected "donwload" typo in the error message.
            raise RuntimeError(
                "Dataset not found. Please use `VCTK_092` class "
                "with `download=True` to download the latest version."
            )

        walker = walk_files(
            self._path, suffix=self._ext_audio, prefix=False, remove_suffix=True
        )
        # Skip files inside the excluded folder.
        walker = filter(lambda w: self._except_folder not in w, walker)
        self._walker = list(walker)
コード例 #11
0
ファイル: utils.py プロジェクト: FolabiAhn/dialectai
def tokenizer_librispeech(limit=10, path="../librispeech/LibriSpeech/", version = "train-clean-360"):
    """Build a tokenizer from the first ``limit`` LibriSpeech transcriptions.

    Args:
        limit: maximum number of transcriptions to read.
        path: root of the extracted LibriSpeech dataset.
        version: subset folder name (e.g. "train-clean-360").

    Returns:
        The result of ``tokenize`` over the collected, preprocessed
        sentences.
    """
    ext_txt = ".trans.txt"
    ext_audio = ".flac"
    path = os.path.join(path, version)

    walker = list(walk_files(path, suffix=ext_audio, prefix=False, remove_suffix=True))
    all_sentences = []
    for i, fileid in enumerate(walker):
        # Stop once `limit` transcriptions have been collected.
        if i == limit:
            break
        # text_only=True skips loading the waveform; only the transcript is
        # needed to fit the tokenizer.
        sentence = load_librispeech_item(fileid, path=path, ext_audio=ext_audio, ext_txt=ext_txt, text_only=True)
        all_sentences.append(preprocess_sentence(sentence))

    # BUG FIX: the message previously read "They are ..." (grammar typo).
    print(" ===== There are {} transcriptions in the dataset. ===== ".format(len(all_sentences)))
    return tokenize(all_sentences)
コード例 #12
0
ファイル: libritts.py プロジェクト: zeta1999/audio
    def __init__(
        self,
        root: Union[str, Path],
        url: str = URL,
        folder_in_archive: str = FOLDER_IN_ARCHIVE,
        download: bool = False,
    ) -> None:
        """Index the LibriTTS dataset, optionally downloading it first.

        Args:
            root: directory where the dataset is (or will be) stored.
            url: a known subset name (e.g. "train-clean-100") or a full
                download URL.
            folder_in_archive: top-level folder name for the extracted data.
            download: download and extract the archive when missing.
        """
        # A bare subset name expands into the full OpenSLR download URL.
        known_subsets = (
            "dev-clean",
            "dev-other",
            "test-clean",
            "test-other",
            "train-clean-100",
            "train-clean-360",
            "train-other-500",
        )
        if url in known_subsets:
            url = os.path.join("http://www.openslr.org/resources/60/",
                               url + ".tar.gz")

        # Accept pathlib.Path objects for `root`.
        root = os.fspath(root)

        archive_name = os.path.basename(url)
        archive = os.path.join(root, archive_name)

        dataset_folder = archive_name.split(".")[0]
        self._path = os.path.join(root, folder_in_archive, dataset_folder)

        if download and not os.path.isdir(self._path):
            if not os.path.isfile(archive):
                download_url(url, root,
                             hash_value=_CHECKSUMS.get(url, None))
            extract_archive(archive)

        self._walker = list(
            walk_files(self._path,
                       suffix=self._ext_audio,
                       prefix=False,
                       remove_suffix=True))
コード例 #13
0
    def __init__(self,
                 root,
                 url=URL,
                 folder_in_archive=FOLDER_IN_ARCHIVE,
                 download=False,
                 preprocess=False):
        """Index a LibriSpeech subset; optionally download it and pre-compute
        embeddings.

        Args:
            root: directory where the dataset is (or will be) stored.
            url: a known subset name (e.g. "train-clean-100") or a full
                download URL.
            folder_in_archive: top-level folder name for the extracted data.
            download: download and extract the archive when missing.
            preprocess: when True, pre-compute embeddings for all files.
        """
        # A bare subset name expands into the full OpenSLR download URL.
        known_subsets = (
            "dev-clean",
            "dev-other",
            "test-clean",
            "test-other",
            "train-clean-100",
            "train-clean-360",
            "train-other-500",
        )
        if url in known_subsets:
            url = os.path.join("http://www.openslr.org/resources/12/",
                               url + ".tar.gz")

        archive_name = os.path.basename(url)
        archive = os.path.join(root, archive_name)

        dataset_folder = archive_name.split(".")[0]
        self._path = os.path.join(root, folder_in_archive, dataset_folder)

        if download and not os.path.isdir(self._path):
            if not os.path.isfile(archive):
                download_url(url, root)
            extract_archive(archive)

        self._walker = list(
            walk_files(self._path,
                       suffix=self._ext_audio,
                       prefix=False,
                       remove_suffix=True))

        if preprocess:
            self.preprocess_embeddings(self._path, self._ext_audio,
                                       self._ext_embed)
コード例 #14
0
ファイル: utils.py プロジェクト: FolabiAhn/dialectai
 def __init__(self, tokenizer, limit=None, n_channels=1, n_frames=128, sr=16000, n_fft=2048, max_target_length=40,
              n_mels=40, hop_length=512, power=1.0, n_mfcc=39, duration=10, path="../librispeech/LibriSpeech/", version = "train-clean-360"):
     """Store feature-extraction settings and index up to ``limit`` audio
     files from the chosen LibriSpeech subset."""
     self.tokenizer = tokenizer
     self.limit = limit
     self.n_channels = n_channels
     self.n_frames = n_frames
     self.sr = sr
     self.n_fft = n_fft
     self.n_mels = n_mels
     self.hop_length = hop_length
     self.power = power
     self.n_mfcc = n_mfcc
     self.duration = duration
     self.max_length = max_target_length
     self._path = os.path.join(path, version)
     # Cap the file list at `limit` entries (slicing with None keeps all).
     self._walker = list(walk_files(self._path,
                                    suffix=self._ext_audio,
                                    prefix=False,
                                    remove_suffix=True))[:limit]
コード例 #15
0
    def __init__(self,
                 root: Union[str, Path],
                 url: str = URL,
                 folder_in_archive: str = FOLDER_IN_ARCHIVE,
                 download: bool = False,
                 transform: Any = None,
                 target_transform: Any = None) -> None:
        """Index the dataset under ``root``, optionally downloading it first.

        Args:
            root: directory where the dataset is (or will be) stored.
            url: download URL of the archive.
            folder_in_archive: dataset folder name under ``root``.
            download: download and extract the archive when missing.
            transform / target_transform: deprecated; warn when set.
        """
        if transform is not None or target_transform is not None:
            warnings.warn(
                "In the next version, transforms will not be part of the dataset. "
                "Please remove the option `transform=True` and "
                "`target_transform=True` to suppress this warning.")

        self.transform = transform
        self.target_transform = target_transform

        # Accept pathlib.Path objects for `root`.
        root = os.fspath(root)

        archive = os.path.join(root, os.path.basename(url))
        self._path = os.path.join(root, folder_in_archive)

        if download and not os.path.isdir(self._path):
            if not os.path.isfile(archive):
                download_url(url,
                             root,
                             hash_value=_CHECKSUMS.get(url, None),
                             hash_type="md5")
            extract_archive(archive)

        if not os.path.isdir(self._path):
            raise RuntimeError(
                "Dataset not found. Please use `download=True` to download it."
            )

        self._walker = list(
            walk_files(self._path,
                       suffix=self._ext_audio,
                       prefix=False,
                       remove_suffix=True))
コード例 #16
0
    def __init__(self, root: str, audio_folder: str, text_file: str):
        """
        Args:
            root (str): root folder of the dataset
            audio_folder (str): folder with the audio files inside root folder
            text_file (str): path to the file with the text transcriptions of the audio files inside
                root folder
        """
        self._root = root
        self._audio_folder = audio_folder
        self._walker = list(
            walk_files(root, suffix=self._ext_audio, prefix=False,
                       remove_suffix=True))

        text_path = os.path.join(root, text_file)
        # Use a distinct handle name so the `text_file` parameter is not
        # shadowed by the open file object.
        with open(text_path, "r") as f:
            rows = unicode_csv_reader(f, delimiter="|", quoting=csv.QUOTE_NONE)
            self._text = list(rows)
        # Drop the CSV header row describing the columns.
        del self._text[0]

        assert len(self._walker) == len(self._text), \
            "Number of audiofiles is different from number of texts"
コード例 #17
0
    def __init__(self, root: str, audio_folder: str, text_file: str):
        """Index audio files under ``root`` and load their transcriptions
        from the pipe-separated ``text_file``."""
        self._root = root
        self._audio_folder = audio_folder
        # Run on GPU when available.
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')

        self._walker = list(
            walk_files(root,
                       suffix=self._ext_audio,
                       prefix=False,
                       remove_suffix=True))

        text_path = os.path.join(root, text_file)
        # Use a distinct handle name so the `text_file` parameter is not
        # shadowed by the open file object.
        with open(text_path, "r") as f:
            rows = unicode_csv_reader(f,
                                      delimiter="|",
                                      quoting=csv.QUOTE_NONE)
            self._text = list(rows)
        # Drop the CSV header row describing the columns.
        del self._text[0]

        assert len(self._walker) == len(self._text), \
            "Number of audiofiles is different from number of texts"
コード例 #18
0
ファイル: yesno.py プロジェクト: jevenzh/audio
    def __init__(self,
                 root: str,
                 url: str = URL,
                 folder_in_archive: str = FOLDER_IN_ARCHIVE,
                 download: bool = False,
                 transform: Any = None,
                 target_transform: Any = None) -> None:
        """Index the YESNO dataset, optionally downloading it first.

        Args:
            root: directory where the dataset is (or will be) stored.
            url: download URL of the archive.
            folder_in_archive: dataset folder name under ``root``.
            download: download and extract the archive when missing.
            transform / target_transform: deprecated; warn when set.
        """
        if transform is not None or target_transform is not None:
            warnings.warn(
                "In the next version, transforms will not be part of the dataset. "
                "Please remove the option `transform=True` and "
                "`target_transform=True` to suppress this warning."
            )

        self.transform = transform
        self.target_transform = target_transform

        self._path = os.path.join(root, folder_in_archive)
        archive = os.path.join(root, os.path.basename(url))

        if download and not os.path.isdir(self._path):
            if not os.path.isfile(archive):
                download_url(url, root)
            extract_archive(archive)

        if not os.path.isdir(self._path):
            raise RuntimeError(
                "Dataset not found. Please use `download=True` to download it."
            )

        self._walker = list(
            walk_files(self._path,
                       suffix=self._ext_audio,
                       prefix=False,
                       remove_suffix=True))
コード例 #19
0
    def __init__(self,
                 root,
                 sample_rate: int,
                 num_noise_to_load: int = 3,
                 noise_to_load: list = None,
                 num_synthetic_noise: int = 0,
                 folder_in_archive=FOLDER_IN_ARCHIVE,
                 url=DEMAND_JSON,
                 download=False,
                 transform=None):
        """DEMAND noise dataset loader.

        Args:
            root: directory where the dataset folder lives (or is created).
            sample_rate: target sample rate; only 16000 or 48000 exist.
            num_noise_to_load: how many noise classes to load when
                ``noise_to_load`` is not given.
            noise_to_load: explicit list of noise class names to load; must
                be a subset of the available classes.
            num_synthetic_noise: number of extra "synthetic" walker entries
                to append (files presumably live under
                ``<path>/synthetic/<i>`` — TODO confirm against __getitem__).
            url: JSON-like record describing the downloadable files; expected
                to expose a ``'files'`` list whose entries have ``'key'`` and
                ``'links'['self']``.
            download: when True, download and extract any missing noise
                archives (archives are deleted after extraction).
            transform: optional transform stored for later use.
        """
        assert sample_rate in (16000, 48000)
        # A noise class name is the part of each file key before the first
        # '_'; keys also embed the rate ("16"/"48"), which filters the list
        # to the requested sample rate.
        available_noise = set([
            i['key'].split('_')[0] for i in url['files']
            if (str(sample_rate // 1000) in i['key'])
        ])
        self.available_noise = list(available_noise)
        self.available_noise.sort()
        self.num_noise_to_load = num_noise_to_load

        if noise_to_load is None:
            # Default: the first N classes in alphabetical order.
            self.noise_to_load = self.available_noise[:num_noise_to_load]
        else:
            assert all([i in self.available_noise for i in noise_to_load])
            self.noise_to_load = noise_to_load

        self.transform = transform

        # One download URL per requested noise class, matched by the exact
        # archive file name (e.g. "kitchen_16k.zip").
        urls_to_load = [[
            i['links']['self'] for i in url['files']
            if i['key'] == f'{noise}_{int(sample_rate / 1000)}k.zip'
        ][0] for noise in self.noise_to_load]

        self._path = os.path.join(root, folder_in_archive)
        archive_list = [
            os.path.join(self._path, f'{noise}_{int(sample_rate / 1000)}k.zip')
            for noise in self.noise_to_load
        ]

        if download:
            for archive, url, data_name in zip(archive_list, urls_to_load,
                                               self.noise_to_load):
                # Skip classes that are already extracted.
                if os.path.isdir(os.path.join(self._path, data_name)):
                    continue
                if not os.path.isfile(archive):
                    logging.info(f'Loading {archive}')
                    folder_to_load = os.path.split(archive)[0]
                    os.makedirs(folder_to_load, exist_ok=True)
                    download_url(url, folder_to_load)
                extract_archive(archive)
                # Delete the archive after extraction to save disk space.
                os.remove(archive)

        if not os.path.isdir(self._path):
            raise RuntimeError(
                "Dataset not found. Please use `download=True` to download it."
            )

        walker = walk_files(self._path,
                            suffix=self._ext_audio,
                            prefix=True,
                            remove_suffix=True)
        self._walker = list(walker)

        # Append placeholder walker entries for synthetic noise generated
        # elsewhere.
        for i in range(num_synthetic_noise):
            self._walker.append(os.path.join(self._path, 'synthetic', str(i)))
コード例 #20
0
    def __init__(self,
                 root: str,
                 url: str = URL,
                 split: str = "train",
                 folder_in_archive: str = FOLDER_IN_ARCHIVE,
                 download: bool = False) -> None:
        '''
        :param root: root directory of the dataset
        :type root: str
        :param url: dataset version, defaults to v0.02
        :type url: str, optional
        :param split: dataset split, one of ``"train", "test", "val"``, defaults to ``"train"``
        :type split: str, optional
        :param folder_in_archive: name of the directory after extraction, defaults to ``"SpeechCommands"``
        :type folder_in_archive: str, optional
        :param download: whether to download the data, defaults to False
        :type download: bool, optional

        The SpeechCommands dataset, from `Speech Commands: A Dataset for Limited-Vocabulary Speech Recognition <https://arxiv.org/abs/1804.03209>`_, split according to the provided test and validation lists. It comes in two versions, v0.01 and v0.02.

        The dataset contains audio for three broad groups of words:

        #. Command words, 10 in total: "Yes", "No", "Up", "Down", "Left", "Right", "On", "Off", "Stop", "Go". v0.02 adds 5 more: "Forward", "Backward", "Follow", "Learn", "Visual".

        #. Digits, 10 in total: "One", "Two", "Three", "Four", "Five", "Six", "Seven", "Eight", "Nine".

        #. Non-keywords (distractor words), 10 in total: "Bed", "Bird", "Cat", "Dog", "Happy", "House", "Marvin", "Sheila", "Tree", "Wow".

        v0.01 contains 30 classes with 64,727 audio clips in total; v0.02 contains 35 classes with 105,829 clips. See the paper above and the dataset README for more details.
        '''

        self.split = verify_str_arg(split, "split", ("train", "val", "test"))

        # A bare version name expands into the full download URL.
        if url in [
                "speech_commands_v0.01",
                "speech_commands_v0.02",
        ]:
            base_url = "https://storage.googleapis.com/download.tensorflow.org/data/"
            ext_archive = ".tar.gz"

            url = os.path.join(base_url, url + ext_archive)

        basename = os.path.basename(url)
        archive = os.path.join(root, basename)

        # Strip the ".tar.gz" double extension for the dataset folder name.
        basename = basename.rsplit(".", 2)[0]
        folder_in_archive = os.path.join(folder_in_archive, basename)

        self._path = os.path.join(root, folder_in_archive)

        if download:
            if not os.path.isdir(self._path):
                if not os.path.isfile(archive):
                    checksum = _CHECKSUMS.get(url, None)
                    download_url(url,
                                 root,
                                 hash_value=checksum,
                                 hash_type="md5")
                extract_archive(archive, self._path)
        elif not os.path.isdir(self._path):
            raise FileNotFoundError(
                "Audio data not found. Please specify \"download=True\" and try again."
            )

        if self.split == "train":
            # The training list is derived (all files minus val/test lists)
            # and cached in TRAIN_RECORD after the first run.
            record = os.path.join(self._path, TRAIN_RECORD)
            if os.path.exists(record):
                with open(record, 'r') as f:
                    self._walker = list([line.rstrip('\n') for line in f])
            else:
                print("No training list, generating...")
                # Keep only real samples: paths containing the speaker hash
                # divider and outside the background-noise folder.
                walker = walk_files(self._path, suffix=".wav", prefix=True)
                walker = filter(
                    lambda w: HASH_DIVIDER in w and EXCEPT_FOLDER not in w,
                    walker)
                # Relative paths so entries are comparable with the list
                # files shipped with the dataset.
                walker = map(lambda w: os.path.relpath(w, self._path), walker)

                walker = set(walker)

                val_record = os.path.join(self._path, VAL_RECORD)
                with open(val_record, 'r') as f:
                    val_walker = set([line.rstrip('\n') for line in f])

                test_record = os.path.join(self._path, TEST_RECORD)
                with open(test_record, 'r') as f:
                    test_walker = set([line.rstrip('\n') for line in f])

                # Training set = everything not in validation or testing.
                walker = walker - val_walker - test_walker
                self._walker = list(walker)

                # Cache the generated list for subsequent runs.
                with open(record, 'w') as f:
                    f.write('\n'.join(self._walker))

                print("Training list generated!")

        else:
            # "val" and "test" splits are read straight from the shipped
            # list files.
            if self.split == "val":
                record = os.path.join(self._path, VAL_RECORD)
            else:
                record = os.path.join(self._path, TEST_RECORD)
            with open(record, 'r') as f:
                self._walker = list([line.rstrip('\n') for line in f])
コード例 #21
0
 def __init__(self, path):
     """Index all audio files found under ``path``."""
     audio_files = walk_files(path,
                              suffix=self._ext_audio,
                              prefix=False,
                              remove_suffix=True)
     self._walker = list(audio_files)
コード例 #22
0
    def __init__(
        self,
        root: Union[str, Path],
        url: str = URL,
        folder_in_archive: str = FOLDER_IN_ARCHIVE,
        download: bool = False,
        subset: Optional[str] = None,
    ) -> None:
        """Index the SpeechCommands dataset with optional subset selection.

        Args:
            root: directory where the dataset is (or will be) stored.
            url: a known version name ("speech_commands_v0.01"/"v0.02")
                or a full download URL.
            folder_in_archive: top-level folder name for the extracted data.
            download: download and extract the archive when missing.
            subset: None (all files) or "training"/"validation"/"testing".
        """
        assert subset is None or subset in [
            "training", "validation", "testing"
        ], ("When `subset` not None, it must take a value from " +
            "{'training', 'validation', 'testing'}.")

        # A bare version name expands into the full download URL.
        if url in ("speech_commands_v0.01", "speech_commands_v0.02"):
            url = os.path.join(
                "https://storage.googleapis.com/download.tensorflow.org/data/",
                url + ".tar.gz")

        # Accept pathlib.Path objects for `root`.
        root = os.fspath(root)

        archive_name = os.path.basename(url)
        archive = os.path.join(root, archive_name)

        # Strip the ".tar.gz" double extension for the dataset folder name.
        dataset_folder = archive_name.rsplit(".", 2)[0]
        self._path = os.path.join(root, folder_in_archive, dataset_folder)

        if download and not os.path.isdir(self._path):
            if not os.path.isfile(archive):
                download_url(url,
                             root,
                             hash_value=_CHECKSUMS.get(url, None),
                             hash_type="md5")
            extract_archive(archive, self._path)

        def _is_sample(w):
            # Real samples contain the speaker hash divider and are not in
            # the background-noise folder.
            return HASH_DIVIDER in w and EXCEPT_FOLDER not in w

        if subset == "validation":
            self._walker = _load_list(self._path, "validation_list.txt")
        elif subset == "testing":
            self._walker = _load_list(self._path, "testing_list.txt")
        elif subset == "training":
            # Training set = all samples minus the validation/testing lists.
            excludes = set(
                _load_list(self._path, "validation_list.txt",
                           "testing_list.txt"))
            self._walker = [
                w for w in walk_files(self._path, suffix=".wav", prefix=True)
                if _is_sample(w) and os.path.normpath(w) not in excludes
            ]
        else:
            self._walker = [
                w for w in walk_files(self._path, suffix=".wav", prefix=True)
                if _is_sample(w)
            ]
コード例 #23
0
    def __init__(self,
                 label_dict: Dict,
                 root: str,
                 transform: Optional[Callable] = None,
                 url: str = URL,
                 split: str = "train",
                 folder_in_archive: str = FOLDER_IN_ARCHIVE,
                 download: bool = False) -> None:
        '''
        SpeechCommands dataset, from `Speech Commands: A Dataset for
        Limited-Vocabulary Speech Recognition <https://arxiv.org/abs/1804.03209>`_,
        split according to the official validation/testing lists. Both the
        v0.01 and v0.02 releases are supported.

        :param label_dict: mapping from class (folder) name to integer label
        :type label_dict: Dict
        :param root: root directory of the dataset
        :type root: str
        :param transform: a function/transform that takes in a raw audio
        :type transform: Callable, optional
        :param url: dataset version, defaults to v0.02
        :type url: str, optional
        :param split: dataset split, one of ``"train", "test", "val"``,
            defaults to ``"train"``
        :type split: str, optional
        :param folder_in_archive: name of the extracted directory, defaults
            to ``"SpeechCommands"``
        :type folder_in_archive: str, optional
        :param download: whether to download the data, defaults to False
        :type download: bool, optional

        The dataset contains three broad word categories: 10 (v0.01) or 15
        (v0.02) command words, the digits zero through nine, and 10 auxiliary
        (distractor) words. v0.01 has 30 classes / 64,727 clips; v0.02 has
        35 classes / 105,829 clips. See the paper and the dataset README for
        details. Implementation extends torchaudio's SPEECHCOMMANDS and also
        follows `the original paper's implementation
        <https://github.com/romainzimmer/s2net/blob/b073f755e70966ef133bbcd4a8f0343354f5edcd/data.py>`_.
        '''

        self.split = verify_str_arg(split, "split", ("train", "val", "test"))
        self.label_dict = label_dict
        self.transform = transform

        # Expand a known version shorthand into the full download URL.
        if url in [
            "speech_commands_v0.01",
            "speech_commands_v0.02",
        ]:
            base_url = "https://storage.googleapis.com/download.tensorflow.org/data/"
            ext_archive = ".tar.gz"

            url = os.path.join(base_url, url + ext_archive)

        basename = os.path.basename(url)
        archive = os.path.join(root, basename)

        # Strip the two-part ".tar.gz" suffix to get the extraction folder name.
        basename = basename.rsplit(".", 2)[0]
        folder_in_archive = os.path.join(folder_in_archive, basename)

        self._path = os.path.join(root, folder_in_archive)

        if download:
            if not os.path.isdir(self._path):
                if not os.path.isfile(archive):
                    # checksum may be None when the URL is not in _CHECKSUMS;
                    # download_url then skips hash verification.
                    checksum = _CHECKSUMS.get(url, None)
                    download_url(url, root, hash_value=checksum, hash_type="md5")
                extract_archive(archive, self._path)
        elif not os.path.isdir(self._path):
            raise FileNotFoundError("Audio data not found. Please specify \"download=True\" and try again.")


        if self.split == "train":
            record = os.path.join(self._path, TRAIN_RECORD)
            if os.path.exists(record):
                # Reuse the cached training list.
                with open(record, 'r') as f:
                    self._walker = [line.rstrip('\n') for line in f]
            else:
                # The training split is the complement of the official
                # validation and testing lists; scan once and cache to disk.
                print("No training list, generating...")
                walker = walk_files(self._path, suffix=".wav", prefix=True)
                walker = filter(lambda w: HASH_DIVIDER in w and EXCEPT_FOLDER not in w, walker)
                walker = {os.path.relpath(w, self._path) for w in walker}

                val_record = os.path.join(self._path, VAL_RECORD)
                with open(val_record, 'r') as f:
                    val_walker = {line.rstrip('\n') for line in f}

                test_record = os.path.join(self._path, TEST_RECORD)
                with open(test_record, 'r') as f:
                    test_walker = {line.rstrip('\n') for line in f}

                self._walker = list(walker - val_walker - test_walker)

                with open(record, 'w') as f:
                    f.write('\n'.join(self._walker))

                print("Training list generated!")

            # Per-sample weights inversely proportional to class frequency
            # (e.g. for a WeightedRandomSampler). NOTE(review): indexing
            # label_weights[label] assumes label_dict maps folder names to
            # contiguous ints 0..K-1 so they align with np.unique's sorted
            # counts -- confirm against the caller's label_dict.
            labels = [self.label_dict.get(os.path.split(relpath)[0]) for relpath in self._walker]
            label_weights = 1. / np.unique(labels, return_counts=True)[1]
            label_weights /= np.sum(label_weights)
            self.weights = torch.DoubleTensor([label_weights[label] for label in labels])

        else:
            # val/test splits come straight from the official list files.
            record = os.path.join(self._path, VAL_RECORD if self.split == "val" else TEST_RECORD)
            with open(record, 'r') as f:
                self._walker = [line.rstrip('\n') for line in f]