def main(args):
    """Preprocess every ``*.wav`` under ``args.data_path`` into per-utterance
    mel-spectrogram, frame-energy and pitch ``.npy`` files.

    NOTE(review): this variant reads a module-level ``hp`` with *flat*
    attributes (``hp.n_fft``, ``hp.data_dir``, ...) — presumably loaded
    elsewhere in this file; confirm it is set before calling.

    Args:
        args: namespace with ``data_path``, the root directory searched
            recursively for wav files.

    Side effects:
        Creates ``mels``/``energy``/``pitch`` dirs under ``hp.data_dir`` and
        writes one ``<utt_id>.npy`` per wav into each.
    """
    stft = TacotronSTFT(filter_length=hp.n_fft,
                        hop_length=hp.hop_length,
                        win_length=hp.win_length,
                        n_mel_channels=hp.n_mels,
                        sampling_rate=hp.sample_rate,
                        mel_fmin=hp.fmin,
                        mel_fmax=hp.fmax)
    wav_files = glob.glob(os.path.join(args.data_path, '**', '*.wav'),
                          recursive=True)

    mel_path = os.path.join(hp.data_dir, 'mels')
    energy_path = os.path.join(hp.data_dir, 'energy')
    pitch_path = os.path.join(hp.data_dir, 'pitch')
    os.makedirs(mel_path, exist_ok=True)
    os.makedirs(energy_path, exist_ok=True)
    os.makedirs(pitch_path, exist_ok=True)

    for wavpath in tqdm.tqdm(wav_files, desc='preprocess wav to mel'):
        sr, wav = read_wav_np(wavpath)
        p = pitch(wav)  # [T, ] T = number of frames
        wav = torch.from_numpy(wav).unsqueeze(0)
        mel, mag = stft.mel_spectrogram(wav)  # mel [1, num_mel, T], mag [1, num_mag, T]
        mel = mel.squeeze(0)  # [num_mel, T]
        mag = mag.squeeze(0)  # [num_mag, T]
        e = torch.norm(mag, dim=0)  # per-frame energy, [T, ]
        p = p[:mel.shape[1]]  # trim pitch to the number of mel frames

        # BUGFIX: use splitext instead of split('.')[0] so utterance ids that
        # contain dots (e.g. "spk.utt1.wav") are not truncated; also avoids
        # shadowing the builtin `id`.
        utt_id = os.path.splitext(os.path.basename(wavpath))[0]
        np.save('{}/{}.npy'.format(mel_path, utt_id), mel.numpy(),
                allow_pickle=False)
        np.save('{}/{}.npy'.format(energy_path, utt_id), e.numpy(),
                allow_pickle=False)
        np.save('{}/{}.npy'.format(pitch_path, utt_id), p, allow_pickle=False)
