class MyWriter(SummaryWriter):
    """Tensorboard writer for vocoder training.

    Logs scalar losses, audio clips, waveform plots, mel-spectrogram images
    and parameter histograms. The ground-truth target is logged only once
    (on the first validation call) since it never changes.
    """

    def __init__(self, hp, logdir):
        super(MyWriter, self).__init__(logdir)
        self.sample_rate = hp.audio.sampling_rate
        # CPU STFT used to turn predicted waveforms back into mels for display.
        self.stft = TacotronSTFT(filter_length=hp.audio.filter_length,
                                 hop_length=hp.audio.hop_length,
                                 win_length=hp.audio.win_length,
                                 n_mel_channels=hp.audio.n_mel_channels,
                                 sampling_rate=hp.audio.sampling_rate,
                                 mel_fmin=hp.audio.mel_fmin,
                                 mel_fmax=hp.audio.mel_fmax)
        # True until the first validation pass; gates one-time target logging.
        self.is_first = True

    def log_training(self, g_loss, d_loss, adv_loss, step):
        """Log generator / discriminator / adversarial training losses."""
        self.add_scalar('train.g_loss', g_loss, step)
        self.add_scalar('train.d_loss', d_loss, step)
        self.add_scalar('train.adv_loss', adv_loss, step)

    def _log_wave_artifacts(self, wav_np, audio_tag, wave_tag, mel_tag, step):
        """Log one waveform (1-D numpy array) as audio, plot and mel image.

        Extracted helper: the original duplicated this block for the
        prediction and the target.
        """
        self.add_audio(audio_tag, wav_np, step, self.sample_rate)
        self.add_image(wave_tag, plot_waveform_to_numpy(wav_np), step)
        wav = torch.from_numpy(wav_np).unsqueeze(0)
        mel = self.stft.mel_spectrogram(wav)  # mel [1, num_mel, T]
        self.add_image(mel_tag,
                       plot_spectrogram_to_numpy(
                           mel.squeeze(0).data.cpu().numpy()),
                       step, dataformats='HWC')

    def log_validation(self, g_loss, d_loss, adv_loss, generator,
                       discriminator, target, prediction, step):
        """Log validation losses, generated audio and model histograms."""
        self.add_scalar('validation.g_loss', g_loss, step)
        self.add_scalar('validation.d_loss', d_loss, step)
        self.add_scalar('validation.adv_loss', adv_loss, step)
        self._log_wave_artifacts(prediction, 'raw_audio_predicted',
                                 'waveform_predicted',
                                 'melspectrogram_prediction', step)
        self.log_histogram(generator, step)
        self.log_histogram(discriminator, step)
        if self.is_first:
            # Ground truth is constant; log it a single time.
            self._log_wave_artifacts(target, 'raw_audio_target',
                                     'waveform_target',
                                     'melspectrogram_target', step)
            self.is_first = False

    def log_evaluation(self, generated, step, name):
        """Log a generated evaluation clip under evaluation/<name>."""
        self.add_audio(f'evaluation/{name}', generated, step, self.sample_rate)

    def log_histogram(self, model, step):
        """Log a histogram for every named parameter of *model*."""
        for tag, value in model.named_parameters():
            self.add_histogram(tag.replace('.', '/'),
                               value.cpu().detach().numpy(), step)
def main(hp, args):
    """Precompute mel-spectrograms for every .wav under args.data_path.

    Each utterance is saved as <hp.data.mel_path>/<basename>.npy with
    shape [num_mel, T]. Raises AssertionError on sample-rate mismatch.
    """
    stft = TacotronSTFT(filter_length=hp.audio.filter_length,
                        hop_length=hp.audio.hop_length,
                        win_length=hp.audio.win_length,
                        n_mel_channels=hp.audio.n_mel_channels,
                        sampling_rate=hp.audio.sampling_rate,
                        mel_fmin=hp.audio.mel_fmin,
                        mel_fmax=hp.audio.mel_fmax)
    wav_files = glob.glob(os.path.join(args.data_path, '**', '*.wav'),
                          recursive=True)
    mel_path = hp.data.mel_path
    os.makedirs(mel_path, exist_ok=True)
    for wavpath in tqdm.tqdm(wav_files, desc='preprocess wav to mel'):
        sr, wav = read_wav_np(wavpath)
        assert sr == hp.audio.sampling_rate, \
            "sample rate mismatch. expected %d, got %d at %s" % \
            (hp.audio.sampling_rate, sr, wavpath)
        # Pad short clips so downstream segment cropping never underflows.
        if len(wav) < hp.audio.segment_length + hp.audio.pad_short:
            wav = np.pad(wav,
                         (0, hp.audio.segment_length + hp.audio.pad_short - len(wav)),
                         mode='constant', constant_values=0.0)
        wav = torch.from_numpy(wav).unsqueeze(0)
        mel = stft.mel_spectrogram(wav)  # mel [1, num_mel, T]
        mel = mel.squeeze(0)             # [num_mel, T]
        # Renamed from `id`: avoid shadowing the builtin.
        utt_id = os.path.basename(wavpath).split(".")[0]
        np.save(os.path.join(mel_path, utt_id + '.npy'), mel.numpy(),
                allow_pickle=False)
