def main():
    testset = TextDataset(args.testset)
    test_loader = DataLoader(dataset=testset, batch_size=args.test_batch,
                             drop_last=False, shuffle=False,
                             collate_fn=synth_collate_fn, pin_memory=True)

    t2m = Text2Mel().to(DEVICE)
    ssrn = SSRN().to(DEVICE)

    # Load the latest Text2Mel checkpoint (assumes sorted() puts the newest
    # '*k.pth.tar' file last).
    mname = type(t2m).__name__
    ckpt = sorted(glob.glob(os.path.join(args.logdir, mname, '*k.pth.tar')))
    state = torch.load(ckpt[-1])
    t2m.load_state_dict(state['model'])
    args.global_step = state['global_step']

    # Load the latest SSRN checkpoint the same way.
    mname = type(ssrn).__name__
    ckpt = sorted(glob.glob(os.path.join(args.logdir, mname, '*k.pth.tar')))
    state = torch.load(ckpt[-1])
    ssrn.load_state_dict(state['model'])
    print('All models are loaded.')

    t2m.eval()
    ssrn.eval()

    if not os.path.exists(os.path.join(args.sampledir, 'A')):
        os.makedirs(os.path.join(args.sampledir, 'A'))
    synthesize(t2m, ssrn, test_loader, args.test_batch)
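# The loading code above expects checkpoints saved as a plain dict. A minimal
# sketch of the matching save side, inferred from the keys read here and in the
# training entry point below ('model', 'optimizer', 'global_step'); the exact
# filename scheme is an assumption:
def save_checkpoint(model, optimizer, global_step, ckpt_dir):
    state = {'model': model.state_dict(),
             'optimizer': optimizer.state_dict(),
             'global_step': global_step}
    # e.g. step 150000 -> '150k.pth.tar', matching the '*k.pth.tar' glob above
    torch.save(state, os.path.join(ckpt_dir, '{}k.pth.tar'.format(global_step // 1000)))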
def main(network):
    if network == 'text2mel':
        model = Text2Mel().to(DEVICE)
    elif network == 'ssrn':
        model = SSRN().to(DEVICE)
    else:
        print("Wrong network; choose 'text2mel' or 'ssrn'.")
        return
    print('Model {} is working...'.format(type(model).__name__))
    print('{} threads are used...'.format(torch.get_num_threads()))
    ckpt_dir = os.path.join(args.logdir, type(model).__name__)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    scheduler = MultiStepLR(optimizer, milestones=[50000, 150000, 300000], gamma=0.5)
    #
    if not os.path.exists(ckpt_dir):
        os.makedirs(os.path.join(ckpt_dir, 'A', 'train'))
    else:
        # The checkpoint directory already exists: resume from the latest checkpoint.
        print('Checkpoint directory already exists; resuming from the latest checkpoint.')
        ckpt = sorted(glob.glob(os.path.join(ckpt_dir, '*k.pth.tar')))[-1]
        state = torch.load(ckpt)
        model.load_state_dict(state['model'])
        args.global_step = state['global_step']
        optimizer.load_state_dict(state['optimizer'])
        # scheduler.load_state_dict(state['scheduler'])
    # model = torch.nn.DataParallel(model, device_ids=list(range(args.no_gpu))).to(DEVICE)

    if type(model).__name__ == 'Text2Mel':
        if args.ga_mode:
            # Guided attention: the training collate function also builds attention targets.
            cfn_train, cfn_eval = t2m_ga_collate_fn, t2m_collate_fn
        else:
            cfn_train, cfn_eval = t2m_collate_fn, t2m_collate_fn
    else:
        cfn_train, cfn_eval = collate_fn, collate_fn

    dataset = SpeechDataset(args.data_path, args.meta_train, type(model).__name__,
                            mem_mode=args.mem_mode, ga_mode=args.ga_mode)
    validset = SpeechDataset(args.data_path, args.meta_eval, type(model).__name__,
                             mem_mode=args.mem_mode)
    data_loader = DataLoader(dataset=dataset, batch_size=args.batch_size,
                             shuffle=True, collate_fn=cfn_train,
                             drop_last=True, pin_memory=True)
    valid_loader = DataLoader(dataset=validset, batch_size=args.test_batch,
                              shuffle=False, collate_fn=cfn_eval, pin_memory=True)

    writer = SummaryWriter(ckpt_dir)
    train(model, data_loader, valid_loader, optimizer, scheduler,
          batch_size=args.batch_size, ckpt_dir=ckpt_dir, writer=writer)
    return None
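# A hypothetical command-line entry point for the trainer above; the
# '--network' flag name is an assumption, not taken from the source:
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--network', type=str, default='text2mel',
                        help="which sub-network to train: 'text2mel' or 'ssrn'")
    cli_args = parser.parse_args()
    main(cli_args.network)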
        mode='valid')
else:
    if args.dataset == 'ljspeech':
        from datasets.lj_speech import vocab, LJSpeech as SpeechDataset
    elif args.dataset == 'mbspeech':
        from datasets.mb_speech import vocab, MBSpeech as SpeechDataset
    train_data_loader = Text2MelDataLoader(text2mel_dataset=SpeechDataset(['texts', 'mels', 'mel_gates']),
                                           batch_size=64, mode='train')
    valid_data_loader = Text2MelDataLoader(text2mel_dataset=SpeechDataset(['texts', 'mels', 'mel_gates']),
                                           batch_size=64, mode='valid')

text2mel = Text2Mel(vocab).cuda()
"""
if args.warmstart:
    old_lr = hp.text2mel_lr
    hp.text2mel_lr = hp.text2mel_lr / 10.0
    print("Reducing learning rate from %.9f to %.9f because of warmstart" % (old_lr, hp.text2mel_lr))
"""
optimizer = torch.optim.Adam(text2mel.parameters(), lr=hp.text2mel_lr)

start_timestamp = int(time.time() * 1000)
start_epoch = 0
global_step = 0

logger = Logger(args.dataset, 'text2mel')
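# start_epoch and global_step above are initialised to zero; a hedged sketch of
# how a run could instead resume from the newest checkpoint, reusing the
# get_last_checkpoint_file_name/load_checkpoint helpers that appear in the
# synthesis snippets below (their return values here are an assumption):
last_ckpt = get_last_checkpoint_file_name(os.path.join(hp.logdir, '%s-text2mel' % args.dataset))
if last_ckpt:
    print("resuming training from '%s'..." % last_ckpt)
    # assumed to return the saved epoch and step counters
    start_epoch, global_step = load_checkpoint(last_ckpt, text2mel, optimizer)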
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import torch
from num2words import num2words
from playsound import playsound

from hparams import HParams as hp
from audio import save_to_wav
from models import SSRN, Text2Mel
from lj_speech import vocab, idx2char, get_test_data

# Inference only: no gradients are needed anywhere in this script.
torch.set_grad_enabled(False)

text2mel = Text2Mel(vocab)
text2mel.load_state_dict(torch.load("ljspeech-text2mel.pth").state_dict())
text2mel = text2mel.eval()

ssrn = SSRN()
ssrn.load_state_dict(torch.load("ljspeech-ssrn.pth").state_dict())
ssrn = ssrn.eval()


def say(sentence):
    # Spell out digits (e.g. "3" -> "three") so they fall inside the character vocabulary.
    new_sentence = " ".join([num2words(w) if w.isdigit() else w for w in sentence.split()])
    # Drop every character the vocabulary does not cover.
    normalized_sentence = "".join([c if c.lower() in vocab else '' for c in new_sentence])
    print(normalized_sentence)
    sentences = [normalized_sentence]
    max_N = len(normalized_sentence)
    L = torch.from_numpy(get_test_data(sentences, max_N))
    zeros = torch.from_numpy(np.zeros((1, hp.n_mels, 1), np.float32))
    Y = zeros  # the mel spectrogram is grown frame by frame from a single zero frame
    A = None
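    # --- assumed completion: the snippet is truncated at this point. ---
    # The rest of say() would decode mel frames autoregressively (the same
    # monotonic-attention loop as in the last snippet of this section), then
    # upsample with SSRN and play the result. hp.max_T as the frame budget and
    # the temporary filename are assumptions, not taken from the source.
    for _ in range(hp.max_T):
        _, Y_t, A = text2mel(L, Y, monotonic_attention=True)
        Y = torch.cat((zeros, Y_t), -1)
    _, Z = ssrn(Y)                          # mel -> full-resolution magnitude spectrogram
    Z = Z.cpu().detach().numpy()
    save_to_wav(Z[0, :, :].T, 'say.wav')    # Griffin-Lim inversion inside audio.save_to_wav
    playsound('say.wav')

# Example: digits are expanded by num2words before synthesis.
# say("The 3 little pigs")   # prints "The three little pigs" and plays it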
# Mongolian test sentences for the MBSpeech model (the last one is Russian).
SENTENCES = [
    "Нийслэлийн прокурорын газраас төрийн өндөр албан тушаалтнуудад холбогдох зарим эрүүгийн хэргүүдийг шүүхэд шилжүүлэв.",
    "Мөнх тэнгэрийн хүчин дор Монгол Улс цэцэглэн хөгжих болтугай.",
    "Унасан хүлгээ түрүү магнай, аман хүзүүнд уралдуулж, айрагдуулсан унаач хүүхдүүдэд бэлэг гардууллаа.",
    "Албан ёсоор хэлэхэд “Монгол Улсын хэрэг эрхлэх газрын гэгээнтэн” гэж нэрлээд байгаа зүйл огт байхгүй.",
    "Сайн чанарын бохирын хоолой зарна.",
    "Хараа тэглэх мэс заслын дараа хараа дахин муудах магадлал бага.",
    "Ер нь бол хараа тэглэх мэс заслыг гоо сайхны мэс засалтай адилхан гэж зүйрлэж болно.",
    "Хашлага даван, зүлэг гэмтээсэн жолоочийн эрхийг хоёр жилээр хасжээ.",
    "Монгол хүн бидний сэтгэлийг сорсон орон. Энэ бол миний төрсөн нутаг. Монголын сайхан орон.",
    "Постройка крейсера затягивалась из-за проектных неувязок, необходимости."
]

torch.set_grad_enabled(False)

text2mel = Text2Mel(vocab).eval()
last_checkpoint_file_name = get_last_checkpoint_file_name(
    os.path.join(hp.logdir, '%s-text2mel' % args.dataset))
# last_checkpoint_file_name = 'logdir/%s-text2mel/step-020K.pth' % args.dataset
if last_checkpoint_file_name:
    print("loading text2mel checkpoint '%s'..." % last_checkpoint_file_name)
    load_checkpoint(last_checkpoint_file_name, text2mel, None)
else:
    print("text2mel checkpoint does not exist")
    sys.exit(1)

ssrn = SSRN().eval()
last_checkpoint_file_name = get_last_checkpoint_file_name(
    os.path.join(hp.logdir, '%s-ssrn' % args.dataset))
# last_checkpoint_file_name = 'logdir/%s-ssrn/step-005K.pth' % args.dataset
if last_checkpoint_file_name:
    print("loading ssrn checkpoint '%s'..." % last_checkpoint_file_name)
    load_checkpoint(last_checkpoint_file_name, ssrn, None)
else:
    print("ssrn checkpoint does not exist")
    sys.exit(1)
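# get_last_checkpoint_file_name is imported from elsewhere in the repository
# and is not shown in these snippets. A minimal sketch of the behaviour the
# calls above rely on, assuming checkpoints follow the 'step-020K.pth' naming
# visible in the commented-out paths (not the repository's actual
# implementation):
import glob
import os

def get_last_checkpoint_file_name(logdir):
    # Return the lexicographically last 'step-*.pth' file, or None if none exist.
    checkpoints = sorted(glob.glob(os.path.join(logdir, 'step-*.pth')))
    return checkpoints[-1] if checkpoints else None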
load_checkpoint('trained/ssrn/lj/step-140K.pth', ssrn, None)
# last_checkpoint_file_name = get_last_checkpoint_file_name(os.path.join(hp.logdir, '%s-ssrn' % args.dataset))
# last_checkpoint_file_name = 'logdir/%s-ssrn/step-005K.pth' % args.dataset
# if last_checkpoint_file_name:
#     print("loading ssrn checkpoint '%s'..." % last_checkpoint_file_name)
#     load_checkpoint(last_checkpoint_file_name, ssrn, None)
# else:
#     print("ssrn checkpoint does not exist")
#     sys.exit(1)

if not os.path.isdir('samples'):
    os.mkdir('samples')

for t2m in t2m_list:
    # Derive output names from the checkpoint path, e.g. a path like
    # 'logdir/run1/step-100K.pth' gives folder 'run1' and filename 'step-100K'.
    filename = os.path.splitext(os.path.basename(t2m))[0]
    folder = os.path.split(os.path.split(t2m)[0])[-1]

    text2mel = Text2Mel(vocab).to(device).eval()
    print("loading text2mel...")
    load_checkpoint(t2m, text2mel, None)
    # text2mel = Text2Mel(vocab)
    # text2mel.load_state_dict(torch.load(t2m).state_dict())
    # text2mel = text2mel.eval()

    for sentence in SENTENCES:
        with torch.no_grad():
            L = torch.from_numpy(get_test_data(sentence)).to(device)
            zeros = torch.from_numpy(np.zeros((1, hp.n_mels, 1), np.float32)).to(device)
            Y = zeros
            # A = None
            # Autoregressive decoding: append one predicted mel frame per step.
            while True:
                _, Y_t, A = text2mel(L, Y, monotonic_attention=True)
                Y = torch.cat((zeros, Y_t), -1)