# Example #1
# 0
 def compute_stft(waveform: Tensor, n_fft: int, win_length: int,
                  hop_length: int) -> Tuple[Tensor, Tensor]:
     """Return the magnitude and phase of the STFT of ``waveform``.

     A ``_transf.Spectrogram`` with ``power=None`` (no power reduction) is
     built for the given sizes, moved to the device ``waveform`` lives on,
     applied to ``waveform``, and its output is split by ``_func.magphase``.

     :param waveform: input signal tensor.
     :param n_fft: FFT size passed to the spectrogram transform.
     :param win_length: window length passed to the spectrogram transform.
     :param hop_length: hop length passed to the spectrogram transform.
     :return: ``(amplitude, phase)`` as produced by ``_func.magphase``.
     """
     transform = _transf.Spectrogram(
         n_fft=n_fft,
         win_length=win_length,
         hop_length=hop_length,
         power=None,
     )
     # Build on the default device, then follow the input's device.
     transform = transform.to(waveform.device)
     raw_spec = transform(waveform)
     magnitude, angle = _func.magphase(raw_spec)
     return magnitude, angle
    def forward(self, x):
        """Denoise ``x`` by masking its STFT magnitude with ``self.unet``.

        Input formatting, the forward STFT and the magnitude/phase split run
        under ``torch.no_grad()``. Only the mask prediction and the steps after
        it are tracked by autograd. ``stft`` and ``istft`` are both called with
        the same two sizes: ``int(fs * win_len)`` and ``int(fs * hop_len)``.

        :param x: input accepted by ``self.format_in``; presumably a batch of
            waveforms shaped (batch, samples). TODO confirm.
        :return: the waveform ``istft`` rebuilds from the masked magnitude and
            the original (unmasked) phase.
        """
        # Compute the sizes once so the analysis and synthesis transforms
        # cannot drift apart.
        n_fft = int(self.fs * self.win_len)
        hop_length = int(self.fs * self.hop_len)

        with torch.no_grad():
            x = self.format_in(x)
            spec = stft(x, n_fft, hop_length)
            mag, pha = magphase(spec)
            # Insert a size-1 axis at dim 1 (presumably the U-Net channel axis).
            mag = mag.unsqueeze(1)

        mask = self.unet(mag, 'mask')
        # Remove only the axis inserted above. A bare squeeze() would also
        # drop a batch, frequency or time axis that happens to have size 1.
        mag_masked = (mag * mask).squeeze(1)
        spec_processed = self.recover_spec(mag_masked, pha)
        denoised_wav = istft(spec_processed, n_fft, hop_length)
        return denoised_wav
# Example #3
# 0
def _log_magnitude_spectrogram(path, n_fft=400, win_length=400, hop_length=160):
    """Load ``path`` and return ``log(|STFT| + 1e-5)`` with size-1 axes squeezed.

    The audio is loaded with ``normalization=True``. The 1e-5 epsilon keeps
    silent bins from producing ``-inf``.
    """
    waveform = torchaudio.load(path, normalization=True)[0]
    complex_spec = torch.stft(waveform, win_length=win_length,
                              hop_length=hop_length, n_fft=n_fft)
    magnitude = magphase(complex_spec)[0]
    return torch.log(magnitude + 1e-5).squeeze()