def main(hp, args):
    """Precompute mel-spectrograms for every .flac under args.data_path.

    Each mel (torch tensor, [1, num_mel, T]) is saved next to its audio
    file with the extension swapped to .mel.
    """
    stft = TacotronSTFT(filter_length=hp.audio.filter_length,
                        hop_length=hp.audio.hop_length,
                        win_length=hp.audio.win_length,
                        n_mel_channels=hp.audio.n_mel_channels,
                        sampling_rate=hp.audio.sampling_rate,
                        mel_fmin=hp.audio.mel_fmin,
                        mel_fmax=hp.audio.mel_fmax)
    wav_files = glob.glob(os.path.join(args.data_path, '**', '*.flac'),
                          recursive=True)
    for wavpath in tqdm.tqdm(wav_files, desc='preprocess wav to mel'):
        sr, wav = read_flac_np(wavpath)
        assert sr == hp.audio.sampling_rate, \
            "sample rate mismatch. expected %d, got %d at %s" % \
            (hp.audio.sampling_rate, sr, wavpath)
        # Pad short clips so downstream segment cropping never underflows.
        if len(wav) < hp.audio.segment_length + hp.audio.pad_short:
            wav = np.pad(wav,
                         (0, hp.audio.segment_length + hp.audio.pad_short - len(wav)),
                         mode='constant', constant_values=0.0)
        wav = torch.from_numpy(wav).unsqueeze(0)
        mel = stft.mel_spectrogram(wav)
        # Fix: str.replace('.flac', '.mel') would also rewrite '.flac'
        # occurring anywhere inside the path; splitext swaps only the
        # trailing extension.
        melpath = os.path.splitext(wavpath)[0] + '.mel'
        torch.save(mel, melpath)
def main(hp, args):
    """Precompute mel-spectrograms (transposed to [T, num_mel]) into args.out_path.

    The energy channel returned by this TacotronSTFT variant is discarded.
    """
    stft = TacotronSTFT(filter_length=hp.filter_length,
                        hop_length=hp.hop_length,
                        win_length=hp.win_length,
                        n_mel_channels=hp.n_mel_channels,
                        sampling_rate=hp.sampling_rate,
                        mel_fmin=hp.mel_fmin,
                        mel_fmax=hp.mel_fmax)
    wav_files = glob.glob(os.path.join(args.data_path, '**', '*.wav'),
                          recursive=True)
    for wavpath in tqdm.tqdm(wav_files, desc='preprocess wav to mel'):
        sr, wav = read_wav_np(wavpath)
        assert sr == hp.sampling_rate, \
            "sample rate mismatch. expected %d, got %d at %s" % \
            (hp.sampling_rate, sr, wavpath)
        # Pad short clips so downstream segment cropping never underflows.
        if len(wav) < hp.segment_length + hp.pad_short:
            wav = np.pad(wav, (0, hp.segment_length + hp.pad_short - len(wav)),
                         mode='constant', constant_values=0.0)
        wav = torch.from_numpy(wav).unsqueeze(0)
        mel, _ = stft.mel_spectrogram(wav)  # energy channel unused here
        # Extension swap via splitext (replace-all on the path was fragile).
        melpath = os.path.join(
            args.out_path,
            os.path.splitext(os.path.basename(wavpath))[0] + '.npy')
        np.save(melpath, mel.squeeze(0).transpose(0, 1).numpy())
def extract_mel(wav_dir, fs, speaker_id_pos, filter_length, hop_length,
                mel_fmin, mel_fmax):
    """Extract mel-spectrograms on GPU for every wav under *wav_dir*.

    Batched through a DataLoader; each clip's mel is trimmed back to its
    true (unpadded) frame count and saved as .npy, mirroring the input
    layout with the 'wav16' path component swapped for 'mel16k'.
    """
    mel_extractor = TacotronSTFT(filter_length=filter_length,
                                 hop_length=hop_length,
                                 n_mel_channels=80,
                                 sampling_rate=fs,
                                 mel_fmin=mel_fmin,
                                 mel_fmax=mel_fmax).cuda()
    audio_dataset = AudioDataset(wav_dir=wav_dir, n_speaker=None,
                                 speaker_id_pos=speaker_id_pos)
    audio_collate_fn = AudioCollateFn()
    audio_loader = DataLoader(audio_dataset, batch_size=32, shuffle=False,
                              collate_fn=audio_collate_fn, num_workers=0)
    for batch in progressbar(audio_loader):
        audio, lengths, audio_path = batch
        audio = audio.cuda(0)
        mel = mel_extractor.mel_spectrogram(audio).detach().cpu().numpy()
        for i, input_wav in enumerate(audio_path):
            # Paths are produced with '/' by the dataset; keep '/'-based
            # manipulation so the wav16 -> mel16k substitution stays intact.
            mel_fname = input_wav.split("/")[-1].replace(".wav", "")
            output_mel_dir = input_wav.replace("wav16", "mel16k")
            output_mel_dir = "/".join(output_mel_dir.split("/")[:-1])
            # Fix: exist_ok avoids the check-then-create race of the
            # original exists()/makedirs pair.
            os.makedirs(output_mel_dir, exist_ok=True)
            # Trim batch padding back to this clip's true frame count.
            _mel = mel[i, :, :int(math.ceil(lengths[i] / hop_length))]
            np.save(join(output_mel_dir, mel_fname), _mel)
