def __init__(self, device: str = 'cuda'):
    """Build the mel frontend, PQMF bank, pretrained generator and denoiser state.

    :param device: torch device string every module is moved to (default 'cuda')
    """
    self.device = device

    # log-mel converter
    self.mel_func = LogMelSpectrogram(
        settings.SAMPLE_RATE, settings.MEL_SIZE, settings.N_FFT,
        settings.WIN_LENGTH, settings.HOP_LENGTH,
        float(settings.MEL_MIN), float(settings.MEL_MAX)).to(device)

    # PQMF module for multi-band synthesis
    self.pqmf_func = PQMF().to(device)

    # restore the pretrained generator and freeze it in eval mode
    self.gen = Generator().to(device)
    checkpoint = torch.load(VCTK_BASE_CHK_PATH, map_location=torch.device(device))
    self.gen.load_state_dict(checkpoint)
    self.gen.eval()

    self.stft = STFT(settings.WIN_LENGTH, settings.HOP_LENGTH).to(device)

    # denoise - reference https://github.com/NVIDIA/waveglow/blob/master/denoiser.py
    # run the vocoder on an all-zero mel to capture its bias spectrum;
    # only frame 0 of the bias is kept and broadcast at denoise time
    silence_mel = torch.zeros((1, 80, 88)).float().to(device)
    with torch.no_grad():
        bias_audio = self.decode(silence_mel, is_denoise=False).squeeze(1)
        bias_spec, _ = self.stft.transform(bias_audio)
    self.bias_spec = bias_spec[:, :, 0][:, :, None]
def __init__(self, sample_rate: int, mel_size: int, n_fft: int, win_length: int,
             hop_length: int, mel_min: float = 0., mel_max: float = None,
             normalize: bool = True):
    """Set up the STFT frontend, mel filter bank and optional VCTK normalization.

    :param sample_rate: audio sample rate used to build the mel filters
    :param mel_size: number of mel bins
    :param n_fft: FFT size
    :param win_length: analysis window length
    :param hop_length: hop between consecutive frames
    :param mel_min: lowest mel filter frequency
    :param mel_max: highest mel filter frequency (None = librosa default / Nyquist)
    :param normalize: standardize output with precomputed VCTK mean/std
    """
    super().__init__()
    self.mel_size = mel_size
    self.normalize = normalize

    self.stft = CustomSTFT(filter_length=n_fft, hop_length=hop_length,
                           win_length=win_length)

    # mel filter banks, registered as a buffer so .to(device) moves them
    filter_bank = librosa.filters.mel(sample_rate, n_fft, mel_size,
                                      fmin=mel_min, fmax=mel_max)
    self.register_buffer('mel_filter',
                         torch.tensor(filter_bank, dtype=torch.float))

    if normalize:
        # corpus statistics shaped (1, mel, 1) to broadcast over (N, mel, T)
        mean_stat = torch.FloatTensor(np.array(settings.VCTK_MEL_MEAN))
        std_stat = torch.FloatTensor(np.array(settings.VCTK_MEL_STD))
        self.register_buffer('mean', mean_stat.unsqueeze(0).unsqueeze(2))
        self.register_buffer('std', std_stat.unsqueeze(0).unsqueeze(2))
def build_stft_functions(*params: Tuple[int, int, int]):
    """
    Make stft modules by given parameters
    :param params: arguments of tuples (n_fft, window size, hop size)
    :return: STFT modules
    """
    print('Build Mel Functions ...')
    # FIX: the original called STFT(win, hop, win, fft), which used the window
    # size as filter_length and passed the FFT size into the 4th positional
    # slot (the window argument). Use (filter_length=fft, hop_length=hop,
    # win_length=win), matching the docstring and the sibling
    # build_stft_functions in this codebase.
    return [
        STFT(fft, hop, win).cuda()
        for fft, win, hop in params
    ]
def build_stft_functions():
    """Build the log-mel converter and the STFT modules used by the loss terms.

    :return: tuple of (mel function, list of STFT modules for the losses)
    """
    print('Build Mel Functions ...')

    # one STFT per (fft, win, hop) configuration: full-band then multi-band
    stft_params = FB_STFT_PARAMS + MB_STFT_PARAMS
    mel_funcs_for_loss = []
    for fft, win, hop in stft_params:
        mel_funcs_for_loss.append(STFT(fft, hop, win).cuda())

    mel_func = LogMelSpectrogram(
        settings.SAMPLE_RATE, settings.MEL_SIZE, settings.N_FFT,
        settings.WIN_LENGTH, settings.HOP_LENGTH,
        mel_min=float(settings.MEL_MIN),
        mel_max=float(settings.MEL_MAX)).cuda()

    return mel_func, mel_funcs_for_loss
