def forward(self, x, return_maxima=False):
    """Map a linear spectrogram to a (log) mel spectrogram.

    In evaluation mode, or when no ``warping_fn`` is configured, the static
    filterbank ``self.fbanks`` is applied. Otherwise a freshly warped
    filterbank is built for this call via ``get_fbanks``: one filterbank per
    position along ``self.independent_axis``, shared (broadcast) along all
    other leading axes.

    Args:
        x: torch tensor whose last axis holds the linear frequency bins.
        return_maxima: if True, also return, for every filter, the index of
            the frequency bin with the largest weight (``-1`` for filters
            whose weights are all zero).

    Returns:
        ``x`` filtered along its last axis (replaced by
        ``log(x + self.eps)`` if ``self.log``), or the tuple
        ``(x, maxima)`` when ``return_maxima`` is True.

    Raises:
        ValueError: in the warping branch, if an entry of
            ``self.independent_axis`` is not an existing axis of ``x`` or
            refers to its last (frequency) axis.
    """
    if not self.training or self.warping_fn is None:
        fbanks = self.fbanks
        x = x.matmul(fbanks)
    else:
        # Resolve negative axes so they can be matched against range() below.
        independent_axis = [
            ax if ax >= 0 else x.ndim + ax for ax in self.independent_axis
        ]
        # The last axis is the frequency axis and cannot be independent.
        # Out-of-range axes are rejected too; they would otherwise be
        # silently ignored when building `size`.
        if not all(0 <= ax < x.ndim - 1 for ax in independent_axis):
            raise ValueError(
                f'Invalid independent_axis {self.independent_axis} for an '
                f'input with {x.ndim} dimensions: every axis must refer to '
                f'an existing axis other than the last (frequency) axis.'
            )
        # Number of independently warped filterbanks along each leading axis
        # of x (1 means: one filterbank shared via broadcasting).
        size = [
            x.shape[i] if i in independent_axis else 1
            for i in range(x.ndim - 1)
        ]
        # assumes get_fbanks returns shape (*size, number_of_filters, bins)
        # -- TODO confirm against get_fbanks
        fbanks = get_fbanks(
            sample_rate=self.sample_rate,
            stft_size=self.stft_size,
            number_of_filters=self.number_of_filters,
            lowest_frequency=self.lowest_frequency,
            highest_frequency=self.highest_frequency,
            htk_mel=self.htk_mel,
            warping_fn=self.warping_fn,
            size=size,
        ).astype(np.float32)
        # Normalize every filter over its last axis; 1e-6 avoids division by
        # zero for all-zero filters.
        fbanks = fbanks / (fbanks.sum(axis=-1, keepdims=True) + 1e-6)
        # Swap the last two axes so the bank can be right-multiplied onto x
        # (same transposed layout as self.fbanks). Follow x's floating dtype
        # so e.g. float64 inputs do not hit a dtype mismatch in matmul.
        dtype = x.dtype if torch.is_floating_point(x) else None
        fbanks = torch.from_numpy(fbanks).transpose(-2, -1).to(
            device=x.device, dtype=dtype
        )
        if fbanks.shape[-3] == 1:
            # Axis -2 of x is not independent: the bank broadcasts over it.
            x = x.matmul(fbanks.squeeze(-3))
        else:
            # Axis -2 of x is independent: multiply every row of x with its
            # own filterbank by treating it as a batch of 1 x F vectors.
            x = x[..., None, :].matmul(fbanks).squeeze(-2)
    if self.log:
        x = torch.log(x + self.eps)
    if return_maxima:
        # Peak frequency bin (axis -2 of fbanks) of every filter; filters
        # without any positive weight sum map to -1.
        maxima = (fbanks.argmax(-2) + 1) * (fbanks.sum(-2) > 0) - 1
        return x, maxima
    return x
