Ejemplo n.º 1
0
    def __init__(self,
                 melgan,
                 filter_length=1024,
                 n_overlap=4,
                 win_length=1024,
                 mode='zeros',
                 device="cuda"):
        """Estimate the vocoder bias spectrum used for denoising.

        A constant dummy mel input (zeros, or Gaussian noise for
        ``mode='normal'``) is run through ``melgan.inference``. The first
        frame of the STFT output of that audio is stored as the
        ``bias_spec`` buffer.

        Args:
            melgan: vocoder exposing ``inference(mel)``.
            filter_length: STFT filter length.
            n_overlap: the hop length is ``filter_length / n_overlap``.
            win_length: STFT window length.
            mode: ``'zeros'`` or ``'normal'``; selects the dummy mel input.
            device: device for the STFT module and the dummy input.

        Raises:
            ValueError: if ``mode`` is not ``'zeros'`` or ``'normal'``.
        """
        super(Denoiser, self).__init__()
        self.stft = STFT(filter_length=filter_length,
                         hop_length=int(filter_length / n_overlap),
                         win_length=win_length).to(device)
        # Dummy mel input: batch 1, 80 channels, 88 frames.
        if mode == 'zeros':
            mel_input = torch.zeros((1, 80, 88)).to(device)
        elif mode == 'normal':
            mel_input = torch.randn((1, 80, 88)).to(device)
        else:
            # ValueError subclasses Exception, so existing handlers still match.
            raise ValueError("Mode {} is not supported".format(mode))

        with torch.no_grad():
            bias_audio = melgan.inference(mel_input).float()  # [B, 1, T]
            bias_spec, _ = self.stft.transform(bias_audio.squeeze(0))

        # Keep only the first frame, with a trailing singleton axis so the
        # buffer broadcasts along the time axis of other spectrograms.
        self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None])
        self.device = device
Ejemplo n.º 2
0
class Denoiser(torch.nn.Module):
    """Removes model bias from audio produced with a vocoder.

    At construction a constant dummy mel input is vocoded, and the first STFT
    frame of that audio is stored as ``bias_spec``. ``forward`` subtracts
    ``strength * bias_spec`` from the STFT of the input audio, clamps the
    result at zero, and inverts it using the input's own STFT angles.
    """

    def __init__(self,
                 melgan,
                 filter_length=1024,
                 n_overlap=4,
                 win_length=1024,
                 mode='zeros',
                 device='cuda'):
        """Estimate the bias spectrum of ``melgan``.

        Args:
            melgan: vocoder exposing ``inference(mel)``.
            filter_length: STFT filter length.
            n_overlap: the hop length is ``filter_length / n_overlap``.
            win_length: STFT window length.
            mode: ``'zeros'`` or ``'normal'``; selects the dummy mel input.
            device: device for the STFT module and dummy input. The default
                ``'cuda'`` matches the previous hard-coded behaviour.

        Raises:
            ValueError: if ``mode`` is not ``'zeros'`` or ``'normal'``.
        """
        super(Denoiser, self).__init__()
        self.stft = STFT(filter_length=filter_length,
                         hop_length=int(filter_length / n_overlap),
                         win_length=win_length).to(device)
        # Dummy mel input: batch 1, 80 channels, 88 frames.
        if mode == 'zeros':
            mel_input = torch.zeros((1, 80, 88)).to(device)
        elif mode == 'normal':
            mel_input = torch.randn((1, 80, 88)).to(device)
        else:
            # ValueError subclasses Exception, so existing handlers still match.
            raise ValueError("Mode {} is not supported".format(mode))

        with torch.no_grad():
            bias_audio = melgan.inference(mel_input).float()  # [B, 1, T]

            bias_spec, _ = self.stft.transform(bias_audio.squeeze(0))

        # First frame only, with a trailing singleton axis for broadcasting.
        self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None])

    def forward(self, audio, strength=0.1):
        """Subtract ``strength * bias_spec`` from ``audio``'s STFT and invert it.

        All tensors are moved to the device of the ``bias_spec`` buffer
        instead of being forced onto CUDA.
        """
        device = self.bias_spec.device
        audio_spec, audio_angles = self.stft.transform(audio.to(device).float())
        audio_spec_denoised = audio_spec.to(device) - self.bias_spec * strength
        audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
        audio_denoised = self.stft.inverse(audio_spec_denoised,
                                           audio_angles.to(device))
        return audio_denoised
