# Standard-library and third-party imports used below. Project-local helpers
# (loader, fastspeech, get_std_opt, valid_symbols, get_commit_hash, num_params,
# plot_spectrogram_to_numpy, generate_audio, read_wav_np, evaluate) are assumed
# to be imported elsewhere in this module.
import logging
import math
import os
import time

import torch
import tqdm
from tensorboardX import SummaryWriter  # or: from torch.utils.tensorboard import SummaryWriter


def train(args, hp, hp_str, logger, vocoder):
    os.makedirs(os.path.join(hp.train.chkpt_dir, args.name), exist_ok=True)
    os.makedirs(os.path.join(args.outdir, args.name), exist_ok=True)
    os.makedirs(os.path.join(args.outdir, args.name, "assets"), exist_ok=True)
    device = torch.device("cuda" if hp.train.ngpu > 0 else "cpu")

    dataloader = loader.get_tts_dataset(hp.data.data_dir, hp.train.batch_size, hp)
    validloader = loader.get_tts_dataset(hp.data.data_dir, 1, hp, True)

    idim = len(valid_symbols)
    odim = hp.audio.num_mels
    model = fastspeech.FeedForwardTransformer(idim, odim, hp)
    # set torch device
    model = model.to(device)
    print("Model is loaded ...")

    githash = get_commit_hash()
    if args.checkpoint_path is not None:
        if os.path.exists(args.checkpoint_path):
            logger.info("Resuming from checkpoint: %s" % args.checkpoint_path)
            checkpoint = torch.load(args.checkpoint_path)
            model.load_state_dict(checkpoint["model"])
            optimizer = get_std_opt(
                model,
                hp.model.adim,
                hp.model.transformer_warmup_steps,
                hp.model.transformer_lr,
            )
            optimizer.load_state_dict(checkpoint["optim"])
            global_step = checkpoint["step"]
            if hp_str != checkpoint["hp_str"]:
                logger.warning("New hparams differ from checkpoint. Will use new.")
            if githash != checkpoint["githash"]:
                logger.warning("Code might be different: git hash is different.")
                logger.warning("%s -> %s" % (checkpoint["githash"], githash))
        else:
            print("Checkpoint does not exist")
            global_step = 0
            return None
    else:
        print("New Training")
        global_step = 0
        optimizer = get_std_opt(
            model,
            hp.model.adim,
            hp.model.transformer_warmup_steps,
            hp.model.transformer_lr,
        )

    print("Batch Size :", hp.train.batch_size)
    num_params(model)

    os.makedirs(os.path.join(hp.train.log_dir, args.name), exist_ok=True)
    writer = SummaryWriter(os.path.join(hp.train.log_dir, args.name))
    model.train()
    forward_count = 0
    # print(model)
    for epoch in range(hp.train.epochs):
        start = time.time()
        running_loss = 0
        j = 0

        pbar = tqdm.tqdm(dataloader, desc="Loading train data")
        for data in pbar:
            global_step += 1
            x, input_length, y, _, out_length, _, dur, e, p = data
            # x : [batch, num_char], input_length : [batch],
            # y : [batch, T_in, num_mel], out_length : [batch]
            # (the unused slots carry stop_token : [batch, T_in] and ids)

            loss, report_dict = model(
                x.cuda(),
                input_length.cuda(),
                y.cuda(),
                out_length.cuda(),
                dur.cuda(),
                e.cuda(),
                p.cuda(),
            )
            # normalize the loss by the number of gradient-accumulation steps
            loss = loss.mean() / hp.train.accum_grad
            running_loss += loss.item()

            loss.backward()

            # update parameters only every accum_grad forward passes
            forward_count += 1
            j = j + 1
            if forward_count != hp.train.accum_grad:
                continue
            forward_count = 0
            step = global_step

            # compute the gradient norm to check if it is normal or not
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), hp.train.grad_clip)
            logging.debug("grad norm={}".format(grad_norm))
            if math.isnan(grad_norm):
                logging.warning("grad norm is nan. Do not update model.")
            else:
                optimizer.step()
            optimizer.zero_grad()

            if step % hp.train.summary_interval == 0:
                pbar.set_description(
                    "Average Loss %.04f Loss %.04f | step %d"
                    % (running_loss / j, loss.item(), step))

                for r in report_dict:
                    for k, v in r.items():
                        if k is not None and v is not None:
                            # convert cupy arrays to host values before logging
                            if "cupy" in str(type(v)):
                                v = v.get()
                            if "cupy" in str(type(k)):
                                k = k.get()
                            writer.add_scalar("main/{}".format(k), v, step)

            if step % hp.train.validation_step == 0:
                for valid in validloader:
                    x_, input_length_, y_, _, out_length_, ids_, dur_, e_, p_ = valid
                    model.eval()
                    with torch.no_grad():
                        loss_, report_dict_ = model(
                            x_.cuda(),
                            input_length_.cuda(),
                            y_.cuda(),
                            out_length_.cuda(),
                            dur_.cuda(),
                            e_.cuda(),
                            p_.cuda(),
                        )
                        mels_ = model.inference(x_[-1].cuda())  # [T, num_mel]
                    model.train()

                    for r in report_dict_:
                        for k, v in r.items():
                            if k is not None and v is not None:
                                if "cupy" in str(type(v)):
                                    v = v.get()
                                if "cupy" in str(type(k)):
                                    k = k.get()
                                writer.add_scalar("validation/{}".format(k), v, step)

                    mels_ = mels_.T  # Out: [num_mels, T]
                    writer.add_image(
                        "melspectrogram_target_{}".format(ids_[-1]),
                        plot_spectrogram_to_numpy(
                            y_[-1].T.data.cpu().numpy()[:, :out_length_[-1]]),
                        step,
                        dataformats="HWC",
                    )
                    writer.add_image(
                        "melspectrogram_prediction_{}".format(ids_[-1]),
                        plot_spectrogram_to_numpy(mels_.data.cpu().numpy()),
                        step,
                        dataformats="HWC",
                    )

                    # print(mels.unsqueeze(0).shape)
                    audio = generate_audio(
                        mels_.unsqueeze(0), vocoder
                    )  # the last data point is selected to match the mel generated above
                    audio = audio.cpu().float().numpy()
                    audio = audio / (audio.max() - audio.min())  # scale by the peak-to-peak range
                    writer.add_audio(
                        tag="generated_audio_{}".format(ids_[-1]),
                        snd_tensor=torch.Tensor(audio),
                        global_step=step,
                        sample_rate=hp.audio.sample_rate,
                    )

                    _, target = read_wav_np(
                        hp.data.wav_dir + f"{ids_[-1]}.wav",
                        sample_rate=hp.audio.sample_rate,
                    )
                    writer.add_audio(
                        tag="target_audio_{}".format(ids_[-1]),
                        snd_tensor=torch.Tensor(target),
                        global_step=step,
                        sample_rate=hp.audio.sample_rate,
                    )

            if step % hp.train.save_interval == 0:
                avg_p, avg_e, avg_d = evaluate(hp, validloader, model)
                writer.add_scalar("evaluation/Pitch Loss", avg_p, step)
                writer.add_scalar("evaluation/Energy Loss", avg_e, step)
                writer.add_scalar("evaluation/Dur Loss", avg_d, step)
                save_path = os.path.join(
                    hp.train.chkpt_dir,
                    args.name,
                    "{}_fastspeech_{}_{}k_steps.pyt".format(
                        args.name, githash, step // 1000),
                )
                torch.save(
                    {
                        "model": model.state_dict(),
                        "optim": optimizer.state_dict(),
                        "step": step,
                        "hp_str": hp_str,
                        "githash": githash,
                    },
                    save_path,
                )
                logger.info("Saved checkpoint to: %s" % save_path)

        print("Time taken for epoch {} is {} sec\n".format(
            epoch + 1, int(time.time() - start)))
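
# For context: get_std_opt above is expected to return an ESPnet-style Noam
# optimizer wrapper (Adam plus the Transformer warmup learning-rate schedule).
# The class below is a minimal sketch of that interface, written under the
# assumption of the usual ESPnet definition; it illustrates why step(),
# zero_grad(), state_dict(), load_state_dict(), and the _step counter used in
# this file all exist on the optimizer object. It is not the repo's own code.
class NoamOptSketch:
    """Adam wrapper with lr = factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)."""

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer  # the wrapped torch.optim.Adam instance
        self._step = 0              # number of parameter updates so far
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size

    def rate(self, step=None):
        # Linear warmup for the first `warmup` steps, then inverse-sqrt decay.
        step = self._step if step is None else step
        return (self.factor * self.model_size ** (-0.5)
                * min(step ** (-0.5), step * self.warmup ** (-1.5)))

    def step(self):
        self._step += 1
        lr = self.rate()
        for group in self.optimizer.param_groups:
            group["lr"] = lr
        self.optimizer.step()

    def zero_grad(self):
        self.optimizer.zero_grad()

    def state_dict(self):
        return {"_step": self._step, "optimizer": self.optimizer.state_dict()}

    def load_state_dict(self, state_dict):
        self._step = state_dict["_step"]
        self.optimizer.load_state_dict(state_dict["optimizer"])
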
# A second train() variant that uses a flat, module-level hp namespace and
# saves the model and optimizer checkpoints as separate files.
def train(args):
    os.makedirs(hp.chkpt_dir, exist_ok=True)
    os.makedirs(args.outdir, exist_ok=True)
    os.makedirs(os.path.join(args.outdir, 'img'), exist_ok=True)
    device = torch.device("cuda" if hp.ngpu > 0 else "cpu")

    dataloader = loader.get_tts_dataset(hp.data_dir, hp.batch_size)
    validloader = loader.get_tts_dataset(hp.data_dir, 5, True)
    global_step = 0

    idim = hp.symbol_len
    odim = hp.num_mels
    model = fastspeech.FeedForwardTransformer(idim, odim)
    # set torch device
    model = model.to(device)
    print("Model is loaded ...")

    if args.resume is not None:
        if os.path.exists(args.resume):
            print('\nResuming from checkpoint...\n')
            model.load_state_dict(torch.load(args.resume), strict=False)
            optimizer = get_std_opt(model, hp.adim, hp.transformer_warmup_steps,
                                    hp.transformer_lr)
            # the optimizer checkpoint is stored next to the model checkpoint
            optimizer.load_state_dict(
                torch.load(args.resume.replace("model", "optim")))
            global_step = hp.accum_grad * optimizer._step
        else:
            print("Checkpoint does not exist")
            return None
    else:
        optimizer = get_std_opt(model, hp.adim, hp.transformer_warmup_steps,
                                hp.transformer_lr)

    print("Batch Size :", hp.batch_size)
    num_params(model)

    writer = SummaryWriter(hp.log_dir)
    model.train()
    forward_count = 0
    print(model)

    # pretty names for the per-loss console printouts below
    loss_names = {
        'l1_loss': 'L1 loss',
        'before_loss': 'Before loss',
        'after_loss': 'After loss',
        'duration_loss': 'D loss',
        'pitch_loss': 'P loss',
        'energy_loss': 'E loss',
    }

    for epoch in range(hp.epochs):
        start = time.time()
        running_loss = 0
        j = 0

        pbar = tqdm.tqdm(dataloader, desc='Loading train data')
        for data in pbar:
            global_step += 1
            x, input_length, y, _, out_length, _, dur, e, p = data
            # x : [batch, num_char], input_length : [batch],
            # y : [batch, T_in, num_mel], out_length : [batch]
            # (the unused slots carry stop_token : [batch, T_in] and ids)

            loss, report_dict = model(x.cuda(), input_length.cuda(), y.cuda(),
                                      out_length.cuda(), dur.cuda(), e.cuda(),
                                      p.cuda())
            loss = loss.mean() / hp.accum_grad
            running_loss += loss.item()

            loss.backward()

            # update parameters only every accum_grad forward passes
            forward_count += 1
            j = j + 1
            if forward_count != hp.accum_grad:
                continue
            forward_count = 0
            step = global_step

            # compute the gradient norm to check if it is normal or not
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       hp.grad_clip)
            logging.debug('grad norm={}'.format(grad_norm))
            if math.isnan(grad_norm):
                logging.warning('grad norm is nan. Do not update model.')
            else:
                optimizer.step()
            optimizer.zero_grad()

            if step % hp.summary_interval == 0:
                # torch.cuda.empty_cache()
                pbar.set_description(
                    "Average Loss %.04f Loss %.04f | step %d"
                    % (running_loss / j, loss.item(), step))

                print("Losses :")
                for r in report_dict:
                    for k, v in r.items():
                        if k in loss_names:
                            print("\n{} :".format(loss_names[k]), v)
                        if k is not None and v is not None:
                            # convert cupy arrays to host values before logging
                            if 'cupy' in str(type(v)):
                                v = v.get()
                            if 'cupy' in str(type(k)):
                                k = k.get()
                            writer.add_scalar("main/{}".format(k), v, step)

            if step % hp.validation_step == 0:
                plot_class = model.attention_plot_class
                plot_fn = plot_class(args.outdir + '/att_ws', device)
                for valid in validloader:
                    x_, input_length_, y_, _, out_length_, ids_, dur_, e_, p_ = valid
                    model.eval()
                    with torch.no_grad():
                        loss_, report_dict_ = model(x_.cuda(),
                                                    input_length_.cuda(),
                                                    y_.cuda(),
                                                    out_length_.cuda(),
                                                    dur_.cuda(), e_.cuda(),
                                                    p_.cuda())
                        att_ws = model.calculate_all_attentions(
                            x_.cuda(), input_length_.cuda(), y_.cuda(),
                            out_length_.cuda(), dur_.cuda(), e_.cuda(),
                            p_.cuda())
                    model.train()

                    print("Validation Losses :")
                    for r in report_dict_:
                        for k, v in r.items():
                            if k in loss_names:
                                print("\n{} :".format(loss_names[k]), v)
                            if k is not None and v is not None:
                                if 'cupy' in str(type(v)):
                                    v = v.get()
                                if 'cupy' in str(type(k)):
                                    k = k.get()
                                writer.add_scalar("validation/{}".format(k), v,
                                                  step)

                    plot_fn.__call__(step, input_length_, out_length_, att_ws)
                    plot_fn.log_attentions(writer, step, input_length_,
                                           out_length_, att_ws)

            if step % hp.save_interval == 0:
                save_path = os.path.join(
                    hp.chkpt_dir,
                    'checkpoint_model_{}k_steps.pyt'.format(step // 1000))
                optim_path = os.path.join(
                    hp.chkpt_dir,
                    'checkpoint_optim_{}k_steps.pyt'.format(step // 1000))
                torch.save(model.state_dict(), save_path)
                torch.save(optimizer.state_dict(), optim_path)
                print("Model Saved")

        print('Time taken for epoch {} is {} sec\n'.format(
            epoch + 1, int(time.time() - start)))
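
# A minimal sketch of a command-line entry point for the train() variant above.
# The flag names and defaults here are assumptions for illustration; the repo's
# real launcher may parse different arguments.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Train a FastSpeech model")
    parser.add_argument("--outdir", default="out",
                        help="output directory for attention plots and images")
    parser.add_argument("--resume", default=None,
                        help="path to a saved model checkpoint, e.g. "
                             "checkpoint_model_100k_steps.pyt (the matching "
                             "optim checkpoint is located by filename)")
    args = parser.parse_args()
    train(args)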