def test_log_mel_equal():
    """Check that both routes to log-Mel features give identical arrays.

    Route A feeds the input straight into ``LogMelFbank``.  Route B first
    runs ``LinearSpectrogram`` and then passes its output and lengths to
    ``LogMelFbank.logmel``.  The two results must be exactly equal.
    """
    spectrogram = LinearSpectrogram(n_fft=4, hop_length=1)
    fbank = LogMelFbank(n_fft=4, hop_length=1, n_mels=2)
    signal = torch.randn(2, 4, 9)

    # Route B, step 1: linear spectrogram together with its output lengths.
    linear_spec, linear_lens = spectrogram(signal, torch.LongTensor([4, 3]))
    # Route A: log-Mel features computed directly from the input.
    direct_feats, _ = fbank(signal, torch.LongTensor([4, 3]))
    # Route B, step 2: log-Mel features derived from the linear spectrogram.
    two_step_feats, _ = fbank.logmel(linear_spec, linear_lens)

    direct_np = direct_feats.detach().cpu().numpy()
    two_step_np = two_step_feats.detach().cpu().numpy()
    np.testing.assert_array_equal(direct_np, two_step_np)
# Example #2
# 0
    def __init__(
        self,
        fs: int = 22050,
        n_fft: int = 1024,
        hop_length: int = 256,
        win_length: Optional[int] = None,
        window: str = "hann",
        n_mels: int = 80,
        fmin: Optional[int] = 0,
        fmax: Optional[int] = None,
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
        log_base: Optional[float] = 10.0,
    ):
        """Initialize Mel-spectrogram loss.

        Every argument is forwarded unchanged to ``LogMelFbank``; the
        resulting extractor is stored as ``self.wav_to_mel``.

        Args:
            fs (int): Sampling rate.
            n_fft (int): Number of FFT points.
            hop_length (int): Hop length.
            win_length (Optional[int]): Window length.
            window (str): Window type.
            n_mels (int): Number of Mel basis.
            fmin (Optional[int]): Minimum frequency for Mel.
            fmax (Optional[int]): Maximum frequency for Mel.
            center (bool): Whether to use a centered window.
            normalized (bool): Whether to use the normalized variant.
            onesided (bool): Whether to use the one-sided variant.
            log_base (Optional[float]): Log base value.

        """
        super().__init__()
        # Gather the feature-extraction settings once and hand them over
        # as keyword arguments.
        mel_options = dict(
            fs=fs,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window=window,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
            center=center,
            normalized=normalized,
            onesided=onesided,
            log_base=log_base,
        )
        self.wav_to_mel = LogMelFbank(**mel_options)
def test_compatible_with_espnet1():
    """Compare ``LogMelFbank`` against the reference ``logmelspectrogram``.

    Both implementations process the same random 100-sample signal with
    matching analysis settings; the results must agree within an absolute
    tolerance of 1e-5 (relative tolerance disabled).
    """
    # Analysis settings shared by the two implementations.
    n_fft, shift, n_mels = 16, 4, 4
    fmin, fmax = 80, 7600

    frontend = LogMelFbank(
        n_fft=n_fft,
        hop_length=shift,
        n_mels=n_mels,
        fs="16k",
        fmin=fmin,
        fmax=fmax,
    )
    wav = torch.randn(1, 100)
    feats, _ = frontend(wav, torch.LongTensor([100]))
    actual = feats.numpy()[0]

    expected = logmelspectrogram(
        wav[0].numpy(),
        n_fft=n_fft,
        n_shift=shift,
        n_mels=n_mels,
        fs=16000,
        fmin=fmin,
        fmax=fmax,
    )
    np.testing.assert_allclose(actual, expected, rtol=0, atol=1e-5)
