Пример #1
0
def main(args):
    """Convert one mel-spectrogram ``.npy`` file into a waveform file.

    Reads ``args.checkpoint_path``, ``args.config``, ``args.input`` and the
    ``args.d`` flag.  Hyper-parameters come from ``args.config`` when it is
    given, otherwise from the ``hp_str`` entry stored in the checkpoint.  The
    result is handed to ``write`` under a path derived from ``args.input``.
    """
    ckpt = torch.load(args.checkpoint_path)
    # An explicit config file takes precedence over the checkpoint's own copy.
    hp = (HParam(args.config) if args.config is not None
          else load_hparam_str(ckpt['hp_str']))

    generator = ModifiedGenerator(
        hp.audio.n_mel_channels,
        hp.model.n_residual_layers,
        ratios=hp.model.generator_ratio,
        mult=hp.model.mult,
        out_band=hp.model.out_channels,
    ).cuda()
    generator.load_state_dict(ckpt['model_g'])
    generator.eval(inference=True)

    with torch.no_grad():
        spec = torch.from_numpy(np.load(args.input))
        if spec.dim() == 2:
            # Add the leading axis the generator is called with.
            spec = spec.unsqueeze(0)
        wav = generator.inference(spec.cuda())

        wav = wav.squeeze(0)
        if args.d:
            # Optional denoising pass, strength 0.01.
            wav = Denoiser(generator).cuda()(wav, 0.01)

        # Drop the last ten hops of samples, then scale and clamp before the
        # conversion to 16-bit integers.
        tail = hp.audio.hop_length * 10
        wav = wav.squeeze()[:-tail]
        wav = (MAX_WAV_VALUE * wav).clamp(min=-MAX_WAV_VALUE,
                                          max=MAX_WAV_VALUE - 1)
        pcm = wav.short().cpu().detach().numpy()

        suffix = '_reconstructed_epoch%04d.wav' % ckpt['epoch']
        write(args.input.replace('.npy', suffix), hp.audio.sampling_rate, pcm)
Пример #2
0
def main(cmd_args):
    """Evaluate a FeedForwardTransformer checkpoint.

    Args:
        cmd_args: Command-line argument list handed to the parser returned by
            ``get_parser()``.

    Returns:
        None.  Returns early (after printing a message) when
        ``args.checkpoint_path`` does not exist.
    """
    parser = get_parser()
    # A single strict parse is enough; a preceding parse_known_args() result
    # would be overwritten immediately.
    args = parser.parse_args(cmd_args)

    if not os.path.exists(args.checkpoint_path):
        print("Checkpoint not exixts")
        return None

    # Load the checkpoint once: it supplies both the fallback
    # hyper-parameters and the model weights.
    checkpoint = torch.load(args.checkpoint_path)

    if args.config is not None:
        hp = HParam(args.config)
    else:
        hp = load_hparam_str(checkpoint["hp_str"])

    validloader = loader.get_tts_dataset(hp.data.data_dir, 1, hp, True)
    print("Checkpoint : ", args.checkpoint_path)

    idim = len(valid_symbols)
    odim = hp.audio.num_mels
    model = FeedForwardTransformer(idim, odim, hp)
    model.load_state_dict(checkpoint["model"])

    evaluate(hp, validloader, model)
def main(args):
    """Run the generator over every ``*.mel`` file in ``args.input_folder``.

    Hyper-parameters are read from ``args.config`` when it is given and from
    the checkpoint's ``hp_str`` entry otherwise.  For each mel file the output
    of ``model.inference`` is passed to ``write`` next to the input, with the
    checkpoint epoch embedded in the file name.
    """
    ckpt = torch.load(args.checkpoint_path)
    if args.config is None:
        hp = load_hparam_str(ckpt['hp_str'])
    else:
        hp = HParam(args.config)

    generator = Generator(hp.audio.n_mel_channels).cuda()
    generator.load_state_dict(ckpt['model_g'])
    generator.eval(inference=False)

    pattern = os.path.join(args.input_folder, '*.mel')
    with torch.no_grad():
        for mel_file in tqdm.tqdm(glob.glob(pattern)):
            spec = torch.load(mel_file)
            if len(spec.shape) == 2:
                # Add the leading axis the generator is called with.
                spec = spec.unsqueeze(0)

            waveform = generator.inference(spec.cuda())
            samples = waveform.cpu().detach().numpy()

            target = mel_file.replace(
                '.mel', '_reconstructed_epoch%04d.wav' % ckpt['epoch'])
            write(target, hp.audio.sampling_rate, samples)