def show_results(wav_dir):
    """Plot the four result streams in ``wav_dir`` and attach play buttons.

    Reads ``noisy.wav``, ``clean.wav``, ``denoised.wav`` and
    ``noisy_denoised.wav`` from ``wav_dir``. Each log-magnitude spectrogram is
    drawn in its own row of a 4x1 subplot grid, and one ``PlayAudio`` button
    is added per stream. The function then blocks in ``plt.show()``.

    :param wav_dir: directory containing the four wav files.
    """
    # (file stem / PlayAudio feature name, subplot title, button rect
    # [left, bottom, width, height] in figure coordinates)
    panels = [
        ('noisy', 'Noisy', [0.8, 0.84, 0.1, 0.075]),
        ('clean', 'Clean', [0.8, 0.6, 0.1, 0.075]),
        ('denoised', 'Denoised', [0.8, 0.35, 0.1, 0.075]),
        ('noisy_denoised', 'Noisy => Denoised', [0.8, 0.11, 0.1, 0.075]),
    ]

    # Load everything before creating the figure so a missing file fails fast.
    specs = [_log_magnitude_spectrogram(pjoin(wav_dir, stem + '.wav'))
             for stem, _, _ in panels]

    fig = plt.figure()
    for row, ((_, title, _), spec) in enumerate(zip(panels, specs), start=1):
        # Use each spectrogram's own shape. The streams need not have the same
        # number of frames, and pcolormesh requires X/Y to match C.
        f_spec, t_spec = spec.shape
        plt.subplot(len(panels), 1, row)
        plt.pcolormesh(range(t_spec), range(f_spec), spec, shading='auto')
        plt.title(title)

    # Keep the PlayAudio objects referenced until plt.show() returns. If they
    # own matplotlib widgets (presumably Buttons), those stop responding once
    # garbage-collected.
    players = []
    for stem, _, position in panels:
        player = PlayAudio(wav_dir, feature_name=stem, button_position=position)
        player.set_button()
        players.append(player)

    fig.tight_layout()

    plt.show()