def __init__(
        self,
        n_mels: int,
        sample_rate: int,
        fft_length: int,
        fmin: Optional[float] = 50.,
        fmax: Optional[float] = None,
        log: bool = True,
        eps=1e-12,
        *,
        warping_fn=None,
):
    """Transforms linear spectrogram to (log) mel spectrogram.

    Args:
        n_mels: number of mel filters to apply
        sample_rate: sample rate of the audio signal
        fft_length: FFT length used in the STFT
        fmin: lowest frequency (onset of the first filter)
        fmax: highest frequency (offset of the last filter)
        log: whether to apply a log to the mel spectrogram
        eps: stored as ``self.eps``; presumably the offset added before the
            log to avoid log(0) -- confirm in forward
        warping_fn: optional frequency-warping function; in training mode
            ``get_fbanks`` uses it to build warped filterbanks per batch

    >>> mel_transform = MelTransform(40, 16000, 512)
    >>> spec = torch.zeros((10, 1, 100, 257))
    >>> logmelspec = mel_transform(spec)
    >>> logmelspec.shape
    torch.Size([10, 1, 100, 40])
    >>> rec = mel_transform.inverse(logmelspec)
    >>> rec.shape
    torch.Size([10, 1, 100, 257])
    """
    super().__init__()
    self.sample_rate = sample_rate
    self.fft_length = fft_length
    self.n_mels = n_mels
    self.fmin, self.fmax = fmin, fmax
    self.log = log
    self.eps = eps
    self.warping_fn = warping_fn

    # Static (unwarped) filterbank: each filter normalised to ~unit sum
    # (1e-6 guards empty filters) and stored transposed, frozen.
    static_fbanks = get_fbanks(
        n_mels=n_mels,
        fft_length=fft_length,
        sample_rate=sample_rate,
        fmin=fmin,
        fmax=fmax,
    ).astype(np.float32)
    filter_sums = static_fbanks.sum(axis=-1, keepdims=True)
    static_fbanks = static_fbanks / (filter_sums + 1e-6)
    self._fbanks = nn.Parameter(
        torch.from_numpy(static_fbanks.T), requires_grad=False,
    )
def get_fbanks(self, x):
    """Return the non-negative mel filterbank to right-multiply onto ``x``.

    Outside training, or when no ``warping_fn`` is configured, the static
    filterbank ``self._fbanks`` is returned unchanged in shape. Otherwise a
    freshly warped filterbank is built with ``warping_fn`` bound to
    ``n=x.shape[0]`` (presumably one warp per example along the first
    axis). Singleton axes are then inserted after that first axis until the
    bank has as many dimensions as ``x``, so it broadcasts in a matmul.

    Args:
        x: input tensor; only its shape, dimensionality, device and dtype
            are used here.

    Returns:
        The filterbank with negative entries clipped to zero.
    """
    if not self.training or self.warping_fn is None:
        fbanks = self._fbanks
    else:
        fbanks = get_fbanks(
            n_mels=self.n_mels,
            fft_length=self.fft_length,
            sample_rate=self.sample_rate,
            fmin=self.fmin,
            fmax=self.fmax,
            warping_fn=partial(self.warping_fn, n=x.shape[0]),
        ).astype(np.float32)
        # Normalize every filter; 1e-6 avoids division by zero.
        fbanks = fbanks / (fbanks.sum(axis=-1, keepdims=True) + 1e-6)
        # Follow x's floating dtype so non-float32 inputs (e.g. after
        # module.double()) do not hit a dtype mismatch in matmul.
        dtype = x.dtype if torch.is_floating_point(x) else None
        fbanks = torch.from_numpy(fbanks).transpose(-2, -1).to(
            device=x.device, dtype=dtype
        )
        # Only the per-example warped bank needs the extra singleton axes;
        # the 2-D static bank already broadcasts against x in a matmul.
        while x.dim() > fbanks.dim():
            fbanks = fbanks[:, None]
    # torch.relu is what nn.ReLU() calls; no need to build a module per call.
    return torch.relu(fbanks)