class LogMelSpectrogram2(nn.Module):
    """Waveform to log10 mel spectrogram, optionally standardized with VCTK stats."""

    def __init__(self, sample_rate: int, mel_size: int, n_fft: int,
                 win_length: int, hop_length: int, mel_min: float = 0.,
                 mel_max: float = None, normalize: bool = True):
        """
        :param sample_rate: audio sample rate used to build the mel filters
        :param mel_size: number of mel bins
        :param n_fft: FFT size
        :param win_length: analysis window length
        :param hop_length: hop between consecutive frames
        :param mel_min: lowest mel filter frequency
        :param mel_max: highest mel filter frequency (None = librosa default)
        :param normalize: standardize output with precomputed VCTK mean/std
        """
        super().__init__()
        self.mel_size = mel_size
        self.normalize = normalize

        # STFT frontend
        self.stft = CustomSTFT(filter_length=n_fft, hop_length=hop_length,
                               win_length=win_length)

        # mel filter banks
        banks = librosa.filters.mel(sample_rate, n_fft, mel_size,
                                    fmin=mel_min, fmax=mel_max)
        self.register_buffer('mel_filter', torch.tensor(banks, dtype=torch.float))

        if normalize:
            # corpus statistics shaped (1, mel, 1) to broadcast over (N, mel, T)
            mean_stat = torch.FloatTensor(np.array(settings.VCTK_MEL_MEAN))
            std_stat = torch.FloatTensor(np.array(settings.VCTK_MEL_STD))
            self.register_buffer('mean', mean_stat.unsqueeze(0).unsqueeze(2))
            self.register_buffer('std', std_stat.unsqueeze(0).unsqueeze(2))

    def forward(self, wav: torch.tensor, eps: float = 1e-10) -> torch.tensor:
        """Return the (optionally normalized) log10 mel spectrogram of *wav*."""
        magnitude, _ = self.stft.transform(wav)

        # project the linear-frequency magnitudes onto the mel basis
        mel = self.mel_filter.matmul(magnitude)

        # clamp before the log so silent frames don't produce -inf
        mel = torch.log10(mel.clamp_min(eps))

        if self.normalize:
            mel = (mel - self.mean) / self.std
        return mel
