# NOTE: The ESPnet import paths below are assumed (ESPnet2 layout); adjust them
# if the modules live elsewhere in this repository.
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F

from espnet.transform.spectrogram import logmelspectrogram
from espnet2.tts.feats_extract.linear_spectrogram import LinearSpectrogram
from espnet2.tts.feats_extract.log_mel_fbank import LogMelFbank


def test_log_mel_equal():
    # Log-Mel computed from a precomputed linear spectrogram must equal
    # log-Mel computed directly from the raw input.
    layer1 = LinearSpectrogram(n_fft=4, hop_length=1)
    layer2 = LogMelFbank(n_fft=4, hop_length=1, n_mels=2)
    x = torch.randn(2, 4, 9)
    y1, y1_lens = layer1(x, torch.LongTensor([4, 3]))
    y2, _ = layer2(x, torch.LongTensor([4, 3]))
    y1_2, _ = layer2.logmel(y1, y1_lens)
    np.testing.assert_array_equal(
        y2.detach().cpu().numpy(),
        y1_2.detach().cpu().numpy(),
    )
def test_compatible_with_espnet1():
    layer = LogMelFbank(
        n_fft=16, hop_length=4, n_mels=4, fs="16k", fmin=80, fmax=7600
    )
    x = torch.randn(1, 100)
    y, _ = layer(x, torch.LongTensor([100]))
    y = y.numpy()[0]
    y2 = logmelspectrogram(
        x[0].numpy(), n_fft=16, n_shift=4, n_mels=4, fs=16000, fmin=80, fmax=7600
    )
    np.testing.assert_allclose(y, y2, rtol=0, atol=1e-5)
class MelSpectrogramLoss(torch.nn.Module):
    """Mel-spectrogram loss."""

    def __init__(
        self,
        fs: int = 22050,
        n_fft: int = 1024,
        hop_length: int = 256,
        win_length: Optional[int] = None,
        window: str = "hann",
        n_mels: int = 80,
        fmin: Optional[int] = 0,
        fmax: Optional[int] = None,
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
        log_base: Optional[float] = 10.0,
    ):
        """Initialize Mel-spectrogram loss.

        Args:
            fs (int): Sampling rate.
            n_fft (int): FFT points.
            hop_length (int): Hop length.
            win_length (Optional[int]): Window length.
            window (str): Window type.
            n_mels (int): Number of Mel basis.
            fmin (Optional[int]): Minimum frequency for Mel.
            fmax (Optional[int]): Maximum frequency for Mel.
            center (bool): Whether to use center window.
            normalized (bool): Whether to use normalized STFT.
            onesided (bool): Whether to use one-sided STFT.
            log_base (Optional[float]): Log base value.

        """
        super().__init__()
        self.wav_to_mel = LogMelFbank(
            fs=fs,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window=window,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
            center=center,
            normalized=normalized,
            onesided=onesided,
            log_base=log_base,
        )

    def forward(
        self,
        y_hat: torch.Tensor,
        y: torch.Tensor,
        spec: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Calculate Mel-spectrogram loss.

        Args:
            y_hat (Tensor): Generated waveform tensor (B, 1, T).
            y (Tensor): Groundtruth waveform tensor (B, 1, T).
            spec (Optional[Tensor]): Groundtruth linear amplitude spectrum tensor
                (B, n_fft, T). If provided, it is used instead of the groundtruth
                waveform.

        Returns:
            Tensor: Mel-spectrogram loss value.

        """
        mel_hat, _ = self.wav_to_mel(y_hat.squeeze(1))
        if spec is None:
            mel, _ = self.wav_to_mel(y.squeeze(1))
        else:
            mel, _ = self.wav_to_mel.logmel(spec)
        mel_loss = F.l1_loss(mel_hat, mel)

        return mel_loss
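# Minimal usage sketch (added here for illustration, not part of the original
# code): compute the loss between a generated and a groundtruth waveform.
# Shapes and parameter values are illustrative only.
def _example_mel_spectrogram_loss():
    criterion = MelSpectrogramLoss(fs=22050, n_fft=1024, hop_length=256, n_mels=80)
    y_hat = torch.randn(2, 1, 8192)  # generated waveform (B, 1, T)
    y = torch.randn(2, 1, 8192)  # groundtruth waveform (B, 1, T)
    loss = criterion(y_hat, y)
    assert loss.dim() == 0  # scalar L1 distance between log-Mel spectrograms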
def test_forward():
    layer = LogMelFbank(n_fft=4, hop_length=1, n_mels=2)
    x = torch.randn(2, 4, 9)
    y, _ = layer(x, torch.LongTensor([4, 3]))
    assert y.shape == (2, 5, 9, 2)
def test_get_parameters():
    layer = LogMelFbank(n_fft=4, hop_length=1, n_mels=2, fs="16k")
    print(layer.get_parameters())
def test_output_size():
    layer = LogMelFbank(n_fft=4, hop_length=1, n_mels=2, fs="16k")
    print(layer.output_size())
def test_backward_not_leaf_in():
    layer = LogMelFbank(n_fft=4, hop_length=1, n_mels=2)
    x = torch.randn(2, 4, 9, requires_grad=True)
    x = x + 2
    y, _ = layer(x, torch.LongTensor([4, 3]))
    y.sum().backward()
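# Added sketch (an assumption, not taken from the original tests): the same
# backward check with a leaf tensor as input, so gradients accumulate in x.grad.
def test_backward_leaf_in():
    layer = LogMelFbank(n_fft=4, hop_length=1, n_mels=2)
    x = torch.randn(2, 4, 9, requires_grad=True)
    y, _ = layer(x, torch.LongTensor([4, 3]))
    y.sum().backward()
    assert x.grad is not None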