Ejemplo n.º 3
0
    def __init__(self,
                 waveflow,
                 filter_length=1024,
                 n_overlap=4,
                 win_length=1024,
                 mode='zeros',
                 half=False,
                 device=torch.device('cuda')):
        """Estimate the bias spectrum of a WaveFlow vocoder.

        A constant dummy mel input (zeros, or Gaussian noise for
        ``mode='normal'``) is passed to ``waveflow.infer``. The first frame of
        the STFT output of the generated audio is stored as ``bias_spec``.

        Args:
            waveflow: vocoder exposing ``infer(mel)`` returning a 2-tuple
                whose first item is audio.
            filter_length: STFT filter length.
            n_overlap: the hop length is ``filter_length / n_overlap``.
            win_length: STFT window length.
            mode: ``'zeros'`` or ``'normal'``.
            half: cast the dummy mel input to float16 before inference.
            device: device for the STFT module and the dummy input.

        Raises:
            ValueError: if ``mode`` is not ``'zeros'`` or ``'normal'``.
        """
        super(Denoiser, self).__init__()
        self.device = device
        # The STFT must live on the same device as the audio it transforms,
        # so honour ``device`` instead of forcing CUDA.
        self.stft = STFT(filter_length=filter_length,
                         hop_length=int(filter_length / n_overlap),
                         win_length=win_length).to(device)
        # Dummy mel input: batch 1, 80 channels, 88 frames.
        if mode == 'zeros':
            mel_input = torch.zeros((1, 80, 88)).to(device)
        elif mode == 'normal':
            mel_input = torch.randn((1, 80, 88)).to(device)
        else:
            # ValueError subclasses Exception, so existing handlers still match.
            raise ValueError("Mode {} is not supported".format(mode))
        if half:
            mel_input = mel_input.half()
        with torch.no_grad():
            bias_audio, _ = waveflow.infer(mel_input)  # [B, 1, T]
            # .float() undoes the optional half-precision cast before the STFT.
            bias_spec, _ = self.stft.transform(bias_audio.unsqueeze(0).float())

        # First frame only, with a trailing singleton axis for broadcasting.
        self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None])
Ejemplo n.º 4
0
 def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
              n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
              mel_fmax=None):
     """Create the STFT module and register the mel filterbank buffer.

     Args:
         filter_length: STFT filter length, also used as the filterbank's
             FFT size.
         hop_length: STFT hop length.
         win_length: STFT window length.
         n_mel_channels: number of mel bands.
         sampling_rate: audio sampling rate in Hz.
         mel_fmin: lowest filterbank frequency.
         mel_fmax: highest filterbank frequency (``None`` lets librosa pick).
     """
     super(WaveGlowSTFT, self).__init__()
     self.n_mel_channels = n_mel_channels
     self.sampling_rate = sampling_rate
     self.stft_fn = STFT(filter_length, hop_length, win_length)
     # Keyword arguments: librosa.filters.mel is keyword-only since librosa
     # 0.10, and these names also work with older releases.
     # NOTE(review): assumes librosa_mel_fn is librosa.filters.mel — confirm
     # against the file's imports.
     mel_basis = librosa_mel_fn(
         sr=sampling_rate, n_fft=filter_length, n_mels=n_mel_channels,
         fmin=mel_fmin, fmax=mel_fmax)
     mel_basis = torch.from_numpy(mel_basis).float()
     self.register_buffer('mel_basis', mel_basis)
class Resynthesizer(object):
    """Rebuild a waveform from an estimated magnitude and the mixture phase."""

    def __init__(self, device, win_size=320, hop_size=160):
        stft_module = STFT(win_size, hop_size)
        self.stft = stft_module.to(device)

    def __call__(self, est, mix):
        """Combine ``est`` with the phase of ``mix`` and invert the STFT.

        The waveform is right-padded with zeros (a negative amount trims it)
        so that its second dimension equals ``mix.shape[1]``.
        """
        mix_re, mix_im = self.stft.stft(mix)
        # Phase comes from the mixture's .data (not tracked by autograd).
        phase = torch.atan2(mix_im.data, mix_re.data)
        spec = torch.stack([est * torch.cos(phase), est * torch.sin(phase)],
                           dim=1)
        waveform = self.stft.istft(spec)
        pad_right = mix.shape[1] - waveform.shape[1]
        return F.pad(waveform, [0, pad_right])
