def synthesis(text_list, plot=True, device="cuda:2"):
    """Synthesize speech waveforms for a list of input texts.

    Runs the two-stage pipeline: Text2Mel autoregressively generates a
    coarse mel spectrogram frame by frame, then SuperRes upsamples it to a
    full linear magnitude spectrogram, which is inverted to a waveform and
    written as ``syn_<i>.wav`` under ``<Hyper.root_dir>/synthesis``.

    Args:
        text_list: list of raw text strings to synthesize.
        plot: if True, save attention and spectrogram plots alongside the wavs.
        device: torch device string to run inference on (default preserves
            the previously hard-coded "cuda:2").
    """
    # Load both graphs onto CPU first; they are moved to `device` one at a
    # time so only one model occupies the accelerator at once.
    graph0 = Text2Mel()
    graph1 = SuperRes()
    load(os.path.join(Hyper.logdir, "text2mel/pkg/trained.pkg"),
         graph0, device='cpu')
    load(os.path.join(Hyper.logdir, "superres/pkg/trained.pkg"),
         graph1, device='cpu')
    graph0.eval()
    graph1.eval()

    # Output directory (exist_ok avoids the exists()/makedirs() race).
    syn_dir = os.path.join(Hyper.root_dir, "synthesis")
    os.makedirs(syn_dir, exist_ok=True)

    # ---- phase 1: text -> mel ----
    graph0.to(device)
    texts = [process_text(text, padding=True) for text in text_list]
    texts = torch.LongTensor(np.asarray(texts)).to(device)
    # mels[:, :, 0] stays zero: it is the "previous frame" for t == 0.
    mels = torch.zeros(len(texts), Hyper.dim_f,
                       Hyper.data_max_mel_length).to(device)
    prev_atten = None
    bar = PrettyBar(Hyper.data_max_mel_length - 1)
    bar.set_description("Text to Mel")
    # Inference only: no_grad avoids building autograd graphs each step.
    with torch.no_grad():
        for t in bar:
            _, new_mel = graph0(texts, mels, None if t == 0 else t - 1,
                                prev_atten)
            # Autoregressive step: predicted frame t becomes input frame t+1.
            mels[:, :, t + 1].data.copy_(new_mel[:, :, t].data)
            prev_atten = graph0.attention
    if plot:
        for i in range(len(text_list)):
            plot_attention(graph0.attention[i].cpu().data,
                           "atten", i, True, syn_dir)
            plot_spectrum(mels[i].cpu().data, "mels", i, True, syn_dir)
    # Release Text2Mel before moving SuperRes onto the device.
    del graph0

    # ---- phase 2: mel -> linear magnitude -> waveform ----
    graph1.to(device)
    with torch.no_grad():
        _, mags = graph1(mels)
    bar = PrettyBar(len(text_list))
    bar.set_description("Super Resolution")
    for i in bar:
        wav = spectrogram2wav(mags[i].cpu().data.numpy())
        wavfile.write(os.path.join(syn_dir, "syn_{}.wav".format(i)),
                      Hyper.audio_samplerate, wav)
        if plot:
            plot_spectrum(mags[i].cpu().data, "mags", i, True, syn_dir)
    del graph1