Пример #4
0
def main(args):
    """Run the generator over every ``*.mel`` file in ``args.input_folder``.

    Hyper-parameters are read from ``args.config`` when it is given and from
    the checkpoint's ``hp_str`` entry otherwise.  Each mel is padded with ten
    extra frames before synthesis and the matching ``hop_length * 10`` samples
    are cut from the result; the waveform is then scaled, clamped, converted
    to 16-bit integers and handed to ``write``.
    """
    checkpoint = torch.load(args.checkpoint_path)
    if args.config is not None:
        hp = HParam(args.config)
    else:
        hp = load_hparam_str(checkpoint['hp_str'])

    model = Generator(hp.audio.n_mel_channels).cuda()
    model.load_state_dict(checkpoint['model_g'])
    model.eval()

    # Loop-invariant values, built once instead of once per file.
    # pad input mel with zeros to cut artifact
    # see https://github.com/seungwonpark/melgan/issues/8
    zero = torch.full((1, hp.audio.n_mel_channels, 10), -11.5129).cuda()
    trim = hp.audio.hop_length * 10
    suffix = '_reconstructed_epoch%04d.wav' % checkpoint['epoch']

    with torch.no_grad():
        for melpath in tqdm.tqdm(glob.glob(os.path.join(args.input_folder, '*.mel'))):
            mel = torch.load(melpath)
            if len(mel.shape) == 2:
                mel = mel.unsqueeze(0)
            mel = mel.cuda()
            mel = torch.cat((mel, zero), dim=2)

            audio = model(mel)
            audio = audio.squeeze()  # collapse all dimension except time axis
            audio = audio[:-trim]
            audio = MAX_WAV_VALUE * audio
            # Upper bound is MAX_WAV_VALUE - 1: assuming MAX_WAV_VALUE is
            # 32768.0 (TODO confirm), +MAX_WAV_VALUE itself does not fit in a
            # signed 16-bit sample and would overflow in .short().
            audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE - 1)
            audio = audio.short()
            audio = audio.cpu().detach().numpy()

            out_path = melpath.replace('.mel', suffix)
            write(out_path, hp.audio.sampling_rate, audio)
Пример #5
0
def main(args):
    """Convert one mel-spectrogram ``.npy`` file into a waveform file.

    Reads ``args.checkpoint_path``, ``args.config``, ``args.input`` and the
    ``args.d`` flag.  Hyper-parameters come from ``args.config`` when it is
    given, otherwise from the checkpoint's ``hp_str`` entry.  The generator
    output shape is printed, and the final samples are handed to ``write``
    under a path derived from ``args.input``.
    """
    state = torch.load(args.checkpoint_path)
    if args.config is None:
        hp = load_hparam_str(state['hp_str'])
    else:
        hp = HParam(args.config)

    net = Generator(hp.audio.n_mel_channels).cuda()
    net.load_state_dict(state['model_g'])
    net.eval()

    with torch.no_grad():
        features = torch.from_numpy(np.load(args.input))
        if features.dim() == 2:
            # Add the leading axis the generator is called with.
            features = features.unsqueeze(0)
        waveform = net(features.cuda())
        print(waveform.shape)

        waveform = waveform.squeeze(0)
        if args.d:
            # Optional denoising pass, strength 0.1.
            denoise = Denoiser(net).cuda()
            waveform = denoise(waveform, 0.1)

        # Drop the last ten hops of samples, then scale and clamp before the
        # conversion to 16-bit integers.
        trim = hp.audio.hop_length * 10
        waveform = waveform.squeeze()[:-trim]
        waveform = (MAX_WAV_VALUE * waveform).clamp(min=-MAX_WAV_VALUE,
                                                    max=MAX_WAV_VALUE - 1)
        samples = waveform.short().cpu().detach().numpy()

        wav_path = args.input.replace(
            '.npy', '_hifi_GAN_epoch%04d.wav' % state['epoch'])
        write(wav_path, hp.audio.sampling_rate, samples)
Пример #6
0
    def load_tiers(self):
        """Load the weights of every tier from ``self.infer_hp.checkpoints``.

        The i-th checkpoint path (counting from zero) is loaded into
        ``self.tiers[i + 1]``.  A warning is printed when the hyper-parameters
        stored in a checkpoint differ from ``self.hp``; loading continues
        regardless.
        """
        for tier_no, path in enumerate(self.infer_hp.checkpoints, start=1):
            state = torch.load(path)
            stored_hp = load_hparam_str(state['hp_str'])

            if self.hp != stored_hp:
                print('Warning: hp different in file %s' % path)

            self.tiers[tier_no].load_state_dict(state['model'])