class NetFeeder(object):
    """Produce (features, labels) as stacked real/imaginary STFT parts."""

    def __init__(self, device, win_size=320, hop_size=160):
        self.eps = torch.finfo(torch.float32).eps
        stft_module = STFT(win_size, hop_size)
        self.stft = stft_module.to(device)

    def __call__(self, mix, sph):
        """Return ``(feat, lbl)`` for the mixture and the clean speech.

        Each is ``torch.stack([real, imag], dim=1)`` of ``self.stft.stft``.
        """
        def stacked_spec(signal):
            re_part, im_part = self.stft.stft(signal)
            return torch.stack([re_part, im_part], dim=1)

        feat = stacked_spec(mix)
        lbl = stacked_spec(sph)
        return feat, lbl
class NetFeeder(object):
    """Produce (features, labels) as STFT magnitudes of the inputs."""

    def __init__(self, device, win_size=320, hop_size=160):
        self.eps = torch.finfo(torch.float32).eps
        stft_module = STFT(win_size, hop_size)
        self.stft = stft_module.to(device)

    def __call__(self, mix, sph):
        """Return ``(feat, lbl)``: STFT magnitudes of ``mix`` and ``sph``."""
        def magnitude(signal):
            re_part, im_part = self.stft.stft(signal)
            return torch.sqrt(re_part ** 2 + im_part ** 2)

        feat = magnitude(mix)
        lbl = magnitude(sph)
        return feat, lbl
Ejemplo n.º 8
0
def infer(text):
    """Synthesize ``text`` and save it as ``<args.out>/test_tts.wav``.

    The remaining options (checkpoint path, output directory, ...) are parsed
    from ``sys.argv`` with ``get_parser_tts()``.

    Args:
        text: text to synthesize. When ``None``, ``args.text`` from the
            command line is used.

    Returns:
        Path of the written WAV file.
    """
    parser = get_parser_tts()
    args = parser.parse_args(sys.argv[1:])
    if text is None:
        text = args.text

    # display PYTHONPATH
    logging.info('python path = ' + os.environ.get('PYTHONPATH', '(None)'))

    print("Text : ", text)
    # Synthesize the text that was passed in (and printed), not args.text.
    audio = synthesis_tts(args, text, args.path)
    m = audio.T
    if hp.melgan_vocoder:
        mel = m.unsqueeze(0)
        vocoder = torch.hub.load('seungwonpark/melgan', 'melgan')
        vocoder.eval()
        # `mel` is defined before this branch, so CPU-only machines no
        # longer hit a NameError below.
        if torch.cuda.is_available():
            vocoder = vocoder.cuda()
            mel = mel.cuda()

        with torch.no_grad():
            wav = vocoder.inference(mel)  # mel ---> batch, num_mels, frames [1, 80, 234]
            wav = wav.cpu().numpy()
    else:
        stft = STFT(filter_length=1024, hop_length=256, win_length=1024)
        print(m.size())
        m = m.unsqueeze(0)
        wav = griffin_lim(m, stft, 30)
        wav = wav.cpu().numpy()
    save_path = '{}/test_tts.wav'.format(args.out)
    save_wav(wav, save_path)
    return save_path
