def __getitem__(self, index):
        filename = self.data_path[index]
        n_fft = 128
        #fbins = n_fft//2 + 1
        spec_transform = transforms.Spectrogram(n_fft = n_fft, normalized = False)

        # filename is assumed to be "<label>_<soundSource>_<number>.wav"
        label = int(filename.split("/")[-1].split("_")[0])
        soundSource = filename.split("/")[-1].split("_")[1]
        number = filename.split("/")[-1].split("_")[2]

        wave, sample_rate = torchaudio.load_wav(filename)

        spec = spec_transform(wave)

        log_spec = (spec + 1e-9).log2()[0, :, :]
        

        width = 65
        height = log_spec.shape[0]
        dim = (width, height)
        log_spec = cv2.resize(log_spec.numpy(), dim, interpolation = cv2.INTER_AREA)
        # Debug visualization of the resized log-spectrogram
        plt.figure()
        plt.imshow(log_spec)
        plt.show()
        

        return log_spec, label, soundSource
Example #2
def transform(filename):
    if T.__version__ > '0.7.0':
        audio, sr = T.load(filename)
        audio = torch.clamp(audio[0], -1.0, 1.0)
    else:
        audio, sr = T.load_wav(filename)
        audio = torch.clamp(audio[0] / 32767.5, -1.0, 1.0)

    if params.sample_rate != sr:
        raise ValueError(f'Invalid sample rate {sr}.')
    mel_args = {
        'sample_rate': sr,
        'win_length': params.hop_samples * 4,
        'hop_length': params.hop_samples,
        'n_fft': params.n_fft,
        'f_min': 20.0,
        'f_max': sr / 2.0,
        'n_mels': params.n_mels,
        'power': 1.0,
        'normalized': True,
    }
    mel_spec_transform = TT.MelSpectrogram(**mel_args)

    with torch.no_grad():
        spectrogram = mel_spec_transform(audio)
        spectrogram = 20 * torch.log10(torch.clamp(spectrogram, min=1e-5)) - 20
        spectrogram = torch.clamp((spectrogram + 100) / 100, 0.0, 1.0)
        np.save(f'{filename}.spec.npy', spectrogram.cpu().numpy())
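The function above writes a [0, 1]-scaled log-mel spectrogram next to each wav as <filename>.spec.npy. A minimal usage sketch, assuming a `params` object with the fields referenced above; the dataset path is a placeholder:

from glob import glob

# Preprocess every wav under a (hypothetical) dataset root; transform() saves
# <wav>.spec.npy alongside each file.
for wav_path in glob('dataset/**/*.wav', recursive=True):
    transform(wav_path)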
Example #3
 def __call__(self, batch):
     mean_stat = torch.zeros(self.feat_dim)
     var_stat = torch.zeros(self.feat_dim)
     number = 0
     for item in batch:
         value = item[1].strip().split(",")
         assert len(value) == 3 or len(value) == 1
         wav_path = value[0]
         sample_rate = torchaudio.backend.sox_backend.info(wav_path)[0].rate
         # len(value) == 3 means segmented wav.scp,
         # len(value) == 1 means original wav.scp
         if len(value) == 3:
             start_frame = int(float(value[1]) * sample_rate)
             end_frame = int(float(value[2]) * sample_rate)
             waveform, sample_rate = torchaudio.backend.sox_backend.load(
                 filepath=wav_path,
                 num_frames=end_frame - start_frame,
                 offset=start_frame)
             waveform = waveform * (1 << 15)
         else:
             waveform, sample_rate = torchaudio.load_wav(item[1])
         mat = kaldi.fbank(waveform,
                           num_mel_bins=self.feat_dim,
                           dither=0.0,
                           energy_floor=0.0)
         mean_stat += torch.sum(mat, axis=0)
         var_stat += torch.sum(torch.square(mat), axis=0)
         number += mat.shape[0]
     return number, mean_stat, var_stat
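The per-dimension sums returned above are typically accumulated over all batches and then reduced to a global mean and standard deviation for CMVN. A small sketch of that reduction (names are illustrative, not from the original code):

import torch

def cmvn_from_stats(total_frames, mean_stat, var_stat, eps=1e-8):
    # Reduce the accumulated sums to per-dimension mean / std:
    #   mean = E[x],  var = E[x^2] - E[x]^2
    mean = mean_stat / total_frames
    var = var_stat / total_frames - mean ** 2
    return mean, torch.sqrt(torch.clamp(var, min=eps))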
Example #4
File: dataset.py  Project: entn-at/wenet
def _load_wav_with_speed(wav_file, speed):
    """ Load the wave from file and apply speed perpturbation

    Args:
        wav_file: input feature, T * F 2D

    Returns:
        augmented feature
    """
    if speed == 1.0:
        return torchaudio.load_wav(wav_file)
    else:
        si, _ = torchaudio.info(wav_file)

        # get torchaudio version
        ta_no = torchaudio.__version__.split(".")
        ta_version = 100 * int(ta_no[0]) + 10 * int(ta_no[1])

        if ta_version < 80:
            # Note: deprecated in torchaudio>=0.8.0
            E = sox_effects.SoxEffectsChain()
            E.append_effect_to_chain('speed', speed)
            E.append_effect_to_chain("rate", si.rate)
            E.set_input_file(wav_file)
            wav, sr = E.sox_build_flow_effects()
        else:
            # Note: apply_effects_file is available in torchaudio>=0.8.0
            wav, sr = sox_effects.apply_effects_file(
                wav_file,
                [['speed', str(speed)], ['rate', str(si.rate)]])

        # sox will normalize the waveform, scale to [-32768, 32767]
        wav = wav * (1 << 15)
        return wav, sr
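A hedged usage sketch of the helper above: conventional 3-way speed perturbation draws one factor per utterance (the 0.9/1.0/1.1 values are the usual choice, not taken from this file, and the path is a placeholder):

import random

speed = random.choice([0.9, 1.0, 1.1])  # conventional speed-perturbation factors
waveform, sample_rate = _load_wav_with_speed('path/to/utt.wav', speed)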
Example #5
def process_data():
    base_loc = DATA_DIR + '/raw/speech_commands/raw_speech_commands_data'
    X = torch.empty(34975, 16000, 1)
    y = torch.empty(34975, dtype=torch.long)

    batch_index = 0
    y_index = 0
    for foldername in ('yes', 'no', 'up', 'down', 'left', 'right', 'on', 'off', 'stop', 'go'):
        loc = base_loc +'/'+ foldername
        for filename in os.listdir(loc):
            audio, _ = torchaudio.load_wav(loc + '/' + filename, channels_first=False,
                                           normalization=False)  # for forward compatibility if they fix it
            audio = audio / 2 ** 15  # Normalization argument doesn't seem to work so we do it manually.

            # A few samples are shorter than the full length; for simplicity we discard them.
            if len(audio) != 16000:
                continue

            X[batch_index] = audio
            y[batch_index] = y_index
            batch_index += 1
        y_index += 1
    assert batch_index == 34975, "batch_index is {}".format(batch_index)

    # X is of shape (batch=34975, length=16000, channels=1)
    X = torchaudio.transforms.MFCC(log_mels=True,
                                   melkwargs=dict(n_fft=100, n_mels=32), n_mfcc=10)(X.squeeze(-1)).transpose(1, 2).detach()
    # X is of shape (batch=34975, length=321, channels=10). For some crazy reason it requires a gradient, so detach.

    train_X, val_X, test_X = split_data(X, y)
    train_y, val_y, test_y = split_data(y, y)

    return train_X, val_X, test_X, train_y, val_y, test_y
Example #6
def preprocess(pkl_path='timit_tokenized.pkl'):
    phoneme_samples = {
        'TRAIN': {x: defaultdict(list) for x in dialects},
        'TEST': {x: defaultdict(list) for x in dialects}
    }

    with torch.no_grad():
        for dataset in ['TRAIN', 'TEST']:
            for dialect in dialects:
                speakers = glob(f'TIMIT/{dataset}/{dialect}/M*')
                for speaker in tqdm(speakers):
                    sentences = set(path.split('/')[-1][:-4] for path in glob(speaker + '/*'))
                    for sentence in sentences:
                        current_path = f'TIMIT/{dataset}/{dialect}/{speaker.split("/")[-1]}/{sentence}'
                        sample, _ = torchaudio.load_wav(current_path + '.WAV')
                        df_sample = pandas.read_csv(current_path + '.PHN', sep=' ', names=phoneme_cols)
                        for _, row in df_sample.iterrows():
                            subsample = sample[:, row[0]:row[1]]
                            phn_len = row[1] - row[0]
                            if phn_len > max_len:
                                continue
                            mspec_flat_tensor = wav_to_padded_mspec_flat_tensor(subsample, phn_len)
                            phoneme_samples[dataset][dialect][row[2]].append(
                                torch.log(mspec_flat_tensor + epsilon)
                            )

                for phoneme in phoneme_samples[dataset][dialect]:
                    phoneme_samples[dataset][dialect][phoneme] = \
                        torch.cat(phoneme_samples[dataset][dialect][phoneme]).data.numpy()

    with open(pkl_path, 'wb') as f:
        pickle.dump(phoneme_samples, f)
Example #7
    def additive_noise(self, noisecat, audio):

        clean_db = 10 * numpy.log10(numpy.mean(audio**2) + 1e-4)

        numnoise = self.numnoise[noisecat]
        noiselist = random.sample(self.noiselist[noisecat],
                                  random.randint(numnoise[0], numnoise[1]))

        noises = []
        audio_length = audio.shape[1]
        for noise in noiselist:
            noiseaudio, sample_rate = torchaudio.load_wav(noise)
            noiseaudio = noiseaudio.detach().numpy()
            noise_length = noiseaudio.shape[1]
            if noise_length <= audio_length:
                shortage = audio_length - noise_length + 1
                noiseaudio = numpy.pad(noiseaudio, ((0, 0), (0, shortage)),
                                       'wrap')
                noiseaudio = noiseaudio[:, :audio_length]
            else:
                startframe = numpy.int64(random.random() *
                                         (noise_length - audio_length))
                noiseaudio = noiseaudio[:,
                                        int(startframe):int(startframe) +
                                        audio_length]
            noise_snr = random.uniform(self.noisesnr[noisecat][0],
                                       self.noisesnr[noisecat][1])
            noise_db = 10 * numpy.log10(numpy.mean(noiseaudio[0]**2) + 1e-4)
            noises.append(
                numpy.sqrt(10**((clean_db - noise_db - noise_snr) / 10)) *
                noiseaudio)
        noise_sum = numpy.sum(numpy.concatenate(noises, axis=0),
                              axis=0,
                              keepdims=True)
        return noise_sum + audio
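For reference, the scale factor inside the loop implements the standard SNR mixing rule: the noise power is adjusted so the clean-to-noise ratio equals noise_snr dB. A standalone sketch of the same arithmetic (not part of the original class; it assumes `noise` is already cropped or padded to the length of `clean`):

import numpy

def mix_at_snr(clean, noise, snr_db, eps=1e-4):
    # Scale `noise` so that 10*log10(P_clean / P_scaled_noise) equals snr_db, then mix.
    clean_db = 10 * numpy.log10(numpy.mean(clean ** 2) + eps)
    noise_db = 10 * numpy.log10(numpy.mean(noise ** 2) + eps)
    scale = numpy.sqrt(10 ** ((clean_db - noise_db - snr_db) / 10))
    return clean + scale * noise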
Example #8
    def __getitem__(self, index):
        utt_id, path = self.file_list[index]

        if self.from_kaldi:
            feature = kio.load_mat(path)
        else:
            wavform, sample_frequency = ta.load_wav(path)
            feature = compute_fbank(wavform, num_mel_bins=self.params['num_mel_bins'], sample_frequency=sample_frequency)

        if self.params['apply_cmvn']:
            spk_id = self.utt2spk[utt_id]
            stats = kio.load_mat(self.cmvns[spk_id])
            feature = apply_cmvn(feature, stats)

        if self.params['normalization']:
            feature = normalization(feature)
            
        if self.params['spec_argument']:
            try:
                feature = spec_augment(feature)
            except Exception:
                # fall back to the un-augmented feature if SpecAugment fails
                pass

        if self.left_frames > 0 or self.right_frames > 0:
            feature = concat_and_subsample(feature, left_frames=self.left_frames,
                                           right_frames=self.right_frames, skip_frames=self.skip_frames)

        feature_length = feature.shape[0]
        targets = self.targets_dict[utt_id]
        targets_length = len(targets)

        return utt_id, feature, feature_length, targets, targets_length
Example #9
    def __init__(self,
                 root: str,
                 training: bool = True,
                 return_length: bool = False,
                 transform=None):
        self.data = []
        self.return_length = return_length
        self.transform = transform

        self.training = training
        self.filenames = []

        if training:
            df_labels = pd.read_csv(root + "train_label.csv")
            root = root + "Train/"
            self.labels = []
        else:
            root = root + "Public_Test/"

        for filename in os.listdir(root):
            if filename.endswith(".wav"):
                self.filenames.append(filename)
                input_audio, sample_rate = load_wav(root + filename)

                self.data.append(input_audio)
                if training:
                    self.labels.append(
                        df_labels.loc[df_labels["File"] == filename,
                                      "Label"].values.item())
Example #10
    def __getitem__(self, index):
        utt_id, path = self.file_list[index]

        if self.from_kaldi:
            feature = kio.load_mat(path)
        else:
            wavform, sample_frequency = ta.load_wav(path)
            feature = compute_fbank(wavform,
                                    num_mel_bins=self.params['num_mel_bins'],
                                    sample_frequency=sample_frequency,
                                    dither=0.0)

        if self.params['apply_cmvn']:
            spk_id = self.utt2spk[utt_id]
            stats = kio.load_mat(self.cmvns[spk_id])
            feature = apply_cmvn(feature, stats)

        if self.params['normalization']:
            feature = normalization(feature)

        if self.apply_spec_augment:
            feature = spec_augment(feature)

        feature_length = feature.shape[0]
        targets = self.targets_dict[utt_id]
        targets_length = len(targets)

        return utt_id, feature, feature_length, targets, targets_length
Example #11
    def __getitem__(self, index):
        # index=None
        import torchaudio
        import torchaudio.compliance.kaldi as kaldi
        # from . import kaldi as kaldi

        tgt_item = self.tgt[index] if self.tgt is not None else None
        print(index)
        path = self.aud_paths[index]
        if not os.path.exists(path):
            raise FileNotFoundError("Audio file not found: {}".format(path))
        sound, sample_rate = torchaudio.load_wav(path)
        output = kaldi.fbank(sound,
                             num_mel_bins=self.num_mel_bins,
                             frame_length=self.frame_length,
                             frame_shift=self.frame_shift)

        output_cmvn = data_utils.apply_mv_norm(output)
        self.s2s_collater = Seq2SeqCollater(0,
                                            1,
                                            pad_index=self.tgt_dict.pad(),
                                            eos_index=self.tgt_dict.eos(),
                                            move_eos_to_beginning=True)

        return {"id": index, "data": [output_cmvn.detach(), tgt_item]}
Example #12
    def __getitem__(self, index):
        import torchaudio
        import torchaudio.compliance.kaldi as kaldi
        tgt_item = self.tgt[index] if self.tgt is not None else None

        path = self.aud_paths[index]
        if not os.path.exists(path):
            raise FileNotFoundError("Audio file not found: {}".format(path))
        
        vid_data = self.load_video(index)
        sound, sample_rate = torchaudio.load_wav(path)
        
        if self.video_offset > 0: # positive offset - audio and video
            padding_frame = np.zeros([self.video_offset, np.shape(vid_data)[1]], dtype='float32')
            vid_data = np.concatenate((padding_frame,vid_data),axis=0)
        elif self.video_offset < 0: # negative offset - video and audio
            padding_frame = np.zeros([abs(self.video_offset), np.shape(vid_data)[1]], dtype='float32')
            vid_data = np.concatenate((vid_data, padding_frame),axis=0)
            aud_padding_size = int(abs(self.video_offset) * 40 * sample_rate * 0.001)  # assumes 40 ms of audio per video frame
            aud_padding = torch.zeros_like(sound)[:,0:aud_padding_size]
            sound = torch.cat((aud_padding, sound), 1)

        output = kaldi.fbank(
            sound,
            num_mel_bins=self.num_mel_bins,
            frame_length=self.frame_length,
            frame_shift=self.frame_shift
        )
        output_cmvn = data_utils.apply_mv_norm(output)

        return {"id": index, "audio_data": [output_cmvn.detach(), tgt_item], "video_data": [vid_data, tgt_item]}
Example #13
 def __getitem__(self, i):
     path = self.files[i]
     basename = os.path.basename(path)
     info = self.info[basename]
     x, sr = torchaudio.load_wav(path)
     x = normalize(x)
     y = make_target_vec(x.shape[1], info, sr)
     return x, y
Example #14
 def __getitem__(self, idx):
     audio_filename = self.filenames[idx]
     signal, _ = torchaudio.load_wav(audio_filename)
     return {
         "audio": short2float(signal[0]),
         "speaker": speaker_map[Path(audio_filename).parent.stem]
         # "speaker": self.resolve_speaker(audio_filename)
     }
Example #15
def load_wav(filename):
    audio, sr = T.load_wav(filename)

    if Fs != sr:
        audio = T.transforms.Resample(orig_freq=sr, new_freq=Fs)(audio)

    audio = torch.clamp(audio[0] / 32767.5, -1.0, 1.0)
    return audio, Fs
Example #16
def get_fu(path_ = 'temp.wav'):
    _wavform, _ = ta.load_wav( path_ )
    _feature = ta.compliance.kaldi.fbank(_wavform, num_mel_bins=40) 
    _mean = torch.mean(_feature)
    _std = torch.std(_feature)
    _T_feature =  (_feature - _mean) / _std
    inst_T = _T_feature.unsqueeze(0)
    return inst_T
Example #17
 def get_fu(self):
     _wavform, _ = ta.load_wav(self.filename)
     _feature = ta.compliance.kaldi.fbank(_wavform, num_mel_bins=40)
     _mean = torch.mean(_feature)
     _std = torch.std(_feature)
     _T_feature = (_feature - _mean) / _std
     inst_T = _T_feature.unsqueeze(0)
     return inst_T
Example #18
    def load_audio(self, audio_file):
        """ Loads audio wav file into torch tensor.

            Args:
                audio_file (string).
            Returns:
                Torch tensor of shape [1 x 16000] (audio files are sampled at 16 kHz).
        """
        return torchaudio.load_wav("../data/" + audio_file + ".wav")[0]
Example #19
def load_audio(path):
    sound, _ = torchaudio.load_wav(path)
    sound = sound.numpy()
    if sound.shape[0] == 1:
        sound = sound.squeeze()
    else:
        sound = sound.mean(axis=0)  # multiple channels, average
    sound = sound / (2**15)
    return sound
Example #20
 def __getitem__(self, idx):
   audio_filename = self.filenames[idx]
   spec_filename = f'{audio_filename}.spec.npy'
   signal, _ = torchaudio.load_wav(audio_filename)
   spectrogram = np.load(spec_filename)
   return {
       'audio': signal[0] / 32767.5,
       'spectrogram': spectrogram.T
   }
Example #21
def main(args):
    check_args(args)
    import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 30000
    logger.info(args)

    #use_cuda = torch.cuda.is_available() and not args.cpu
    # use_cuda = False

    # Load dataset splits
    task = tasks.setup_task(args)

    # Set dictionary
    tgt_dict = task.target_dictionary

    if args.ctc or args.rnnt:
        tgt_dict.add_symbol("<ctc_blank>")
        if args.ctc:
            logger.info("| decoding a ctc model")
        if args.rnnt:
            logger.info("| decoding a rnnt model")

    # Load ensemble
    logger.info("| loading model(s) from {}".format(args.path))
    models, _model_args = utils.load_ensemble_for_inference(
        args.path.split(":"),
        task,
        model_arg_overrides=eval(args.model_overrides),  # noqa
    )
    optimize_models(args, models)

    # Initialize generator
    generator = task.build_generator(args)

    sp = spm.SentencePieceProcessor()
    sp.Load(os.path.join(args.data, 'spm.model'))

    # TODO: replace this
    # path = '/Users/jamarshon/Downloads/snippet.mp3'
    # path = '/Users/jamarshon/Downloads/hamlet.mp3'
    path = '/home/aakashns/speech_transcribe/deepspeech.pytorch/data/an4_dataset/train/an4/wav/cen8-mwhw-b.wav'
    if not os.path.exists(path):
        raise FileNotFoundError("Audio file not found: {}".format(path))
    waveform, sample_rate = torchaudio.load_wav(path)
    waveform = waveform.mean(0, True)
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate,
                                              new_freq=16000)(waveform)
    # waveform = waveform[:, :16000*30]
    # torchaudio.save('/Users/jamarshon/Downloads/hello.wav', waveform >> 16, 16000)
    import time
    print(sample_rate, waveform.shape)
    start = time.time()
    transcribe(waveform, args, task, generator, models, sp, tgt_dict)
    end = time.time()
    print(end - start)
Example #22
 def __getitem__(self, idx):
     audio_filename = self.filenames[idx]
     spec_filename = f'{audio_filename}.spec.npy'
     if torchaudio.__version__ > '0.7.0':
         signal, _ = torchaudio.load(audio_filename)
     else:
         signal, _ = torchaudio.load_wav(audio_filename)
     out = signal[
         0] if torchaudio.__version__ > '0.7.0' else signal[0] / 32767.5
     return {'audio': out, 'spectrogram': None}
Example #23
def preprocessing(audio_filename: str, word: str):
    waveform, _ = torchaudio.load_wav(audio_filename)
    label = text_transform.one_hot_enc(word).transpose(1, 0)

    # spectrogram
    specgram = torchaudio.transforms.MelSpectrogram()(waveform)
    specgram = F.interpolate(specgram, size=len(word),
                             mode="nearest").transpose(1, 2)

    return specgram.unsqueeze(0), label.unsqueeze(0)
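A possible call site for preprocessing (the filename and word below are placeholders); it returns a batch-of-one mel spectrogram aligned to the word length plus the one-hot label:

specgram, label = preprocessing('hello.wav', 'hello')  # placeholder inputs
print(specgram.shape)  # [1, 1, len(word), n_mels] after the interpolate/transpose above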
Example #24
    def _compliance_test_helper(self,
                                sound_filepath,
                                filepath_key,
                                expected_num_files,
                                expected_num_args,
                                get_output_fn,
                                atol=1e-5,
                                rtol=1e-7):
        """
        Inputs:
            sound_filepath (str): The location of the sound file
            filepath_key (str): A key to `test_filepaths` which matches which files to use
            expected_num_files (int): The expected number of kaldi files to read
            expected_num_args (int): The expected number of arguments used in a kaldi configuration
            get_output_fn (Callable[[Tensor, List], Tensor]): A function that takes in a sound signal
                and a configuration and returns an output
            atol (float): absolute tolerance
            rtol (float): relative tolerance
        """
        sound, sr = torchaudio.load_wav(sound_filepath)
        files = self.test_filepaths[filepath_key]

        assert len(files) == expected_num_files, (
            'number of kaldi %s file changed to %d' %
            (filepath_key, len(files)))

        for f in files:
            print(f)

            # Read kaldi's output from file
            kaldi_output_path = os.path.join(self.kaldi_output_dir, f)
            kaldi_output_dict = {
                k: v
                for k, v in torchaudio.kaldi_io.read_mat_ark(kaldi_output_path)
            }

            assert len(
                kaldi_output_dict
            ) == 1 and 'my_id' in kaldi_output_dict, 'invalid test kaldi ark file'
            kaldi_output = kaldi_output_dict['my_id']

            # Construct the same configuration used by kaldi
            args = f.split('-')
            args[-1] = os.path.splitext(args[-1])[0]
            assert len(
                args) == expected_num_args, 'invalid test kaldi file name'
            args = [compliance_utils.parse(arg) for arg in args]

            output = get_output_fn(sound, args)

            self._print_diagnostic(output, kaldi_output)
            torch.testing.assert_allclose(output,
                                          kaldi_output,
                                          atol=atol,
                                          rtol=rtol)
Example #25
def main():
    filepath = 'data/timit_test/DR1/FAKS0/SA1.WAV'

    assert os.path.isfile(filepath)

    data, fs = torchaudio.load_wav(filepath)

    print(data)
    print(type(data))
    print(data.shape)
    print(fs)
Example #26
    def predict(self, wav, num_mel_bins=440, apply_normalize=True):
        wavform, _ = ta.load_wav(wav)
        feature = compute_fbank(wavform, num_mel_bins=num_mel_bins)

        if apply_normalize:
            feature = normalization(feature)

        feature = feature.unsqueeze(0)
        feature_length = torch.tensor([feature.size(1)], dtype=torch.long, device=feature.device)  # number of frames

        return self.recognize(feature, feature_length)[0]
Example #27
 def __getitem__(self, idx):
     audio_fn = self.audio_list[idx]
     if self.verbose:
         print(f"Loading audio file {audio_fn}")
     waveform, sample_rate = torchaudio.load_wav(audio_fn)
     if self.transform:
         waveform = self.transform(waveform)
     sample = {
         'input': waveform,
         'digit': int(os.path.basename(audio_fn).split("_")[0])
     }
     return sample
Example #28
 def __getitem__(self, idx):
     audio_filename = self.filenames[idx]
     spec_filename = f'{audio_filename}.spec.npy'
     if torchaudio.__version__ > '0.7.0':
         signal, _ = torchaudio.load(audio_filename)
     else:
         signal, _ = torchaudio.load_wav(audio_filename)
     spectrogram = np.load(spec_filename)
     # https://github.com/lmnt-com/diffwave/issues/15
     out = signal[
         0] if torchaudio.__version__ > '0.7.0' else signal[0] / 32767.5
     return {'audio': out, 'spectrogram': spectrogram.T}
Example #29
def main(args):
    check_args(args)
    import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 30000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)

    # Set dictionary
    tgt_dict = task.target_dictionary

    if args.ctc or args.rnnt:
        tgt_dict.add_symbol("<ctc_blank>")
        if args.ctc:
            logger.info("| decoding a ctc model")
        if args.rnnt:
            logger.info("| decoding a rnnt model")

    # Load ensemble
    logger.info("| loading model(s) from {}".format(args.path))
    models, _model_args = utils.load_ensemble_for_inference(
        args.path.split(":"),
        task,
        model_arg_overrides=eval(args.model_overrides),  # noqa
    )
    optimize_models(args, use_cuda, models)

    # Initialize generator
    generator = task.build_generator(args)

    sp = spm.SentencePieceProcessor()
    sp.Load(os.path.join(args.data, 'spm.model'))

    path = args.input_file
    if not os.path.exists(path):
        raise FileNotFoundError("Audio file not found: {}".format(path))
    waveform, sample_rate = torchaudio.load_wav(path)
    waveform = waveform.mean(0, True)
    waveform = torchaudio.transforms.Resample(orig_freq=sample_rate,
                                              new_freq=16000)(waveform)
    import time
    print(sample_rate, waveform.shape)
    start = time.time()
    transcribe(waveform, args, task, generator, models, sp, tgt_dict)
    end = time.time()
    print(end - start)
Example #30
def _process_data():
    base_loc = here / '..' / 'experiments' / 'data' / 'SpeechCommands'
    X = torch.empty(34975, 16000, 1)
    y = torch.empty(34975, dtype=torch.long)

    batch_index = 0
    y_index = 0
    for foldername in ('yes', 'no', 'up', 'down', 'left', 'right', 'on', 'off',
                       'stop', 'go'):
        loc = base_loc / foldername
        for filename in os.listdir(loc):
            audio, _ = torchaudio.load_wav(
                loc / filename, channels_first=False,
                normalization=False)  # for forward compatibility if they fix it
            audio = audio / 2**15  # Normalization argument doesn't seem to work so we do it manually.

            # A few samples are shorter than the full length; for simplicity we discard them.
            if len(audio) != 16000:
                continue

            X[batch_index] = audio
            y[batch_index] = y_index
            batch_index += 1
        y_index += 1
    assert batch_index == 34975, "batch_index is {}".format(batch_index)

    audio_X = X

    # X is of shape (batch=34975, length=16000, channels=1)
    X = torchaudio.transforms.MFCC(log_mels=True)(X.squeeze(-1)).transpose(
        1, 2).detach()
    # X is of shape (batch=34975, length=81, channels=40). For some crazy reason it requires a gradient, so detach.

    train_X, _, _ = _split_data(X, y)
    out = []
    means = []
    stds = []
    for Xi, train_Xi in zip(X.unbind(dim=-1), train_X.unbind(dim=-1)):
        mean = train_Xi.mean()
        std = train_Xi.std()
        means.append(mean)
        stds.append(std)
        out.append((Xi - mean) / (std + 1e-5))
    X = torch.stack(out, dim=-1)

    train_audio_X, val_audio_X, test_audio_X = _split_data(audio_X, y)
    train_X, val_X, test_X = _split_data(X, y)
    train_y, val_y, test_y = _split_data(y, y)

    return train_X, val_X, test_X, train_y, val_y, test_y, torch.stack(means), torch.stack(stds), train_audio_X, \
           val_audio_X, test_audio_X
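The returned means/stds are per-MFCC-channel statistics computed on the training split; features produced later by the same pipeline would presumably be normalized with those same values. A minimal sketch, assuming new_X has shape (batch, length, channels):

def normalize_with_train_stats(new_X, means, stds, eps=1e-5):
    # Apply the per-channel training-set statistics to new MFCC features
    # of shape (batch, length, channels).
    return (new_X - means) / (stds + eps)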