예제 #1
0
 def forward(self, x, return_maxima=False):
     """Apply the mel filterbank to a linear spectrogram ``x``.

     In eval mode, or when no ``warping_fn`` is set, the stored filterbank
     ``self.fbanks`` (frequency x filters) is applied with a single matmul.
     In training mode with a ``warping_fn``, a warped filterbank is
     recomputed on every call via ``get_fbanks``. Its leading dimensions
     equal ``x.shape[i]`` for each axis ``i`` in ``self.independent_axis``
     and 1 for the other leading axes, so it broadcasts against ``x``.

     Args:
         x: tensor whose last axis is frequency; it is right-multiplied by
             the filterbank.
         return_maxima: if True, also return for every filter the frequency
             index of its largest weight (-1 where the filter's weights do
             not sum to a positive value).

     Returns:
         The (log) mel spectrogram, or ``(mel, maxima)`` if
         ``return_maxima`` is True.

     Raises:
         ValueError: if an entry of ``self.independent_axis`` does not
             address one of the leading (non-frequency) axes of ``x``.
     """
     if not self.training or self.warping_fn is None:
         fbanks = self.fbanks
         x = x.matmul(fbanks)
     else:
         # Normalize negative axes, then reject anything outside the leading
         # (non-frequency) axes. Without the lower bound, an out-of-range
         # negative axis would never match below and be silently ignored.
         independent_axis = [
             ax if ax >= 0 else x.ndim + ax for ax in self.independent_axis
         ]
         if not all(0 <= ax < x.ndim - 1 for ax in independent_axis):
             raise ValueError(
                 f'independent_axis {self.independent_axis} must address the '
                 f'leading axes of an input with {x.ndim} dimensions '
                 f'(the last axis is frequency).'
             )
         # Full length along independent axes, singleton (broadcast) elsewhere.
         size = [
             x.shape[i] if i in independent_axis else 1
             for i in range(x.ndim - 1)
         ]
         fbanks = get_fbanks(
             sample_rate=self.sample_rate,
             stft_size=self.stft_size,
             number_of_filters=self.number_of_filters,
             lowest_frequency=self.lowest_frequency,
             highest_frequency=self.highest_frequency,
             htk_mel=self.htk_mel,
             warping_fn=self.warping_fn,
             size=size,
         ).astype(np.float32)
         fbanks = fbanks / (fbanks.sum(axis=-1, keepdims=True) + 1e-6)
         # Cast to the stored filterbank's dtype as well as moving devices, so
         # this branch behaves like the eval branch (e.g. after
         # ``module.double()``) instead of failing with a dtype mismatch.
         fbanks = torch.from_numpy(fbanks).transpose(-2, -1).to(
             device=x.device, dtype=self.fbanks.dtype
         )
         if fbanks.shape[-3] == 1:
             # One filterbank shared along axis -2 of x: a plain matmul.
             x = x.matmul(fbanks.squeeze(-3))
         else:
             # A separate filterbank per position along axis -2 of x.
             x = x[..., None, :].matmul(fbanks).squeeze(-2)
     if self.log:
         x = torch.log(x + self.eps)
     if return_maxima:
         # argmax over the frequency axis; filters without positive total
         # weight are marked with -1.
         maxima = (fbanks.argmax(-2) + 1) * (fbanks.sum(-2) > 0) - 1
         return x, maxima
     return x
예제 #2
0
    def __init__(
        self,
        n_mels: int,
        sample_rate: int,
        fft_length: int,
        fmin: Optional[float] = 50.,
        fmax: Optional[float] = None,
        log: bool = True,
        eps=1e-12,
        *,
        warping_fn=None,
    ):
        """
        Build a module that maps a linear spectrogram to a (log) mel spectrogram.

        Args:
            n_mels: number of mel filters to apply
            sample_rate: sample rate of the audio signal
            fft_length: fft_length used in the stft
            fmin: lowest frequency (onset of the first filter)
            fmax: highest frequency (offset of the last filter)
            log: whether to apply a log to the mel spectrogram
            eps: small constant stored as ``self.eps`` (presumably added
                before the log to avoid log(0) -- TODO confirm)
            warping_fn: optional frequency-warping function stored as
                ``self.warping_fn``; the filterbank built here is unwarped

        >>> mel_transform = MelTransform(40, 16000, 512)
        >>> spec = torch.zeros((10, 1, 100, 257))
        >>> logmelspec = mel_transform(spec)
        >>> logmelspec.shape
        torch.Size([10, 1, 100, 40])
        >>> rec = mel_transform.inverse(logmelspec)
        >>> rec.shape
        torch.Size([10, 1, 100, 257])
        """
        super().__init__()
        # Keep the full configuration on the module.
        self.sample_rate, self.fft_length, self.n_mels = (
            sample_rate, fft_length, n_mels
        )
        self.fmin, self.fmax = fmin, fmax
        self.log, self.eps = log, eps
        self.warping_fn = warping_fn

        # Unwarped filterbank; every row is divided by its sum (plus a small
        # constant so all-zero rows do not divide by zero).
        filters = get_fbanks(
            n_mels=n_mels,
            fft_length=fft_length,
            sample_rate=sample_rate,
            fmin=fmin,
            fmax=fmax,
        ).astype(np.float32)
        row_sums = filters.sum(axis=-1, keepdims=True)
        filters = filters / (row_sums + 1e-6)
        # Stored transposed and frozen (no gradient).
        self._fbanks = nn.Parameter(
            torch.from_numpy(filters.T), requires_grad=False
        )