Ejemplo n.º 9
0
def main(args):
    """Run decoding: synthesize ``args.text`` and write ``test_tts.wav``.

    Also saves the (transposed) synthesis output to ``mel.npy``. The
    waveform is produced by the MelGAN vocoder when ``hp.melgan_vocoder`` is
    set, otherwise by 30 Griffin-Lim iterations.

    Args:
        args: command-line argument list, parsed with ``get_parser()``.
    """
    parser = get_parser()
    args = parser.parse_args(args)

    # display PYTHONPATH
    logging.info('python path = ' + os.environ.get('PYTHONPATH', '(None)'))

    print("Text : ", args.text)
    print("Checkpoint : ", args.path)
    audio = synthesis_tts(args, args.text, args.path)
    m = audio.T

    np.save("mel.npy", m.cpu().numpy())
    if hp.melgan_vocoder:
        mel = m.unsqueeze(0)
        print("Mel shape: ", mel.shape)
        vocoder = torch.hub.load('seungwonpark/melgan', 'melgan')
        vocoder.eval()
        # `mel` is defined before this branch, so CPU-only machines no
        # longer hit a NameError below.
        if torch.cuda.is_available():
            vocoder = vocoder.cuda()
            mel = mel.cuda()

        with torch.no_grad():
            wav = vocoder.inference(mel)  # mel ---> batch, num_mels, frames [1, 80, 234]
            wav = wav.cpu().float().numpy()
    else:
        stft = STFT(filter_length=1024, hop_length=256, win_length=1024)
        print(m.size())
        m = m.unsqueeze(0)
        wav = griffin_lim(m, stft, 30)
        wav = wav.cpu().numpy()
    save_path = '{}/test_tts.wav'.format(args.out)
    # NOTE(review): the int16 cast assumes `wav` is already scaled to the
    # int16 range; confirm this holds for the griffin_lim branch.
    write(save_path, hp.sample_rate, wav.astype('int16'))
class Resynthesizer(object):
    """Turn a spectrum estimate back into a waveform with ``istft``."""

    def __init__(self, device, win_size=320, hop_size=160):
        stft_module = STFT(win_size, hop_size)
        self.stft = stft_module.to(device)

    def __call__(self, est, mix):
        """Inverse-STFT ``est`` and pad/trim it to ``mix.shape[1]`` samples."""
        waveform = self.stft.istft(est)
        missing = mix.shape[1] - waveform.shape[1]
        return F.pad(waveform, [0, missing])
    def __init__(self, config, resume: bool, model, optimizer, loss_function):
        """Set up device, model, STFT features, directories and logging.

        Args:
            config: experiment configuration dict; the keys read are
                ``n_gpu``, ``use_cudnn``, ``trainer`` (``epochs``,
                ``save_checkpoint_interval``, ``validation_interval``,
                ``find_max``, ``z_score``), ``save_location``,
                ``experiment_name`` and ``description``.
            resume: when True, ``_resume_checkpoint()`` is called after
                setup.
            model: network. It is moved to the selected device and wrapped in
                ``torch.nn.DataParallel`` when ``config["n_gpu"] > 1``.
            optimizer: optimizer stored on the instance.
            loss_function: loss stored on the instance.
        """
        # Device and model placement
        self.n_gpu = config["n_gpu"]
        self.device = self._prepare_device(self.n_gpu, config["use_cudnn"])

        net = model.to(self.device)
        if self.n_gpu > 1:
            net = torch.nn.DataParallel(net,
                                        device_ids=list(range(self.n_gpu)))
        self.model = net

        self.optimizer = optimizer
        self.loss_function = loss_function

        # Feature extraction
        self.stft = STFT(filter_length=320, hop_length=160).to(self.device)

        # Trainer settings
        trainer_cfg = config["trainer"]
        self.epochs = trainer_cfg["epochs"]
        self.save_checkpoint_interval = trainer_cfg["save_checkpoint_interval"]
        self.validation_interval = trainer_cfg["validation_interval"]
        self.find_max = trainer_cfg["find_max"]
        self.z_score = trainer_cfg["z_score"]

        # Not taken from the config; updated on resume and during training.
        self.start_epoch = 1
        if self.find_max:
            self.best_score = 0.0
        else:
            self.best_score = 100

        # Output directories
        experiment_root = Path(config["save_location"]) / config["experiment_name"]
        self.root_dir = experiment_root.expanduser().absolute()
        self.checkpoints_dir = self.root_dir / "checkpoints"
        self.logs_dir = self.root_dir / "logs"
        prepare_empty_dir([self.checkpoints_dir, self.logs_dir], resume)

        # TensorBoard logging of the configuration
        self.viz = TensorboardWriter(self.logs_dir.as_posix())
        config_text = json.dumps(config, indent=2, sort_keys=False)
        self.viz.writer.add_text("Config", config_text, global_step=1)
        self.viz.writer.add_text("Description", config["description"],
                                 global_step=1)

        if resume:
            self._resume_checkpoint()

        print("Model, optimizer, parameters and directories initialized.")
        print("Configurations are as follows: ")
        print(json.dumps(config, indent=2, sort_keys=False))

        # Persist the configuration next to the experiment outputs.
        config_save_path = (self.root_dir / "config.json").as_posix()
        with open(config_save_path, "w") as handle:
            json.dump(config, handle, indent=2, sort_keys=False)
        self._print_networks([self.model])