def __init__(
        self,
        sample_rate: int,
        stft_size: int,
        number_of_filters: int,
        lowest_frequency: Optional[float] = 50.,
        highest_frequency: Optional[float] = None,
        htk_mel=True,
        log: bool = True,
        eps=1e-12,
        *,
        warping_fn=None,
        independent_axis=0,
):
    """Transforms linear spectrogram to (log) mel spectrogram.

    Args:
        sample_rate: sample rate of the audio signal
        stft_size: FFT length used in the STFT
        number_of_filters: number of mel filters to apply
        lowest_frequency: lowest frequency (onset of the first filter)
        highest_frequency: highest frequency (offset of the last filter)
        htk_mel: forwarded to ``get_fbanks``; presumably selects the HTK
            mel-scale formula (confirm in ``get_fbanks``)
        log: whether ``forward`` replaces its output by ``log(x + eps)``
        eps: offset added before the log in ``forward``
        warping_fn: optional frequency-warping function; when set,
            ``forward`` builds freshly warped filterbanks in training mode
        independent_axis: axis (or sequence of axes) of the input along
            which ``forward`` draws independently warped filterbanks; must
            not be the last (frequency) axis

    >>> sample_rate = 16000
    >>> highest_frequency = sample_rate/2
    >>> mel_transform = MelTransform(sample_rate, 512, 40)
    >>> spec = torch.rand((3, 1, 100, 257))
    >>> logmelspec = mel_transform(spec)
    >>> logmelspec.shape
    torch.Size([3, 1, 100, 40])
    >>> rec = mel_transform.inverse(logmelspec)
    >>> rec.shape
    torch.Size([3, 1, 100, 257])
    >>> from paderbox.transform.module_fbank import HzWarping
    >>> from paderbox.utils.random_utils import Uniform
    >>> warping_fn = HzWarping(
    ...     warp_factor_sampling_fn=Uniform(low=.9, high=1.1),
    ...     boundary_frequency_ratio_sampling_fn=Uniform(low=.6, high=.7),
    ...     highest_frequency=highest_frequency,
    ... )
    >>> mel_transform = MelTransform(sample_rate, 512, 40, warping_fn=warping_fn)
    >>> mel_transform(spec).shape
    torch.Size([3, 1, 100, 40])
    >>> mel_transform = MelTransform(sample_rate, 512, 40, warping_fn=warping_fn, independent_axis=(0,1,2))
    >>> np.random.seed(0)
    >>> x = mel_transform(spec)
    >>> x.shape
    torch.Size([3, 1, 100, 40])
    >>> from paderbox.transform.module_fbank import MelTransform as MelTransformNumpy
    >>> mel_transform_np = MelTransformNumpy(sample_rate, 512, 40, warping_fn=warping_fn, independent_axis=(0,1,2))
    >>> np.random.seed(0)
    >>> x_ref = mel_transform_np(spec.numpy())
    >>> assert (x.numpy()-x_ref).max() < 1e-6
    """
    super().__init__()
    self.sample_rate = sample_rate
    self.stft_size = stft_size
    self.number_of_filters = number_of_filters
    self.lowest_frequency = lowest_frequency
    self.highest_frequency = highest_frequency
    self.htk_mel = htk_mel
    self.log = log
    self.eps = eps
    self.warping_fn = warping_fn
    # A single axis is accepted as shorthand for a one-element list.
    if np.isscalar(independent_axis):
        independent_axis = [independent_axis]
    self.independent_axis = independent_axis

    # Static (unwarped) filterbank: each filter normalised to ~unit sum
    # (1e-6 guards empty filters) and stored transposed so it can be
    # right-multiplied onto a spectrogram; frozen (no gradient).
    static_fbanks = get_fbanks(
        sample_rate=sample_rate,
        stft_size=stft_size,
        number_of_filters=number_of_filters,
        lowest_frequency=lowest_frequency,
        highest_frequency=highest_frequency,
        htk_mel=htk_mel,
    ).astype(np.float32)
    filter_sums = static_fbanks.sum(axis=-1, keepdims=True)
    static_fbanks = static_fbanks / (filter_sums + 1e-6)
    self.fbanks = nn.Parameter(
        torch.from_numpy(static_fbanks.T), requires_grad=False,
    )