import glob
import os

import torch
from torch.utils.data import DataLoader

from models import Text2Mel, SSRN

# NOTE: TextDataset, synth_collate_fn, synthesize, args and DEVICE are assumed
# to come from this repo's data/synthesis modules; the snippet does not show
# where they are defined.


def main():
    testset = TextDataset(args.testset)
    test_loader = DataLoader(dataset=testset, batch_size=args.test_batch,
                             drop_last=False, shuffle=False,
                             collate_fn=synth_collate_fn, pin_memory=True)

    t2m = Text2Mel().to(DEVICE)
    ssrn = SSRN().to(DEVICE)

    # Restore the latest Text2Mel checkpoint; checkpoints are named like
    # '<step>k.pth.tar', and sorted(...)[-1] picks the one that sorts last.
    mname = type(t2m).__name__
    ckpt = sorted(glob.glob(os.path.join(args.logdir, mname, '*k.pth.tar')))
    state = torch.load(ckpt[-1])
    t2m.load_state_dict(state['model'])
    args.global_step = state['global_step']

    # Restore the latest SSRN checkpoint the same way.
    mname = type(ssrn).__name__
    ckpt = sorted(glob.glob(os.path.join(args.logdir, mname, '*k.pth.tar')))
    state = torch.load(ckpt[-1])
    ssrn.load_state_dict(state['model'])
    print('All models are loaded.')

    t2m.eval()
    ssrn.eval()

    if not os.path.exists(os.path.join(args.sampledir, 'A')):
        os.makedirs(os.path.join(args.sampledir, 'A'))
    synthesize(t2m, ssrn, test_loader, args.test_batch)
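# A minimal entry point, assuming this script is meant to be run directly
# (hypothetical addition; the snippet above only defines main()):
if __name__ == '__main__':
    main()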
from playsound import playsound
import numpy as np
import torch
from num2words import num2words

from hparams import HParams as hp
from audio import save_to_wav
from models import SSRN, Text2Mel
from lj_speech import vocab, idx2char, get_test_data

torch.set_grad_enabled(False)

# Load the pretrained weights (the checkpoints store whole serialized models,
# hence the .state_dict() call) and switch both networks to eval mode.
text2mel = Text2Mel(vocab)
text2mel.load_state_dict(torch.load("ljspeech-text2mel.pth").state_dict())
text2mel = text2mel.eval()

ssrn = SSRN()
ssrn.load_state_dict(torch.load("ljspeech-ssrn.pth").state_dict())
ssrn = ssrn.eval()


def say(sentence):
    # Spell out digits ("3" -> "three") and drop any character that is not in
    # the model's vocabulary.
    new_sentence = " ".join([num2words(w) if w.isdigit() else w for w in sentence.split()])
    normalized_sentence = "".join([c if c.lower() in vocab else '' for c in new_sentence])
    print(normalized_sentence)

    sentences = [normalized_sentence]
    max_N = len(normalized_sentence)
    L = torch.from_numpy(get_test_data(sentences, max_N))
    zeros = torch.from_numpy(np.zeros((1, hp.n_mels, 1), np.float32))
    Y = zeros
    A = None

    # Autoregressively generate the mel spectrogram one frame at a time,
    # feeding the frames produced so far back in as input.
    for t in range(hp.max_T):
        _, Y_t, A = text2mel(L, Y, monotonic_attention=True)
        # The original snippet breaks off at the line above; the rest of the
        # loop and the SSRN/playback steps below are the usual DC-TTS decode
        # logic, reconstructed so the function actually produces audio.
        Y = torch.cat((zeros, Y_t), -1)
        # Stop once the attention has reached the end-of-sentence marker 'E'.
        _, attention = torch.max(A[0, :, -1], 0)
        if L[0, attention.item()] == vocab.index('E'):
            break

    # Upsample the coarse mel spectrogram to a full linear spectrogram with
    # SSRN, write it to a wav file, and play it back.
    _, Z = ssrn(Y)
    Z = Z.cpu().detach().numpy()
    save_to_wav(Z[0, :, :].T, 'sentence.wav')
    playsound('sentence.wav')
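# Example usage (hypothetical sentence): digits are spelled out by num2words
# before synthesis, so "3" becomes "three".
say("The quick brown fox jumps over the lazy dog 3 times.")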