Ejemplo n.º 12
0
    def __init__(self, melgan, pqmf=None, filter_length=1024, n_overlap=4,
                 win_length=1024, mode='zeros', device='cuda'):
        """Estimate the vocoder bias spectrum, optionally for multi-band MelGAN.

        Args:
            melgan: vocoder exposing ``inference(mel)``.
            pqmf: optional module whose ``synthesis`` merges the vocoder's
                output bands; used only when not ``None``.
            filter_length: STFT filter length.
            n_overlap: the hop length is ``filter_length / n_overlap``.
            win_length: STFT window length.
            mode: ``'zeros'`` or ``'normal'``; selects the dummy mel input.
            device: device for the STFT module and dummy input. The default
                ``'cuda'`` matches the previous hard-coded behaviour.

        Raises:
            ValueError: if ``mode`` is not ``'zeros'`` or ``'normal'``.
        """
        super(Denoiser, self).__init__()
        self.stft = STFT(filter_length=filter_length,
                         hop_length=int(filter_length / n_overlap),
                         win_length=win_length).to(device)
        # Dummy mel input: batch 1, 80 channels, 88 frames.
        if mode == 'zeros':
            mel_input = torch.zeros((1, 80, 88)).to(device)
        elif mode == 'normal':
            mel_input = torch.randn((1, 80, 88)).to(device)
        else:
            # ValueError subclasses Exception, so existing handlers still match.
            raise ValueError("Mode {} is not supported".format(mode))

        with torch.no_grad():
            bias_audio = melgan.inference(mel_input).float()  # [B, 1, T]

            # For multi-band inference: merge bands and flatten to 1-D.
            if pqmf is not None:
                bias_audio = pqmf.synthesis(bias_audio).view(-1)
            bias_spec, _ = self.stft.transform(bias_audio.unsqueeze(0))

        # First frame only, with a trailing singleton axis for broadcasting.
        self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None])