def main(args):
    """Precompute mel, frame energy and pitch for every .wav under args.data_path.

    Writes one .npy per utterance into the mels/energy/pitch
    subdirectories of hp.data_dir. `hp` is a module-level global here.
    """
    stft = TacotronSTFT(filter_length=hp.n_fft,
                        hop_length=hp.hop_length,
                        win_length=hp.win_length,
                        n_mel_channels=hp.n_mels,
                        sampling_rate=hp.sample_rate,
                        mel_fmin=hp.fmin,
                        mel_fmax=hp.fmax)
    wav_files = glob.glob(os.path.join(args.data_path, '**', '*.wav'),
                          recursive=True)
    mel_path = os.path.join(hp.data_dir, 'mels')
    energy_path = os.path.join(hp.data_dir, 'energy')
    pitch_path = os.path.join(hp.data_dir, 'pitch')
    os.makedirs(mel_path, exist_ok=True)
    os.makedirs(energy_path, exist_ok=True)
    os.makedirs(pitch_path, exist_ok=True)
    for wavpath in tqdm.tqdm(wav_files, desc='preprocess wav to mel'):
        sr, wav = read_wav_np(wavpath)
        # Fix: the original read `sr` but never validated it, so mismatched
        # audio silently produced wrong features. Fail loudly like the
        # sibling preprocessors in this codebase.
        assert sr == hp.sample_rate, \
            "sample rate mismatch. expected %d, got %d at %s" % \
            (hp.sample_rate, sr, wavpath)
        p = pitch(wav)  # [T, ] T = Number of frames
        wav = torch.from_numpy(wav).unsqueeze(0)
        mel, mag = stft.mel_spectrogram(wav)  # mel [1, 80, T] mag [1, num_mag, T]
        mel = mel.squeeze(0)  # [num_mel, T]
        mag = mag.squeeze(0)  # [num_mag, T]
        e = torch.norm(mag, dim=0)  # per-frame energy [T, ]
        p = p[:mel.shape[1]]  # clip pitch to the mel frame count
        # Renamed from `id`: avoid shadowing the builtin.
        utt_id = os.path.basename(wavpath).split(".")[0]
        np.save('{}/{}.npy'.format(mel_path, utt_id), mel.numpy(),
                allow_pickle=False)
        np.save('{}/{}.npy'.format(energy_path, utt_id), e.numpy(),
                allow_pickle=False)
        np.save('{}/{}.npy'.format(pitch_path, utt_id), p,
                allow_pickle=False)
def preprocess(data_path, hp, file):
    """Precompute mel, energy, pitch and phoneme-averaged mel features.

    Reads a pipe-separated metadata file (phoneme durations at index 2,
    relative wav path at index 4) and writes one .npy per utterance into
    the mels/energy/pitch/avg_mel_ph subdirectories of hp.data.data_dir.
    """
    stft = TacotronSTFT(
        filter_length=hp.audio.n_fft,
        hop_length=hp.audio.hop_length,
        win_length=hp.audio.win_length,
        n_mel_channels=hp.audio.n_mels,
        sampling_rate=hp.audio.sample_rate,
        mel_fmin=hp.audio.fmin,
        mel_fmax=hp.audio.fmax,
    )
    mel_path = os.path.join(hp.data.data_dir, "mels")
    energy_path = os.path.join(hp.data.data_dir, "energy")
    pitch_path = os.path.join(hp.data.data_dir, "pitch")
    avg_mel_phon = os.path.join(hp.data.data_dir, "avg_mel_ph")
    os.makedirs(mel_path, exist_ok=True)
    os.makedirs(energy_path, exist_ok=True)
    os.makedirs(pitch_path, exist_ok=True)
    os.makedirs(avg_mel_phon, exist_ok=True)
    print("Sample Rate : ", hp.audio.sample_rate)
    with open("{}".format(file), encoding="utf-8") as f:
        _metadata = [line.strip().split("|") for line in f]
    for metadata in tqdm.tqdm(_metadata, desc="preprocess wav to mel"):
        wavpath = os.path.join(data_path, metadata[4])
        sr, wav = read_wav_np(wavpath, hp.audio.sample_rate)
        dur = str_to_int_list(metadata[2])
        dur = torch.from_numpy(np.array(dur))
        p = pitch(wav, hp)  # [T, ] T = Number of frames
        wav = torch.from_numpy(wav).unsqueeze(0)
        mel, mag = stft.mel_spectrogram(wav)  # mel [1, 80, T] mag [1, num_mag, T]
        mel = mel.squeeze(0)  # [num_mel, T]
        mag = mag.squeeze(0)  # [num_mag, T]
        e = torch.norm(mag, dim=0)  # per-frame energy [T, ]
        p = p[:mel.shape[1]]  # clip pitch to the mel frame count
        # One averaged-mel row per phoneme; the assert below implies the
        # first dim is the phoneme count L (original comment said [num_mel, L]).
        avg_mel_ph = _average_mel_by_duration(mel, dur)
        assert (avg_mel_ph.shape[0] == dur.shape[-1])
        # Renamed from `id`: avoid shadowing the builtin.
        utt_id = os.path.basename(wavpath).split(".")[0]
        np.save("{}/{}.npy".format(mel_path, utt_id), mel.numpy(),
                allow_pickle=False)
        np.save("{}/{}.npy".format(energy_path, utt_id), e.numpy(),
                allow_pickle=False)
        np.save("{}/{}.npy".format(pitch_path, utt_id), p,
                allow_pickle=False)
        np.save("{}/{}.npy".format(avg_mel_phon, utt_id), avg_mel_ph.numpy(),
                allow_pickle=False)