def preprocess(data_path, hp, file):
    """Preprocess utterances listed in a metadata file into mel, energy,
    pitch and phoneme-averaged-mel ``.npy`` files.

    Args:
        data_path: root directory the wav paths in the metadata are relative to.
        hp: hyper-parameter tree (``hp.audio.*`` for STFT settings,
            ``hp.data.data_dir`` for the output root).
        file: path to a '|'-separated metadata file; field 2 is a duration
            string per phoneme, field 4 the relative wav path.

    Side effects:
        Creates ``mels``/``energy``/``pitch``/``avg_mel_ph`` dirs under
        ``hp.data.data_dir`` and writes one ``<utt_id>.npy`` per utterance
        into each.
    """
    stft = TacotronSTFT(
        filter_length=hp.audio.n_fft,
        hop_length=hp.audio.hop_length,
        win_length=hp.audio.win_length,
        n_mel_channels=hp.audio.n_mels,
        sampling_rate=hp.audio.sample_rate,
        mel_fmin=hp.audio.fmin,
        mel_fmax=hp.audio.fmax,
    )
    mel_path = os.path.join(hp.data.data_dir, "mels")
    energy_path = os.path.join(hp.data.data_dir, "energy")
    pitch_path = os.path.join(hp.data.data_dir, "pitch")
    avg_mel_phon = os.path.join(hp.data.data_dir, "avg_mel_ph")
    for out_dir in (mel_path, energy_path, pitch_path, avg_mel_phon):
        os.makedirs(out_dir, exist_ok=True)
    print("Sample Rate : ", hp.audio.sample_rate)

    with open("{}".format(file), encoding="utf-8") as f:
        _metadata = [line.strip().split("|") for line in f]

    for metadata in tqdm.tqdm(_metadata, desc="preprocess wav to mel"):
        wavpath = os.path.join(data_path, metadata[4])
        sr, wav = read_wav_np(wavpath, hp.audio.sample_rate)
        dur = str_to_int_list(metadata[2])  # per-phoneme frame durations
        dur = torch.from_numpy(np.array(dur))
        p = pitch(wav, hp)  # [T, ] T = number of frames
        wav = torch.from_numpy(wav).unsqueeze(0)
        mel, mag = stft.mel_spectrogram(wav)  # mel [1, num_mel, T], mag [1, num_mag, T]
        mel = mel.squeeze(0)  # [num_mel, T]
        mag = mag.squeeze(0)  # [num_mag, T]
        e = torch.norm(mag, dim=0)  # per-frame energy, [T, ]
        p = p[:mel.shape[1]]  # trim pitch to the number of mel frames

        # NOTE(review): the assert below implies avg_mel_ph is [L, num_mel]
        # (L = number of phonemes), not [num_mel, L] as the original comment
        # claimed — confirm against _average_mel_by_duration.
        avg_mel_ph = _average_mel_by_duration(mel, dur)
        assert avg_mel_ph.shape[0] == dur.shape[-1]

        # BUGFIX: use splitext instead of split('.')[0] so utterance ids that
        # contain dots are not truncated; also avoids shadowing builtin `id`.
        utt_id = os.path.splitext(os.path.basename(wavpath))[0]
        np.save("{}/{}.npy".format(mel_path, utt_id), mel.numpy(),
                allow_pickle=False)
        np.save("{}/{}.npy".format(energy_path, utt_id), e.numpy(),
                allow_pickle=False)
        np.save("{}/{}.npy".format(pitch_path, utt_id), p, allow_pickle=False)
        np.save("{}/{}.npy".format(avg_mel_phon, utt_id), avg_mel_ph.numpy(),
                allow_pickle=False)
def main(args, hp):
    """Preprocess every ``*.wav`` under ``args.data_path`` into per-utterance
    mel-spectrogram, frame-energy and pitch ``.npy`` files.

    Args:
        args: namespace with ``data_path``, the root directory searched
            recursively for wav files.
        hp: hyper-parameter tree (``hp.audio.*`` for STFT settings,
            ``hp.data.data_dir`` for the output root).

    Side effects:
        Creates ``mels``/``energy``/``pitch`` dirs under ``hp.data.data_dir``
        and writes one ``<utt_id>.npy`` per wav into each.
    """
    stft = TacotronSTFT(
        filter_length=hp.audio.n_fft,
        hop_length=hp.audio.hop_length,
        win_length=hp.audio.win_length,
        n_mel_channels=hp.audio.n_mels,
        sampling_rate=hp.audio.sample_rate,
        mel_fmin=hp.audio.fmin,
        mel_fmax=hp.audio.fmax,
    )
    wav_files = glob.glob(os.path.join(args.data_path, "**", "*.wav"),
                          recursive=True)

    mel_path = os.path.join(hp.data.data_dir, "mels")
    energy_path = os.path.join(hp.data.data_dir, "energy")
    pitch_path = os.path.join(hp.data.data_dir, "pitch")
    os.makedirs(mel_path, exist_ok=True)
    os.makedirs(energy_path, exist_ok=True)
    os.makedirs(pitch_path, exist_ok=True)
    print("Sample Rate : ", hp.audio.sample_rate)

    for wavpath in tqdm.tqdm(wav_files, desc="preprocess wav to mel"):
        sr, wav = read_wav_np(wavpath, hp.audio.sample_rate)
        p = pitch(wav, hp)  # [T, ] T = number of frames
        wav = torch.from_numpy(wav).unsqueeze(0)
        mel, mag = stft.mel_spectrogram(wav)  # mel [1, num_mel, T], mag [1, num_mag, T]
        mel = mel.squeeze(0)  # [num_mel, T]
        mag = mag.squeeze(0)  # [num_mag, T]
        e = torch.norm(mag, dim=0)  # per-frame energy, [T, ]
        p = p[:mel.shape[1]]  # trim pitch to the number of mel frames

        # BUGFIX: use splitext instead of split('.')[0] so utterance ids that
        # contain dots are not truncated; also avoids shadowing builtin `id`.
        utt_id = os.path.splitext(os.path.basename(wavpath))[0]
        np.save("{}/{}.npy".format(mel_path, utt_id), mel.numpy(),
                allow_pickle=False)
        np.save("{}/{}.npy".format(energy_path, utt_id), e.numpy(),
                allow_pickle=False)
        np.save("{}/{}.npy".format(pitch_path, utt_id), p, allow_pickle=False)
