class Denoiser(torch.nn.Module):
    """Remove a vocoder's constant bias from the audio it produces.

    The bias is estimated once, at construction time, by running ``melgan``
    on a fixed mel input (all zeros or Gaussian noise) and keeping the first
    STFT magnitude frame of the generated audio. ``forward`` subtracts a
    scaled copy of that frame from the magnitude spectrogram of its input.
    """

    def __init__(self, melgan, filter_length=1024, n_overlap=4,
                 win_length=1024, mode='zeros'):
        """Estimate and store the bias spectrum of ``melgan``.

        Args:
            melgan: Vocoder exposing ``inference(mel)``; it is expected to
                live on the same device this module selects (CUDA when
                available, otherwise CPU).
            filter_length: STFT filter length.
            n_overlap: The STFT hop length is ``filter_length / n_overlap``.
            win_length: STFT window length.
            mode: ``'zeros'`` feeds an all-zero mel to the vocoder,
                ``'normal'`` feeds standard Gaussian noise.

        Raises:
            ValueError: If ``mode`` is neither ``'zeros'`` nor ``'normal'``.
        """
        super(Denoiser, self).__init__()
        # Use the GPU when there is one, but stay usable on CPU-only hosts.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.stft = STFT(filter_length=filter_length,
                         hop_length=int(filter_length / n_overlap),
                         win_length=win_length).to(device)

        if mode == 'zeros':
            mel_input = torch.zeros((1, 80, 88)).to(device)
        elif mode == 'normal':
            # Sampled on CPU and then moved, so seeding behaves the same
            # regardless of the target device.
            mel_input = torch.randn((1, 80, 88)).to(device)
        else:
            raise ValueError("Mode {} is not supported".format(mode))

        with torch.no_grad():
            bias_audio = melgan.inference(mel_input).float()  # [B, 1, T]
            bias_spec, _ = self.stft.transform(bias_audio.squeeze(0))

        # Keep only the first frame along the last axis, with that axis
        # retained as size 1 so it broadcasts in ``forward``.
        self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None])

    def forward(self, audio, strength=0.1):
        """Return ``audio`` with the stored bias spectrum subtracted.

        Args:
            audio: Waveform tensor accepted by ``self.stft.transform``; it is
                moved to this module's device and cast to float.
            strength: Scale applied to the bias spectrum before subtraction.

        Returns:
            The waveform reconstructed by ``self.stft.inverse`` from the
            denoised magnitudes and the original phase.
        """
        # The buffer moves with the module, so it tells us where to compute.
        device = self.bias_spec.device
        audio_spec, audio_angles = self.stft.transform(
            audio.to(device).float())
        audio_spec_denoised = audio_spec.to(device) - self.bias_spec * strength
        # Magnitudes cannot be negative after the subtraction.
        audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
        audio_denoised = self.stft.inverse(audio_spec_denoised,
                                           audio_angles.to(device))
        return audio_denoised
class TacotronSTFT(torch.nn.Module):
    """Compute dynamic-range-compressed mel-spectrograms from waveforms.

    Wraps an ``STFT`` module and a mel filterbank stored as the
    ``mel_basis`` buffer.
    """

    def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
                 n_mel_channels=80, sampling_rate=44800, mel_fmin=0.0,
                 mel_fmax=8000.0):
        """Build the STFT module and the mel filterbank.

        Args:
            filter_length: STFT filter length, also used as the FFT size of
                the mel filterbank.
            hop_length: STFT hop length.
            win_length: STFT window length.
            n_mel_channels: Number of mel bands.
            sampling_rate: Sampling rate used to build the filterbank.
            mel_fmin: Lowest filterbank frequency.
            mel_fmax: Highest filterbank frequency.
        """
        super(TacotronSTFT, self).__init__()
        # NOTE(review): the 44800 default is unusual (44100 and 48000 are the
        # common rates) -- confirm it is intended.
        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.stft_fn = STFT(filter_length, hop_length, win_length)
        # Keyword arguments: presumably librosa_mel_fn is librosa.filters.mel,
        # whose parameters are keyword-only from librosa 0.10 on; keywords
        # are accepted by older releases as well.
        mel_basis = librosa_mel_fn(sr=sampling_rate,
                                   n_fft=filter_length,
                                   n_mels=n_mel_channels,
                                   fmin=mel_fmin,
                                   fmax=mel_fmax)
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer('mel_basis', mel_basis)

    def spectral_normalize(self, magnitudes):
        """Apply ``dynamic_range_compression`` to ``magnitudes``."""
        output = dynamic_range_compression(magnitudes)
        return output

    def spectral_de_normalize(self, magnitudes):
        """Apply ``dynamic_range_decompression`` to ``magnitudes``."""
        output = dynamic_range_decompression(magnitudes)
        return output

    def mel_spectrogram(self, y):
        """Compute mel-spectrograms from a batch of waveforms.

        PARAMS
        ------
        y: torch.FloatTensor with shape (B, T) in range [-1, 1]

        RETURNS
        -------
        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)

        RAISES
        ------
        AssertionError: if any sample of ``y`` lies outside [-1, 1]
        """
        assert (torch.min(y.data) >= -1)
        assert (torch.max(y.data) <= 1)

        magnitudes, phases = self.stft_fn.transform(y)
        # ``.data`` drops the autograd history of the magnitudes.
        magnitudes = magnitudes.data
        mel_output = torch.matmul(self.mel_basis, magnitudes)
        mel_output = self.spectral_normalize(mel_output)
        return mel_output