class Inferencer:
    """Convenience wrapper bundling the mel frontend and the MB-MelGAN vocoder.

    Methods
    -------
    encode(wav_tensor: torch.Tensor)
        :arg wav_tensor dimension 1 or 2
        :return log mel spectrum
    decode(mel_tensor: torch.Tensor)
        :arg mel_tensor (N, C, Tm)
        :return predicted wav_tensor (N, 1, Tw)

    Examples::

        inferencer = Inferencer()

        # load audio and make tensor
        wav, sr = librosa.load(audio_path, sr=22050)
        wav_tensor = torch.FloatTensor(wav).unsqueeze(0).cuda()

        # convert to mel
        mel = inferencer.encode(wav_tensor)

        # convert back to wav
        pred_wav = inferencer.decode(mel)
    """

    def __init__(self, device: str = 'cuda'):
        self.device = device

        # log-mel frontend
        self.mel_func = LogMelSpectrogram(
            settings.SAMPLE_RATE, settings.MEL_SIZE, settings.N_FFT,
            settings.WIN_LENGTH, settings.HOP_LENGTH,
            float(settings.MEL_MIN), float(settings.MEL_MAX)).to(device)

        # PQMF module for multi-band synthesis
        self.pqmf_func = PQMF().to(device)

        # restore the pretrained generator and freeze it in eval mode
        self.gen = Generator().to(device)
        state = torch.load(VCTK_BASE_CHK_PATH, map_location=torch.device(device))
        self.gen.load_state_dict(state)
        self.gen.eval()

        self.stft = STFT(settings.WIN_LENGTH, settings.HOP_LENGTH).to(device)

        # denoise - reference https://github.com/NVIDIA/waveglow/blob/master/denoiser.py
        # cache the spectrum the vocoder emits for an all-zero mel input
        silence_mel = torch.zeros((1, 80, 88)).float().to(device)
        with torch.no_grad():
            bias_audio = self.decode(silence_mel, is_denoise=False).squeeze(1)
            bias_spec, _ = self.stft.transform(bias_audio)
        self.bias_spec = bias_spec[:, :, 0][:, :, None]

    def encode(self, wav_tensor: torch.Tensor) -> torch.Tensor:
        """
        Convert wav tensor to mel tensor
        :param wav_tensor: wav tensor (N, T)
        :return: mel tensor (N, C, T)
        """
        if wav_tensor.dim() == 1:
            wav_tensor = wav_tensor.unsqueeze(0)
        assert wav_tensor.dim() <= 2, 'The expected dimension of wav is 1 or 2'
        return self.mel_func(wav_tensor)

    def decode(self, mel_tensor: torch.Tensor,
               is_denoise: bool = True) -> torch.Tensor:
        """
        Convert mel tensor to wav tensor by using multi-band melgan
        :param mel_tensor: mel tensor (N, C, T)
        :param is_denoise: using denoise function
        :return: wav tensor (N, T)
        """
        # TODO: Multiband-melgan returns noises from zero tensor. Get some tries on training time.
        # run generator, then recombine its sub-bands through PQMF synthesis
        with torch.no_grad():
            wav = self.pqmf_func.synthesis(self.gen(mel_tensor)).squeeze(1)

        # denoising
        if is_denoise:
            wav = self.denoise(wav)
        return wav

    def denoise(self, audio: torch.Tensor, strength: float = 0.1) -> torch.Tensor:
        """Subtract *strength* times the cached bias spectrum from *audio*."""
        spec, angles = self.stft.transform(audio.to(self.device).float())
        cleaned_spec = torch.clamp(spec - self.bias_spec * strength, 0.0)
        return self.stft.inverse(cleaned_spec, angles)
def __init__(self, spec_dim: int, hidden_dim: int, filter_len: int, hop_len: int,
             layers: int = 3, block_layers: int = 3, kernel_size: int = 5,
             is_mask: bool = False, norm: str = 'bn', act: str = 'tanh'):
    """Assemble the complex U-Net: STFT frontend, encoder/decoder paths, refine head.

    :param spec_dim: spectrogram bin count (channels are doubled for real/imag)
    :param hidden_dim: channel width of the body blocks
    :param filter_len: STFT filter length
    :param hop_len: STFT hop length
    :param layers: depth of the down/up paths
    :param block_layers: conv layers inside each ComplexConvBlock
    :param kernel_size: conv kernel size
    :param is_mask: whether magnitude/phase masking is applied in forward
    :param norm: 'bn' (BatchNorm1d) or 'ins' (affine InstanceNorm1d)
    :param act: 'tanh' or 'comp' (complex activation)
    """
    super().__init__()
    self.layers = layers
    self.is_mask = is_mask

    # stft modules
    self.stft = STFT(filter_len, hop_len)

    # normalization layer factory
    if norm == 'bn':
        self.bn_func = nn.BatchNorm1d
    elif norm == 'ins':
        self.bn_func = lambda c: nn.InstanceNorm1d(c, affine=True)
    else:
        raise NotImplementedError('{} is not implemented !'.format(norm))

    # activation layer factory
    if act == 'tanh':
        self.act_func = nn.Tanh
        self.act_out = nn.Tanh
    elif act == 'comp':
        self.act_func = ComplexActLayer
        self.act_out = lambda: ComplexActLayer(is_out=True)
    else:
        raise NotImplementedError('{} is not implemented !'.format(act))

    # entry conv mixing the stacked real/imag channels into the hidden width
    self.prev_conv = ComplexConv1d(spec_dim * 2, hidden_dim, 1)

    # encoder path: one conv block per level (pooling is applied in forward)
    self.down = nn.ModuleList()
    self.down_pool = nn.MaxPool1d(3, stride=2, padding=1)
    for _ in range(self.layers):
        self.down.append(
            ComplexConvBlock(hidden_dim, hidden_dim,
                             kernel_size=kernel_size,
                             padding=kernel_size // 2,
                             bn_func=self.bn_func,
                             act_func=self.act_func,
                             layers=block_layers))

    # decoder path: input channels double after the first stage (skip concat)
    self.up = nn.ModuleList()
    for stage in range(self.layers):
        in_c = hidden_dim if stage == 0 else hidden_dim * 2
        self.up.append(
            nn.Sequential(
                ComplexConvBlock(in_c, hidden_dim,
                                 kernel_size=kernel_size,
                                 padding=kernel_size // 2,
                                 bn_func=self.bn_func,
                                 act_func=self.act_func,
                                 layers=block_layers),
                self.bn_func(hidden_dim),
                self.act_func(),
                ComplexTransposedConv1d(hidden_dim, hidden_dim,
                                        kernel_size=2, stride=2),
            ))

    # out_conv: project decoder output back to spectrogram channels
    self.out_conv = nn.Sequential(
        ComplexConvBlock(hidden_dim * 2, spec_dim * 2,
                         kernel_size=kernel_size,
                         padding=kernel_size // 2,
                         bn_func=self.bn_func,
                         act_func=self.act_func),
        self.bn_func(spec_dim * 2),
        self.act_func())

    # refine conv: fuse the prediction with the original spectrogram
    self.refine_conv = nn.Sequential(
        ComplexConvBlock(spec_dim * 4, spec_dim * 2,
                         kernel_size=kernel_size,
                         padding=kernel_size // 2,
                         bn_func=self.bn_func,
                         act_func=self.act_func),
        self.bn_func(spec_dim * 2),
        self.act_func())