Ejemplo n.º 13
0
    def __init__(self, channels: int, nfft: int, hop: int,
                 activation: str) -> None:
        """
        Args:
            channels -- Number of audio channels
            nfft -- Number of points used to compute the FFT
            hop -- Hop size in samples
            activation -- Activation function to use: "sigmoid" or "tanh"

        Raises:
            NotImplementedError: if ``activation`` is not supported.
        """
        super(BlendNet, self).__init__()
        self.channels = channels
        self.nfft = nfft
        self.bins = self.nfft // 2 + 1
        self.hop = hop
        blend = 2

        self.stft = STFT(self.nfft, self.hop)

        self.conv_stft = nn.Sequential(
            STFTConvLayer(features=self.bins,
                          in_channels=blend * self.channels,
                          out_channels=8),
            STFTConvLayer(features=(self.bins - 2) // 2, in_channels=8),
            STFTConvLayer(features=(self.bins - 6) // 4, in_channels=16),
            STFTConvLayer(features=(self.bins - 14) // 8, in_channels=32),
            STFTConvLayer(features=(self.bins - 30) // 16, in_channels=64)
        )  # h_out = (h_in - 62) // 32, w_out = w_in, out_channels = 128

        self.linear_stft = nn.Linear(in_features=(self.bins - 62) // 32 * 128,
                                     out_features=blend * self.bins *
                                     self.channels)

        self.conv_wave = nn.Sequential(
            WaveConvLayer(in_channels=(blend + 1) * self.channels,
                          out_channels=8), WaveConvLayer(in_channels=8),
            WaveConvLayer(in_channels=16), WaveConvLayer(in_channels=32),
            WaveConvLayer(in_channels=64))

        self.linear_wave = nn.Linear(in_features=128,
                                     out_features=(blend + 1) * self.channels)

        if activation == "sigmoid":
            self.activation = nn.Sigmoid()
        elif activation == "tanh":
            self.activation = nn.Tanh()
        else:
            # Same exception type as before, now naming the bad value.
            raise NotImplementedError(
                f"Unsupported activation {activation!r}; "
                "expected 'sigmoid' or 'tanh'")
Ejemplo n.º 14
0
class TacotronSTFT(torch.nn.Module):
    """Compute compressed mel spectrograms from waveforms.

    Combines an ``STFT`` module with a mel filterbank (``mel_basis`` buffer)
    and applies ``dynamic_range_compression`` to the mel output.
    """

    def __init__(self,
                 filter_length=1024,
                 hop_length=256,
                 win_length=1024,
                 n_mel_channels=80,
                 sampling_rate=44800,
                 mel_fmin=0.0,
                 mel_fmax=8000.0):
        """Create the STFT module and register the mel filterbank.

        NOTE(review): 44800 Hz is an unusual default (44100 / 48000 are the
        common rates); possibly a typo. Kept as-is for compatibility; confirm
        against the training data.
        """
        super(TacotronSTFT, self).__init__()
        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.stft_fn = STFT(filter_length, hop_length, win_length)
        # Keyword arguments: librosa.filters.mel is keyword-only since librosa
        # 0.10, and these names also work with older releases.
        # NOTE(review): assumes librosa_mel_fn is librosa.filters.mel.
        mel_basis = librosa_mel_fn(sr=sampling_rate, n_fft=filter_length,
                                   n_mels=n_mel_channels, fmin=mel_fmin,
                                   fmax=mel_fmax)
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer('mel_basis', mel_basis)

    def spectral_normalize(self, magnitudes):
        """Apply ``dynamic_range_compression`` to ``magnitudes``."""
        return dynamic_range_compression(magnitudes)

    def spectral_de_normalize(self, magnitudes):
        """Apply ``dynamic_range_decompression`` to ``magnitudes``."""
        return dynamic_range_decompression(magnitudes)

    def mel_spectrogram(self, y):
        """Computes mel-spectrograms from a batch of waves
        PARAMS
        ------
        y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
        RETURNS
        -------
        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
        """
        assert (torch.min(y.detach()) >= -1)
        assert (torch.max(y.detach()) <= 1)

        magnitudes, _ = self.stft_fn.transform(y)
        # Detach: the mel output is not part of the autograd graph.
        magnitudes = magnitudes.detach()
        mel_output = torch.matmul(self.mel_basis, magnitudes)
        mel_output = self.spectral_normalize(mel_output)
        return mel_output
Ejemplo n.º 15
0
    def __init__(self, n_channels: int, hidden_size: int, num_layers: int,
                 dropout: float, n_fft: int, hop: int) -> None:
        """
        Args:
            n_channels -- Number of audio channels
            hidden_size -- Number of units in each BLSTM layer
            num_layers -- Number of BLSTM layers
            dropout -- Dropout of the BLSTM layers
            n_fft -- FFT size for the spectrogram
            hop -- Hop size for the spectrogram
        """
        super(SpectrogramModel, self).__init__()

        freq_bins = n_fft // 2 + 1  # n_fft // 2 + 1 frequency bins
        self.n_fft, self.hop = n_fft, hop
        # Submodules are registered in the same order as before so parameter
        # ordering (and state_dict keys) is unchanged.
        self.stft = STFT(n_fft, hop)
        self.batch_norm = BatchNorm(freq_bins)
        blstm_inputs = n_channels * freq_bins
        self.blstm = BLSTM(blstm_inputs, hidden_size, num_layers, dropout)
        self.mask = Mask(freq_bins, 2 * hidden_size, n_channels)
def _split_frames(mixture_mag, n_chunks=5):
    """Split ``mixture_mag`` along dim 1 into chunks of ``T // n_chunks`` frames.

    A trailing remainder chunk that is shorter than the others is dropped,
    but a full-size last chunk is kept. The chunk size is at least 1, so
    inputs with fewer than ``n_chunks`` frames do not give a zero split size.
    """
    chunk_size = max(1, mixture_mag.size(1) // n_chunks)
    chunks = list(torch.split(mixture_mag, chunk_size, dim=1))
    if len(chunks) > 1 and chunks[-1].size(1) < chunk_size:
        chunks = chunks[:-1]
    return chunks


def _apply_mixture_phase(enhanced_mag, mixture_real, mixture_imag, mixture_mag):
    """Combine ``enhanced_mag`` with the mixture phase; real/imag stacked on dim 3.

    The mixture tensors are trimmed to the enhanced length along dim 1. The
    mixture magnitude is clamped to machine epsilon so that zero-magnitude
    bins give 0 instead of NaN.
    """
    n_frames = enhanced_mag.size(1)
    eps = torch.finfo(mixture_mag.dtype).eps
    safe_mag = torch.clamp(mixture_mag[:, :n_frames, :], min=eps)
    enhanced_real = enhanced_mag * mixture_real[:, :n_frames, :] / safe_mag
    enhanced_imag = enhanced_mag * mixture_imag[:, :n_frames, :] / safe_mag
    return torch.stack([enhanced_real, enhanced_imag], 3)


def main(config, epoch):
    """Enhance every utterance of the configured dataset and write WAV files.

    Args:
        config: dict with ``experiments_dir``, ``name``, ``dataset`` and
            ``model`` entries (the latter two go to ``initialize_config``).
        epoch: ``"best"``, ``"latest"``, or an epoch number selecting the
            checkpoint under ``<experiments_dir>/<name>/checkpoints``.

    Results are written at 16 kHz to
    ``<experiments_dir>/<name>/enhancements/<checkpoint tag>/<name>.wav``.
    """
    root_dir = Path(config["experiments_dir"]) / config["name"]
    enhancement_dir = root_dir / "enhancements"
    checkpoints_dir = root_dir / "checkpoints"

    # ============== Load dataset ==============
    dataset = initialize_config(config["dataset"])
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=1,
        num_workers=0,
    )

    # ============== Load model checkpoint ("best", "latest", or an epoch number) ==============
    model = initialize_config(config["model"])
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    # The STFT runs on the CPU; only the model runs on `device`.
    stft = STFT(
        filter_length=320,
        hop_length=160
    ).to("cpu")

    if epoch in ("best", "latest"):
        model_path = checkpoints_dir / f"{epoch}_model.tar"
        model_checkpoint = torch.load(model_path.as_posix(), map_location=device)
        model_static_dict = model_checkpoint["model"]
        checkpoint_epoch = model_checkpoint['epoch']
    else:
        model_path = checkpoints_dir / f"model_{str(epoch).zfill(4)}.pth"
        model_checkpoint = torch.load(model_path.as_posix(), map_location=device)
        model_static_dict = model_checkpoint
        checkpoint_epoch = epoch

    print(f"Loading model checkpoint, epoch = {checkpoint_epoch}")
    model.load_state_dict(model_static_dict)
    model.to(device)
    model.eval()

    # ============== Enhance speech ==============
    if epoch in ("best", "latest"):
        results_dir = enhancement_dir / f"{epoch}_checkpoint_{checkpoint_epoch}_epoch"
    else:
        results_dir = enhancement_dir / f"checkpoint_{epoch}_epoch"

    results_dir.mkdir(parents=True, exist_ok=True)

    for i, (mixture, _clean, _, names) in enumerate(dataloader):
        print(f"Enhance {i + 1}th speech")
        name = names[0]

        # Mixture magnitude
        print("\tSTFT...")
        mixture_D = stft.transform(mixture)
        mixture_real = mixture_D[:, :, :, 0]
        mixture_imag = mixture_D[:, :, :, 1]
        mixture_mag = torch.sqrt(mixture_real ** 2 + mixture_imag ** 2)  # [1, T, F]

        print("\tEnhancement...")
        enhanced_mag_chunks = []
        # Inference only: no autograd graph needed.
        with torch.no_grad():
            for mixture_mag_chunk in tqdm(_split_frames(mixture_mag)):
                mixture_mag_chunk = mixture_mag_chunk.to(device)
                enhanced_mag_chunks.append(model(mixture_mag_chunk).detach().cpu())  # [T, F]

        enhanced_mag = torch.cat(enhanced_mag_chunks, dim=0).unsqueeze(0)  # [1, T, F]

        enhanced_D = _apply_mixture_phase(enhanced_mag, mixture_real,
                                          mixture_imag, mixture_mag)
        enhanced = stft.inverse(enhanced_D)

        enhanced = enhanced.detach().cpu().squeeze().numpy()

        sf.write(f"{results_dir}/{name}.wav", enhanced, 16000)
 def __init__(self, device, win_size=320, hop_size=160):
     """Build the STFT helper (``win_size`` / ``hop_size``) and move it to ``device``."""
     stft_module = STFT(win_size, hop_size)
     self.stft = stft_module.to(device)
 def __init__(self, device, win_size=320, hop_size=160):
     """Store float32 machine epsilon and build the STFT helper on ``device``."""
     self.eps = torch.finfo(torch.float32).eps
     stft_module = STFT(win_size, hop_size)
     self.stft = stft_module.to(device)
Ejemplo n.º 19
0
def main(args):
    """Run decoding: synthesize ``args.text`` and write ``test_tts.wav``.

    The text is split with ``process_paragraph``. Each paragraph is
    synthesized separately, and the per-paragraph outputs are concatenated
    along dim 1. The result is saved to ``mel.npy`` and plotted, then
    vocoded with MelGAN (``hp.train.melgan_vocoder``) or 30 Griffin-Lim
    iterations.

    Args:
        args: command-line argument list, parsed with ``get_parser()``.

    Returns:
        None. Returns early, after logging an error, when the checkpoint
        file does not exist.
    """
    para_mel = []
    parser = get_parser()
    args = parser.parse_args(args)

    logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))

    print("Text : ", args.text)
    print("Checkpoint : ", args.checkpoint_path)
    if not os.path.exists(args.checkpoint_path):
        logging.error("Checkpoint does not exist")
        return None
    # Map GPU-saved tensors onto the CPU when no GPU is available.
    map_location = None if torch.cuda.is_available() else "cpu"
    checkpoint = torch.load(args.checkpoint_path, map_location=map_location)

    if args.config is not None:
        hp = HParam(args.config)
    else:
        hp = load_hparam_str(checkpoint["hp_str"])

    idim = len(valid_symbols)
    odim = hp.audio.num_mels
    model = FeedForwardTransformer(idim, odim, hp)

    os.makedirs(args.out, exist_ok=True)
    if args.old_model:
        logging.info("\nSynthesis Session...\n")
        model.load_state_dict(checkpoint, strict=False)
    else:
        # Reuse the checkpoint loaded above instead of reading the file twice.
        model.load_state_dict(checkpoint["model"])

    for paragraph in process_paragraph(args.text):
        txt = preprocess(paragraph)
        audio = synth(txt, model, hp)
        para_mel.append(audio.T)

    m = torch.cat(para_mel, dim=1)
    np.save("mel.npy", m.cpu().numpy())
    plot_mel(m)

    if hp.train.melgan_vocoder:
        mel = m.unsqueeze(0)
        print("Mel shape: ", mel.shape)
        vocoder = torch.hub.load("seungwonpark/melgan", "melgan")
        vocoder.eval()
        # `mel` is defined before this branch, so CPU-only machines no
        # longer hit a NameError below.
        if torch.cuda.is_available():
            vocoder = vocoder.cuda()
            mel = mel.cuda()

        with torch.no_grad():
            wav = vocoder.inference(
                mel)  # mel ---> batch, num_mels, frames [1, 80, 234]
            wav = wav.cpu().float().numpy()
    else:
        stft = STFT(filter_length=1024, hop_length=256, win_length=1024)
        print(m.size())
        m = m.unsqueeze(0)
        wav = griffin_lim(m, stft, 30)
        wav = wav.cpu().numpy()
    save_path = "{}/test_tts.wav".format(args.out)
    # NOTE(review): the int16 cast assumes `wav` is already scaled to the
    # int16 range; confirm this holds for the griffin_lim branch.
    write(save_path, hp.audio.sample_rate, wav.astype("int16"))