# Example 1
class LogMelSpectrogram2(nn.Module):
    """Convert raw waveforms into (optionally normalized) log10 mel spectrograms.

    Pipeline: STFT magnitude -> mel filter bank -> log10 -> (x - mean) / std,
    where mean/std are dataset-level statistics taken from ``settings``.

    :param sample_rate: audio sample rate used to build the mel filter bank
    :param mel_size: number of mel bins
    :param n_fft: FFT size of the underlying STFT
    :param win_length: STFT window length
    :param hop_length: STFT hop length
    :param mel_min: lowest mel-filter frequency (Hz)
    :param mel_max: highest mel-filter frequency (Hz); None = Nyquist
    :param normalize: whether to apply the VCTK mean/std normalization
    """
    def __init__(self,
                 sample_rate: int,
                 mel_size: int,
                 n_fft: int,
                 win_length: int,
                 hop_length: int,
                 mel_min: float = 0.,
                 mel_max: float = None,
                 normalize: bool = True):
        super().__init__()
        self.mel_size = mel_size
        self.normalize = normalize

        self.stft = CustomSTFT(filter_length=n_fft,
                               hop_length=hop_length,
                               win_length=win_length)

        # mel filter banks
        # NOTE: librosa >= 0.10 removed positional arguments for this
        # function, so pass everything by keyword (also works on older
        # librosa versions, hence backward compatible).
        mel_filter = librosa.filters.mel(sr=sample_rate,
                                         n_fft=n_fft,
                                         n_mels=mel_size,
                                         fmin=mel_min,
                                         fmax=mel_max)
        self.register_buffer('mel_filter',
                             torch.tensor(mel_filter, dtype=torch.float))
        if normalize:
            # Statistics are registered as buffers shaped (1, mel_size, 1)
            # so they broadcast over the batch and time axes in forward().
            self.register_buffer(
                'mean',
                torch.FloatTensor(np.array(
                    settings.VCTK_MEL_MEAN)).unsqueeze(0).unsqueeze(2))
            self.register_buffer(
                'std',
                torch.FloatTensor(np.array(
                    settings.VCTK_MEL_STD)).unsqueeze(0).unsqueeze(2))

    def forward(self, wav: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
        """Return the log10 mel spectrogram of ``wav``.

        :param wav: waveform batch (presumably (N, T) — matches what
            CustomSTFT.transform expects; confirm against its contract)
        :param eps: floor applied before the log to avoid -inf on silence
        :return: mel tensor of shape (N, mel_size, frames)
        """
        mag, phase = self.stft.transform(wav)

        # apply mel filter
        mel = torch.matmul(self.mel_filter, mag)

        # to log-space; clamp first so silent frames do not produce -inf
        mel = torch.log10(mel.clamp_min(eps))
        if self.normalize:
            mel = (mel - self.mean) / self.std
        return mel
# Example 2
class Inferencer:
    """
    This class is made to build "mel function" and "vocoder" simply

    Methods
    -------
    encode(wav_tensor: torch.Tensor)
        :arg wav_tensor dimension 1 or 2
        :return log mel spectrum
    decode(mel_tensor: torch.Tensor)
        :arg mel_tensor (N, C, Tm)
        :return predicted wav_tensor (N, 1, Tw)

    Examples::
        inferencer = Inferencer()

        # load audio and make tensor
        wav, sr = librosa.load(audio_path, sr=22050)
        wav_tensor = torch.FloatTensor(wav).unsqueeze(0).cuda()

        # convert to mel
        mel = inferencer.encode(wav_tensor)

        # convert back to wav
        pred_wav = inferencer.decode(mel)
    """
    def __init__(self, device: str = 'cuda'):
        self.device = device

        # waveform -> log-mel converter
        self.mel_func = LogMelSpectrogram(settings.SAMPLE_RATE,
                                          settings.MEL_SIZE, settings.N_FFT,
                                          settings.WIN_LENGTH,
                                          settings.HOP_LENGTH,
                                          float(settings.MEL_MIN),
                                          float(settings.MEL_MAX)).to(device)

        # PQMF synthesis module
        self.pqmf_func = PQMF().to(device)

        # restore the pretrained generator and freeze it for inference
        self.gen = Generator().to(device)
        state_dict = torch.load(VCTK_BASE_CHK_PATH,
                                map_location=torch.device(device))
        self.gen.load_state_dict(state_dict)
        self.gen.eval()

        self.stft = STFT(settings.WIN_LENGTH, settings.HOP_LENGTH).to(device)

        # denoise - reference https://github.com/NVIDIA/waveglow/blob/master/denoiser.py
        # Run the vocoder on an all-zero mel to capture its bias spectrum.
        silence_mel = torch.zeros((1, 80, 88)).float().to(device)
        with torch.no_grad():
            bias_wav = self.decode(silence_mel, is_denoise=False).squeeze(1)
            bias_spec, _ = self.stft.transform(bias_wav)

        self.bias_spec = bias_spec[:, :, 0][:, :, None]

    def encode(self, wav_tensor: torch.Tensor) -> torch.Tensor:
        """
        Convert wav tensor to mel tensor
        :param wav_tensor: wav tensor (N, T)
        :return: mel tensor (N, C, T)
        """
        if wav_tensor.dim() == 1:
            wav_tensor = wav_tensor.unsqueeze(0)
        assert wav_tensor.dim() <= 2, 'The expected dimension of wav is 1 or 2'
        return self.mel_func(wav_tensor)

    def decode(self,
               mel_tensor: torch.Tensor,
               is_denoise: bool = True) -> torch.Tensor:
        """
        Convert mel tensor to wav tensor by using multi-band melgan
        :param mel_tensor: mel tensor (N, C, T)
        :param is_denoise: using denoise function
        :return: wav tensor (N, T)
        """
        # TODO: Multiband-melgan returns noises from zero tensor. Get some tries on training time.
        # run generator then collapse the sub-bands through PQMF synthesis
        with torch.no_grad():
            wav = self.pqmf_func.synthesis(self.gen(mel_tensor)).squeeze(1)

        if is_denoise:
            wav = self.denoise(wav)
        return wav

    def denoise(self,
                audio: torch.Tensor,
                strength: float = 0.1) -> torch.Tensor:
        """Subtract the vocoder bias spectrum from ``audio`` (WaveGlow-style)."""
        spec, angles = self.stft.transform(audio.to(self.device).float())
        # remove a scaled copy of the bias spectrum, floor at zero magnitude
        denoised_spec = torch.clamp(spec - self.bias_spec * strength, 0.0)
        return self.stft.inverse(denoised_spec, angles)
# Example 3
class SpectrogramUnet(nn.Module):
    """U-Net over complex (log-magnitude + phase) spectrograms.

    forward(): STFT -> log -> complex conv encoder/decoder with skip
    connections -> optional masking -> inverse STFT back to a waveform of
    the input's length.

    :param spec_dim: number of frequency bins in the STFT magnitude
    :param hidden_dim: channel width of the U-Net body
    :param filter_len: STFT filter (FFT) length
    :param hop_len: STFT hop length
    :param layers: number of down/up sampling stages
    :param block_layers: conv layers per ComplexConvBlock
    :param kernel_size: conv kernel size
    :param is_mask: apply magnitude/phase masking against the input spectrum
    :param norm: 'bn' (BatchNorm1d) or 'ins' (affine InstanceNorm1d)
    :param act: 'tanh' or 'comp' (ComplexActLayer)
    :raises NotImplementedError: on an unknown ``norm`` or ``act`` name
    """
    def __init__(self,
                 spec_dim: int,
                 hidden_dim: int,
                 filter_len: int,
                 hop_len: int,
                 layers: int = 3,
                 block_layers: int = 3,
                 kernel_size: int = 5,
                 is_mask: bool = False,
                 norm: str = 'bn',
                 act: str = 'tanh'):
        super().__init__()
        self.layers = layers
        self.is_mask = is_mask

        # stft modules
        self.stft = STFT(filter_len, hop_len)

        if norm == 'bn':
            self.bn_func = nn.BatchNorm1d
        elif norm == 'ins':
            self.bn_func = lambda x: nn.InstanceNorm1d(x, affine=True)
        else:
            raise NotImplementedError('{} is not implemented !'.format(norm))

        if act == 'tanh':
            self.act_func = nn.Tanh
            self.act_out = nn.Tanh
        elif act == 'comp':
            self.act_func = ComplexActLayer
            self.act_out = lambda: ComplexActLayer(is_out=True)
        else:
            raise NotImplementedError('{} is not implemented !'.format(act))

        # prev conv: project concatenated [mag, phase] (2 * spec_dim channels)
        self.prev_conv = ComplexConv1d(spec_dim * 2, hidden_dim, 1)

        # down path: conv block then stride-2 max pooling at every stage
        self.down = nn.ModuleList()
        self.down_pool = nn.MaxPool1d(3, stride=2, padding=1)
        for idx in range(self.layers):
            block = ComplexConvBlock(hidden_dim,
                                     hidden_dim,
                                     kernel_size=kernel_size,
                                     padding=kernel_size // 2,
                                     bn_func=self.bn_func,
                                     act_func=self.act_func,
                                     layers=block_layers)
            self.down.append(block)

        # up path: after the first stage inputs are doubled by the skip concat
        self.up = nn.ModuleList()
        for idx in range(self.layers):
            in_c = hidden_dim if idx == 0 else hidden_dim * 2
            self.up.append(
                nn.Sequential(
                    ComplexConvBlock(in_c,
                                     hidden_dim,
                                     kernel_size=kernel_size,
                                     padding=kernel_size // 2,
                                     bn_func=self.bn_func,
                                     act_func=self.act_func,
                                     layers=block_layers),
                    self.bn_func(hidden_dim),
                    self.act_func(),
                    ComplexTransposedConv1d(hidden_dim,
                                            hidden_dim,
                                            kernel_size=2,
                                            stride=2),
                ))

        # out_conv: map back to spec_dim * 2 (mag + phase) channels
        self.out_conv = nn.Sequential(
            ComplexConvBlock(hidden_dim * 2,
                             spec_dim * 2,
                             kernel_size=kernel_size,
                             padding=kernel_size // 2,
                             bn_func=self.bn_func,
                             act_func=self.act_func),
            self.bn_func(spec_dim * 2), self.act_func())

        # refine conv: consumes the output concatenated with the input spectrum
        self.refine_conv = nn.Sequential(
            ComplexConvBlock(spec_dim * 4,
                             spec_dim * 2,
                             kernel_size=kernel_size,
                             padding=kernel_size // 2,
                             bn_func=self.bn_func,
                             act_func=self.act_func),
            self.bn_func(spec_dim * 2), self.act_func())

    def log_stft(self, wav):
        """STFT then log1p-compress the magnitude; returns (log_mag, phase)."""
        mag, phase = self.stft.transform(wav)
        return torch.log(mag + 1), phase

    def exp_istft(self, log_mag, phase):
        """Invert log_stft: expm1 the magnitude, then inverse STFT."""
        # torch.exp(x) - 1 is the exact inverse of log(mag + 1); this
        # replaces the equivalent but less idiomatic np.e ** log_mag - 1.
        mag = torch.exp(log_mag) - 1
        return self.stft.inverse(mag, phase)

    def adjust_diff(self, x, target):
        """Reflect-pad ``x`` on its last axis to match ``target``'s length."""
        size_diff = (target.size()[-1] - x.size()[-1])
        assert size_diff >= 0
        if size_diff > 0:
            # BUGFIX: an odd size_diff previously padded only 2*(size_diff//2)
            # samples, returning a tensor one shorter than target. Pad the
            # remainder on the right so lengths always match exactly.
            left = size_diff // 2
            x = F.pad(x.unsqueeze(1), (left, size_diff - left),
                      'reflect').squeeze(1)
        return x

    def masking(self, mag, phase, origin_mag, origin_phase):
        """Gate the input spectrum with masks derived from the prediction."""
        abs_mag = torch.abs(mag)
        mag_mask = torch.tanh(abs_mag)
        # BUGFIX: mag / abs_mag produced NaN wherever mag == 0; clamp the
        # denominator (values != 0 are unaffected).
        phase_mask = mag / abs_mag.clamp_min(1e-8)

        # masking
        mag = mag_mask * origin_mag
        phase = phase_mask * (origin_phase + phase)
        return mag, phase

    def forward(self, wav):
        """Run the U-Net on ``wav`` and return the refined waveform."""
        # stft
        origin_mag, origin_phase = self.log_stft(wav)
        origin_x = torch.cat([origin_mag, origin_phase], dim=1)

        # prev
        x = self.prev_conv(origin_x)

        # down path, caching activations for the skip connections
        down_cache = []
        for idx, block in enumerate(self.down):
            x = block(x)
            down_cache.append(x)
            x = self.down_pool(x)

        # up path: resize the matching skip to x's length before concat
        for idx, block in enumerate(self.up):
            x = block(x)
            res = F.interpolate(down_cache[self.layers - (idx + 1)],
                                size=[x.size()[2]],
                                mode='linear',
                                align_corners=False)
            x = concat_complex(x, res, dim=1)

        # match spec dimension
        x = self.out_conv(x)
        if origin_mag.size(2) != x.size(2):
            x = F.interpolate(x,
                              size=[origin_mag.size(2)],
                              mode='linear',
                              align_corners=False)

        # refine; dim=1 made explicit for consistency with the up-path call
        # (channel concat is required: spec_dim*2 + spec_dim*2 -> spec_dim*4)
        x = self.refine_conv(concat_complex(x, origin_x, dim=1))

        def to_wav(stft):
            mag, phase = stft.chunk(2, 1)
            if self.is_mask:
                mag, phase = self.masking(mag, phase, origin_mag, origin_phase)
            out = self.exp_istft(mag, phase)
            out = self.adjust_diff(out, wav)
            return out

        refine_wav = to_wav(x)

        return refine_wav