Example #1
0
def main(cmd_args):
    """Load a FastSpeech checkpoint and run ``evaluate`` on the validation set.

    Args:
        cmd_args: Argument list handed to the parser from ``get_parser()``;
            it must provide ``checkpoint_path`` and ``config``.

    Returns:
        None. Returns early (after printing a message) when the checkpoint
        file does not exist.
    """
    parser = get_parser()
    # parse_args() already rejects unknown options, so a single strict parse
    # is all that is needed.
    args = parser.parse_args(cmd_args)

    if not os.path.exists(args.checkpoint_path):
        print("Checkpoint not exixts")
        return None

    # Load the checkpoint exactly once: it supplies the model weights and,
    # when no config file is given, the serialized hyper-parameters.
    checkpoint = torch.load(args.checkpoint_path)

    if args.config is not None:
        hp = HParam(args.config)
    else:
        hp = load_hparam_str(checkpoint["hp_str"])

    validloader = loader.get_tts_dataset(hp.data.data_dir, 1, hp, True)
    print("Checkpoint : ", args.checkpoint_path)

    idim = len(valid_symbols)
    odim = hp.audio.num_mels
    model = FeedForwardTransformer(idim, odim, hp)
    model.load_state_dict(checkpoint["model"])

    evaluate(hp, validloader, model)
Example #2
0
def synthesis(args, text, hp):
    """Synthesize a mel spectrogram for a phoneme string with FastSpeech.

    Args:
        args: Parsed options. Reads ``path`` (checkpoint file), ``ngpu``
            (> 0 selects CUDA) and ``maxlenratio``. Also passed to
            ``set_deterministic_pytorch``.
        text: Whitespace-separated phoneme string.
        hp: Hyper-parameter object; reads ``symbol_len`` and ``num_mels``.

    Returns:
        The model output as a NumPy array (``outs.cpu().numpy()``), or None
        when the checkpoint file does not exist.
    """
    set_deterministic_pytorch(args)
    idim = hp.symbol_len
    odim = hp.num_mels
    model = FeedForwardTransformer(idim, odim, hp)
    print(model)

    # Pick the device first so the checkpoint can be mapped onto it; a
    # GPU-saved checkpoint otherwise fails to load on a CPU-only host.
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")

    if not os.path.exists(args.path):
        print("Checkpoint not exixts")
        return None
    print("\nSynthesis Session...\n")
    state = torch.load(args.path, map_location=device)
    model.load_state_dict(state, strict=False)

    model.eval()
    model = model.to(device)

    # [num_char] token ids, placed on the same device as the model so that
    # CPU-only runs (ngpu == 0) work too.
    seq = np.asarray(phonemes_to_sequence(text.split()))
    text = torch.LongTensor(seq).to(device)

    with torch.no_grad():
        start_time = time.time()
        print("text :", text.size())
        outs, probs, att_ws = model.inference(text, hp)
        print("Out size : ", outs.size())

        num_frames = int(outs.size(0))
        if num_frames > 0:
            # elapsed seconds * 1000 = milliseconds, then per output frame.
            logging.info("inference speed = %s msec / frame.",
                         (time.time() - start_time) * 1000 / num_frames)
        if outs.size(0) == text.size(0) * args.maxlenratio:
            logging.warning("output length reaches maximum length .")

        print("mels", outs.size())
        mel = outs.cpu().numpy()  # [T_out, num_mel]
        print("numpy ", mel.shape)

        return mel
