def get_and_plot_alignments(hp, epoch, attention_graph, sess, attention_inputs, attention_mels, alignment_dir):
    # use attention_graph to obtain attention maps for a few given inputs and mels
    return_values = sess.run([attention_graph.alignments],
                             {attention_graph.L: attention_inputs,
                              attention_graph.mels: attention_mels})
    alignments = return_values[0]  # sess.run returns a list, so unpack this list
    for i in range(hp.num_sentences_to_plot_attention):
        plot_alignment(hp, alignments[i], i+1, epoch, dir=alignment_dir)
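# None of the snippets in this collection define plot_alignment itself, and its signature varies
# between codebases (some pass hp and an epoch, others a path and an info string). The helper
# below is a minimal matplotlib-based sketch for orientation only: the name
# plot_alignment_sketch, its arguments, and the layout are illustrative assumptions, not the
# actual helper of any of these repositories.
import matplotlib
matplotlib.use('Agg')  # render off-screen so this also works on headless training machines
import matplotlib.pyplot as plt

def plot_alignment_sketch(alignment, path, info=None):
    """Save an encoder/decoder attention matrix (enc_steps x dec_steps) as a heatmap."""
    fig, ax = plt.subplots(figsize=(6, 4))
    im = ax.imshow(alignment, aspect='auto', origin='lower', interpolation='none')
    fig.colorbar(im, ax=ax)
    ax.set_xlabel('Decoder timestep' + ('\n\n' + info if info else ''))
    ax.set_ylabel('Encoder timestep')
    plt.tight_layout()
    plt.savefig(path, format='png')
    plt.close(fig)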
def synthesis(test_lines, model, device, log_dir):
    global global_epoch
    global global_step

    synthesis_dir = os.path.join(log_dir, "synthesis_mels")
    os.makedirs(synthesis_dir, exist_ok=True)

    model.eval()
    with torch.no_grad():
        for idx, line in enumerate(test_lines):
            txt = text_to_seq(line)
            if device > -1:
                txt = txt.cuda(device)
            frames, _, _, alignment = model(txt)
            dst_alignment_path = join(
                synthesis_dir, "{}_alignment_{}.png".format(global_step, idx))
            dst_mels_path = join(synthesis_dir,
                                 "{}_mels_{}.npy".format(global_step, idx))
            plot_alignment(alignment.T, dst_alignment_path,
                           info="{}, {}".format(hparams.builder, global_step))
            np.save(dst_mels_path, frames)
def synthesize(hp, speaker_id='', num_sentences=0, ncores=1, topoutdir='', t2m_epoch=-1, ssrn_epoch=-1):
    '''
    topoutdir: store samples under here; defaults to hp.sampledir
    t2m_epoch and ssrn_epoch: default -1 means use latest. Otherwise go to archived models.
    '''
    assert hp.vocoder in ['griffin_lim', 'world'], 'Other vocoders than griffin_lim/world not yet supported'

    dataset = load_data(hp, mode="synthesis")  # since mode != 'train' or 'validation', will load test_transcript rather than transcript
    fpaths, L = dataset['fpaths'], dataset['texts']
    position_in_phone_data = duration_data = labels = None  # default

    if hp.use_external_durations:
        duration_data = dataset['durations']
        if num_sentences > 0:
            duration_data = duration_data[:num_sentences, :, :]

    if 'position_in_phone' in hp.history_type:
        ## TODO: combine + deduplicate with relevant code in train.py for making validation set
        def duration2position(duration, fractional=False):
            ### very roundabout -- need to deflate A matrix back to integers:
            duration = duration.sum(axis=0)
            #print(duration)
            # sys.exit('evs')
            positions = durations_to_position(duration, fractional=fractional)
            ###positions = end_pad_for_reduction_shape_sync(positions, hp)
            positions = positions[0::hp.r, :]
            #print(positions)
            return positions

        position_in_phone_data = [duration2position(dur, fractional=('fractional' in hp.history_type)) \
                                  for dur in duration_data]
        position_in_phone_data = list2batch(position_in_phone_data, hp.max_T)

    # Ensure we aren't trying to generate more utterances than are actually in our test_transcript
    if num_sentences > 0:
        assert num_sentences < len(fpaths)
        L = L[:num_sentences, :]
        fpaths = fpaths[:num_sentences]

    bases = [basename(fpath) for fpath in fpaths]

    if hp.merlin_label_dir:
        labels = [np.load("{}/{}".format(hp.merlin_label_dir, basename(fpath)+".npy")) \
                  for fpath in fpaths]
        labels = list2batch(labels, hp.max_N)

    if speaker_id:
        speaker2ix = dict(zip(hp.speaker_list, range(len(hp.speaker_list))))
        speaker_ix = speaker2ix[speaker_id]
        ## Speaker codes are held in (batch, 1) matrix -- tiling is done inside the graph:
        speaker_data = np.ones((len(L), 1)) * speaker_ix
    else:
        speaker_data = None

    # Load graph
    ## TODO: generalise to combine other types of models into a synthesis pipeline?
    g1 = Text2MelGraph(hp, mode="synthesize"); print("Graph 1 (t2m) loaded")
    g2 = SSRNGraph(hp, mode="synthesize"); print("Graph 2 (ssrn) loaded")

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        ### TODO: specify epoch from comm line?
        ### TODO: t2m and ssrn from separate configs?
        if t2m_epoch > -1:
            restore_archived_model_parameters(sess, hp, 't2m', t2m_epoch)
        else:
            t2m_epoch = restore_latest_model_parameters(sess, hp, 't2m')

        if ssrn_epoch > -1:
            restore_archived_model_parameters(sess, hp, 'ssrn', ssrn_epoch)
        else:
            ssrn_epoch = restore_latest_model_parameters(sess, hp, 'ssrn')

        # Pass input L through Text2Mel Graph
        t = start_clock('Text2Mel generating...')
        ### TODO: after further efficiency testing, remove this fork
        if 1:  ### efficient route -- only make K&V once
            ## 3.86, 3.70, 3.80 seconds (2 sentences)
            text_lengths = get_text_lengths(L)
            K, V = encode_text(hp, L, g1, sess, speaker_data=speaker_data, labels=labels)
            Y, lengths, alignments = synth_codedtext2mel(hp, K, V, text_lengths, g1, sess, \
                                                         speaker_data=speaker_data, duration_data=duration_data, \
                                                         position_in_phone_data=position_in_phone_data, \
                                                         labels=labels)
        else:
            ## 5.68, 5.43, 5.38 seconds (2 sentences)
            Y, lengths = synth_text2mel(hp, L, g1, sess, speaker_data=speaker_data, \
                                        duration_data=duration_data, \
                                        position_in_phone_data=position_in_phone_data, \
                                        labels=labels)
        stop_clock(t)

        ### TODO: useful to test this?
        # print(Y[0,:,:])
        # print (np.isnan(Y).any())
        # print('nan1')

        # Then pass output Y of Text2Mel Graph through SSRN graph to get high res spectrogram Z.
        t = start_clock('Mel2Mag generating...')
        Z = synth_mel2mag(hp, Y, g2, sess)
        stop_clock(t)

        if (np.isnan(Z).any()):  ### TODO: keep?
            Z = np.nan_to_num(Z)

        # Generate wav files
        if not topoutdir:
            topoutdir = hp.sampledir
        outdir = os.path.join(topoutdir, 't2m%s_ssrn%s'%(t2m_epoch, ssrn_epoch))
        if speaker_id:
            outdir += '_speaker-%s'%(speaker_id)
        safe_makedir(outdir)

        print("Generating wav files, will save to following dir: %s"%(outdir))

        assert hp.vocoder in ['griffin_lim', 'world'], 'Other vocoders than griffin_lim/world not yet supported'

        if ncores==1:
            for i, mag in tqdm(enumerate(Z)):
                outfile = os.path.join(outdir, bases[i] + '.wav')
                mag = mag[:lengths[i]*hp.r, :]  ### trim to generated length
                synth_wave(hp, mag, outfile)
        else:
            executor = ProcessPoolExecutor(max_workers=ncores)
            futures = []
            for i, mag in tqdm(enumerate(Z)):
                outfile = os.path.join(outdir, bases[i] + '.wav')
                mag = mag[:lengths[i]*hp.r, :]  ### trim to generated length
                futures.append(executor.submit(synth_wave, hp, mag, outfile))
            proc_list = [future.result() for future in tqdm(futures)]

        # for i, mag in enumerate(Z):
        #     print("Working on %s"%(bases[i]))
        #     mag = mag[:lengths[i]*hp.r,:]  ### trim to generated length
        #     if hp.vocoder=='magphase_compressed':
        #         mag = denorm(mag, s, hp.normtype)
        #         streams = split_streams(mag, ['mag', 'lf0', 'vuv', 'real', 'imag'], [60,1,1,45,45])
        #         wav = magphase_synth_from_compressed(streams, samplerate=hp.sr)
        #     elif hp.vocoder=='griffin_lim':
        #         wav = spectrogram2wav(hp, mag)
        #     else:
        #         sys.exit('Unsupported vocoder type: %s'%(hp.vocoder))
        #     #write(outdir + "/{}.wav".format(bases[i]), hp.sr, wav)
        #     soundfile.write(outdir + "/{}.wav".format(bases[i]), wav, hp.sr)

        # Plot attention alignments
        for i in range(num_sentences):
            plot_alignment(hp, alignments[i], utt_idx=i+1, t2m_epoch=t2m_epoch, dir=outdir)
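# A hypothetical command-line driver for the synthesize() routine above, kept as a minimal
# sketch. The flag names and the load_config helper are illustrative assumptions made here for
# the example; the real project wires this up through its own argument handling and
# hyperparameter module.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', required=True, help='config module that defines hp')
    parser.add_argument('-N', '--num_sentences', type=int, default=0,
                        help='0 means synthesize every sentence in the test transcript')
    parser.add_argument('-o', '--outdir', default='', help='override hp.sampledir')
    opts = parser.parse_args()
    hp = load_config(opts.config)  # assumed helper returning the hyperparameter object
    synthesize(hp, num_sentences=opts.num_sentences, topoutdir=opts.outdir)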
def main():
    # DataSet Loader
    if args.dataset == "ljspeech":
        from datasets.ljspeech import LJSpeech  # LJSpeech-1.1 dataset loader

        ljs = LJSpeech(path=cfg.dataset_path,
                       save_to='npy',
                       load_from=None if not os.path.exists(cfg.dataset_path + "/npy") else "npy",
                       verbose=cfg.verbose)
    else:
        raise NotImplementedError("[-] Not Implemented Yet...")

    # Train/Test split
    tr_size = int(len(ljs) * (1. - cfg.test_size))
    tr_text_data, va_text_data = \
        ljs.text_data[:tr_size], ljs.text_data[tr_size:]
    tr_text_len_data, va_text_len_data = \
        ljs.text_len_data[:tr_size], ljs.text_len_data[tr_size:]
    tr_mels, va_mels = ljs.mels[:tr_size], ljs.mels[tr_size:]
    tr_mags, va_mags = ljs.mags[:tr_size], ljs.mags[tr_size:]

    del ljs  # memory release

    # Data Iterator
    di = DataIterator(text=tr_text_data, text_len=tr_text_len_data,
                      mel=tr_mels, mag=tr_mags,
                      batch_size=cfg.batch_size)

    if cfg.verbose:
        print("[*] Train/Test split : %d/%d (%.2f/%.2f)" % (tr_text_data.shape[0], va_text_data.shape[0],
                                                            1. - cfg.test_size, cfg.test_size))
        print(" Train")
        print("\ttext     : ", tr_text_data.shape)
        print("\ttext_len : ", tr_text_len_data.shape)
        print("\tmels     : ", tr_mels.shape)
        print("\tmags     : ", tr_mags.shape)
        print(" Test")
        print("\ttext     : ", va_text_data.shape)
        print("\ttext_len : ", va_text_len_data.shape)
        print("\tmels     : ", va_mels.shape)
        print("\tmags     : ", va_mags.shape)

    # Model Loading
    gpu_config = tf.GPUOptions(allow_growth=True)
    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False,
                            gpu_options=gpu_config)

    with tf.Session(config=config) as sess:
        if cfg.model == "Tacotron":
            model = Tacotron(sess=sess,
                             mode=args.mode,
                             sample_rate=cfg.sample_rate,
                             vocab_size=cfg.vocab_size,
                             embed_size=cfg.embed_size,
                             n_mels=cfg.n_mels,
                             n_fft=cfg.n_fft,
                             reduction_factor=cfg.reduction_factor,
                             n_encoder_banks=cfg.n_encoder_banks,
                             n_decoder_banks=cfg.n_decoder_banks,
                             n_highway_blocks=cfg.n_highway_blocks,
                             lr=cfg.lr,
                             lr_decay=cfg.lr_decay,
                             optimizer=cfg.optimizer,
                             grad_clip=cfg.grad_clip,
                             model_path=cfg.model_path)
        else:
            raise NotImplementedError("[-] Not Implemented Yet...")

        if cfg.verbose:
            print("[*] %s model is loaded!" % cfg.model)

        # Initializing
        sess.run(tf.global_variables_initializer())

        # Load model & Graph & Weights
        global_step = 0
        ckpt = tf.train.get_checkpoint_state(cfg.model_path)
        if ckpt and ckpt.model_checkpoint_path:
            # Restores from checkpoint
            model.saver.restore(sess, ckpt.model_checkpoint_path)

            global_step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
            print("[+] global step : %d" % global_step, " successfully loaded")
        else:
            print('[-] No checkpoint file found')

        start_time = time.time()

        best_loss = np.inf
        batch_size = cfg.batch_size

        model.global_step.assign(tf.constant(global_step))
        restored_epochs = global_step // (di.text.shape[0] // batch_size)
        for epoch in range(restored_epochs, cfg.epochs):
            for text, text_len, mel, mag in di.iterate():
                batch_start = time.time()

                _, y_loss, z_loss = sess.run(
                    [model.train_op, model.y_loss, model.z_loss],
                    feed_dict={
                        model.x: text,
                        model.x_len: text_len,
                        model.y: mel,
                        model.z: mag,
                    })

                batch_end = time.time()

                if global_step and global_step % cfg.logging_step == 0:
                    va_y_loss, va_z_loss = 0., 0.
                    va_batch = 20
                    va_iter = len(va_text_data)

                    for idx in range(0, va_iter, va_batch):
                        va_y, va_z = sess.run(
                            [model.y_loss, model.z_loss],
                            feed_dict={
                                # idx already advances in steps of va_batch, so slice [idx:idx + va_batch]
                                model.x: va_text_data[idx:idx + va_batch],
                                model.x_len: va_text_len_data[idx:idx + va_batch],
                                model.y: va_mels[idx:idx + va_batch],
                                model.z: va_mags[idx:idx + va_batch],
                            })

                        va_y_loss += va_y
                        va_z_loss += va_z

                    va_y_loss /= (va_iter // va_batch)
                    va_z_loss /= (va_iter // va_batch)

                    print("[*] epoch %03d global step %07d [%.03f sec/step]"
                          % (epoch, global_step, (batch_end - batch_start)),
                          " Train \n"
                          " y_loss : {:.6f} z_loss : {:.6f}".format(y_loss, z_loss),
                          " Valid \n"
                          " y_loss : {:.6f} z_loss : {:.6f}".format(va_y_loss, va_z_loss))

                    # summary
                    summary = sess.run(model.merged,
                                       feed_dict={
                                           model.x: va_text_data[:batch_size],
                                           model.x_len: va_text_len_data[:batch_size],
                                           model.y: va_mels[:batch_size],
                                           model.z: va_mags[:batch_size],
                                       })

                    # getting/plotting alignment (important)
                    alignment = sess.run(model.alignments,
                                         feed_dict={
                                             model.x: va_text_data[:batch_size],
                                             model.x_len: va_text_len_data[:batch_size],
                                             model.y: va_mels[:batch_size],
                                         })
                    plot_alignment(alignments=alignment, gs=global_step,
                                   path=os.path.join(cfg.model_path, "alignments"))

                    # Summary saver
                    model.writer.add_summary(summary, global_step)

                    # Model save
                    model.saver.save(sess,
                                     cfg.model_path + '%s.ckpt' % cfg.model,
                                     global_step=global_step)
                    if va_y_loss + va_z_loss < best_loss:
                        model.best_saver.save(sess,
                                              cfg.model_path + '%s-best_loss.ckpt' % cfg.model,
                                              global_step=global_step)
                        best_loss = va_y_loss + va_z_loss

                model.global_step.assign_add(tf.constant(1))
                global_step += 1

        end_time = time.time()

        print("[+] Training Done! Elapsed {:.8f}s".format(end_time - start_time))
def synthesize(hp, speaker_id='', num_sentences=0, ncores=1, topoutdir='', t2m_epoch=-1, ssrn_epoch=-1):
    '''
    topoutdir: store samples under here; defaults to hp.sampledir
    t2m_epoch and ssrn_epoch: default -1 means use latest. Otherwise go to archived models.
    '''
    assert hp.vocoder in ['griffin_lim', 'world'], 'Other vocoders than griffin_lim/world not yet supported'

    dataset = load_data(hp, mode="synthesis")  # since mode != 'train' or 'validation', will load test_transcript rather than transcript
    fpaths, L = dataset['fpaths'], dataset['texts']
    position_in_phone_data = duration_data = labels = None  # default

    if hp.use_external_durations:
        duration_data = dataset['durations']
        if num_sentences > 0:
            duration_data = duration_data[:num_sentences, :, :]

    if 'position_in_phone' in hp.history_type:
        ## TODO: combine + deduplicate with relevant code in train.py for making validation set
        def duration2position(duration, fractional=False):
            ### very roundabout -- need to deflate A matrix back to integers:
            duration = duration.sum(axis=0)
            #print(duration)
            # sys.exit('evs')
            positions = durations_to_position(duration, fractional=fractional)
            ###positions = end_pad_for_reduction_shape_sync(positions, hp)
            positions = positions[0::hp.r, :]
            #print(positions)
            return positions

        position_in_phone_data = [duration2position(dur, fractional=('fractional' in hp.history_type)) \
                                  for dur in duration_data]
        position_in_phone_data = list2batch(position_in_phone_data, hp.max_T)

    # Ensure we aren't trying to generate more utterances than are actually in our test_transcript
    if num_sentences > 0:
        assert num_sentences <= len(fpaths)
        L = L[:num_sentences, :]
        fpaths = fpaths[:num_sentences]

    bases = [basename(fpath) for fpath in fpaths]

    if hp.merlin_label_dir:
        labels = []
        for fpath in fpaths:
            label = np.load("{}/{}".format(hp.merlin_label_dir, basename(fpath)+".npy"))
            if hp.select_central:
                central_ind = get_labels_indices(hp.merlin_lab_dim)
                label = label[:, central_ind==1]
            labels.append(label)
        labels = list2batch(labels, hp.max_N)

    if speaker_id:
        speaker2ix = dict(zip(hp.speaker_list, range(len(hp.speaker_list))))
        speaker_ix = speaker2ix[speaker_id]
        ## Speaker codes are held in (batch, 1) matrix -- tiling is done inside the graph:
        speaker_data = np.ones((len(L), 1)) * speaker_ix
    else:
        speaker_data = None

    if hp.turn_off_monotonic_for_synthesis:  # if the FIA mechanism is turned off
        text_lengths = get_text_lengths(L)
        hp.text_lengths = text_lengths + 1

    # Load graph
    ## TODO: generalise to combine other types of models into a synthesis pipeline?
    g1 = Text2MelGraph(hp, mode="synthesize"); print("Graph 1 (t2m) loaded")

    if hp.norm is None:
        t2m_layer_norm = False
        hp.norm = 'layer'
        hp.lr = 0.001
        hp.beta1 = 0.9
        hp.beta2 = 0.999
        hp.epsilon = 0.00000001
        hp.decay_lr = True
        hp.batchsize = {'t2m': 32, 'ssrn': 8}
    else:
        t2m_layer_norm = True

    g2 = SSRNGraph(hp, mode="synthesize"); print("Graph 2 (ssrn) loaded")

    if t2m_layer_norm == False:
        hp.norm = None
        hp.lr = 0.0002
        hp.beta1 = 0.5
        hp.beta2 = 0.9
        hp.epsilon = 0.000001
        hp.decay_lr = False
        hp.batchsize = {'t2m': 16, 'ssrn': 8}

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        ### TODO: specify epoch from comm line?
        ### TODO: t2m and ssrn from separate configs?
        if t2m_epoch > -1:
            restore_archived_model_parameters(sess, hp, 't2m', t2m_epoch)
        else:
            t2m_epoch = restore_latest_model_parameters(sess, hp, 't2m')

        if ssrn_epoch > -1:
            restore_archived_model_parameters(sess, hp, 'ssrn', ssrn_epoch)
        else:
            ssrn_epoch = restore_latest_model_parameters(sess, hp, 'ssrn')

        # Pass input L through Text2Mel Graph
        t = start_clock('Text2Mel generating...')
        ### TODO: after further efficiency testing, remove this fork
        if 1:  ### efficient route -- only make K&V once
            ## 3.86, 3.70, 3.80 seconds (2 sentences)
            text_lengths = get_text_lengths(L)
            K, V = encode_text(hp, L, g1, sess, speaker_data=speaker_data, labels=labels)
            Y, lengths, alignments = synth_codedtext2mel(hp, K, V, text_lengths, g1, sess, \
                                                         speaker_data=speaker_data, duration_data=duration_data, \
                                                         position_in_phone_data=position_in_phone_data, \
                                                         labels=labels)
        else:
            ## 5.68, 5.43, 5.38 seconds (2 sentences)
            Y, lengths = synth_text2mel(hp, L, g1, sess, speaker_data=speaker_data, \
                                        duration_data=duration_data, \
                                        position_in_phone_data=position_in_phone_data, \
                                        labels=labels)
        stop_clock(t)

        ### TODO: useful to test this?
        # print(Y[0,:,:])
        # print (np.isnan(Y).any())
        # print('nan1')

        # Then pass output Y of Text2Mel Graph through SSRN graph to get high res spectrogram Z.
        t = start_clock('Mel2Mag generating...')
        Z = synth_mel2mag(hp, Y, g2, sess)
        stop_clock(t)

        if (np.isnan(Z).any()):  ### TODO: keep?
            Z = np.nan_to_num(Z)

        # Generate wav files
        if not topoutdir:
            topoutdir = hp.sampledir
        outdir = os.path.join(topoutdir, 't2m%s_ssrn%s'%(t2m_epoch, ssrn_epoch))
        if speaker_id:
            outdir += '_speaker-%s'%(speaker_id)
        safe_makedir(outdir)

        # Plot trimmed attention alignment with filename
        print("Plot attention, will save to following dir: %s"%(outdir))
        print("File | CDP | Ain")
        for i, mag in enumerate(Z):
            outfile = os.path.join(outdir, bases[i])
            trimmed_alignment = alignments[i, :text_lengths[i], :lengths[i]]
            plot_alignment(hp, trimmed_alignment, utt_idx=i+1, t2m_epoch=t2m_epoch, dir=outdir, outfile=outfile)
            CDP = getCDP(trimmed_alignment)
            APin, APout = getAP(trimmed_alignment)
            print("%s | %.2f | %.2f"%(bases[i], CDP, APin))

        print("Generating wav files, will save to following dir: %s"%(outdir))

        assert hp.vocoder in ['griffin_lim', 'world'], 'Other vocoders than griffin_lim/world not yet supported'

        if ncores==1:
            for i, mag in tqdm(enumerate(Z)):
                outfile = os.path.join(outdir, bases[i] + '.wav')
                mag = mag[:lengths[i]*hp.r, :]  ### trim to generated length
                synth_wave(hp, mag, outfile)
        else:
            executor = ProcessPoolExecutor(max_workers=ncores)
            futures = []
            for i, mag in tqdm(enumerate(Z)):
                outfile = os.path.join(outdir, bases[i] + '.wav')
                mag = mag[:lengths[i]*hp.r, :]  ### trim to generated length
                futures.append(executor.submit(synth_wave, hp, mag, outfile))
            proc_list = [future.result() for future in tqdm(futures)]
def synthesize(self, text=None, emo_code=None, mels=None, speaker_id='',
               num_sentences=0, ncores=1, topoutdir=''):
    '''
    topoutdir: store samples under here; defaults to hp.sampledir
    t2m_epoch and ssrn_epoch: default -1 means use latest. Otherwise go to archived models.
    '''
    assert self.hp.vocoder in [
        'griffin_lim', 'world'
    ], 'Other vocoders than griffin_lim/world not yet supported'

    if text is not None:
        text_to_phonetic(text=text)
        dataset = load_data(self.hp, mode='demo')
    else:
        dataset = load_data(
            self.hp, mode="synthesis"
        )  # since mode != 'train' or 'validation', will load test_transcript rather than transcript
    fpaths, L = dataset['fpaths'], dataset['texts']
    position_in_phone_data = duration_data = labels = None  # default

    if self.hp.use_external_durations:
        duration_data = dataset['durations']
        if num_sentences > 0:
            duration_data = duration_data[:num_sentences, :, :]

    if 'position_in_phone' in self.hp.history_type:
        ## TODO: combine + deduplicate with relevant code in train.py for making validation set
        def duration2position(duration, fractional=False):
            ### very roundabout -- need to deflate A matrix back to integers:
            duration = duration.sum(axis=0)
            #print(duration)
            # sys.exit('evs')
            positions = durations_to_position(duration, fractional=fractional)
            ###positions = end_pad_for_reduction_shape_sync(positions, hp)
            positions = positions[0::self.hp.r, :]
            #print(positions)
            return positions

        position_in_phone_data = [duration2position(dur, fractional=('fractional' in self.hp.history_type)) \
                                  for dur in duration_data]
        position_in_phone_data = list2batch(position_in_phone_data, self.hp.max_T)

    # Ensure we aren't trying to generate more utterances than are actually in our test_transcript
    if num_sentences > 0:
        assert num_sentences < len(fpaths)
        L = L[:num_sentences, :]
        fpaths = fpaths[:num_sentences]

    bases = [basename(fpath) for fpath in fpaths]

    if self.hp.merlin_label_dir:
        labels = [np.load("{}/{}".format(self.hp.merlin_label_dir, basename(fpath)+".npy")) \
                  for fpath in fpaths]
        labels = list2batch(labels, self.hp.max_N)

    if speaker_id:
        speaker2ix = dict(zip(self.hp.speaker_list, range(len(self.hp.speaker_list))))
        speaker_ix = speaker2ix[speaker_id]
        ## Speaker codes are held in (batch, 1) matrix -- tiling is done inside the graph:
        speaker_data = np.ones((len(L), 1)) * speaker_ix
    else:
        speaker_data = None

    # Pass input L through Text2Mel Graph
    t = start_clock('Text2Mel generating...')
    ### TODO: after further efficiency testing, remove this fork
    if 1:  ### efficient route -- only make K&V once
        ## 3.86, 3.70, 3.80 seconds (2 sentences)
        text_lengths = get_text_lengths(L)
        if mels is not None:
            emo_code = encode_audio2emo(self.hp, mels, self.g1, self.sess)
        K, V = encode_text(self.hp, L, self.g1, self.sess, emo_mean=emo_code,
                           speaker_data=speaker_data, labels=labels)
        Y, lengths, alignments = synth_codedtext2mel(self.hp, K, V, text_lengths, self.g1, self.sess, \
                                                     speaker_data=speaker_data, duration_data=duration_data, \
                                                     position_in_phone_data=position_in_phone_data, \
                                                     labels=labels)
    else:
        ## 5.68, 5.43, 5.38 seconds (2 sentences)
        Y, lengths = synth_text2mel(self.hp, L, self.g1, self.sess, speaker_data=speaker_data, \
                                    duration_data=duration_data, \
                                    position_in_phone_data=position_in_phone_data, \
                                    labels=labels)
    stop_clock(t)

    ### TODO: useful to test this?
    # print(Y[0,:,:])
    # print (np.isnan(Y).any())
    # print('nan1')

    # Then pass output Y of Text2Mel Graph through SSRN graph to get high res spectrogram Z.
    t = start_clock('Mel2Mag generating...')
    Z = synth_mel2mag(self.hp, Y, self.g2, self.sess)
    stop_clock(t)

    if (np.isnan(Z).any()):  ### TODO: keep?
        Z = np.nan_to_num(Z)

    # Generate wav files
    if not topoutdir:
        topoutdir = self.hp.sampledir
    outdir = os.path.join(
        topoutdir, 't2m%s_ssrn%s' % (self.t2m_epoch, self.ssrn_epoch))
    if speaker_id:
        outdir += '_speaker-%s' % (speaker_id)
    safe_makedir(outdir)

    print("Generating wav files, will save to following dir: %s" % (outdir))

    assert self.hp.vocoder in [
        'griffin_lim', 'world'
    ], 'Other vocoders than griffin_lim/world not yet supported'

    if ncores == 1:
        for i, mag in tqdm(enumerate(Z)):
            outfile = os.path.join(outdir, bases[i] + '.wav')
            mag = mag[:lengths[i] * self.hp.r, :]  ### trim to generated length
            synth_wave(self.hp, mag, outfile)
    else:
        executor = ProcessPoolExecutor(max_workers=ncores)
        futures = []
        for i, mag in tqdm(enumerate(Z)):
            outfile = os.path.join(outdir, bases[i] + '.wav')
            mag = mag[:lengths[i] * self.hp.r, :]  ### trim to generated length
            futures.append(
                executor.submit(synth_wave, self.hp, mag, outfile))
        proc_list = [future.result() for future in tqdm(futures)]

    # Plot attention alignments
    for i in range(num_sentences):
        plot_alignment(self.hp, alignments[i], utt_idx=i + 1,
                       t2m_epoch=self.t2m_epoch, dir=outdir)

    self.outdir = outdir
            ]
            for idx, sent in enumerate(sentences):
                wav, attn = eval_model(dv3, sent, replace_pronounciation_prob,
                                       min_level_db, ref_level_db, power, n_iter,
                                       win_length, hop_length, preemphasis)
                wav_path = os.path.join(
                    state_dir, "waveform",
                    "eval_sample_{:09d}.wav".format(global_step))
                sf.write(wav_path, wav, sample_rate)
                writer.add_audio(
                    "eval_sample_{}".format(idx),
                    wav,
                    global_step,
                    sample_rate=sample_rate)
                attn_path = os.path.join(
                    state_dir, "alignments",
                    "eval_sample_attn_{:09d}.png".format(global_step))
                plot_alignment(attn, attn_path)
                writer.add_image(
                    "eval_sample_attn{}".format(idx),
                    cm.viridis(attn),
                    global_step,
                    dataformats="HWC")

        # save checkpoint
        if global_step % save_interval == 0:
            io.save_parameters(ckpt_dir, global_step, dv3, optim)

        global_step += 1
def main(args):
    torch.manual_seed(0)

    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    dataset = Dataset("train.txt")
    loader = DataLoader(dataset, batch_size=hp.batch_size**2, shuffle=True,
                        collate_fn=dataset.collate_fn, drop_last=True, num_workers=0)

    # Define model
    model = nn.DataParallel(STYLER()).to(device)
    print("Model Has Been Defined")

    # Parameters
    num_param = utils.get_param_num(model)
    text_encoder = utils.get_param_num(
        model.module.style_modeling.style_encoder.text_encoder)
    audio_encoder = utils.get_param_num(
        model.module.style_modeling.style_encoder.audio_encoder)
    predictors = utils.get_param_num(model.module.style_modeling.duration_predictor)\
        + utils.get_param_num(model.module.style_modeling.pitch_predictor)\
        + utils.get_param_num(model.module.style_modeling.energy_predictor)
    decoder = utils.get_param_num(model.module.decoder)
    print('Number of Model Parameters         :', num_param)
    print('Number of Text Encoder Parameters  :', text_encoder)
    print('Number of Audio Encoder Parameters :', audio_encoder)
    print('Number of Predictor Parameters     :', predictors)
    print('Number of Decoder Parameters       :', decoder)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(), betas=hp.betas,
                                 eps=hp.eps, weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = STYLERLoss().to(device)
    DATLoss = DomainAdversarialTrainingLoss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = os.path.join(hp.checkpoint_path())
    try:
        checkpoint = torch.load(os.path.join(
            checkpoint_path, 'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except:
        print("\n---Start New Training---\n")
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

    # Load vocoder
    vocoder = utils.get_vocoder()

    # Init logger
    log_path = hp.log_path()
    if not os.path.exists(log_path):
        os.makedirs(log_path)
        os.makedirs(os.path.join(log_path, 'train'))
        os.makedirs(os.path.join(log_path, 'validation'))
    train_logger = SummaryWriter(os.path.join(log_path, 'train'))
    val_logger = SummaryWriter(os.path.join(log_path, 'validation'))

    # Init synthesis directory
    synth_path = hp.synth_path()
    if not os.path.exists(synth_path):
        os.makedirs(synth_path)

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()

    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        # Get Training Loader
        total_step = hp.epochs * len(loader) * hp.batch_size

        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i*hp.batch_size + j + args.restore_step + \
                    epoch*len(loader)*hp.batch_size + 1

                # Get Data
                text = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                mel_aug = torch.from_numpy(
                    data_of_batch["mel_aug"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).long().to(device)
                log_D = torch.from_numpy(
                    data_of_batch["log_D"]).float().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                f0_norm = torch.from_numpy(
                    data_of_batch["f0_norm"]).float().to(device)
                f0_norm_aug = torch.from_numpy(
                    data_of_batch["f0_norm_aug"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                energy_input = torch.from_numpy(
                    data_of_batch["energy_input"]).float().to(device)
                energy_input_aug = torch.from_numpy(
                    data_of_batch["energy_input_aug"]).float().to(device)
                speaker_embed = torch.from_numpy(
                    data_of_batch["speaker_embed"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                # Forward
                mel_outputs, mel_postnet_outputs, log_duration_output, f0_output, energy_output, src_mask, mel_mask, _, aug_posteriors = model(
                    text, mel_target, mel_aug, f0_norm, energy_input, src_len, mel_len,
                    D, f0, energy, max_src_len, max_mel_len, speaker_embed=speaker_embed)

                # Cal Loss Clean
                mel_output, mel_postnet_output = mel_outputs[0], mel_postnet_outputs[0]
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss, classifier_loss_a = Loss(
                    log_duration_output, log_D, f0_output, f0, energy_output, energy,
                    mel_output, mel_postnet_output, mel_target, ~src_mask, ~mel_mask, src_len, mel_len,
                    aug_posteriors, torch.zeros(mel_target.size(0)).long().to(device))

                # Cal Loss Noisy
                mel_output_noisy, mel_postnet_output_noisy = mel_outputs[1], mel_postnet_outputs[1]
                mel_noisy_loss, mel_postnet_noisy_loss = Loss.cal_mel_loss(
                    mel_output_noisy, mel_postnet_output_noisy, mel_aug, ~mel_mask)

                # Forward DAT
                enc_cat = model.module.style_modeling.style_encoder.encoder_input_cat(
                    mel_aug, f0_norm_aug, energy_input_aug, mel_aug)
                duration_encoding, pitch_encoding, energy_encoding, _ = model.module.style_modeling.style_encoder.audio_encoder(
                    enc_cat, mel_len, src_len, mask=None)
                aug_posterior_d = model.module.style_modeling.augmentation_classifier_d(
                    duration_encoding)
                aug_posterior_p = model.module.style_modeling.augmentation_classifier_p(
                    pitch_encoding)
                aug_posterior_e = model.module.style_modeling.augmentation_classifier_e(
                    energy_encoding)

                # Cal Loss DAT
                classifier_loss_a_dat = DATLoss(
                    (aug_posterior_d, aug_posterior_p, aug_posterior_e),
                    torch.ones(mel_target.size(0)).long().to(device))

                # Total loss
                total_loss = mel_loss + mel_postnet_loss + mel_noisy_loss + mel_postnet_noisy_loss + d_loss + f_loss + e_loss\
                    + hp.dat_weight*(classifier_loss_a + classifier_loss_a_dat)

                # Logger
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                m_n_l = mel_noisy_loss.item()
                m_p_n_l = mel_postnet_noisy_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                cl_a = classifier_loss_a.item()
                cl_a_dat = classifier_loss_a_dat.item()

                # Backward
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                if current_step % hp.acc_steps != 0:
                    continue

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(), hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                # Print
                if current_step == 1 or current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}, F0 Loss: {:.4f}, Energy Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start), (total_step - current_step) * np.mean(Time))

                    print("\n" + str1)
                    print(str2)
                    print(str3)

                    train_logger.add_scalar('Loss/total_loss', t_l, current_step)
                    train_logger.add_scalar('Loss/mel_loss', m_l, current_step)
                    train_logger.add_scalar('Loss/mel_postnet_loss', m_p_l, current_step)
                    train_logger.add_scalar('Loss/mel_noisy_loss', m_n_l, current_step)
                    train_logger.add_scalar('Loss/mel_postnet_noisy_loss', m_p_n_l, current_step)
                    train_logger.add_scalar('Loss/duration_loss', d_l, current_step)
                    train_logger.add_scalar('Loss/F0_loss', f_l, current_step)
                    train_logger.add_scalar('Loss/energy_loss', e_l, current_step)
                    train_logger.add_scalar('Loss/dat_clean_loss', cl_a, current_step)
                    train_logger.add_scalar('Loss/dat_noisy_loss', cl_a_dat, current_step)

                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(checkpoint_path,
                                     'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                if current_step == 1 or current_step % hp.synth_step == 0:
                    length = mel_len[0].item()
                    mel_target_torch = mel_target[0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_aug_torch = mel_aug[0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_target = mel_target[0, :length].detach().cpu().transpose(0, 1)
                    mel_aug = mel_aug[0, :length].detach().cpu().transpose(0, 1)
                    mel_torch = mel_output[0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_noisy_torch = mel_output_noisy[0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel = mel_output[0, :length].detach().cpu().transpose(0, 1)
                    mel_noisy = mel_output_noisy[0, :length].detach().cpu().transpose(0, 1)
                    mel_postnet_torch = mel_postnet_output[0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet_noisy_torch = mel_postnet_output_noisy[0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet = mel_postnet_output[0, :length].detach().cpu().transpose(0, 1)
                    mel_postnet_noisy = mel_postnet_output_noisy[0, :length].detach().cpu().transpose(0, 1)

                    # Audio.tools.inv_mel_spec(mel, os.path.join(
                    #     synth_path, "step_{}_{}_griffin_lim.wav".format(current_step, "c")))
                    # Audio.tools.inv_mel_spec(mel_postnet, os.path.join(
                    #     synth_path, "step_{}_{}_postnet_griffin_lim.wav".format(current_step, "c")))
                    # Audio.tools.inv_mel_spec(mel_noisy, os.path.join(
                    #     synth_path, "step_{}_{}_griffin_lim.wav".format(current_step, "n")))
                    # Audio.tools.inv_mel_spec(mel_postnet_noisy, os.path.join(
                    #     synth_path, "step_{}_{}_postnet_griffin_lim.wav".format(current_step, "n")))

                    wav_mel = utils.vocoder_infer(
                        mel_torch, vocoder,
                        os.path.join(hp.synth_path(),
                                     'step_{}_{}_{}.wav'.format(current_step, "c", hp.vocoder)))
                    wav_mel_postnet = utils.vocoder_infer(
                        mel_postnet_torch, vocoder,
                        os.path.join(hp.synth_path(),
                                     'step_{}_{}_postnet_{}.wav'.format(current_step, "c", hp.vocoder)))
                    wav_ground_truth = utils.vocoder_infer(
                        mel_target_torch, vocoder,
                        os.path.join(hp.synth_path(),
                                     'step_{}_{}_ground-truth_{}.wav'.format(current_step, "c", hp.vocoder)))
                    wav_mel_noisy = utils.vocoder_infer(
                        mel_noisy_torch, vocoder,
                        os.path.join(hp.synth_path(),
                                     'step_{}_{}_{}.wav'.format(current_step, "n", hp.vocoder)))
                    wav_mel_postnet_noisy = utils.vocoder_infer(
                        mel_postnet_noisy_torch, vocoder,
                        os.path.join(hp.synth_path(),
                                     'step_{}_{}_postnet_{}.wav'.format(current_step, "n", hp.vocoder)))
                    wav_aug = utils.vocoder_infer(
                        mel_aug_torch, vocoder,
                        os.path.join(hp.synth_path(),
                                     'step_{}_{}_ground-truth_{}.wav'.format(current_step, "n", hp.vocoder)))

                    # Model duration prediction
                    log_duration_output = log_duration_output[0, :src_len[0].item()].detach().cpu()  # [seg_len]
                    log_duration_output = torch.clamp(
                        torch.round(torch.exp(log_duration_output) - hp.log_offset),
                        min=0).int()
                    model_duration = utils.get_alignment_2D(log_duration_output).T  # [seg_len, mel_len]
                    model_duration = utils.plot_alignment([model_duration])

                    # Model mel prediction
                    f0 = f0[0, :length].detach().cpu().numpy()
                    energy = energy[0, :length].detach().cpu().numpy()
                    f0_output = f0_output[0, :length].detach().cpu().numpy()
                    energy_output = energy_output[0, :length].detach().cpu().numpy()
                    mel_predicted = utils.plot_data(
                        [(mel_postnet.numpy(), f0_output, energy_output),
                         (mel_target.numpy(), f0, energy)],
                        ['Synthetized Spectrogram Clean', 'Ground-Truth Spectrogram'],
                        filename=os.path.join(synth_path, 'step_{}_{}.png'.format(current_step, "c")))
                    mel_noisy_predicted = utils.plot_data(
                        [(mel_postnet_noisy.numpy(), f0_output, energy_output),
                         (mel_aug.numpy(), f0, energy)],
                        ['Synthetized Spectrogram Noisy', 'Aug Spectrogram'],
                        filename=os.path.join(synth_path, 'step_{}_{}.png'.format(current_step, "n")))

                    # Normalize audio for tensorboard logger.
                    # See https://github.com/lanpa/tensorboardX/issues/511#issuecomment-537600045
                    wav_ground_truth = wav_ground_truth / max(wav_ground_truth)
                    wav_mel = wav_mel / max(wav_mel)
                    wav_mel_postnet = wav_mel_postnet / max(wav_mel_postnet)
                    wav_aug = wav_aug / max(wav_aug)
                    wav_mel_noisy = wav_mel_noisy / max(wav_mel_noisy)
                    wav_mel_postnet_noisy = wav_mel_postnet_noisy / max(wav_mel_postnet_noisy)

                    train_logger.add_image("model_duration", model_duration,
                                           current_step, dataformats='HWC')
                    train_logger.add_image("mel_predicted/Clean", mel_predicted,
                                           current_step, dataformats='HWC')
                    train_logger.add_image("mel_predicted/Noisy", mel_noisy_predicted,
                                           current_step, dataformats='HWC')
                    train_logger.add_audio("Clean/wav_ground_truth", wav_ground_truth,
                                           current_step, sample_rate=hp.sampling_rate)
                    train_logger.add_audio("Clean/wav_mel", wav_mel,
                                           current_step, sample_rate=hp.sampling_rate)
                    train_logger.add_audio("Clean/wav_mel_postnet", wav_mel_postnet,
                                           current_step, sample_rate=hp.sampling_rate)
                    train_logger.add_audio("Noisy/wav_aug", wav_aug,
                                           current_step, sample_rate=hp.sampling_rate)
                    train_logger.add_audio("Noisy/wav_mel_noisy", wav_mel_noisy,
                                           current_step, sample_rate=hp.sampling_rate)
                    train_logger.add_audio("Noisy/wav_mel_postnet_noisy", wav_mel_postnet_noisy,
                                           current_step, sample_rate=hp.sampling_rate)

                if current_step == 1 or current_step % hp.eval_step == 0:
                    model.eval()
                    with torch.no_grad():
                        d_l, f_l, e_l, cl_a, cl_a_dat, m_l, m_p_l, m_n_l, m_p_n_l = evaluate(
                            model, current_step)
                        t_l = d_l + f_l + e_l + m_l + m_p_l + m_n_l + m_p_n_l\
                            + hp.dat_weight*(cl_a + cl_a_dat)

                        val_logger.add_scalar('Loss/total_loss', t_l, current_step)
                        val_logger.add_scalar('Loss/mel_loss', m_l, current_step)
                        val_logger.add_scalar('Loss/mel_postnet_loss', m_p_l, current_step)
                        val_logger.add_scalar('Loss/mel_noisy_loss', m_n_l, current_step)
                        val_logger.add_scalar('Loss/mel_postnet_noisy_loss', m_p_n_l, current_step)
                        val_logger.add_scalar('Loss/duration_loss', d_l, current_step)
                        val_logger.add_scalar('Loss/F0_loss', f_l, current_step)
                        val_logger.add_scalar('Loss/energy_loss', e_l, current_step)
                        val_logger.add_scalar('Loss/dat_clean_loss', cl_a, current_step)
                        val_logger.add_scalar('Loss/dat_noisy_loss', cl_a_dat, current_step)

                    model.train()

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    temp_value = np.mean(Time)
                    Time = np.delete(Time, [i for i in range(len(Time))], axis=None)
                    Time = np.append(Time, temp_value)
def train(train_loader, model, device, mels_criterion, stop_criterion,
          optimizer, scheduler, writer, train_dir):
    batch_time = ValueWindow()
    data_time = ValueWindow()
    losses = ValueWindow()

    # switch to train mode
    model.train()

    end = time.time()
    global global_epoch
    global global_step
    for i, (txts, mels, stop_tokens, txt_lengths, mels_lengths) in enumerate(train_loader):
        scheduler.adjust_learning_rate(optimizer, global_step)

        # measure data loading time
        data_time.update(time.time() - end)

        if device > -1:
            txts = txts.cuda(device)
            mels = mels.cuda(device)
            stop_tokens = stop_tokens.cuda(device)
            txt_lengths = txt_lengths.cuda(device)
            mels_lengths = mels_lengths.cuda(device)

        # compute output
        frames, decoder_frames, stop_tokens_predict, alignment = model(
            txts, txt_lengths, mels)
        decoder_frames_loss = mels_criterion(decoder_frames, mels, lengths=mels_lengths)
        frames_loss = mels_criterion(frames, mels, lengths=mels_lengths)
        stop_token_loss = stop_criterion(stop_tokens_predict, stop_tokens, lengths=mels_lengths)
        loss = decoder_frames_loss + frames_loss + stop_token_loss
        #print(frames_loss, decoder_frames_loss)
        losses.update(loss.item())

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        if hparams.clip_thresh > 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.get_trainable_parameters(), hparams.clip_thresh)
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % hparams.print_freq == 0:
            log('Epoch: [{0}][{1}/{2}]\t'
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                    global_epoch, i, len(train_loader),
                    batch_time=batch_time, data_time=data_time, loss=losses))

            # Logs
            writer.add_scalar("loss", float(loss.item()), global_step)
            writer.add_scalar("avg_loss in {} window".format(losses.get_dinwow_size),
                              float(losses.avg), global_step)
            writer.add_scalar("stop_token_loss", float(stop_token_loss.item()), global_step)
            writer.add_scalar("decoder_frames_loss", float(decoder_frames_loss.item()), global_step)
            writer.add_scalar("output_frames_loss", float(frames_loss.item()), global_step)
            if hparams.clip_thresh > 0:
                writer.add_scalar("gradient norm", grad_norm, global_step)
            writer.add_scalar("learning rate", optimizer.param_groups[0]['lr'], global_step)

        global_step += 1

    dst_alignment_path = join(train_dir, "{}_alignment.png".format(global_step))
    alignment = alignment.cpu().detach().numpy()
    plot_alignment(alignment[0, :txt_lengths[0], :mels_lengths[0]],
                   dst_alignment_path,
                   info="{}, {}".format(hparams.builder, global_step))
def validate(val_loader, model, device, mels_criterion, stop_criterion, writer, val_dir):
    batch_time = ValueWindow()
    losses = ValueWindow()

    # switch to evaluate mode
    model.eval()

    global global_epoch
    global global_step
    with torch.no_grad():
        end = time.time()
        for i, (txts, mels, stop_tokens, txt_lengths, mels_lengths) in enumerate(val_loader):
            # measure data loading time
            batch_time.update(time.time() - end)

            if device > -1:
                txts = txts.cuda(device)
                mels = mels.cuda(device)
                stop_tokens = stop_tokens.cuda(device)
                txt_lengths = txt_lengths.cuda(device)
                mels_lengths = mels_lengths.cuda(device)

            # compute output
            frames, decoder_frames, stop_tokens_predict, alignment = model(
                txts, txt_lengths, mels)
            decoder_frames_loss = mels_criterion(decoder_frames, mels, lengths=mels_lengths)
            frames_loss = mels_criterion(frames, mels, lengths=mels_lengths)
            stop_token_loss = stop_criterion(stop_tokens_predict, stop_tokens, lengths=mels_lengths)
            loss = decoder_frames_loss + frames_loss + stop_token_loss
            losses.update(loss.item())

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % hparams.print_freq == 0:
                log('Epoch: [{0}]\t'
                    'Test: [{1}/{2}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                        global_epoch, i, len(val_loader),
                        batch_time=batch_time, loss=losses))

                # Logs
                writer.add_scalar("loss", float(loss.item()), global_step)
                writer.add_scalar("avg_loss in {} window".format(losses.get_dinwow_size),
                                  float(losses.avg), global_step)
                writer.add_scalar("stop_token_loss", float(stop_token_loss.item()), global_step)
                writer.add_scalar("decoder_frames_loss", float(decoder_frames_loss.item()), global_step)
                writer.add_scalar("output_frames_loss", float(frames_loss.item()), global_step)

        dst_alignment_path = join(val_dir, "{}_alignment.png".format(global_step))
        alignment = alignment.cpu().detach().numpy()
        plot_alignment(alignment[0, :txt_lengths[0], :mels_lengths[0]],
                       dst_alignment_path,
                       info="{}, {}".format(hparams.builder, global_step))

    return losses.avg
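# train() and validate() above rely on a ValueWindow-style accumulator exposing .update(), .val,
# .avg, and the window-size accessor referenced in the logging tag. A minimal sketch of such a
# class is given here under those assumptions; the project's own ValueWindow implementation may
# differ, and the class name below is chosen only to avoid shadowing it.
from collections import deque

class ValueWindowSketch:
    def __init__(self, window_size=100):
        self._window_size = window_size
        self._values = deque(maxlen=window_size)  # oldest values fall off automatically

    def update(self, value):
        self._values.append(value)

    @property
    def val(self):
        # most recent value pushed into the window
        return self._values[-1] if self._values else 0.0

    @property
    def avg(self):
        # running average over at most the last window_size values
        return sum(self._values) / len(self._values) if self._values else 0.0

    @property
    def get_dinwow_size(self):
        # name kept to match the attribute referenced in the logging code above
        return self._window_size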
ref_level_db = transform_config["ref_level_db"]
preemphasis = transform_config["preemphasis"]
win_length = transform_config["win_length"]
hop_length = transform_config["hop_length"]

synthesis_config = config["synthesis"]
power = synthesis_config["power"]
n_iter = synthesis_config["n_iter"]

synthesis_dir = os.path.join(args.output, "synthesis")
if not os.path.exists(synthesis_dir):
    os.makedirs(synthesis_dir)

with open(args.text, "rt", encoding="utf-8") as f:
    lines = f.readlines()
    for idx, line in enumerate(lines):
        text = line[:-1]
        dv3.eval()
        wav, attn = eval_model(dv3, text, replace_pronounciation_prob,
                               min_level_db, ref_level_db, power, n_iter,
                               win_length, hop_length, preemphasis)
        plot_alignment(
            attn,
            os.path.join(synthesis_dir,
                         "test_{}_step_{}.png".format(idx, iteration)))
        sf.write(
            os.path.join(synthesis_dir,
                         "test_{}_step{}.wav".format(idx, iteration)),
            wav, sample_rate)