Пример #7
0
def init(config, checkpoint_path, device="cuda"):
    """Build a ``Generator`` in inference mode from a checkpoint.

    Args:
        config: Path handed to ``HParam``; when ``None`` the hyper-parameters
            are read from the checkpoint's ``hp_str`` entry instead.
        checkpoint_path: Path of the checkpoint to load; must contain a
            ``model_g`` state dict.
        device: Device the checkpoint tensors are mapped to and the model is
            moved to.  Defaults to ``"cuda"``.

    Returns:
        Tuple ``(hp, model)``.
    """
    # map_location keeps the load working when the checkpoint was saved on a
    # device that is unavailable here (e.g. device="cpu" on a GPU-less host).
    checkpoint = torch.load(checkpoint_path, map_location=device)
    if config is not None:
        hp = HParam(config)
    else:
        hp = load_hparam_str(checkpoint['hp_str'])

    model = Generator(hp.audio.n_mel_channels,
                      hp.model.n_residual_layers,
                      ratios=hp.model.generator_ratio,
                      mult=hp.model.mult,
                      out_band=hp.model.out_channels).to(device)
    model.load_state_dict(checkpoint['model_g'])
    model.eval(inference=True)
    return hp, model
Пример #8
0
def main(args):
    """Trace a ``ModifiedGenerator`` with TorchScript and save it to disk.

    Hyper-parameters are read from ``args.config`` when it is given and from
    the checkpoint's ``hp_str`` entry otherwise.  The mel stored in
    ``args.input`` is padded with ten extra frames and used as the example
    input for ``torch.jit.trace``.  The traced module is saved inside
    ``args.out`` with the checkpoint's ``githash`` and ``epoch`` in its name.
    """
    checkpoint = torch.load(args.checkpoint_path)
    if args.config is not None:
        hp = HParam(args.config)
    else:
        hp = load_hparam_str(checkpoint['hp_str'])

    model = ModifiedGenerator(hp.audio.n_mel_channels, hp.model.n_residual_layers,
                              ratios=hp.model.generator_ratio, mult=hp.model.mult,
                              out_band=hp.model.out_channels).cuda()
    model.load_state_dict(checkpoint['model_g'])
    model.eval(inference=True)

    with torch.no_grad():
        mel = torch.from_numpy(np.load(args.input))
        if len(mel.shape) == 2:
            mel = mel.unsqueeze(0)
        mel = mel.cuda()
        # Size the padding from the configured channel count and the actual
        # batch size instead of a fixed (1, 80, 10), so the concatenation
        # also works for models that do not use 80 mel channels.
        zero = torch.full((mel.size(0), hp.audio.n_mel_channels, 10),
                          -11.5129).to(mel.device)
        mel = torch.cat((mel, zero), dim=2)
        vocgan_trace = torch.jit.trace(model, mel)
        vocgan_trace.save("{}/vocgan_ex_female_en_{}_{}.pt".format(args.out, checkpoint['githash'], checkpoint['epoch']))
Пример #9
0
def main(args):
    """Trace a ``Generator`` with TorchScript and save it to disk.

    Hyper-parameters are read from ``args.config`` when it is given and from
    the checkpoint's ``hp_str`` entry otherwise.  The mel stored in
    ``args.input`` serves as the example input for ``torch.jit.trace``; the
    traced module is saved inside ``args.out`` under a name built from
    ``args.name``.
    """
    state = torch.load(args.checkpoint_path)
    hp = (HParam(args.config) if args.config is not None
          else load_hparam_str(state['hp_str']))

    net = Generator(hp.audio.n_mel_channels).cuda()
    net.load_state_dict(state['model_g'])
    net.eval()

    with torch.no_grad():
        example = torch.from_numpy(np.load(args.input))
        if example.dim() == 2:
            # Add the leading axis the generator is called with.
            example = example.unsqueeze(0)
        traced = torch.jit.trace(net, example.cuda())
        traced.save("{}/hifigan_{}.pt".format(args.out, args.name))