def train_text2mel(load_trained):
    """Train the Text2Mel graph on the prepared dataset.

    Optimizes three losses: L1 on mel frames, BCE-with-logits on mel frames,
    and a guided-attention L1 term whose weight (`dynamic_guide`) decays
    geometrically each step down to `Hyper.guide_lowbound`. Checkpoints are
    saved every 10k steps; loss curves and sample plots are written
    periodically under ``<Hyper.logdir>/text2mel``.

    Args:
        load_trained: if > 0, resume from ``pkg/save_<load_trained>k.pkg``
            (the value is the checkpoint's step count in thousands).
    """
    # Create log/checkpoint directories.
    logdir = os.path.join(Hyper.logdir, "text2mel")
    os.makedirs(os.path.join(logdir, "pkg"), exist_ok=True)

    device = Hyper.device_text2mel
    graph = Text2Mel().to(device)
    graph.train()  # training mode (dropout/batchnorm etc., if any)

    # Data pipeline.
    names, lengths, texts = load_data()
    batch_maker = BatchMaker(Hyper.batch_size, names, lengths, texts)

    # Losses and optimizer.
    criterion_mels = nn.L1Loss().to(device)
    criterion_bd1 = nn.BCEWithLogitsLoss().to(device)
    criterion_atten = nn.L1Loss().to(device)
    optimizer = torch.optim.Adam(graph.parameters(),
                                 lr=Hyper.adam_alpha,
                                 betas=Hyper.adam_betas,
                                 eps=Hyper.adam_eps)

    # Loss-curve loggers.
    lossplot_mels = LogHelper("mel_l1", logdir)
    lossplot_bd1 = LogHelper("mel_BCE", logdir)
    lossplot_atten = LogHelper("atten", logdir)

    dynamic_guide = float(Hyper.guide_weight)
    global_step = 0

    # Optionally resume from a checkpoint; replay the guide-weight decay the
    # checkpointed run already applied (one decay per step).
    if load_trained > 0:
        print("load model trained for {}k batches".format(load_trained))
        global_step = load(
            os.path.join(logdir, "pkg/save_{}k.pkg".format(load_trained)),
            graph, {
                "mels": criterion_mels,
                "bd1": criterion_bd1,
                "atten": criterion_atten
            }, optimizer)
        dynamic_guide *= Hyper.guide_decay**(load_trained * 1000)

    # Enough epochs over the dataset to reach ~Hyper.num_batches steps.
    for loop_cnt in range(
            int(Hyper.num_batches / batch_maker.num_batches() + 0.5)):
        print("loop", loop_cnt)
        bar = PrettyBar(batch_maker.num_batches())
        bar.set_description("training...")
        loss_str0 = MovingAverage()
        loss_str1 = MovingAverage()
        loss_str2 = MovingAverage()
        for bi in bar:
            batch = batch_maker.next_batch()
            texts = torch.LongTensor(batch["texts"]).to(device)
            # Teacher forcing input: ground-truth mels shifted right by one
            # frame, with a zero frame prepended.
            shift_mels = torch.FloatTensor(
                np.concatenate((np.zeros(
                    (batch["mels"].shape[0], batch["mels"].shape[1], 1)),
                                batch["mels"][:, :, :-1]),
                               axis=2)).to(device)
            # Ground truth.
            mels = torch.FloatTensor(batch["mels"]).to(device)

            # Forward pass.
            pred_logits, pred_mels = graph(texts, shift_mels)

            # Frame reconstruction losses over the full padded length.
            loss_mels = criterion_mels(pred_mels, mels)
            loss_bd1 = criterion_bd1(pred_logits, mels)

            # Guided attention: penalize attention mass far from the
            # diagonal, weighted by the decaying dynamic_guide.
            atten_guide = torch.FloatTensor(batch["atten_guides"]).to(device)
            # NOTE(review): the batch-provided "atten_masks" were loaded and
            # then immediately overwritten with an all-ones mask, so the mask
            # is effectively disabled. Preserving that behavior here —
            # confirm whether batch["atten_masks"] should be used instead.
            atten_mask = torch.ones_like(graph.attention)
            loss_atten = criterion_atten(
                atten_guide * graph.attention * atten_mask,
                torch.zeros_like(graph.attention)) * dynamic_guide
            loss = loss_mels + loss_bd1 + loss_atten

            # Backward + clipped update.
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_value_(graph.parameters(), 1)
            optimizer.step()

            # Log smoothed losses.
            loss_str0.add(loss_mels.cpu().data.mean())
            loss_str1.add(loss_bd1.cpu().data.mean())
            loss_str2.add(loss_atten.cpu().data.mean())
            lossplot_mels.add(loss_str0(), global_step)
            lossplot_bd1.add(loss_str1(), global_step)
            lossplot_atten.add(loss_str2(), global_step)

            # Decay the attention-guide weight, clamped at its lower bound.
            dynamic_guide *= Hyper.guide_decay
            if dynamic_guide < Hyper.guide_lowbound:
                dynamic_guide = Hyper.guide_lowbound
            bar.set_description(
                "gs: {}, mels: {}, bd1: {}, atten: {}, scale: {}".format(
                    global_step, loss_str0(), loss_str1(), loss_str2(),
                    "%4f" % dynamic_guide))

            # Periodic plots: gs is pinned to 0 so each plot overwrites the
            # previous one (a rolling "latest sample" snapshot).
            if global_step % 100 == 0:
                gs = 0
                plot_spectrum(mels[0].cpu().data, "mel_true", gs, dir=logdir)
                plot_spectrum(shift_mels[0].cpu().data,
                              "mel_input", gs, dir=logdir)
                plot_spectrum(pred_mels[0].cpu().data,
                              "mel_pred", gs, dir=logdir)
                plot_spectrum(graph.query[0].cpu().data,
                              "query", gs, dir=logdir)
                plot_attention(graph.attention[0].cpu().data,
                               "atten", gs, True, dir=logdir)
                plot_attention((atten_guide)[0].cpu().data,
                               "atten_guide", gs, True, dir=logdir)
            if global_step % 500 == 0:
                lossplot_mels.plot()
                lossplot_bd1.plot()
                lossplot_atten.plot()
            # Checkpoint every 10k steps (includes step 0).
            if global_step % 10000 == 0:
                save(
                    os.path.join(logdir, "pkg/save_{}k.pkg").format(
                        global_step // 1000), graph, {
                            "mels": criterion_mels,
                            "bd1": criterion_bd1,
                            "atten": criterion_atten
                        }, optimizer, global_step, True)
            global_step += 1
