Esempio n. 1
0
    def test_batch_TimeStretch(self):
        """Check that ``TimeStretch`` agrees between single and batched input.

        The STFT of the test file is stretched once and the result is tiled
        three times along a new leading dimension. That is compared against
        stretching the already-tiled STFT in one call.
        """
        waveform, _sample_rate = torchaudio.load(self.test_filepath)

        n_fft = 2048
        stft_kwargs = dict(
            n_fft=n_fft,
            hop_length=512,
            win_length=n_fft,
            window=torch.hann_window(n_fft),
            center=True,
            pad_mode='reflect',
            normalized=True,
            onesided=True,
        )
        rate = 2
        complex_specgrams = torch.stft(waveform, **stft_kwargs)

        # Repeat pattern that prepends a batch dimension of size 3.
        tile = (3, 1, 1, 1, 1)

        def make_stretch():
            # A fresh, identically configured transform for each path.
            return transforms.TimeStretch(fixed_rate=rate, n_freq=1025, hop_length=512)

        # Path 1: stretch the single example, then replicate it into a batch.
        expected = make_stretch()(complex_specgrams).repeat(*tile)
        # Path 2: replicate into a batch first, then stretch the whole batch.
        computed = make_stretch()(complex_specgrams.repeat(*tile))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected, atol=1e-5))
Esempio n. 2
0
    def test_timestretch_non_zero(self, rate, test_pseudo_complex):
        """Verify that ``T.TimeStretch`` passes gradcheck when its input is not close to 0.

        ``T.TimeStretch`` is not differentiable around 0, so this test checks the
        differentiability only for inputs that stay away from zero.

        As tested above, when the spectrogram contains values close to zero, the
        gradients are unstable and gradcheck fails.

        In this test, we generate a spectrogram from a random signal. Every point
        whose magnitude is below ``epsilon`` is then moved onto the circle of
        radius ``epsilon``, keeping its phase.

        This process does not reflect the real use-case and is not practical for
        users. It does help us understand to what degree the function is
        differentiable and when it is not.
        """
        n_fft = 16
        transform = T.TimeStretch(n_freq=n_fft // 2 + 1, fixed_rate=rate)
        waveform = get_whitenoise(sample_rate=40, duration=1, n_channels=2)
        spectrogram = get_spectrogram(waveform, n_fft=n_fft, power=None)

        # 1e-3 is too small (on CPU)
        epsilon = 1e-2
        too_close = spectrogram.abs() < epsilon
        # Rebuild the offending points in polar form with magnitude ``epsilon``.
        # Unlike ``z / z.abs()``, this cannot produce NaN for a point that is
        # exactly zero: ``angle(0) == 0``, so such a point becomes ``epsilon + 0j``.
        phase = spectrogram[too_close].angle()
        spectrogram[too_close] = torch.polar(torch.full_like(phase, epsilon), phase)
        if test_pseudo_complex:
            spectrogram = torch.view_as_real(spectrogram)
        self.assert_grad(transform, [spectrogram])
Esempio n. 3
0
 def test_TimeStretch(self, test_pseudo_complex):
     """Run ``_assert_consistency_complex`` on ``T.TimeStretch`` with random complex input.

     The complex input is obtained by viewing a random real tensor whose last
     dimension (size 2) holds the real and imaginary parts.
     """
     n_freq, hop_length, fixed_rate = 400, 512, 1.3
     real_imag = torch.rand((10, 2, n_freq, 10, 2))
     stretch = T.TimeStretch(
         n_freq=n_freq,
         hop_length=hop_length,
         fixed_rate=fixed_rate,
     )
     self._assert_consistency_complex(
         stretch, torch.view_as_complex(real_imag), test_pseudo_complex)
 def test_TimeStretch(self):
     """Run ``_assert_consistency`` on ``T.TimeStretch`` with a random real input.

     The trailing dimension of size 2 presumably encodes a pseudo-complex
     (real, imaginary) pair. TODO: confirm against ``T.TimeStretch``.
     """
     # NOTE(review): this ``def`` reuses the name ``test_TimeStretch`` of the
     # method defined directly before it at the same indentation. If both live in
     # the same class, only this later definition survives and the earlier one
     # is never collected. Rename one of them if both are meant to run.
     n_freq = 400
     hop_length = 512
     fixed_rate = 1.3
     tensor = torch.rand((10, 2, n_freq, 10, 2))
     self._assert_consistency(
         T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
         tensor,
     )
Esempio n. 5
0
    def test_timestretch_zeros_fail(self):
        """Run gradcheck on ``T.TimeStretch`` for an all-zero input.

        ``F.phase_vocoder`` converts the spectrogram from cartesian to polar
        coordinates with ``atan2(img, real)``. Its gradient is not defined at
        the origin, so gradcheck fails for this input.
        """
        # NOTE(review): this body calls ``assert_grad`` directly. Presumably a
        # decorator outside this view marks the failure as expected; confirm.
        silence = torch.zeros(2, 40)
        n_fft = 16
        zero_spec = get_spectrogram(silence, n_fft=n_fft, power=None)
        self.assert_grad(
            T.TimeStretch(n_freq=n_fft // 2 + 1, fixed_rate=0.99),
            [zero_spec],
        )
Esempio n. 6
0
    def test_TimeStretch(self):
        """Run ``_assert_consistency_complex`` on ``T.TimeStretch`` with a white-noise spectrogram.

        ``batch * num_channels`` white-noise signals are converted to a
        spectrogram, which is reshaped to ``(batch, num_channels, n_freq, -1)``.
        """
        n_fft = 1025
        n_freq = n_fft // 2 + 1
        hop_length = 512
        fixed_rate = 1.3
        batch = 10
        num_channels = 2

        waveform = common_utils.get_whitenoise(sample_rate=8000,
                                               n_channels=batch * num_channels)
        # NOTE(review): ``get_spectrogram`` is called without ``hop_length``.
        # Confirm whether its default agrees with the ``hop_length`` passed to
        # ``TimeStretch`` below.
        tensor = common_utils.get_spectrogram(waveform, n_fft=n_fft)
        # Split the flat signal axis back into (batch, channel).
        tensor = tensor.reshape(batch, num_channels, n_freq, -1)
        self._assert_consistency_complex(
            T.TimeStretch(n_freq=n_freq,
                          hop_length=hop_length,
                          fixed_rate=fixed_rate),
            tensor,
        )
######################################################################
# SpecAugment
# -----------
#
# `SpecAugment <https://ai.googleblog.com/2019/04/specaugment-new-data-augmentation.html>`__
# is a popular spectrogram augmentation technique.
#
# ``torchaudio`` implements ``TimeStretch``, ``TimeMasking`` and
# ``FrequencyMasking``.
#
# TimeStretch
# ~~~~~~~~~~~
#

spec = get_spectrogram(power=None)
stretch = T.TimeStretch()

# Plot the 1.2x stretch, then the unmodified spectrogram for reference, then
# the 0.9x stretch. ``None`` marks where the reference plot goes.
for rate in (1.2, None, 0.9):
    if rate is None:
        plot_spectrogram(torch.abs(spec[0]), title="Original", aspect='equal', xmax=304)
    else:
        spec_ = stretch(spec, rate)
        plot_spectrogram(torch.abs(spec_[0]), title=f"Stretched x{rate}", aspect='equal', xmax=304)

######################################################################
# TimeMasking
# ~~~~~~~~~~~
#