예제 #3
0
 def get_fbanks(self, x):
     """Return the filterbank to apply to ``x``, with negative weights clipped.

     In eval mode, or when no ``warping_fn`` is set, this is the stored
     ``self._fbanks``. In training mode with a ``warping_fn``, a warped
     filterbank is recomputed on every call. ``n=x.shape[0]`` is forwarded
     to the warping function (presumably one warp per batch entry -- TODO
     confirm). Singleton axes are then inserted after the first axis until
     the filterbank has as many dimensions as ``x``.

     Args:
         x: the tensor the filterbank will be applied to. Only its
             first-axis length, number of dimensions and device are read.

     Returns:
         ``relu`` of the selected filterbank (negative weights set to 0).
     """
     if not self.training or self.warping_fn is None:
         fbanks = self._fbanks
     else:
         # This calls the module-level ``get_fbanks`` function, not this method.
         fbanks = get_fbanks(
             n_mels=self.n_mels,
             fft_length=self.fft_length,
             sample_rate=self.sample_rate,
             fmin=self.fmin,
             fmax=self.fmax,
             warping_fn=partial(self.warping_fn, n=x.shape[0]),
         ).astype(np.float32)
         fbanks = fbanks / (fbanks.sum(axis=-1, keepdims=True) + 1e-6)
         # Match the stored filterbank's dtype too (not only the device) so
         # both branches return the same dtype, e.g. after ``module.double()``.
         fbanks = torch.from_numpy(fbanks).transpose(-2, -1).to(
             device=x.device, dtype=self._fbanks.dtype
         )
         while x.dim() > fbanks.dim():
             fbanks = fbanks[:, None]
     # Functional relu: same result as nn.ReLU()(...) without building a
     # new module on every call.
     return torch.relu(fbanks)
예제 #4
0
    def __init__(
            self,
            sample_rate: int,
            stft_size: int,
            number_of_filters: int,
            lowest_frequency: Optional[float] = 50.,
            highest_frequency: Optional[float] = None,
            htk_mel=True,
            log: bool = True,
            eps=1e-12,
            *,
            warping_fn=None,
            independent_axis=0,
    ):
        """
        Transforms linear spectrogram to (log) mel spectrogram.

        Args:
            sample_rate: sample rate of audio signal
            stft_size: fft_length used in stft
            number_of_filters: number of filters to be applied
            lowest_frequency: lowest frequency (onset of first filter)
            highest_frequency: highest frequency (offset of last filter)
            htk_mel: forwarded to ``get_fbanks`` as ``htk_mel`` (presumably
                selects the HTK mel-scale formula -- TODO confirm)
            log: apply log to mel spectrogram
            eps: small constant stored as ``self.eps`` (presumably added
                before the log to avoid log(0) -- TODO confirm)
            warping_fn: optional frequency-warping function stored as
                ``self.warping_fn``; the filterbank built here is unwarped
            independent_axis: axis or axes of the input along which
                separate warpings are presumably used (TODO confirm); a
                scalar is wrapped into a one-element list

        >>> sample_rate = 16000
        >>> highest_frequency = sample_rate/2
        >>> mel_transform = MelTransform(sample_rate, 512, 40)
        >>> spec = torch.rand((3, 1, 100, 257))
        >>> logmelspec = mel_transform(spec)
        >>> logmelspec.shape
        torch.Size([3, 1, 100, 40])
        >>> rec = mel_transform.inverse(logmelspec)
        >>> rec.shape
        torch.Size([3, 1, 100, 257])
        >>> from paderbox.transform.module_fbank import HzWarping
        >>> from paderbox.utils.random_utils import Uniform
        >>> warping_fn = HzWarping(\
                warp_factor_sampling_fn=Uniform(low=.9, high=1.1),\
                boundary_frequency_ratio_sampling_fn=Uniform(low=.6, high=.7),\
                highest_frequency=highest_frequency,\
            )
        >>> mel_transform = MelTransform(sample_rate, 512, 40, warping_fn=warping_fn)
        >>> mel_transform(spec).shape
        torch.Size([3, 1, 100, 40])
        >>> mel_transform = MelTransform(sample_rate, 512, 40, warping_fn=warping_fn, independent_axis=(0,1,2))
        >>> np.random.seed(0)
        >>> x = mel_transform(spec)
        >>> x.shape
        torch.Size([3, 1, 100, 40])
        >>> from paderbox.transform.module_fbank import MelTransform as MelTransformNumpy
        >>> mel_transform_np = MelTransformNumpy(sample_rate, 512, 40, warping_fn=warping_fn, independent_axis=(0,1,2))
        >>> np.random.seed(0)
        >>> x_ref = mel_transform_np(spec.numpy())
        >>> assert np.abs(x.numpy()-x_ref).max() < 1e-6
        """
        super().__init__()
        self.sample_rate = sample_rate
        self.stft_size = stft_size
        self.number_of_filters = number_of_filters
        self.lowest_frequency = lowest_frequency
        self.highest_frequency = highest_frequency
        self.htk_mel = htk_mel
        self.log = log
        self.eps = eps
        self.warping_fn = warping_fn
        # Normalize a single axis into a list so callers can always iterate it.
        self.independent_axis = [independent_axis] if np.isscalar(independent_axis) else independent_axis

        # Unwarped filterbank; every row is divided by its sum (plus a small
        # constant so all-zero rows do not divide by zero).
        fbanks = get_fbanks(
            sample_rate=self.sample_rate,
            stft_size=self.stft_size,
            number_of_filters=self.number_of_filters,
            lowest_frequency=self.lowest_frequency,
            highest_frequency=self.highest_frequency,
            htk_mel=self.htk_mel,
        ).astype(np.float32)
        fbanks = fbanks / (fbanks.sum(axis=-1, keepdims=True) + 1e-6)
        # Stored transposed and frozen (no gradient).
        self.fbanks = nn.Parameter(
            torch.from_numpy(fbanks.T), requires_grad=False
        )