def main(config, epoch):
    """Enhance every item of the configured dataset with a saved checkpoint.

    For each noisy mixture the magnitude spectrogram is split into chunks,
    each chunk is passed through the model, and the enhanced magnitudes are
    recombined with the mixture phase before the inverse STFT. Results are
    written as 16 kHz WAV files under
    ``<experiments_dir>/<name>/enhancements/...``.

    Args:
        config: Mapping with ``"experiments_dir"``, ``"name"``,
            ``"dataset"`` and ``"model"`` entries; the last two are passed
            to ``initialize_config``.
        epoch: ``"best"``, ``"latest"``, or an epoch number selecting
            ``checkpoints/model_<epoch zero-padded to 4 digits>.pth``.
    """
    root_dir = Path(config["experiments_dir"]) / config["name"]
    enhancement_dir = root_dir / "enhancements"
    checkpoints_dir = root_dir / "checkpoints"

    # ============== Load the dataset ==============
    dataset = initialize_config(config["dataset"])
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=1,
        num_workers=0,
    )

    # ====== Load the model checkpoint ("best", "latest", or a number) ======
    model = initialize_config(config["model"])
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    # The STFT itself always runs on the CPU; only the model uses `device`.
    stft = STFT(
        filter_length=320,
        hop_length=160
    ).to("cpu")

    if epoch in ("best", "latest"):
        # These archives bundle the weights with the epoch they came from.
        model_path = checkpoints_dir / f"{epoch}_model.tar"
        model_checkpoint = torch.load(model_path.as_posix(), map_location=device)
        model_static_dict = model_checkpoint["model"]
        checkpoint_epoch = model_checkpoint['epoch']
    else:
        # Per-epoch files hold the bare state dict.
        model_path = checkpoints_dir / f"model_{str(epoch).zfill(4)}.pth"
        model_checkpoint = torch.load(model_path.as_posix(), map_location=device)
        model_static_dict = model_checkpoint
        checkpoint_epoch = epoch

    print(f"Loading model checkpoint, epoch = {checkpoint_epoch}")
    model.load_state_dict(model_static_dict)
    model.to(device)
    model.eval()

    # ============== Enhance the speech ==============
    if epoch == "best" or epoch == "latest":
        results_dir = enhancement_dir / f"{epoch}_checkpoint_{checkpoint_epoch}_epoch"
    else:
        results_dir = enhancement_dir / f"checkpoint_{epoch}_epoch"

    results_dir.mkdir(parents=True, exist_ok=True)

    # Floor for the mixture magnitude so silent bins do not produce 0/0 = NaN.
    eps = 1e-8

    for i, (mixture, _, _, names) in enumerate(dataloader):
        print(f"Enhance {i + 1}th speech")
        name = names[0]

        # Mixture magnitude from the real/imaginary STFT parts.
        print("\tSTFT...")
        mixture_D = stft.transform(mixture)
        mixture_real = mixture_D[:, :, :, 0]
        mixture_imag = mixture_D[:, :, :, 1]
        mixture_mag = torch.sqrt(mixture_real ** 2 + mixture_imag ** 2)  # [1, T, F]

        print("\tEnhancement...")
        # Roughly five chunks per utterance; at least one frame per chunk so
        # very short inputs do not ask torch.split for a zero-sized split.
        chunk_len = max(mixture_mag.size(1) // 5, 1)
        # Drop only a short remainder chunk. Dropping the last chunk
        # unconditionally would discard real audio whenever the frame count
        # divides evenly.
        mixture_mag_chunks = [
            chunk
            for chunk in torch.split(mixture_mag, chunk_len, dim=1)
            if chunk.size(1) == chunk_len
        ]

        enhanced_mag_chunks = []
        # Inference only: do not build an autograd graph.
        with torch.no_grad():
            for mixture_mag_chunk in tqdm(mixture_mag_chunks):
                mixture_mag_chunk = mixture_mag_chunk.to(device)
                # Model output is expected to be [T, F] -- TODO confirm.
                enhanced_mag_chunks.append(model(mixture_mag_chunk).detach().cpu())

        enhanced_mag = torch.cat(enhanced_mag_chunks, dim=0).unsqueeze(0)  # [1, T, F]

        # Reuse the mixture phase: scale the unit-magnitude real/imaginary
        # parts of the mixture by the enhanced magnitude.
        n_frames = enhanced_mag.size(1)
        safe_mag = torch.clamp(mixture_mag[:, :n_frames, :], min=eps)
        enhanced_real = enhanced_mag * mixture_real[:, :n_frames, :] / safe_mag
        enhanced_imag = enhanced_mag * mixture_imag[:, :n_frames, :] / safe_mag

        enhanced_D = torch.stack([enhanced_real, enhanced_imag], 3)
        enhanced = stft.inverse(enhanced_D)

        enhanced = enhanced.detach().cpu().squeeze().numpy()
        sf.write(f"{results_dir}/{name}.wav", enhanced, 16000)