class SpectrogramUnet(nn.Module):
    """U-Net over complex spectrograms: wav -> log STFT -> encode/decode -> wav.

    Real and imaginary parts are stacked along the channel axis (hence the
    ``* 2`` channel counts); skip connections are merged with ``concat_complex``
    so the real/imag halves stay aligned.
    """

    def __init__(self, spec_dim: int, hidden_dim: int, filter_len: int, hop_len: int,
                 layers: int = 3, block_layers: int = 3, kernel_size: int = 5,
                 is_mask: bool = False, norm: str = 'bn', act: str = 'tanh'):
        """
        :param spec_dim: spectrogram bin count (channels doubled for real/imag)
        :param hidden_dim: channel width of the body blocks
        :param filter_len: STFT filter length
        :param hop_len: STFT hop length
        :param layers: depth of the down/up paths
        :param block_layers: conv layers inside each ComplexConvBlock
        :param kernel_size: conv kernel size
        :param is_mask: apply magnitude/phase masking in forward
        :param norm: 'bn' (BatchNorm1d) or 'ins' (affine InstanceNorm1d)
        :param act: 'tanh' or 'comp' (complex activation)
        """
        super().__init__()
        self.layers = layers
        self.is_mask = is_mask

        # stft modules
        self.stft = STFT(filter_len, hop_len)

        # normalization layer factory
        if norm == 'bn':
            self.bn_func = nn.BatchNorm1d
        elif norm == 'ins':
            self.bn_func = lambda x: nn.InstanceNorm1d(x, affine=True)
        else:
            raise NotImplementedError('{} is not implemented !'.format(norm))

        # activation layer factory
        if act == 'tanh':
            self.act_func = nn.Tanh
            self.act_out = nn.Tanh
        elif act == 'comp':
            self.act_func = ComplexActLayer
            self.act_out = lambda: ComplexActLayer(is_out=True)
        else:
            raise NotImplementedError('{} is not implemented !'.format(act))

        # prev conv: mix stacked real/imag channels into the hidden width
        self.prev_conv = ComplexConv1d(spec_dim * 2, hidden_dim, 1)

        # down: encoder blocks (pooling applied separately in forward)
        self.down = nn.ModuleList()
        self.down_pool = nn.MaxPool1d(3, stride=2, padding=1)
        for idx in range(self.layers):
            block = ComplexConvBlock(hidden_dim, hidden_dim,
                                     kernel_size=kernel_size,
                                     padding=kernel_size // 2,
                                     bn_func=self.bn_func,
                                     act_func=self.act_func,
                                     layers=block_layers)
            self.down.append(block)

        # up: decoder blocks; input doubles after stage 0 due to skip concat
        self.up = nn.ModuleList()
        for idx in range(self.layers):
            in_c = hidden_dim if idx == 0 else hidden_dim * 2
            self.up.append(
                nn.Sequential(
                    ComplexConvBlock(in_c, hidden_dim,
                                     kernel_size=kernel_size,
                                     padding=kernel_size // 2,
                                     bn_func=self.bn_func,
                                     act_func=self.act_func,
                                     layers=block_layers),
                    self.bn_func(hidden_dim),
                    self.act_func(),
                    ComplexTransposedConv1d(hidden_dim, hidden_dim,
                                            kernel_size=2, stride=2),
                ))

        # out_conv: project decoder output back to spectrogram channels
        self.out_conv = nn.Sequential(
            ComplexConvBlock(hidden_dim * 2, spec_dim * 2,
                             kernel_size=kernel_size,
                             padding=kernel_size // 2,
                             bn_func=self.bn_func,
                             act_func=self.act_func),
            self.bn_func(spec_dim * 2),
            self.act_func())

        # refine conv: fuse the prediction with the original spectrogram
        self.refine_conv = nn.Sequential(
            ComplexConvBlock(spec_dim * 4, spec_dim * 2,
                             kernel_size=kernel_size,
                             padding=kernel_size // 2,
                             bn_func=self.bn_func,
                             act_func=self.act_func),
            self.bn_func(spec_dim * 2),
            self.act_func())

    def log_stft(self, wav):
        """Return (log1p magnitude, phase) of *wav* via the internal STFT."""
        # stft
        mag, phase = self.stft.transform(wav)
        return torch.log(mag + 1), phase

    def exp_istft(self, log_mag, phase):
        """Invert log_stft: exponentiate the log magnitude, then inverse STFT."""
        # exp — undo the log(mag + 1) applied in log_stft
        mag = np.e**log_mag - 1
        # istft
        wav = self.stft.inverse(mag, phase)
        return wav

    def adjust_diff(self, x, target):
        """Reflect-pad *x* on both ends so its length matches *target*.

        NOTE(review): with an odd size_diff the two size_diff // 2 pads leave
        the result one sample short of target — confirm callers tolerate this.
        """
        size_diff = (target.size()[-1] - x.size()[-1])
        assert size_diff >= 0
        if size_diff > 0:
            x = F.pad(x.unsqueeze(1), (size_diff // 2, size_diff // 2),
                      'reflect').squeeze(1)
        return x

    def masking(self, mag, phase, origin_mag, origin_phase):
        """Apply a tanh magnitude mask and a sign-based phase mask to the originals.

        NOTE(review): phase_mask = mag / abs_mag divides by zero where mag == 0 —
        verify upstream activations keep mag away from exact zero.
        """
        abs_mag = torch.abs(mag)
        mag_mask = torch.tanh(abs_mag)
        phase_mask = mag / abs_mag
        # masking
        mag = mag_mask * origin_mag
        phase = phase_mask * (origin_phase + phase)
        return mag, phase

    def forward(self, wav):
        """Run the full pipeline: wav -> log STFT -> U-Net body -> refined wav."""
        # stft: channel-concat of log magnitude and phase is the network input
        origin_mag, origin_phase = self.log_stft(wav)
        origin_x = torch.cat([origin_mag, origin_phase], dim=1)

        # prev
        x = self.prev_conv(origin_x)

        # body
        # down: cache pre-pool activations for the skip connections
        down_cache = []
        for idx, block in enumerate(self.down):
            x = block(x)
            down_cache.append(x)
            x = self.down_pool(x)

        # up: resize the mirrored skip to the current length, then complex-concat
        for idx, block in enumerate(self.up):
            x = block(x)
            res = F.interpolate(down_cache[self.layers - (idx + 1)],
                                size=[x.size()[2]], mode='linear',
                                align_corners=False)
            x = concat_complex(x, res, dim=1)

        # match spec dimension
        x = self.out_conv(x)
        if origin_mag.size(2) != x.size(2):
            x = F.interpolate(x, size=[origin_mag.size(2)],
                              mode='linear', align_corners=False)

        # refine
        # NOTE(review): this concat_complex omits dim=1 unlike the call above —
        # confirm concat_complex's default dim is the channel axis
        x = self.refine_conv(concat_complex(x, origin_x))

        def to_wav(stft):
            # split back into (log magnitude, phase), optionally mask,
            # then invert and trim/pad to the input length
            mag, phase = stft.chunk(2, 1)
            if self.is_mask:
                mag, phase = self.masking(mag, phase, origin_mag, origin_phase)
            out = self.exp_istft(mag, phase)
            out = self.adjust_diff(out, wav)
            return out

        refine_wav = to_wav(x)

        return refine_wav