class Mel2Samp(torch.utils.data.Dataset):
    """
    This is the main class that calculates the spectrogram and returns the
    spectrogram, audio pair.

    Audio files are shuffled once with a fixed seed so the ordering is
    reproducible across runs.
    """

    def __init__(self, training_files, segment_length, filter_length,
                 hop_length, win_length, sampling_rate, mel_fmin, mel_fmax):
        self.audio_files = files_to_list(training_files)
        random.seed(1234)  # deterministic shuffle for reproducibility
        random.shuffle(self.audio_files)
        self.stft = TacotronSTFT(filter_length=filter_length,
                                 hop_length=hop_length,
                                 win_length=win_length,
                                 sampling_rate=sampling_rate,
                                 mel_fmin=mel_fmin,
                                 mel_fmax=mel_fmax)
        self.segment_length = segment_length
        self.sampling_rate = sampling_rate

    def get_mel(self, audio):
        """Return the mel-spectrogram [num_mel, T] of an int-scaled waveform."""
        audio_norm = audio / MAX_WAV_VALUE
        audio_norm = audio_norm.unsqueeze(0)
        # Fix: dropped the deprecated torch.autograd.Variable wrapper —
        # a no-op since torch 0.4; plain tensors already have
        # requires_grad=False.
        melspec = self.stft.mel_spectrogram(audio_norm)
        melspec = torch.squeeze(melspec, 0)
        return melspec

    def __getitem__(self, index):
        """Return a (mel, audio) pair for one random fixed-length segment.

        Raises ValueError if the file's sample rate differs from the
        configured one.
        """
        # Read audio
        filename = self.audio_files[index]
        audio, sampling_rate = load_wav_to_torch(filename)
        if sampling_rate != self.sampling_rate:
            raise ValueError("{} SR doesn't match target {} SR".format(
                sampling_rate, self.sampling_rate))
        # Take a random segment; pad short clips to segment_length.
        if audio.size(0) >= self.segment_length:
            max_audio_start = audio.size(0) - self.segment_length
            audio_start = random.randint(0, max_audio_start)
            audio = audio[audio_start:audio_start + self.segment_length]
        else:
            audio = torch.nn.functional.pad(
                audio, (0, self.segment_length - audio.size(0)),
                'constant').data
        mel = self.get_mel(audio)
        audio = audio / MAX_WAV_VALUE  # normalize to [-1, 1]
        return (mel, audio)

    def __len__(self):
        return len(self.audio_files)
def main(hp, args):
    """Precompute mels and split the corpus into train/validation folders.

    The wav list is shuffled; the first 300 files become the validation
    set. For each file, the mel tensor (.mel, torch.save) and a copy of
    the wav are written into melgan_{train,val}_mel_data.
    """
    stft = TacotronSTFT(filter_length=hp.audio.filter_length,
                        hop_length=hp.audio.hop_length,
                        win_length=hp.audio.win_length,
                        n_mel_channels=hp.audio.n_mel_channels,
                        sampling_rate=hp.audio.sampling_rate,
                        mel_fmin=hp.audio.mel_fmin,
                        mel_fmax=hp.audio.mel_fmax)
    wav_files = glob.glob(os.path.join(args.data_path, '**', '*.wav'),
                          recursive=True)
    save_train_mel_path = 'melgan_train_mel_data'
    save_val_mel_path = 'melgan_val_mel_data'
    os.makedirs(save_train_mel_path, exist_ok=True)
    os.makedirs(save_val_mel_path, exist_ok=True)
    random.shuffle(wav_files)
    count = 0
    for wavpath in tqdm.tqdm(wav_files, desc='preprocess wav to mel'):
        sr, wav = read_wav_np(wavpath)
        assert sr == hp.audio.sampling_rate, \
            "sample rate mismatch. expected %d, got %d at %s" % \
            (hp.audio.sampling_rate, sr, wavpath)
        # Pad short clips so downstream segment cropping never underflows.
        if len(wav) < hp.audio.segment_length + hp.audio.pad_short:
            wav = np.pad(wav,
                         (0, hp.audio.segment_length + hp.audio.pad_short - len(wav)),
                         mode='constant', constant_values=0.0)
        wav = torch.from_numpy(wav).unsqueeze(0)
        mel = stft.mel_spectrogram(wav)
        # Fix: split('/')[5] hard-coded a directory depth (and the POSIX
        # separator); basename works for any input layout.
        wav_name = os.path.basename(wavpath)
        mel_name = os.path.splitext(wav_name)[0] + '.mel'
        if count < 300:
            # First 300 shuffled files form the validation split.
            final_mel_path = os.path.join(save_val_mel_path, mel_name)
            final_wav_path = os.path.join(save_val_mel_path, wav_name)
        else:
            final_mel_path = os.path.join(save_train_mel_path, mel_name)
            final_wav_path = os.path.join(save_train_mel_path, wav_name)
        torch.save(mel, final_mel_path)
        shutil.copy(wavpath, final_wav_path)
        count += 1