# Example #4
# 0
class MelSpectrogramLoss(torch.nn.Module):
    """L1 loss between log-Mel features of generated and reference audio."""

    def __init__(
        self,
        fs: int = 22050,
        n_fft: int = 1024,
        hop_length: int = 256,
        win_length: Optional[int] = None,
        window: str = "hann",
        n_mels: int = 80,
        fmin: Optional[int] = 0,
        fmax: Optional[int] = None,
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
        log_base: Optional[float] = 10.0,
    ):
        """Initialize Mel-spectrogram loss.

        Every argument is forwarded unchanged to ``LogMelFbank``; the
        resulting extractor is stored as ``self.wav_to_mel``.

        Args:
            fs (int): Sampling rate.
            n_fft (int): Number of FFT points.
            hop_length (int): Hop length.
            win_length (Optional[int]): Window length.
            window (str): Window type.
            n_mels (int): Number of Mel basis.
            fmin (Optional[int]): Minimum frequency for Mel.
            fmax (Optional[int]): Maximum frequency for Mel.
            center (bool): Whether to use a centered window.
            normalized (bool): Whether to use the normalized variant.
            onesided (bool): Whether to use the one-sided variant.
            log_base (Optional[float]): Log base value.

        """
        super().__init__()
        # Gather the feature-extraction settings once and hand them over
        # as keyword arguments.
        mel_options = dict(
            fs=fs,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window=window,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
            center=center,
            normalized=normalized,
            onesided=onesided,
            log_base=log_base,
        )
        self.wav_to_mel = LogMelFbank(**mel_options)

    def forward(
        self,
        y_hat: torch.Tensor,
        y: torch.Tensor,
        spec: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Calculate Mel-spectrogram loss.

        Args:
            y_hat (Tensor): Generated waveform tensor, expected as (B, 1, T);
                axis 1 is squeezed before feature extraction.
            y (Tensor): Groundtruth waveform tensor, expected as (B, 1, T);
                only used when ``spec`` is None.
            spec (Optional[Tensor]): Groundtruth linear amplitude spectrum.
                If provided, it is passed to ``LogMelFbank.logmel`` and used
                instead of the groundtruth waveform.  The original docs give
                its layout as (B, n_fft, T) -- TODO confirm against what
                ``LogMelFbank.logmel`` expects.

        Returns:
            Tensor: Mel-spectrogram loss value, i.e. ``F.l1_loss`` between
                the generated and the groundtruth log-Mel features.

        """
        generated_mel, _ = self.wav_to_mel(y_hat.squeeze(1))

        if spec is not None:
            # A precomputed linear spectrum skips the waveform analysis.
            reference_mel, _ = self.wav_to_mel.logmel(spec)
        else:
            reference_mel, _ = self.wav_to_mel(y.squeeze(1))

        return F.l1_loss(generated_mel, reference_mel)
def test_forward():
    """Check the output shape of ``LogMelFbank`` for a (2, 4, 9) input."""
    fbank = LogMelFbank(n_fft=4, hop_length=1, n_mels=2)
    batch = torch.randn(2, 4, 9)
    lengths = torch.LongTensor([4, 3])

    feats, _ = fbank(batch, lengths)

    expected_shape = (2, 5, 9, 2)
    assert feats.shape == expected_shape
def test_get_parameters():
    """Smoke test: ``get_parameters`` can be called and its result printed."""
    fbank = LogMelFbank(n_fft=4, hop_length=1, n_mels=2, fs="16k")
    parameters = fbank.get_parameters()
    print(parameters)
def test_output_size():
    """Smoke test: ``output_size`` can be called and its result printed."""
    fbank = LogMelFbank(n_fft=4, hop_length=1, n_mels=2, fs="16k")
    size = fbank.output_size()
    print(size)
def test_backward_not_leaf_in():
    """Run backward through ``LogMelFbank`` when the input is not a leaf."""
    fbank = LogMelFbank(n_fft=4, hop_length=1, n_mels=2)
    leaf = torch.randn(2, 4, 9, requires_grad=True)
    # Adding a constant yields a tensor produced by an op, i.e. a non-leaf.
    shifted = leaf + 2

    feats, _ = fbank(shifted, torch.LongTensor([4, 3]))

    total = feats.sum()
    total.backward()