def train(args, hp, hp_str, logger, vocoder):
    """Run the FastSpeech training loop with gradient accumulation.

    Args:
        args: CLI namespace; uses ``args.name``, ``args.outdir`` and
            ``args.checkpoint_path`` (``None`` starts fresh training).
        hp: hyper-parameter tree (``hp.train.*``, ``hp.data.*``,
            ``hp.model.*``, ``hp.audio.*``).
        hp_str: raw hyper-parameter string stored inside checkpoints so a
            resume can detect config drift.
        logger: ``logging.Logger`` for resume/save messages.
        vocoder: vocoder model used to synthesize validation audio.

    Returns:
        ``None``. Side effects: writes checkpoints under
        ``hp.train.chkpt_dir/args.name`` and TensorBoard summaries under
        ``hp.train.log_dir/args.name``. Returns early (without training) if
        a checkpoint path was given but does not exist.
    """
    os.makedirs(os.path.join(hp.train.chkpt_dir, args.name), exist_ok=True)
    os.makedirs(os.path.join(args.outdir, args.name), exist_ok=True)
    os.makedirs(os.path.join(args.outdir, args.name, "assets"), exist_ok=True)

    device = torch.device("cuda" if hp.train.ngpu > 0 else "cpu")

    dataloader = loader.get_tts_dataset(hp.data.data_dir, hp.train.batch_size, hp)
    validloader = loader.get_tts_dataset(hp.data.data_dir, 1, hp, True)

    idim = len(valid_symbols)
    odim = hp.audio.num_mels
    model = fastspeech.FeedForwardTransformer(idim, odim, hp)
    model = model.to(device)
    print("Model is loaded ...")

    githash = get_commit_hash()
    if args.checkpoint_path is not None:
        if os.path.exists(args.checkpoint_path):
            logger.info("Resuming from checkpoint: %s" % args.checkpoint_path)
            checkpoint = torch.load(args.checkpoint_path)
            model.load_state_dict(checkpoint["model"])
            optimizer = get_std_opt(
                model,
                hp.model.adim,
                hp.model.transformer_warmup_steps,
                hp.model.transformer_lr,
            )
            optimizer.load_state_dict(checkpoint["optim"])
            global_step = checkpoint["step"]
            if hp_str != checkpoint["hp_str"]:
                logger.warning(
                    "New hparams is different from checkpoint. Will use new.")
            if githash != checkpoint["githash"]:
                logger.warning("Code might be different: git hash is different.")
                logger.warning("%s -> %s" % (checkpoint["githash"], githash))
        else:
            # BUGFIX: typo in the user-facing message ("exixts").
            print("Checkpoint does not exist")
            global_step = 0
            return None
    else:
        print("New Training")
        global_step = 0
        optimizer = get_std_opt(
            model,
            hp.model.adim,
            hp.model.transformer_warmup_steps,
            hp.model.transformer_lr,
        )

    print("Batch Size :", hp.train.batch_size)
    num_params(model)

    os.makedirs(os.path.join(hp.train.log_dir, args.name), exist_ok=True)
    writer = SummaryWriter(os.path.join(hp.train.log_dir, args.name))
    model.train()
    forward_count = 0
    for epoch in range(hp.train.epochs):
        start = time.time()
        running_loss = 0
        j = 0  # number of forward passes this epoch (for the loss average)

        pbar = tqdm.tqdm(dataloader, desc="Loading train data")
        for data in pbar:
            global_step += 1
            # x: [batch, num_char], input_length: [batch],
            # y: [batch, T_in, num_mel], out_length: [batch]
            x, input_length, y, _, out_length, _, dur, e, p = data
            # BUGFIX: move data via .to(device) instead of hard-coded .cuda()
            # so the ngpu == 0 (CPU) configuration actually works; `device`
            # was previously computed but never used for the batches.
            loss, report_dict = model(
                x.to(device),
                input_length.to(device),
                y.to(device),
                out_length.to(device),
                dur.to(device),
                e.to(device),
                p.to(device),
            )
            loss = loss.mean() / hp.train.accum_grad
            running_loss += loss.item()
            loss.backward()

            # Accumulate gradients over accum_grad forward passes before a
            # single optimizer update.
            forward_count += 1
            j = j + 1
            if forward_count != hp.train.accum_grad:
                continue
            forward_count = 0
            step = global_step

            # Clip and sanity-check the gradient norm before updating.
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       hp.train.grad_clip)
            logging.debug("grad norm={}".format(grad_norm))
            if math.isnan(grad_norm):
                logging.warning("grad norm is nan. Do not update model.")
            else:
                optimizer.step()
            optimizer.zero_grad()

            if step % hp.train.summary_interval == 0:
                pbar.set_description(
                    "Average Loss %.04f Loss %.04f | step %d"
                    % (running_loss / j, loss.item(), step))
                for r in report_dict:
                    for k, v in r.items():
                        if k is not None and v is not None:
                            # Convert cupy scalars to host values for TB.
                            if "cupy" in str(type(v)):
                                v = v.get()
                            if "cupy" in str(type(k)):
                                k = k.get()
                            writer.add_scalar("main/{}".format(k), v, step)

            if step % hp.train.validation_step == 0:
                for valid in validloader:
                    x_, input_length_, y_, _, out_length_, ids_, dur_, e_, p_ = valid
                    model.eval()
                    with torch.no_grad():
                        loss_, report_dict_ = model(
                            x_.to(device),
                            input_length_.to(device),
                            y_.to(device),
                            out_length_.to(device),
                            dur_.to(device),
                            e_.to(device),
                            p_.to(device),
                        )
                        # Synthesize the last utterance of the batch.
                        mels_ = model.inference(x_[-1].to(device))  # [T, num_mel]
                    model.train()

                    for r in report_dict_:
                        for k, v in r.items():
                            if k is not None and v is not None:
                                if "cupy" in str(type(v)):
                                    v = v.get()
                                if "cupy" in str(type(k)):
                                    k = k.get()
                                writer.add_scalar("validation/{}".format(k), v,
                                                  step)

                    mels_ = mels_.T  # -> [num_mels, T]
                    writer.add_image(
                        "melspectrogram_target_{}".format(ids_[-1]),
                        plot_spectrogram_to_numpy(
                            y_[-1].T.data.cpu().numpy()[:, :out_length_[-1]]),
                        step,
                        dataformats="HWC",
                    )
                    writer.add_image(
                        "melspectrogram_prediction_{}".format(ids_[-1]),
                        plot_spectrogram_to_numpy(mels_.data.cpu().numpy()),
                        step,
                        dataformats="HWC",
                    )

                    # Vocode the predicted mel (last data point, matching the
                    # mel generated above).
                    audio = generate_audio(mels_.unsqueeze(0), vocoder)
                    audio = audio.cpu().float().numpy()
                    # Scale by the peak-to-peak range. NOTE(review): this does
                    # not strictly bound the signal to [-1, 1] (no centering).
                    audio = audio / (audio.max() - audio.min())
                    writer.add_audio(
                        tag="generated_audio_{}".format(ids_[-1]),
                        snd_tensor=torch.Tensor(audio),
                        global_step=step,
                        sample_rate=hp.audio.sample_rate,
                    )

                    _, target = read_wav_np(
                        hp.data.wav_dir + f"{ids_[-1]}.wav",
                        sample_rate=hp.audio.sample_rate,
                    )
                    # NOTE(review): tag keeps the original leading space so
                    # existing TensorBoard logs stay comparable.
                    writer.add_audio(
                        tag=" target_audio_{}".format(ids_[-1]),
                        snd_tensor=torch.Tensor(target),
                        global_step=step,
                        sample_rate=hp.audio.sample_rate,
                    )

            if step % hp.train.save_interval == 0:
                avg_p, avg_e, avg_d = evaluate(hp, validloader, model)
                writer.add_scalar("evaluation/Pitch Loss", avg_p, step)
                writer.add_scalar("evaluation/Energy Loss", avg_e, step)
                writer.add_scalar("evaluation/Dur Loss", avg_d, step)
                save_path = os.path.join(
                    hp.train.chkpt_dir,
                    args.name,
                    "{}_fastspeech_{}_{}k_steps.pyt".format(
                        args.name, githash, step // 1000),
                )
                torch.save(
                    {
                        "model": model.state_dict(),
                        "optim": optimizer.state_dict(),
                        "step": step,
                        "hp_str": hp_str,
                        "githash": githash,
                    },
                    save_path,
                )
                logger.info("Saved checkpoint to: %s" % save_path)

        print("Time taken for epoch {} is {} sec\n".format(
            epoch + 1, int(time.time() - start)))