def main(args, hp):
    """Precompute mel, frame energy and pitch for every .wav under args.data_path.

    Writes one .npy per utterance into the mels/energy/pitch
    subdirectories of hp.data.data_dir. read_wav_np is given the target
    sample rate, so no separate rate assertion is needed here.
    """
    stft = TacotronSTFT(
        filter_length=hp.audio.n_fft,
        hop_length=hp.audio.hop_length,
        win_length=hp.audio.win_length,
        n_mel_channels=hp.audio.n_mels,
        sampling_rate=hp.audio.sample_rate,
        mel_fmin=hp.audio.fmin,
        mel_fmax=hp.audio.fmax,
    )
    wav_files = glob.glob(os.path.join(args.data_path, "**", "*.wav"),
                          recursive=True)
    mel_path = os.path.join(hp.data.data_dir, "mels")
    energy_path = os.path.join(hp.data.data_dir, "energy")
    pitch_path = os.path.join(hp.data.data_dir, "pitch")
    os.makedirs(mel_path, exist_ok=True)
    os.makedirs(energy_path, exist_ok=True)
    os.makedirs(pitch_path, exist_ok=True)
    print("Sample Rate : ", hp.audio.sample_rate)
    for wavpath in tqdm.tqdm(wav_files, desc="preprocess wav to mel"):
        sr, wav = read_wav_np(wavpath, hp.audio.sample_rate)
        p = pitch(wav, hp)  # [T, ] T = Number of frames
        wav = torch.from_numpy(wav).unsqueeze(0)
        mel, mag = stft.mel_spectrogram(wav)  # mel [1, 80, T] mag [1, num_mag, T]
        mel = mel.squeeze(0)  # [num_mel, T]
        mag = mag.squeeze(0)  # [num_mag, T]
        e = torch.norm(mag, dim=0)  # per-frame energy [T, ]
        p = p[:mel.shape[1]]  # clip pitch to the mel frame count
        # Renamed from `id`: avoid shadowing the builtin.
        utt_id = os.path.basename(wavpath).split(".")[0]
        np.save("{}/{}.npy".format(mel_path, utt_id), mel.numpy(),
                allow_pickle=False)
        np.save("{}/{}.npy".format(energy_path, utt_id), e.numpy(),
                allow_pickle=False)
        np.save("{}/{}.npy".format(pitch_path, utt_id), p,
                allow_pickle=False)
def main(hp, args):
    """Precompute mel-spectrograms for every .wav under args.data_path.

    Each mel ([1, num_mel, T] torch tensor) is saved next to its audio
    file with the extension swapped to .mel.
    """
    stft = TacotronSTFT(filter_length=hp.audio.filter_length,
                        hop_length=hp.audio.hop_length,
                        win_length=hp.audio.win_length,
                        n_mel_channels=hp.audio.n_mel_channels,
                        sampling_rate=hp.audio.sampling_rate,
                        mel_fmin=hp.audio.mel_fmin,
                        mel_fmax=hp.audio.mel_fmax)
    wav_files = glob.glob(os.path.join(args.data_path, '**', '*.wav'),
                          recursive=True)
    for wavpath in tqdm.tqdm(wav_files, desc='preprocess wav to mel'):
        sr, wav = read_wav_np(wavpath)
        assert sr == hp.audio.sampling_rate, \
            "sample rate mismatch. expected %d, got %d at %s" % \
            (hp.audio.sampling_rate, sr, wavpath)
        wav = torch.from_numpy(wav).unsqueeze(0)
        mel = stft.mel_spectrogram(wav)
        # Fix: str.replace('.wav', '.mel') would also rewrite '.wav'
        # occurring anywhere inside the path; splitext swaps only the
        # trailing extension.
        melpath = os.path.splitext(wavpath)[0] + '.mel'
        torch.save(mel, melpath)
def main(args, hp):
    """Precompute mel-spectrograms into hp.data_path as <basename>.npy.

    The magnitude spectrogram returned by this TacotronSTFT variant is
    discarded. read_wav_np is given the target sample rate.
    """
    stft = TacotronSTFT(filter_length=hp.filter_length,
                        hop_length=hp.hop_length,
                        win_length=hp.win_length,
                        n_mel_channels=hp.n_mel_channels,
                        sampling_rate=hp.sampling_rate,
                        mel_fmin=hp.mel_fmin,
                        mel_fmax=hp.mel_fmax)
    wav_files = glob.glob(os.path.join(args.data_path, '**', '*.wav'),
                          recursive=True)
    mel_path = hp.data_path
    os.makedirs(mel_path, exist_ok=True)
    print("Sample Rate : ", hp.sampling_rate)
    for wavpath in tqdm.tqdm(wav_files, desc='preprocess wav to mel'):
        sr, wav = read_wav_np(wavpath, hp.sampling_rate)
        wav = torch.from_numpy(wav).unsqueeze(0)
        mel, _ = stft.mel_spectrogram(wav)  # magnitude channel unused here
        mel = mel.squeeze(0)  # [num_mel, T]
        # Renamed from `id`: avoid shadowing the builtin.
        utt_id = os.path.basename(wavpath).split(".")[0]
        np.save('{}/{}.npy'.format(mel_path, utt_id), mel.numpy(),
                allow_pickle=False)