Example #3
0
def synthesis_tts(args, text, path):
    """Synthesize a mel spectrogram for a phoneme string with FastSpeech.

    NOTE(review): ``hp`` is not a parameter here; it is resolved from the
    enclosing module scope (presumably a hyper-parameter module imported as
    ``hp``; confirm against the file's imports).

    Args:
        args: Parsed options. Reads ``ngpu`` (> 0 selects CUDA). Also passed
            to ``set_deterministic_pytorch`` and to ``model.inference``.
        text: Whitespace-separated phoneme string.
        path: Path to the model state-dict file.

    Returns:
        The raw ``outs`` tensor from ``model.inference`` (left on the
        inference device), or None when ``path`` does not exist.
    """
    set_deterministic_pytorch(args)
    print("TTS synthesis")
    idim = hp.symbol_len
    odim = hp.num_mels
    print("Text :", text)
    seq = np.asarray(phonemes_to_sequence(text.split()))
    print("Input :", seq)
    model = FeedForwardTransformer(idim, odim)

    # Pick the device first so the checkpoint can be mapped onto it.
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")

    if not os.path.exists(path):
        logging.info("Checkpoint not exixts")
        return None
    logging.info('\nSynthesis Session...\n')
    model.load_state_dict(torch.load(path, map_location=device), strict=False)

    model.eval()
    model = model.to(device)

    # [num_char] token ids, kept on the model's device so CPU runs work.
    text = torch.LongTensor(seq).to(device)

    with torch.no_grad():
        start_time = time.time()
        print("predicting")
        outs, probs, att_ws = model.inference(text, args)

        num_frames = int(outs.size(0))
        if num_frames > 0:
            # elapsed seconds * 1000 = milliseconds, then per output frame.
            logging.info("inference speed = %s msec / frame.",
                         (time.time() - start_time) * 1000 / num_frames)
        # NOTE(review): 5 looks like a hard-coded max length ratio; confirm
        # it matches the model's inference limit.
        if outs.size(0) == text.size(0) * 5:
            logging.warning("output length reaches maximum length .")

        return outs
Example #4
0
def main(args):
    """Run decoding: synthesize ``args.text`` and write ``test_tts.wav``.

    The text is split with ``process_paragraph``. Each piece is synthesized
    with ``synth``, and the transposed outputs are concatenated along dim 1.
    The result is saved to ``mel.npy`` in the current directory and plotted.
    It is then vocoded with MelGAN (when ``hp.train.melgan_vocoder``) or with
    Griffin-Lim, and written to ``<args.out>/test_tts.wav``.

    Args:
        args: Argument list parsed by the parser from ``get_parser()``. The
            parsed namespace must provide ``text``, ``checkpoint_path``,
            ``config``, ``out`` and ``old_model``.

    Returns:
        None. Returns early when the checkpoint does not exist or when the
        text yields nothing to synthesize.
    """
    parser = get_parser()
    args = parser.parse_args(args)

    logging.info("python path = %s", os.environ.get("PYTHONPATH", "(None)"))

    print("Text : ", args.text)
    print("Checkpoint : ", args.checkpoint_path)
    if not os.path.exists(args.checkpoint_path):
        logging.info("Checkpoint not exixts")
        return None
    # Loaded once; reused for the hyper-parameters and the model weights.
    checkpoint = torch.load(args.checkpoint_path)

    if args.config is not None:
        hp = HParam(args.config)
    else:
        hp = load_hparam_str(checkpoint["hp_str"])

    idim = len(valid_symbols)
    odim = hp.audio.num_mels
    model = FeedForwardTransformer(idim, odim, hp)

    os.makedirs(args.out, exist_ok=True)
    if args.old_model:
        # Old-model checkpoints are fed directly to load_state_dict, i.e.
        # treated as a bare state dict.
        logging.info("\nSynthesis Session...\n")
        model.load_state_dict(checkpoint, strict=False)
    else:
        model.load_state_dict(checkpoint["model"])

    para_mel = []
    for paragraph in process_paragraph(args.text):
        txt = preprocess(paragraph)
        audio = synth(txt, model, hp)
        para_mel.append(audio.T)

    if not para_mel:
        # torch.cat would fail on an empty list with an obscure error.
        logging.info("No text to synthesize")
        return None

    m = torch.cat(para_mel, dim=1)
    np.save("mel.npy", m.cpu().numpy())
    plot_mel(m)

    if hp.train.melgan_vocoder:
        m = m.unsqueeze(0)  # add batch dim: [1, num_mels, frames]
        print("Mel shape: ", m.shape)
        vocoder = torch.hub.load("seungwonpark/melgan", "melgan")
        vocoder.eval()
        # Always bind `mel`; it was previously undefined when CUDA was
        # unavailable, crashing vocoder.inference on CPU-only machines.
        mel = m
        if torch.cuda.is_available():
            vocoder = vocoder.cuda()
            mel = mel.cuda()

        with torch.no_grad():
            wav = vocoder.inference(mel)
            wav = wav.cpu().float().numpy()
    else:
        stft = STFT(filter_length=1024, hop_length=256, win_length=1024)
        print(m.size())
        m = m.unsqueeze(0)
        wav = griffin_lim(m, stft, 30)
        wav = wav.cpu().numpy()
    save_path = "{}/test_tts.wav".format(args.out)
    write(save_path, hp.audio.sample_rate, wav.astype("int16"))