def slice_signal(path, win_len, hop_len, win_frames, hop_frames,
                 sampling_rate):
    """Cut a wav file's log-magnitude and phase STFT into windows of frames.

    The file is read with ``wavread`` and min-max normalised with
    ``normalize_wave_minmax``. It is then transformed with ``torch.stft`` and
    split by ``magphase``. The magnitude is compressed as ``log(mag + 1e-7)``.
    Assumes a mono file. TODO confirm for callers passing multi-channel audio.

    :param path: wav file to read.
    :param win_len: passed to ``torch.stft`` as ``n_fft``.
    :param hop_len: passed to ``torch.stft`` as ``hop_length``.
    :param win_frames: number of STFT frames in each slice.
    :param hop_frames: number of STFT frames between consecutive slice starts.
    :param sampling_rate: sampling rate the file is required to have.
    :return: list of ``[log_mag, phase]`` pairs, each sliced to ``win_frames``
        frames along dim 1. Only full-length slices are produced, so the list
        is empty when the signal has fewer than ``win_frames`` frames.
    :raises ValueError: if the file's sampling rate differs from
        ``sampling_rate``.
    """
    sr, wavform = wavread(path)
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if sr != sampling_rate:
        raise ValueError('Expected sampling rate {} but {} has {}'.format(
            sampling_rate, path, sr))
    wavform = torch.from_numpy(normalize_wave_minmax(wavform))
    stft_complex = torch.stft(wavform, win_len, hop_len)
    mag, pha = magphase(stft_complex)
    mag = torch.log(mag + 1e-7)

    # Assumes the legacy real-valued stft layout (..., freq, frames, 2), so the
    # frame count is the second-to-last dimension.
    len_frames = stft_complex.size()[-2]
    # Number of full windows that fit. It is <= 0 (an empty range) when the
    # signal is shorter than one window.
    num_slices = (len_frames - win_frames) // hop_frames + 1

    slices = []
    for idx_slice in range(num_slices):
        start = idx_slice * hop_frames
        stop = start + win_frames
        slices.append([mag[:, start:stop], pha[:, start:stop]])
    return slices
    def stream_processor(self,
                         audio_wavs,
                         feature='time-domain',
                         sr=16000,
                         win_size=None,
                         hop_size=None):
        """Extract ``feature`` from every waveform in ``audio_wavs``.

        :param audio_wavs: dict mapping a stream name (e.g. signal, noise,
            mixed, target, preserve) to its waveform.
        :param feature: one of ``'time-domain'``, ``'magnitude'``,
            ``'log-magnitude'``, ``'spectrogram'``, ``'log-spectrogram'``,
            ``'melspectrogram'``, ``'log-melspectrogram'``, ``'mfcc'`` or
            ``'melscale'``. A ``'log-'`` prefix applies ``log(x + 1e-5)``.
        :param sr: sampling rate in Hz.
        :param win_size: STFT window/FFT length in seconds. Required for every
            feature except ``'time-domain'``.
        :param hop_size: STFT hop length in seconds. Required for every
            feature except ``'time-domain'``.
        :return: ``audio_wavs`` itself for ``'time-domain'``, a dict of
            features for ``'mfcc'``, and otherwise ``[features, phases]``, two
            dicts keyed like ``audio_wavs`` where ``phases`` holds each
            stream's STFT phase.
        :raises ValueError: on an unknown ``feature``, or when ``win_size`` or
            ``hop_size`` is missing for a frequency-domain feature.
        """
        feature_from_timedomain = ['time-domain']
        feature_from_freqdomain = [
            'magnitude', 'log-magnitude', 'spectrogram', 'log-spectrogram',
            'melspectrogram', 'log-melspectrogram', 'mfcc', 'melscale'
        ]
        # Explicit check instead of ``assert`` so validation survives ``-O``.
        if (feature not in feature_from_timedomain
                and feature not in feature_from_freqdomain):
            raise ValueError('Unknown feature type: {!r}'.format(feature))

        if feature == 'time-domain':
            # Raw waveforms need no STFT parameters, so win_size/hop_size may be None.
            return audio_wavs

        if win_size is None or hop_size is None:
            raise ValueError(
                'win_size and hop_size are required for feature {!r}'.format(feature))

        # general parameters for stft-related extracting (sizes in samples)
        hop_length = int(sr * hop_size)
        win_length = int(sr * win_size)
        n_fft = int(sr * win_size)
        n_mels = 128
        n_mfcc = 40

        def feature_converter(feature):
            """Build the torchaudio transform that computes ``feature``."""
            # 'melspectrogram' also ends with 'spectrogram', so the mel case
            # must be tested first. Otherwise it silently yields a linear
            # spectrogram.
            if feature.endswith('melspectrogram'):
                return torchaudio.transforms.MelSpectrogram(
                    sample_rate=sr,
                    n_fft=n_fft,
                    win_length=win_length,
                    hop_length=hop_length,
                    n_mels=n_mels)
            if feature.endswith('spectrogram'):
                return torchaudio.transforms.Spectrogram(
                    n_fft=n_fft,
                    win_length=win_length,
                    hop_length=hop_length,
                    power=2)
            if feature == 'melscale':
                # NOTE(review): MelScale maps a spectrogram onto mel bins, but
                # the loop below applies it to the raw waveform. Confirm the
                # intended input.
                return torchaudio.transforms.MelScale(sample_rate=sr,
                                                      n_mels=n_mels)
            if feature == 'mfcc':
                # Built without melkwargs, so win_size/hop_size do not affect it.
                return torchaudio.transforms.MFCC(sample_rate=sr,
                                                  n_mfcc=n_mfcc)
            raise ValueError('Wrong feature type!')

        # Magnitude features come straight from the STFT. Every other feature
        # uses one transform, built once and reused for all streams.
        converter = None if feature.endswith('magnitude') else feature_converter(feature)

        phases = {}
        features = {}
        for name, wav in audio_wavs.items():
            spec = torch.stft(wav,
                              n_fft=n_fft,
                              hop_length=hop_length,
                              win_length=win_length)
            magnitude, phase = magphase(spec)
            phases[name] = phase
            features[name] = magnitude if converter is None else converter(wav)
            if feature.startswith('log-'):
                # Small epsilon keeps log() finite on silent bins.
                features[name] = torch.log(features[name] + 1e-5)

        # Decide on the requested feature, not on instance state: MFCCs are
        # returned alone, everything else is paired with the STFT phases.
        if feature == 'mfcc':
            return features
        return [features, phases]