def train(args, pt_dir, chkpt_path, trainloader, valloader, writer, logger, hp, hp_str):
    """GAN vocoder training loop: multi-scale discriminator (MSD) + multi-period
    discriminator (MPD), with multi-resolution STFT loss and mel L1 loss on the
    generator. Saves checkpoints to pt_dir every hp.log.save_interval epochs."""
    model_g = Generator(hp.audio.n_mel_channels).cuda()
    model_d = MultiScaleDiscriminator(hp.model.num_D, hp.model.ndf, hp.model.n_layers,
                                      hp.model.downsampling_factor, hp.model.disc_out).cuda()
    model_d_mpd = MPD().cuda()
    optim_g = torch.optim.Adam(model_g.parameters(),
                               lr=hp.train.adam.lr,
                               betas=(hp.train.adam.beta1, hp.train.adam.beta2))
    # Both discriminators share one optimizer via chained parameters.
    optim_d = torch.optim.Adam(itertools.chain(model_d.parameters(), model_d_mpd.parameters()),
                               lr=hp.train.adam.lr,
                               betas=(hp.train.adam.beta1, hp.train.adam.beta2))
    # STFT used only to compute the mel L1 loss against generated audio.
    stft = TacotronSTFT(filter_length=hp.audio.filter_length,
                        hop_length=hp.audio.hop_length,
                        win_length=hp.audio.win_length,
                        n_mel_channels=hp.audio.n_mel_channels,
                        sampling_rate=hp.audio.sampling_rate,
                        mel_fmin=hp.audio.mel_fmin,
                        mel_fmax=hp.audio.mel_fmax)
    # githash = get_commit_hash()
    init_epoch = -1
    step = 0
    if chkpt_path is not None:
        # Resume models, optimizers and counters from the checkpoint dict.
        logger.info("Resuming from checkpoint: %s" % chkpt_path)
        checkpoint = torch.load(chkpt_path)
        model_g.load_state_dict(checkpoint['model_g'])
        model_d.load_state_dict(checkpoint['model_d'])
        model_d_mpd.load_state_dict(checkpoint['model_d_mpd'])
        optim_g.load_state_dict(checkpoint['optim_g'])
        optim_d.load_state_dict(checkpoint['optim_d'])
        step = checkpoint['step']
        init_epoch = checkpoint['epoch']
        if hp_str != checkpoint['hp_str']:
            logger.warning(
                "New hparams is different from checkpoint. Will use new.")
        # if githash != checkpoint['githash']:
        #     logger.warning("Code might be different: git hash is different.")
        #     logger.warning("%s -> %s" % (checkpoint['githash'], githash))
    else:
        logger.info("Starting new training run.")
    # this accelerates training when the size of minibatch is always consistent.
    # if not consistent, it'll horribly slow down.
    torch.backends.cudnn.benchmark = True
    try:
        model_g.train()
        model_d.train()
        stft_loss = MultiResolutionSTFTLoss()
        criterion = torch.nn.MSELoss().cuda()  # LSGAN-style adversarial objective
        l1loss = torch.nn.L1Loss()
        for epoch in itertools.count(init_epoch + 1):
            if epoch % hp.log.validation_interval == 0:
                with torch.no_grad():
                    validate(hp, model_g, model_d, model_d_mpd, valloader,
                             stft_loss, l1loss, criterion, stft, writer, step)
            trainloader.dataset.shuffle_mapping()
            loader = tqdm.tqdm(trainloader, desc='Loading train data')
            avg_g_loss = []
            avg_d_loss = []
            avg_adv_loss = []
            # Each item yields separate (mel, audio) pairs for the generator
            # (G) and discriminator (D) updates.
            for (melG, audioG), (melD, audioD) in loader:
                melG = melG.cuda()      # torch.Size([16, 80, 64])
                audioG = audioG.cuda()  # torch.Size([16, 1, 16000])
                melD = melD.cuda()      # torch.Size([16, 80, 64])
                audioD = audioD.cuda()  # torch.Size([16, 1, 16000]

                # ---- generator update ----
                optim_g.zero_grad()
                fake_audio = model_g(melG)[:, :, :hp.audio.segment_length]  # torch.Size([16, 1, 12800])
                loss_g = 0.0
                # Multi-resolution STFT loss between generated and real audio.
                sc_loss, mag_loss = stft_loss(
                    fake_audio[:, :, :audioG.size(2)].squeeze(1),
                    audioG.squeeze(1))
                loss_g += sc_loss + mag_loss  # STFT Loss
                adv_loss = 0.0
                loss_mel = 0.0
                # Adversarial terms are only added once the discriminators
                # have started training.
                if step > hp.train.discriminator_train_start_steps:
                    disc_real = model_d(audioG)
                    disc_fake = model_d(fake_audio)
                    # for multi-scale discriminator
                    for feats_fake, score_fake in disc_fake:
                        # adv_loss += torch.mean(torch.sum(torch.pow(score_fake - 1.0, 2), dim=[1, 2]))
                        adv_loss += criterion(score_fake, torch.ones_like(score_fake))
                    adv_loss = adv_loss / len(disc_fake)  # len(disc_fake) = 3

                    # MPD Adverserial loss (five period sub-discriminators)
                    out1, out2, out3, out4, out5 = model_d_mpd(fake_audio)
                    adv_mpd_loss = criterion(out1, torch.ones_like(out1)) + criterion(out2, torch.ones_like(out2)) + \
                        criterion(out3, torch.ones_like(out3)) + criterion(out4, torch.ones_like(out4)) + \
                        criterion(out5, torch.ones_like(out5))
                    adv_mpd_loss = adv_mpd_loss / 5
                    adv_loss = adv_loss + adv_mpd_loss  # Adv Loss

                    # Mel Loss: L1 between input mel and the mel of the
                    # generated audio.
                    mel_fake = stft.mel_spectrogram(fake_audio.squeeze(1))
                    loss_mel += l1loss(melG[:, :, :mel_fake.size(2)], mel_fake.cuda())  # Mel L1 loss
                    loss_g += hp.model.lambda_mel * loss_mel
                    # Optional feature-matching loss over MSD intermediate maps.
                    if hp.model.feat_loss:
                        for (feats_fake, score_fake), (feats_real, _) in zip(disc_fake, disc_real):
                            for feat_f, feat_r in zip(feats_fake, feats_real):
                                adv_loss += hp.model.feat_match * torch.mean(
                                    torch.abs(feat_f - feat_r))
                    loss_g += hp.model.lambda_adv * adv_loss
                loss_g.backward()
                optim_g.step()

                # ---- discriminator update ----
                loss_d_avg = 0.0
                if step > hp.train.discriminator_train_start_steps:
                    # Fresh fake audio from the D-batch mel, detached so only
                    # discriminator gradients flow.
                    fake_audio = model_g(melD)[:, :, :hp.audio.segment_length]
                    fake_audio = fake_audio.detach()
                    loss_d_sum = 0.0
                    # Optionally repeat the discriminator step several times
                    # per generator step.
                    for _ in range(hp.train.rep_discriminator):
                        optim_d.zero_grad()
                        disc_fake = model_d(fake_audio)
                        disc_real = model_d(audioD)
                        loss_d = 0.0
                        loss_d_real = 0.0
                        loss_d_fake = 0.0
                        for (_, score_fake), (_, score_real) in zip(
                                disc_fake, disc_real):
                            loss_d_real += criterion(
                                score_real, torch.ones_like(score_real))
                            loss_d_fake += criterion(
                                score_fake, torch.zeros_like(score_fake))
                        loss_d_real = loss_d_real / len(disc_real)  # len(disc_real) = 3
                        loss_d_fake = loss_d_fake / len(disc_fake)  # len(disc_fake) = 3
                        loss_d += loss_d_real + loss_d_fake  # MSD loss
                        loss_d_sum += loss_d

                        # MPD Adverserial loss
                        out1, out2, out3, out4, out5 = model_d_mpd(fake_audio)
                        out1_real, out2_real, out3_real, out4_real, out5_real = model_d_mpd(
                            audioD)
                        loss_mpd_fake = criterion(out1, torch.zeros_like(out1)) + criterion(out2, torch.zeros_like(out2)) + \
                            criterion(out3, torch.zeros_like(out3)) + criterion(out4, torch.zeros_like(out4)) + \
                            criterion(out5, torch.zeros_like(out5))
                        loss_mpd_real = criterion(out1_real, torch.ones_like(out1_real)) + criterion(out2_real, torch.ones_like(out2_real)) + \
                            criterion(out3_real, torch.ones_like(out3_real)) + criterion(out4_real, torch.ones_like(out4_real)) + \
                            criterion(out5_real, torch.ones_like(out5_real))
                        loss_mpd = (loss_mpd_fake + loss_mpd_real) / 5  # MPD Loss
                        loss_d += loss_mpd
                        loss_d.backward()
                        optim_d.step()
                        # NOTE(review): loss_mpd is added to loss_d_sum here
                        # in addition to loss_d above — confirm the reported
                        # average is intended to count MSD + MPD.
                        loss_d_sum += loss_mpd
                    loss_d_avg = loss_d_sum / hp.train.rep_discriminator
                    loss_d_avg = loss_d_avg.item()
                step += 1

                # ---- logging ----
                loss_g = loss_g.item()
                avg_g_loss.append(loss_g)
                avg_d_loss.append(loss_d_avg)
                # NOTE(review): adv_loss may be a tensor (after warm-up) or a
                # float 0.0 (before) — the running average below mixes both.
                avg_adv_loss.append(adv_loss)
                if any([
                        loss_g > 1e8,
                        math.isnan(loss_g), loss_d_avg > 1e8,
                        math.isnan(loss_d_avg)
                ]):
                    logger.error("loss_g %.01f loss_d_avg %.01f at step %d!" %
                                 (loss_g, loss_d_avg, step))
                    raise Exception("Loss exploded")
                if step % hp.log.summary_interval == 0:
                    # NOTE(review): passes 5 args (includes loss_mel); the
                    # MyWriter.log_training defined elsewhere in this file
                    # takes 4 — confirm which writer class this loop uses.
                    writer.log_training(loss_g, loss_d_avg, adv_loss,
                                        loss_mel, step)
                loader.set_description(
                    "Avg : g %.04f d %.04f ad %.04f| step %d" %
                    (sum(avg_g_loss) / len(avg_g_loss),
                     sum(avg_d_loss) / len(avg_d_loss),
                     sum(avg_adv_loss) / len(avg_adv_loss), step))
            if epoch % hp.log.save_interval == 0:
                save_path = os.path.join(pt_dir,
                                         '%s_%04d.pt' % (args.name, epoch))
                torch.save(
                    {
                        'model_g': model_g.state_dict(),
                        'model_d': model_d.state_dict(),
                        'model_d_mpd': model_d_mpd.state_dict(),
                        'optim_g': optim_g.state_dict(),
                        'optim_d': optim_d.state_dict(),
                        'step': step,
                        'epoch': epoch,
                        'hp_str': hp_str
                    }, save_path)
                logger.info("Saved checkpoint to: %s" % save_path)
    except Exception as e:
        # Log and dump the traceback instead of crashing the launcher.
        logger.info("Exiting due to exception: %s" % e)
        traceback.print_exc()