def synthesis(text_list, plot=True, device="cpu"):
    """Synthesize waveforms for `text_list` and return timing/artifact info.

    Instrumented variant of the synthesis pipeline: in addition to writing
    ``syn_<i>.wav`` files, it collects the intermediate mel and magnitude
    tensors, the raw samples, and latency/duration measurements.

    NOTE(review): if the earlier ``synthesis`` definition lives in the same
    module, this redefinition shadows it — confirm that is intentional.

    Args:
        text_list: list of raw text strings to synthesize.
        plot: if True, save attention and spectrogram plots alongside wavs.
        device: torch device string for inference (default "cpu", as before).

    Returns:
        dict with keys "start_time", "mels", "mags", "samples" and
        "time_measurements" (init time, first-frame latency, per-phase
        durations; all in seconds).
    """
    info = {}
    absolute_beginning = time.time()
    info["start_time"] = absolute_beginning
    latency_beginning = None  # set after the first generated frame

    # Load both graphs onto CPU first; moved to `device` one at a time.
    graph0 = Text2Mel()
    graph1 = SuperRes()
    load(os.path.join(Hyper.logdir, "text2mel/pkg/trained.pkg"),
         graph0, device='cpu')
    load(os.path.join(Hyper.logdir, "superres/pkg/trained.pkg"),
         graph1, device='cpu')
    graph0.eval()
    graph1.eval()

    # Output directory (exist_ok avoids the exists()/makedirs() race).
    syn_dir = os.path.join(Hyper.root_dir, "synthesis")
    os.makedirs(syn_dir, exist_ok=True)

    # ---- phase 1: text -> mel ----
    graph0.to(device)
    texts = [process_text(text, padding=True) for text in text_list]
    texts = torch.LongTensor(np.asarray(texts)).to(device)
    # mels[:, :, 0] stays zero: it is the "previous frame" for t == 0.
    mels = torch.zeros(len(texts), Hyper.dim_f,
                       Hyper.data_max_mel_length).to(device)
    prev_atten = None
    bar = PrettyBar(Hyper.data_max_mel_length - 1)
    bar.set_description("Text to Mel")
    begin_of_frame_synthesis = time.time()
    init_time = begin_of_frame_synthesis - absolute_beginning
    with torch.no_grad():
        for t in bar:
            _, new_mel = graph0(texts, mels, None if t == 0 else t - 1,
                                prev_atten)
            # Autoregressive step: predicted frame t becomes input frame t+1.
            mels[:, :, t + 1].data.copy_(new_mel[:, :, t].data)
            prev_atten = graph0.attention
            if latency_beginning is None:
                # Time from call start until the first frame is produced.
                latency_beginning = time.time() - absolute_beginning
            # Re-assigned each step; final value is the whole loop duration.
            latency_synthesis = time.time() - begin_of_frame_synthesis
    duration_mels = time.time() - begin_of_frame_synthesis

    # Keep mel tensors keyed by their eventual output filename.
    info_mels = {}
    for i, mel in enumerate(mels):
        info_mels["syn_{}.wav".format(i)] = mel

    if plot:
        for i in range(len(text_list)):
            plot_attention(graph0.attention[i].cpu().data,
                           "atten", i, True, syn_dir)
            plot_spectrum(mels[i].cpu().data, "mels", i, True, syn_dir)
    # Release Text2Mel before moving SuperRes onto the device.
    del graph0

    # ---- phase 2: mel -> linear magnitude -> waveform ----
    graph1.to(device)
    with torch.no_grad():
        _, mags = graph1(mels)
    duration_mags = time.time() - begin_of_frame_synthesis
    bar = PrettyBar(len(text_list))
    bar.set_description("Super Resolution")
    info_samples = {}
    info_mags = {}
    for i in bar:
        info_mags["syn_{}.wav".format(i)] = mags[i]
        wav = spectrogram2wav(mags[i].cpu().data.numpy())
        info_samples["syn_{}.wav".format(i)] = wav
        wavfile.write(os.path.join(syn_dir, "syn_{}.wav".format(i)),
                      Hyper.audio_samplerate, wav)
        if plot:
            plot_spectrum(mags[i].cpu().data, "mags", i, True, syn_dir)
    del graph1

    duration_total = time.time() - begin_of_frame_synthesis
    time_measurements = {
        "init_time": init_time,
        "Latency_beginning": latency_beginning,
        "Latency_synthesis": latency_synthesis,
        "duration_mels": duration_mels,
        "duration_mags": duration_mags,
        "duration_total": duration_total,
    }
    info["mels"] = info_mels
    info["mags"] = info_mags
    info["samples"] = info_samples
    info["time_measurements"] = time_measurements
    return info