Пример #10
0
def main(args):
    """Run decoding: synthesize ``args.text`` and write the result as a WAV.

    The text is split by ``process_paragraph``; each piece is preprocessed and
    passed to ``synth``.  The transposed outputs are concatenated along dim 1,
    saved to ``mel.npy`` in the working directory and plotted with
    ``plot_mel``.  The waveform is produced by the ``seungwonpark/melgan``
    hub model when ``hp.train.melgan_vocoder`` is set, otherwise by
    ``griffin_lim``, and written to ``<args.out>/test_tts.wav``.

    Args:
        args: Command-line argument list handed to the parser returned by
            ``get_parser()``.

    Returns:
        None.  Returns early when ``args.checkpoint_path`` does not exist.
    """
    parser = get_parser()
    args = parser.parse_args(args)

    logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))

    print("Text : ", args.text)
    print("Checkpoint : ", args.checkpoint_path)
    if not os.path.exists(args.checkpoint_path):
        logging.info("Checkpoint not exixts")
        return None

    # Load the checkpoint once: it supplies both the fallback
    # hyper-parameters and the model weights.
    checkpoint = torch.load(args.checkpoint_path)

    if args.config is not None:
        hp = HParam(args.config)
    else:
        hp = load_hparam_str(checkpoint["hp_str"])

    idim = len(valid_symbols)
    odim = hp.audio.num_mels
    model = FeedForwardTransformer(idim, odim, hp)

    os.makedirs(args.out, exist_ok=True)
    if args.old_model:
        # Old-format checkpoints are the bare state dict itself.
        logging.info("\nSynthesis Session...\n")
        model.load_state_dict(checkpoint, strict=False)
    else:
        model.load_state_dict(checkpoint["model"])

    # One transposed synth() output per piece of the paragraph.
    para_mel = [
        synth(preprocess(sentence), model, hp).T
        for sentence in process_paragraph(args.text)
    ]

    m = torch.cat(para_mel, dim=1)
    np.save("mel.npy", m.cpu().numpy())
    plot_mel(m)

    if hp.train.melgan_vocoder:
        m = m.unsqueeze(0)
        print("Mel shape: ", m.shape)
        vocoder = torch.hub.load("seungwonpark/melgan", "melgan")
        vocoder.eval()
        # Bind `mel` unconditionally so CPU-only hosts do not hit a NameError
        # below; move both vocoder and input to the GPU only when one exists.
        mel = m
        if torch.cuda.is_available():
            vocoder = vocoder.cuda()
            mel = mel.cuda()

        with torch.no_grad():
            wav = vocoder.inference(
                mel)  # mel ---> batch, num_mels, frames [1, 80, 234]
            wav = wav.cpu().float().numpy()
    else:
        stft = STFT(filter_length=1024, hop_length=256, win_length=1024)
        print(m.size())
        m = m.unsqueeze(0)
        wav = griffin_lim(m, stft, 30)
        wav = wav.cpu().numpy()
    save_path = "{}/test_tts.wav".format(args.out)
    write(save_path, hp.audio.sample_rate, wav.astype("int16"))
Пример #11
0
        type=str,
        default=None,
        help=
        "yaml file for config. will use hp_str from checkpoint if not given.")
    # Remaining command-line options: checkpoint path, denoising flag and a
    # switch that forces CPU execution.
    parser.add_argument('-p',
                        '--chkpt',
                        default=None,
                        type=str,
                        help='path to latest checkpoint (default: None)')
    parser.add_argument('-d', action='store_true', help="denoising ")
    parser.add_argument('--cpu', action='store_true')

    args = parser.parse_args()
    # Fixed seed for the torch random number generator.
    torch.manual_seed(2020)

    # NOTE(review): --chkpt defaults to None, which is passed straight to
    # torch.load here — confirm the option is meant to be mandatory.
    checkpoint = torch.load(args.chkpt)
    # An explicit config file takes precedence over the hyper-parameters
    # stored in the checkpoint.
    if args.config is not None:
        hp = HParam(args.config)
    else:
        hp = load_hparam_str(checkpoint['hp_str'])
    # The output file name gets a suffix naming the requested device.
    filename = "out"
    if args.cpu:
        filename = filename + "_" + "cpu"
        device = torch.device('cpu')
    else:
        # NOTE(review): the "_cuda" suffix is applied even when CUDA is
        # unavailable and the device below falls back to CPU.
        filename = filename + "_" + "cuda"
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Mel file input : ", args.mel)
    main(hp, checkpoint, args.infile, args.out, filename, args.sigma,
         args.duration, args.half, args.mel, device)