def train(args, chkpt_dir, chkpt_path, writer, logger, hp, hp_str, seed):
    """WaveFlow training loop with optional AMP; periodically logs scalars,
    audio and mel images to tensorboard and checkpoints every
    hp.log.save_interval epochs."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = (
        True  # https://github.com/pytorch/pytorch/issues/6351
    )
    if args.amp:
        print("Automatic Mixed Precision Training")
        scaler = amp.GradScaler()
    criterion = WaveFlowLoss(hp.model.sigma)
    model = WaveFlow(hp.model.flows,
                     hp.model.n_group,
                     hp.audio.sampling_rate,
                     hp.audio.win_length,
                     hp.audio.n_mel_channels,
                     hp).cuda()
    num_params(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=hp.train.adam.lr)

    # Load checkpoint if one exists
    githash = get_commit_hash()
    init_epoch = -1
    step = 0
    if chkpt_path is not None:
        #logger.info("Resuming from checkpoint: %s" % chkpt_path)
        checkpoint = torch.load(chkpt_path)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optim'])
        step = checkpoint['step']
        init_epoch = checkpoint['epoch']
        if hp_str != checkpoint['hp_str']:
            logger.warning(
                "New hparams is different from checkpoint. Will use new.")
        if githash != checkpoint['githash']:
            logger.warning("Code might be different: git hash is different.")
            logger.warning("%s -> %s" % (checkpoint['githash'], githash))
    else:
        logger.info("Starting new training run.")
    train_loader = create_dataloader(hp, True)
    valid_loader = create_dataloader(hp, False)
    # Get shared output_directory ready
    # STFT used only to render mels of generated audio for tensorboard.
    stft = TacotronSTFT(filter_length=hp.audio.filter_length,
                        hop_length=hp.audio.hop_length,
                        win_length=hp.audio.win_length,
                        n_mel_channels=hp.audio.n_mel_channels,
                        sampling_rate=hp.audio.sampling_rate,
                        mel_fmin=hp.audio.mel_fmin,
                        mel_fmax=hp.audio.mel_fmax)
    model.train()
    # ================ MAIN TRAINNIG LOOP! ===================
    for epoch in itertools.count(init_epoch + 1):
        if epoch % hp.log.validation_interval == 0:
            # NOTE(review): epoch-level validation hook is an empty stub;
            # actual validation runs on the step-interval branch below.
            with torch.no_grad():
                pass
        loader = tqdm.tqdm(train_loader, desc='Loading train data')
        loss_list = []
        for (mel, audio) in loader:
            model.zero_grad()
            #mel, audio = batch
            mel = torch.autograd.Variable(
                mel.cuda())  # [B, num mel, num of frame]
            audio = torch.autograd.Variable(audio.cuda())  # [B, T]
            if args.amp:
                # Mixed-precision forward + scaled backward/step.
                with amp.autocast():
                    z, logdet, _ = model(audio, mel)  # [B, T]
                    loss = criterion(z, logdet)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                z, logdet, _ = model(audio, mel)  # [B, T]
                loss = criterion(z, logdet)
                loss.backward()
                optimizer.step()
            loss_list.append(loss.item())
            loader.set_description("Avg Loss : loss %.04f | step %d" %
                                   (sum(loss_list) / len(loss_list), step))
            if step % hp.log.summary_interval == 0:
                writer.add_scalar("train.loss", loss.item(), step)
                # NOTE(review): the three calls below omit the global_step
                # argument (unlike train.loss) — confirm whether they should
                # also pass `step`.
                writer.add_scalar('train.log_determinant',
                                  logdet.mean().item())
                writer.add_scalar('train.z_mean', z.mean().item())
                writer.add_scalar('train.z_std', z.std().item())
            if step % hp.log.validation_interval == 0:
                # Validate on a single batch (note the `break` at the end).
                for (mel_, audio_) in valid_loader:
                    model.eval()
                    x, logdet_ = model.infer(mel_.cuda(),
                                             hp.model.sigma)  # x -> [T]
                    torch.clamp(x, -1, 1, out=x)  # in-place clamp of output audio
                    writer.add_scalar('valid.log_determinant',
                                      logdet_.mean().item())
                    writer.add_audio('actual_audio',
                                     audio_.squeeze(0).cpu().detach().numpy(),
                                     step,
                                     sample_rate=hp.audio.sampling_rate)
                    writer.add_audio('reconstruct_audio',
                                     x.cpu().detach().numpy(),
                                     step,
                                     sample_rate=hp.audio.sampling_rate)
                    # Min-max normalize mels to [0, 1] for image display.
                    mel_spec = mel_[0].cpu().detach()
                    mel_spec -= mel_spec.min()
                    mel_spec /= mel_spec.max()
                    writer.add_image('actual_mel-spectrum',
                                     plot_spectrogram_to_numpy(
                                         mel_spec.numpy()),
                                     step,
                                     dataformats='HWC')
                    mel_gen, _ = stft.mel_spectrogram(x.unsqueeze(0))
                    mel_g_spec = mel_gen[0].cpu().detach()
                    mel_g_spec -= mel_g_spec.min()
                    mel_g_spec /= mel_g_spec.max()
                    writer.add_image('gen_mel-spectrum',
                                     plot_spectrogram_to_numpy(
                                         mel_g_spec.numpy()),
                                     step,
                                     dataformats='HWC')
                    model.train()
                    break
            step += 1
        if epoch % hp.log.save_interval == 0:
            save_path = os.path.join(
                chkpt_dir, '%s_%s_%04d.pt' % (args.name, githash, epoch))
            torch.save(
                {
                    'model': model.state_dict(),
                    'optim': optimizer.state_dict(),
                    'step': step,
                    'epoch': epoch,
                    'hp_str': hp_str,
                    'githash': githash,
                }, save_path)
            logger.info("Saved checkpoint to: %s" % save_path)