def gen_testset(model: WaveRNN, test_set, samples, batched, target, overlap, save_path: Path):
    k = model.get_step() // 1000

    for i, (m, x) in enumerate(test_set, 1):
        if i > samples:
            break

        print('\n| Generating: %i/%i' % (i, samples))

        x = x[0].numpy()
        bits = 16 if hp.voc_mode == 'MOL' else hp.bits

        if hp.mu_law and hp.voc_mode != 'MOL':
            x = decode_mu_law(x, 2**bits, from_labels=True)
        else:
            x = label_2_float(x, bits)

        save_wav(x, save_path / f'{k}k_steps_{i}_target.wav')

        batch_str = f'gen_batched_target{target}_overlap{overlap}' if batched else 'gen_NOT_BATCHED'
        save_str = str(save_path / f'{k}k_steps_{i}_{batch_str}.wav')

        _ = model.generate(m, save_str, batched, target, overlap, hp.mu_law)
def gen_from_file(model: WaveRNN, load_path: Path, save_path: Path, batched, target, overlap):
    k = model.get_step() // 1000
    file_name = load_path.stem
    suffix = load_path.suffix

    if suffix == ".wav":
        wav = load_wav(load_path)
        save_wav(wav, save_path / f'__{file_name}__{k}k_steps_target.wav')
        mel = melspectrogram(wav)
    elif suffix == ".npy":
        mel = np.load(load_path)
        if mel.ndim != 2 or mel.shape[0] != hp.num_mels:
            raise ValueError(
                f'Expected a numpy array shaped (n_mels, n_hops), but got {mel.shape}!')
        _max = np.max(mel)
        _min = np.min(mel)
        if _max >= 1.01 or _min <= -0.01:
            raise ValueError(
                f'Expected spectrogram range in [0,1] but was instead [{_min}, {_max}]')
    else:
        raise ValueError(f"Expected an extension of .wav or .npy, but got {suffix}!")

    mel = torch.tensor(mel).unsqueeze(0)

    batch_str = f'gen_batched_target{target}_overlap{overlap}' if batched else 'gen_NOT_BATCHED'
    save_str = save_path / f'__{file_name}__{k}k_steps_{batch_str}.wav'

    _ = model.generate(mel, save_str, batched, target, overlap, hp.mu_law)
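# Hedged usage sketch (not part of the original script): assumes an already-loaded
# WaveRNN model; the sample paths and the output directory below are placeholders.
def example_gen_from_file(model: WaveRNN):
    out_dir = Path('output')
    out_dir.mkdir(exist_ok=True)
    # Vocode a ground-truth wav and a pre-computed mel with the same generation settings
    gen_from_file(model, Path('samples/example.wav'), out_dir,
                  batched=hp.voc_gen_batched, target=hp.voc_target, overlap=hp.voc_overlap)
    gen_from_file(model, Path('samples/example_mel.npy'), out_dir,
                  batched=hp.voc_gen_batched, target=hp.voc_target, overlap=hp.voc_overlap)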
def gen_testset(model: WaveRNN, test_set, samples, batched, target, overlap, save_path):
    k = model.get_step() // 1000

    for i, (m, x) in enumerate(test_set, 1):
        if i > samples:
            break

        print('\n| Generating: %i/%i' % (i, samples))

        x = x[0].numpy()
        bits = 16 if hp.voc_mode == 'MOL' else hp.bits

        if hp.mu_law and hp.voc_mode != 'MOL':
            x = decode_mu_law(x, 2**bits, from_labels=True)
        else:
            x = label_2_float(x, bits)

        save_wav(x, save_path.joinpath("%dk_steps_%d_target.wav" % (k, i)))

        batch_str = "gen_batched_target%d_overlap%d" % (target, overlap) if batched else "gen_not_batched"
        save_str = save_path.joinpath("%dk_steps_%d_%s.wav" % (k, i, batch_str))

        wav = model.generate(m, batched, target, overlap, hp.mu_law)
        save_wav(wav, save_str)
class Model(object):
    def __init__(self):
        self._model = None

    def load_from(self, weights_fpath, verbose=True):
        if verbose:
            print("Building Wave-RNN")
        self._model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                              fc_dims=hp.voc_fc_dims,
                              bits=hp.bits,
                              pad=hp.voc_pad,
                              upsample_factors=hp.voc_upsample_factors,
                              feat_dims=hp.num_mels,
                              compute_dims=hp.voc_compute_dims,
                              res_out_dims=hp.voc_res_out_dims,
                              res_blocks=hp.voc_res_blocks,
                              hop_length=hp.hop_length,
                              sample_rate=hp.sample_rate,
                              mode=hp.voc_mode)  # .cuda()

        if verbose:
            print("Loading model weights at %s" % weights_fpath)
        checkpoint = torch.load(weights_fpath, map_location=torch.device('cpu'))
        self._model.load_state_dict(checkpoint['model_state'])
        self._model.eval()

    def is_loaded(self):
        return self._model is not None

    def infer_waveform(self, mel, normalize=True, batched=True, target=8000, overlap=800,
                       progress_callback=None):
        """
        Infers the waveform of a mel spectrogram output by the synthesizer
        (the format must match that of the synthesizer!)
        :param normalize: divide the mel by hp.mel_max_abs_value before vocoding
        :param batched: use batched (folded) generation
        :param target: target samples per batch fold
        :param overlap: overlap samples between folds
        :return: the generated waveform
        """
        if self._model is None:
            raise Exception("Please load Wave-RNN in memory before using it")

        if normalize:
            mel = mel / hp.mel_max_abs_value
        mel = torch.from_numpy(mel[None, ...])
        wav = self._model.generate(mel, batched, target, overlap, hp.mu_law, progress_callback)
        return wav
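# Hedged usage sketch (not in the original source): the weights path is a placeholder,
# and `mel_from_synthesizer` is assumed to be a numpy array shaped (num_mels, n_frames)
# in the synthesizer's output format.
def example_infer(mel_from_synthesizer):
    vocoder = Model()
    vocoder.load_from("vocoder/saved_models/pretrained.pt", verbose=True)
    assert vocoder.is_loaded()
    wav = vocoder.infer_waveform(mel_from_synthesizer, normalize=True,
                                 batched=True, target=8000, overlap=800)
    return wav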
def gen_meltest(model: WaveRNN, batched, target, overlap, save_path):
    mel = []
    mel.append(np.load("/home/sdevgupta/mine/waveglow/outputs/waveglow_specs/mel-1.npy").T)
    mel.append(np.load("/home/sdevgupta/mine/waveglow/outputs/waveglow_specs/mel-3.npy").T)
    mel.append(np.load("/home/sdevgupta/mine/waveglow/outputs/waveglow_specs/mel-5.npy").T)

    k = model.get_step() // 1000

    for i, m in enumerate(mel):
        m = m - 20
        m = audio_synth._normalize(m, hparams_synth.hparams) / 4
        wav = model.generate_from_mel(m,
                                      batched=False,
                                      overlap=hp.voc_overlap,
                                      target=hp.voc_target,
                                      mu_law=True,
                                      cpu=False,
                                      apply_preemphasis=False)
        # wav = wav / np.abs(wav).max() * 0.9
        save_str = save_path.joinpath("mel-" + str(i + 1) + "-steps-" + str(k) + "k.wav")
        save_wav(wav, save_str)
def load_model(weights_fpath, verbose=True):
    global _model, _device

    if verbose:
        print("Building Wave-RNN")
    _model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                     fc_dims=hp.voc_fc_dims,
                     bits=hp.bits,
                     pad=hp.voc_pad,
                     upsample_factors=hp.voc_upsample_factors,
                     feat_dims=hp.num_mels,
                     compute_dims=hp.voc_compute_dims,
                     res_out_dims=hp.voc_res_out_dims,
                     res_blocks=hp.voc_res_blocks,
                     hop_length=hp.hop_length,
                     sample_rate=hp.sample_rate,
                     mode=hp.voc_mode)

    if torch.cuda.is_available():
        _model = _model.cuda()
        _device = torch.device('cuda')
    else:
        _device = torch.device('cpu')

    if verbose:
        print("Loading model weights at %s" % weights_fpath)
    checkpoint = torch.load(weights_fpath, _device)
    _model.load_state_dict(checkpoint['model_state'])
    _model.eval()
def load_model(weights_fpath, verbose=True):
    global _model

    if verbose:
        print("Building Wave-RNN")
    _model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                     fc_dims=hp.voc_fc_dims,
                     bits=hp.bits,
                     pad=hp.voc_pad,
                     upsample_factors=hp.voc_upsample_factors,
                     feat_dims=hp.num_mels,
                     compute_dims=hp.voc_compute_dims,
                     res_out_dims=hp.voc_res_out_dims,
                     res_blocks=hp.voc_res_blocks,
                     hop_length=hp.hop_length,
                     sample_rate=hp.sample_rate,
                     mode=hp.voc_mode)

    if verbose:
        print("Loading model weights at %s" % weights_fpath)
    checkpoint = torch.load(weights_fpath, map_location='cpu')
    _model.load_state_dict(checkpoint['model_state'])
    _model.eval()
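# Hedged usage sketch (not in the original module): after load_model(), the module-level
# _model can vocode a mel the same way the Model wrapper above does. The weights path is
# taken from another snippet in this collection; `mel` is assumed to be a numpy array
# shaped (num_mels, n_frames).
def example_module_level_inference(mel):
    load_model("vocoder/pretrained/voc_weights/latest_weights.pyt", verbose=True)
    mel = torch.from_numpy(mel[None, ...])
    return _model.generate(mel, hp.voc_gen_batched, hp.voc_target, hp.voc_overlap, hp.mu_law)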
gta = args.gta

if not args.force_cpu and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)

print('\nInitialising Model...\n')

model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                fc_dims=hp.voc_fc_dims,
                bits=hp.bits,
                pad=hp.voc_pad,
                upsample_factors=hp.voc_upsample_factors,
                feat_dims=hp.num_mels,
                compute_dims=hp.voc_compute_dims,
                res_out_dims=hp.voc_res_out_dims,
                res_blocks=hp.voc_res_blocks,
                hop_length=hp.hop_length,
                sample_rate=hp.sample_rate,
                mode=hp.voc_mode).to(device)

paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)

voc_weights = args.voc_weights if args.voc_weights else paths.voc_latest_weights
model.load(voc_weights)

simple_table([('Generation Mode', 'Batched' if batched else 'Unbatched'),
              ('Target Samples', target if batched else 'N/A'),
              ('Overlap Samples', overlap if batched else 'N/A')])
def train(run_id: str, models_dir: Path, metadata_path: Path, weights_path: Path,
          ground_truth: bool, save_every: int, backup_every: int, force_restart: bool):
    # Check to make sure the hop length is correctly factorised
    assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length

    # Instantiate the model
    print("Initializing the model...")
    model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                    fc_dims=hp.voc_fc_dims,
                    bits=hp.bits,
                    pad=hp.voc_pad,
                    upsample_factors=hp.voc_upsample_factors,
                    feat_dims=hp.num_mels,
                    compute_dims=hp.voc_compute_dims,
                    res_out_dims=hp.voc_res_out_dims,
                    res_blocks=hp.voc_res_blocks,
                    hop_length=hp.hop_length,
                    sample_rate=hp.sample_rate,
                    mode=hp.voc_mode).cuda()

    # Initialize the optimizer
    optimizer = optim.Adam(model.parameters())
    for p in optimizer.param_groups:
        p["lr"] = hp.voc_lr
    loss_func = F.cross_entropy if model.mode == "RAW" else discretized_mix_logistic_loss

    # Load the weights
    model_dir = models_dir.joinpath(run_id)
    model_dir.mkdir(exist_ok=True)
    weights_fpath = weights_path
    metadata_fpath = metadata_path
    if force_restart:
        print("\nStarting the training of WaveRNN from scratch\n")
        model.save(weights_fpath, optimizer)
    else:
        print("\nLoading weights at %s" % weights_fpath)
        model.load(weights_fpath, optimizer)
        print("WaveRNN weights loaded from step %d" % model.step)

    # Initialize the dataset
    dataset = VocoderDataset(metadata_fpath)
    test_loader = DataLoader(dataset, batch_size=1, shuffle=True, pin_memory=True)

    # Begin the training
    simple_table([('Batch size', hp.voc_batch_size),
                  ('LR', hp.voc_lr),
                  ('Sequence Len', hp.voc_seq_len)])

    epoch_start = int((model.step - 428000) * 110 / dataset.get_number_of_samples())
    epoch_end = 200

    log_path = os.path.join(models_dir, "logs")
    if not os.path.isdir(log_path):
        os.mkdir(log_path)
    writer = SummaryWriter(log_path)
    print("Log path : " + log_path)
    print("Starting from epoch: " + str(epoch_start))

    for epoch in range(epoch_start, epoch_start + epoch_end):
        data_loader = DataLoader(dataset,
                                 collate_fn=collate_vocoder,
                                 batch_size=hp.voc_batch_size,
                                 num_workers=2,
                                 shuffle=True,
                                 pin_memory=True)
        start = time.time()
        running_loss = 0.

        for i, (x, y, m) in enumerate(data_loader, 1):
            x, m, y = x.cuda(), m.cuda(), y.cuda()

            # Forward pass
            y_hat = model(x, m)
            if model.mode == 'RAW':
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
            elif model.mode == 'MOL':
                y = y.float()
            y = y.unsqueeze(-1)

            # Backward pass
            loss = loss_func(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            speed = i / (time.time() - start)
            avg_loss = running_loss / i

            step = model.get_step()
            k = step // 1000

            if backup_every != 0 and step % backup_every == 0:
                model.checkpoint(model_dir, optimizer)

            # if save_every != 0 and step % save_every == 0:
            #     model.save(weights_fpath, optimizer)

            if step % 500 == 0:
                writer.add_scalar('Loss/train', avg_loss, round(step / 1000, 1))

            msg = f"| Epoch: {epoch} ({i}/{len(data_loader)}) | " \
                  f"Loss: {avg_loss:.4f} | {speed:.1f} " \
                  f"steps/s | Step: {k}k | "
            print(msg, flush=True)

            if step % 15000 == 0:
                gen_testset(model, test_loader, hp.voc_gen_at_checkpoint, hp.voc_gen_batched,
                            hp.voc_target, hp.voc_overlap, model_dir)
                gen_meltest(model, hp.voc_gen_batched, hp.voc_target, hp.voc_overlap, model_dir)
def voc_train_loop(paths: Paths, model: WaveRNN, loss_func, optimizer, train_set, test_set, lr, total_steps):
    # Use same device as model parameters
    device = next(model.parameters()).device

    for g in optimizer.param_groups:
        g['lr'] = lr

    total_iters = len(train_set)
    epochs = (total_steps - model.get_step()) // total_iters + 1

    for e in range(1, epochs + 1):
        start = time.time()
        running_loss = 0.

        for i, (x, y, m) in enumerate(train_set, 1):
            x, m, y = x.to(device), m.to(device), y.to(device)

            # Parallelize model onto GPUS using workaround due to python bug
            if device.type == 'cuda' and torch.cuda.device_count() > 1:
                y_hat = data_parallel_workaround(model, x, m)
            else:
                y_hat = model(x, m)

            if model.mode == 'RAW':
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
            elif model.mode == 'MOL':
                y = y.float()
            y = y.unsqueeze(-1)

            loss = loss_func(y_hat, y)

            optimizer.zero_grad()
            loss.backward()

            if hp.voc_clip_grad_norm is not None:
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hp.voc_clip_grad_norm)
                if np.isnan(grad_norm):
                    print('grad_norm was NaN!')

            optimizer.step()

            running_loss += loss.item()
            avg_loss = running_loss / i

            speed = i / (time.time() - start)
            step = model.get_step()
            k = step // 1000

            if step % hp.voc_checkpoint_every == 0:
                gen_testset(model, test_set, hp.voc_gen_at_checkpoint, hp.voc_gen_batched,
                            hp.voc_target, hp.voc_overlap, paths.voc_output)
                ckpt_name = f'wave_step{k}K'
                save_checkpoint('voc', paths, model, optimizer, name=ckpt_name, is_silent=True)

            msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {avg_loss:.4f} | {speed:.1f} steps/s | Step: {k}k | '
            stream(msg)

        # Must save latest optimizer state to ensure that resuming training
        # doesn't produce artifacts
        save_checkpoint('voc', paths, model, optimizer, is_silent=True)
        model.log(paths.voc_log, msg)
        print(' ')
class Generator(nn.Module): """Generator network.""" def __init__(self, dim_neck, dim_emb, dim_pre, freq, dim_spec=80, is_train=False, lr=0.001, loss_content=True, discriminator=False, multigpu=False, lambda_gan=0.0001, lambda_wavenet=0.001, args=None, test_path_source=None, test_path_target=None): super(Generator, self).__init__() self.encoder = MyEncoder(dim_neck, freq, num_mel=dim_spec) self.decoder = Decoder(dim_neck, 0, dim_pre, num_mel=dim_spec) self.postnet = Postnet(num_mel=dim_spec) if discriminator: self.dis = PatchDiscriminator(n_class=num_speakers) self.dis_criterion = GANLoss(use_lsgan=use_lsgan, tensor=torch.cuda.FloatTensor) else: self.dis = None self.loss_content = loss_content self.lambda_gan = lambda_gan self.lambda_wavenet = lambda_wavenet self.multigpu = multigpu self.prepare_test(dim_spec, test_path_source, test_path_target) self.vocoder = WaveRNN(rnn_dims=hparams.voc_rnn_dims, fc_dims=hparams.voc_fc_dims, bits=hparams.bits, pad=hparams.voc_pad, upsample_factors=hparams.voc_upsample_factors, feat_dims=hparams.num_mels, compute_dims=hparams.voc_compute_dims, res_out_dims=hparams.voc_res_out_dims, res_blocks=hparams.voc_res_blocks, hop_length=hparams.hop_size, sample_rate=hparams.sample_rate, mode=hparams.voc_mode) if is_train: self.criterionIdt = torch.nn.L1Loss(reduction='mean') self.opt_encoder = torch.optim.Adam(self.encoder.parameters(), lr=lr) self.opt_decoder = torch.optim.Adam(itertools.chain( self.decoder.parameters(), self.postnet.parameters()), lr=lr) if discriminator: self.opt_dis = torch.optim.Adam(self.dis.parameters(), lr=lr) self.opt_vocoder = torch.optim.Adam(self.vocoder.parameters(), lr=hparams.voc_lr) self.vocoder_loss_func = F.cross_entropy # Only for RAW if multigpu: self.encoder = nn.DataParallel(self.encoder) self.decoder = nn.DataParallel(self.decoder) self.postnet = nn.DataParallel(self.postnet) self.vocoder = nn.DataParallel(self.vocoder) if self.dis is not None: self.dis = nn.DataParallel(self.dis) def prepare_test(self, dim_spec, source_path=None, target_path=None): if source_path is None: source_path = "/mnt/lustre/dengkangle/cmu/datasets/audio/test/trump_02.wav" if target_path is None: target_path = "/mnt/lustre/dengkangle/cmu/datasets/audio/test/female.wav" # source_path = "/home/kangled/datasets/audio/Chaplin_01.wav" # target_path = "/home/kangled/datasets/audio/Obama_01.wav" mel_basis80 = librosa.filters.mel(hparams.sample_rate, hparams.n_fft, n_mels=80) wav, sr = librosa.load(source_path, hparams.sample_rate) wav = preemphasis(wav, hparams.preemphasis, hparams.preemphasize) linear_spec = np.abs( librosa.stft(wav, n_fft=hparams.n_fft, hop_length=hparams.hop_size, win_length=hparams.win_size)) mel_spec = mel_basis80.dot(linear_spec) mel_db = 20 * np.log10(mel_spec) source_spec = np.clip((mel_db + 120) / 125, 0, 1) # source_spec = mel_spec self.source_embed = torch.from_numpy(np.array([0, 1 ])).float().unsqueeze(0) self.source_wav = wav wav, sr = librosa.load(target_path, hparams.sample_rate) wav = preemphasis(wav, hparams.preemphasis, hparams.preemphasize) linear_spec = np.abs( librosa.stft(wav, n_fft=hparams.n_fft, hop_length=hparams.hop_size, win_length=hparams.win_size)) mel_spec = mel_basis80.dot(linear_spec) mel_db = 20 * np.log10(mel_spec) target_spec = np.clip((mel_db + 120) / 125, 0, 1) # target_spec = mel_spec self.target_embed = torch.from_numpy(np.array([1, 0 ])).float().unsqueeze(0) self.target_wav = wav self.source_spec = torch.Tensor(pad_seq(source_spec.T, hparams.freq)).unsqueeze(0) self.target_spec = 
torch.Tensor(pad_seq(target_spec.T, hparams.freq)).unsqueeze(0) def test_fixed(self, device): with torch.no_grad(): t2s_spec = self.conversion(self.target_embed, self.source_embed, self.target_spec, device).cpu() s2s_spec = self.conversion(self.source_embed, self.source_embed, self.source_spec, device).cpu() s2t_spec = self.conversion(self.source_embed, self.target_embed, self.source_spec, device).cpu() t2t_spec = self.conversion(self.target_embed, self.target_embed, self.target_spec, device).cpu() ret_dic = {} ret_dic['A_fake_griffin'], sr = mel2wav(s2t_spec.numpy().squeeze(0).T) ret_dic['B_fake_griffin'], sr = mel2wav(t2s_spec.numpy().squeeze(0).T) ret_dic['A'] = self.source_wav ret_dic['B'] = self.target_wav with torch.no_grad(): if not self.multigpu: ret_dic['A_fake_w'] = inv_preemphasis( self.vocoder.generate(s2t_spec.to(device).transpose(2, 1), False, None, None, mu_law=True), hparams.preemphasis, hparams.preemphasize) ret_dic['B_fake_w'] = inv_preemphasis( self.vocoder.generate(t2s_spec.to(device).transpose(2, 1), False, None, None, mu_law=True), hparams.preemphasis, hparams.preemphasize) else: ret_dic['A_fake_w'] = inv_preemphasis( self.vocoder.module.generate(s2t_spec.to(device).transpose( 2, 1), False, None, None, mu_law=True), hparams.preemphasis, hparams.preemphasize) ret_dic['B_fake_w'] = inv_preemphasis( self.vocoder.module.generate(t2s_spec.to(device).transpose( 2, 1), False, None, None, mu_law=True), hparams.preemphasis, hparams.preemphasize) return ret_dic, sr def conversion(self, speaker_org, speaker_trg, spec, device, speed=1): speaker_org, speaker_trg, spec = speaker_org.to( device), speaker_trg.to(device), spec.to(device) if not self.multigpu: codes = self.encoder(spec, speaker_org) else: codes = self.encoder.module(spec, speaker_org) tmp = [] for code in codes: tmp.append( code.unsqueeze(1).expand( -1, int(speed * spec.size(1) / len(codes)), -1)) code_exp = torch.cat(tmp, dim=1) encoder_outputs = torch.cat((code_exp, speaker_trg.unsqueeze(1).expand( -1, code_exp.size(1), -1)), dim=-1) mel_outputs = self.decoder( code_exp) if not self.multigpu else self.decoder.module(code_exp) mel_outputs_postnet = self.postnet(mel_outputs.transpose(2, 1)) mel_outputs_postnet = mel_outputs + mel_outputs_postnet.transpose(2, 1) return mel_outputs_postnet def optimize_parameters(self, dataloader, epochs, device, display_freq=10, save_freq=1000, save_dir="./", experimentName="Train", load_model=None, initial_niter=0): writer = SummaryWriter(log_dir="logs/" + experimentName) if load_model is not None: print("Loading from %s..." 
% load_model) # self.load_state_dict(torch.load(load_model)) d = torch.load(load_model) newdict = d.copy() for key, value in d.items(): newkey = key if 'wavenet' in key: newdict[key.replace('wavenet', 'vocoder')] = newdict.pop(key) newkey = key.replace('wavenet', 'vocoder') if self.multigpu and 'module' not in key: newdict[newkey.replace('.', '.module.', 1)] = newdict.pop(newkey) newkey = newkey.replace('.', '.module.', 1) if newkey not in self.state_dict(): newdict.pop(newkey) self.load_state_dict(newdict) print("AutoVC Model Loaded") niter = initial_niter for epoch in range(epochs): self.train() for i, data in enumerate(dataloader): speaker_org, spec, prev, wav = data loss_dict, loss_dict_discriminator, loss_dict_wavenet = \ self.train_step(spec.to(device), speaker_org.to(device), prev=prev.to(device), wav=wav.to(device), device=device) if niter % display_freq == 0: print("Epoch[%d] Iter[%d] Niter[%d] %s %s %s" % (epoch, i, niter, loss_dict, loss_dict_discriminator, loss_dict_wavenet)) writer.add_scalars('data/Loss', loss_dict, niter) if loss_dict_discriminator != {}: writer.add_scalars('data/discriminator', loss_dict_discriminator, niter) if loss_dict_wavenet != {}: writer.add_scalars('data/wavenet', loss_dict_wavenet, niter) if niter % save_freq == 0: print("Saving and Testing...", end='\t') torch.save( self.state_dict(), save_dir + '/Epoch' + str(epoch).zfill(3) + '_Iter' + str(niter).zfill(8) + ".pkl") # self.load_state_dict(torch.load('params.pkl')) if len(dataloader) >= 2: wav_dic, sr = self.test_fixed(device) for key, wav in wav_dic.items(): # print(wav.shape) writer.add_audio(key, wav, niter, sample_rate=sr) print("Done") self.train() torch.cuda.empty_cache() # Prevent Out of Memory niter += 1 def train_step(self, x, c_org, mask=None, mask_code=None, prev=None, wav=None, ret_content=False, retain_graph=False, device='cuda:0'): codes = self.encoder(x, c_org) # print(codes[0].shape) content = torch.cat([code.unsqueeze(1) for code in codes], dim=1) # print("content shape", content.shape) tmp = [] for code in codes: tmp.append( code.unsqueeze(1).expand(-1, int(x.size(1) / len(codes)), -1)) code_exp = torch.cat(tmp, dim=1) encoder_outputs = torch.cat( (code_exp, c_org.unsqueeze(1).expand(-1, x.size(1), -1)), dim=-1) mel_outputs = self.decoder(code_exp) mel_outputs_postnet = self.postnet(mel_outputs.transpose(2, 1)) mel_outputs_postnet = mel_outputs + mel_outputs_postnet.transpose(2, 1) loss_dict, loss_dict_discriminator, loss_dict_wavenet = {}, {}, {} loss_recon = self.criterionIdt(x, mel_outputs) loss_recon0 = self.criterionIdt(x, mel_outputs_postnet) loss_dict['recon'], loss_dict['recon0'] = loss_recon.data.item( ), loss_recon0.data.item() if self.loss_content: recons_codes = self.encoder(mel_outputs_postnet, c_org) recons_content = torch.cat( [code.unsqueeze(1) for code in recons_codes], dim=1) if mask is not None: loss_content = self.criterionIdt( content.masked_select(mask_code.byte()), recons_content.masked_select(mask_code.byte())) else: loss_content = self.criterionIdt(content, recons_content) loss_dict['content'] = loss_content.data.item() else: loss_content = torch.from_numpy(np.array(0)) loss_gen, loss_dis, loss_vocoder = [torch.from_numpy(np.array(0))] * 3 fake_mel = None if self.dis: # true_label = torch.from_numpy(np.ones(shape=(x.shape[0]))).to('cuda:0').long() # false_label = torch.from_numpy(np.zeros(shape=(x.shape[0]))).to('cuda:0').long() flip_speaker = 1 - c_org fake_mel = self.conversion(c_org, flip_speaker, x, device) loss_dis = self.dis_criterion(self.dis(x), 
True) + self.dis_criterion( self.dis(fake_mel), False) # + self.dis_criterion(self.dis(mel_outputs_postnet), False) self.opt_dis.zero_grad() loss_dis.backward(retain_graph=True) self.opt_dis.step() loss_gen = self.dis_criterion(self.dis(fake_mel), True) # + self.dis_criterion(self.dis(mel_outputs_postnet), True) loss_dict_discriminator['dis'], loss_dict_discriminator[ 'gen'] = loss_dis.data.item(), loss_gen.data.item() if not self.multigpu: y_hat = self.vocoder( prev, self.vocoder.pad_tensor(mel_outputs_postnet, hparams.voc_pad).transpose(1, 2)) else: y_hat = self.vocoder( prev, self.vocoder.module.pad_tensor(mel_outputs_postnet, hparams.voc_pad).transpose( 1, 2)) y_hat = y_hat.transpose(1, 2).unsqueeze(-1) # assert (0 <= wav < 2 ** 9).all() loss_vocoder = self.vocoder_loss_func(y_hat, wav.unsqueeze(-1).to(device)) self.opt_vocoder.zero_grad() Loss = loss_recon + loss_recon0 + loss_content + \ self.lambda_gan * loss_gen + self.lambda_wavenet * loss_vocoder loss_dict['total'] = Loss.data.item() self.opt_encoder.zero_grad() self.opt_decoder.zero_grad() Loss.backward(retain_graph=retain_graph) self.opt_encoder.step() self.opt_decoder.step() grad_norm = torch.nn.utils.clip_grad_norm_(self.vocoder.parameters(), 65504.0) self.opt_vocoder.step() if ret_content: return loss_recon, loss_recon0, loss_content, Loss, content return loss_dict, loss_dict_discriminator, loss_dict_wavenet
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)

print('\nInitialising WaveRNN Model...\n')

# Instantiate WaveRNN Model
voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                    fc_dims=hp.voc_fc_dims,
                    bits=hp.bits,
                    pad=hp.voc_pad,
                    upsample_factors=hp.voc_upsample_factors,
                    feat_dims=hp.num_mels,
                    compute_dims=hp.voc_compute_dims,
                    res_out_dims=hp.voc_res_out_dims,
                    res_blocks=hp.voc_res_blocks,
                    hop_length=hp.hop_length,
                    sample_rate=hp.sample_rate,
                    mode='MOL').to(device)

voc_model.load('vocoder/pretrained/voc_weights/latest_weights.pyt')


# Get mel-spectrogram and generate wav file
def generate(mel, filename, sampling_rate):
    save_path = 'output/' + filename

    # Scale mel-spectrogram
    new_mel = mel.clone().detach()
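# Hedged sketch (not from the original, which breaks off above): one plausible way such a
# helper could continue, scaling the mel by hp.mel_max_abs_value as infer_waveform does
# elsewhere in this collection, then vocoding and saving. Names and scaling are illustrative
# assumptions only.
def generate_sketch(mel, filename, sampling_rate):
    save_path = 'output/' + filename
    new_mel = mel.clone().detach() / hp.mel_max_abs_value      # assumed normalisation
    if new_mel.dim() == 2:
        new_mel = new_mel.unsqueeze(0)                         # add batch dimension
    wav = voc_model.generate(new_mel, hp.voc_gen_batched, hp.voc_target, hp.voc_overlap, hp.mu_law)
    save_wav(wav, save_path)                                   # sampling_rate unused in this sketch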
class Generator(nn.Module): """Generator network.""" def __init__(self, dim_neck, dim_emb, dim_pre, freq, dim_spec=80, is_train=False, lr=0.001, decoder_type='simple', vocoder_type='wavenet', encoder_type='default', separate_encoder=True, loss_content=True, discriminator=False, dis_type='patch', multigpu=False, cycle=False, lambda_cycle=1, num_speakers=-1, idt_type='L2', use_lsgan=True, lambda_gan=0.0001, train_wavenet=False, lambda_wavenet=0.001, args=None, test_path_source=None, test_path_target=None, attention=False, residual=False): super(Generator, self).__init__() if encoder_type == 'default': self.encoder = Encoder(dim_neck, dim_emb, freq) elif encoder_type == 'nospeaker' or encoder_type == 'single': self.encoder = MyEncoder(dim_neck, freq, num_mel=dim_spec) elif encoder_type == 'multiencoder': self.encoder = MultiEncoder(num_speakers, dim_neck, freq, separate_encoder) if encoder_type == 'multiencoder' or encoder_type == 'single': self.decoder = Decoder(dim_neck, 0, dim_pre, num_mel=dim_spec) elif decoder_type == 'simple': self.decoder = Decoder(dim_neck, dim_emb, dim_pre) elif decoder_type == 'tacotron': self.decoder = TacotronDecoder(hparams) elif decoder_type == 'multidecoder': self.decoder = MultiDecoder(num_speakers, dim_neck, dim_pre, multigpu) elif decoder_type == 'video': # self.decoder = VideoGenerator() self.decoder = STAGE2_G(residual=residual) self.postnet = Postnet(num_mel=dim_spec) if discriminator: if dis_type == 'patch': self.dis = PatchDiscriminator(n_class=num_speakers) else: self.dis = SpeakerDiscriminator() # self.dis_criterion = nn.CrossEntropyLoss(reduction='mean') self.dis_criterion = GANLoss(use_lsgan=use_lsgan, tensor=torch.cuda.FloatTensor) else: self.dis = None self.encoder_type = encoder_type self.decoder_type = decoder_type self.vocoder_type = vocoder_type self.loss_content = loss_content self.cycle = cycle self.lambda_cycle = lambda_cycle self.lambda_gan = lambda_gan self.lambda_wavenet = lambda_wavenet self.attention = attention self.multigpu = multigpu self.train_vocoder = train_wavenet if self.train_vocoder: self.vocoder = WaveRNN( rnn_dims=hparams.voc_rnn_dims, fc_dims=hparams.voc_fc_dims, bits=hparams.bits, pad=hparams.voc_pad, upsample_factors=hparams.voc_upsample_factors, feat_dims=hparams.num_mels, compute_dims=hparams.voc_compute_dims, res_out_dims=hparams.voc_res_out_dims, res_blocks=hparams.voc_res_blocks, hop_length=hparams.hop_size, sample_rate=hparams.sample_rate, mode=hparams.voc_mode) self.opt_vocoder = torch.optim.Adam(self.vocoder.parameters(), lr=hparams.voc_lr) self.vocoder_loss_func = F.cross_entropy # Only for RAW if multigpu: self.encoder = nn.DataParallel(self.encoder) self.decoder = nn.DataParallel(self.decoder) self.postnet = nn.DataParallel(self.postnet) self.vocoder = nn.DataParallel(self.vocoder) if self.dis is not None: self.dis = nn.DataParallel(self.dis) def forward(self, x, c_org, c_trg): codes = self.encoder(x, c_org) if c_trg is None: return torch.cat(codes, dim=-1) tmp = [] for code in codes: tmp.append( code.unsqueeze(1).expand(-1, int(x.size(1) / len(codes)), -1)) code_exp = torch.cat(tmp, dim=1) encoder_outputs = torch.cat( (code_exp, c_trg.unsqueeze(1).expand(-1, x.size(1), -1)), dim=-1) # (batch, T, 256+dim_neck) mel_outputs = self.decoder(encoder_outputs) mel_outputs_postnet = self.postnet(mel_outputs.transpose(2, 1)) mel_outputs_postnet = mel_outputs + mel_outputs_postnet.transpose(2, 1) mel_outputs = mel_outputs.unsqueeze(1) mel_outputs_postnet = mel_outputs_postnet.unsqueeze(1) return mel_outputs, 
mel_outputs_postnet, torch.cat(codes, dim=-1) def conversion(self, speaker_org, speaker_trg, spec, device, speed=1): speaker_org, speaker_trg, spec = speaker_org.to( device), speaker_trg.to(device), spec.to(device) if self.encoder_type == 'multiencoder': codes = self.encoder(spec, speaker_trg) else: if not self.multigpu: codes = self.encoder(spec, speaker_org) else: codes = self.encoder.module(spec, speaker_org) tmp = [] for code in codes: tmp.append( code.unsqueeze(1).expand( -1, int(speed * spec.size(1) / len(codes)), -1)) code_exp = torch.cat(tmp, dim=1) if self.attention: code_exp = self.phonemeToken(code_exp) encoder_outputs = torch.cat((code_exp, speaker_trg.unsqueeze(1).expand( -1, code_exp.size(1), -1)), dim=-1) if self.encoder_type == 'multiencoder' or self.encoder_type == 'single': mel_outputs = self.decoder( code_exp) if not self.multigpu else self.decoder.module( code_exp) elif self.decoder_type == 'simple': mel_outputs = self.decoder(encoder_outputs) elif self.decoder_type == 'tacotron': try: mel_outputs, _, alignments = self.decoder.inference( memory=encoder_outputs) except: mel_outputs, _, alignments = self.decoder.module.inference( memory=encoder_outputs) mel_outputs.transpose_(1, 2) elif self.decoder_type == 'multidecoder': mel_outputs = self.decoder(code_exp, speaker_trg) mel_outputs_postnet = self.postnet(mel_outputs.transpose(2, 1)) mel_outputs_postnet = mel_outputs + mel_outputs_postnet.transpose(2, 1) return mel_outputs_postnet
class VideoAudioGenerator(nn.Module): def __init__(self, dim_neck, dim_emb, dim_pre, freq, dim_spec=80, is_train=False, lr=0.001, vocoder_type='wavenet', multigpu=False, num_speakers=-1, idt_type='L2', train_wavenet=False, lambda_wavenet=0.001, args=None, test_path_source=None, test_path_target=None, residual=False, attention_map=None, use_256=False): super(VideoAudioGenerator, self).__init__() self.encoder = MyEncoder(dim_neck, freq, num_mel=dim_spec) self.decoder = Decoder(dim_neck, 0, dim_pre, num_mel=dim_spec) self.postnet = Postnet(num_mel=dim_spec) if use_256: self.video_decoder = VideoGenerator(use_256=True) else: self.video_decoder = STAGE2_G(residual=residual) self.use_256 = use_256 self.vocoder_type = vocoder_type self.lambda_wavenet = lambda_wavenet self.multigpu = multigpu # self.prepare_test(dim_spec, test_path_source, test_path_target) self.train_vocoder = train_wavenet if self.train_vocoder: if vocoder_type == 'wavenet' or vocoder_type == 'griffin': self.vocoder = WaveRNN( rnn_dims=hparams.voc_rnn_dims, fc_dims=hparams.voc_fc_dims, bits=hparams.bits, pad=hparams.voc_pad, upsample_factors=hparams.voc_upsample_factors, feat_dims=hparams.num_mels, compute_dims=hparams.voc_compute_dims, res_out_dims=hparams.voc_res_out_dims, res_blocks=hparams.voc_res_blocks, hop_length=hparams.hop_size, sample_rate=hparams.sample_rate, mode=hparams.voc_mode) self.opt_vocoder = torch.optim.Adam(self.vocoder.parameters(), lr=hparams.voc_lr) self.vocoder_loss_func = F.cross_entropy # Only for RAW if attention_map is not None: self.attention_map_large = np.load(attention_map) self.attention_map = cv2.resize(self.attention_map_large, dsize=(128, 128), interpolation=cv2.INTER_CUBIC) # self.attention_map_large = self.attention_map_large.astype(np.float64) # self.attention_map = self.attention_map.astype(np.float64) self.attention_map_large = torch.from_numpy( self.attention_map_large / self.attention_map_large.max()).float() self.attention_map = torch.from_numpy( self.attention_map / self.attention_map.max()).float() self.criterionVideo = torch.nn.L1Loss(reduction='none') else: self.attention_map = None if multigpu: self.encoder = nn.DataParallel(self.encoder) self.decoder = nn.DataParallel(self.decoder) self.video_decoder = nn.DataParallel(self.video_decoder) self.postnet = nn.DataParallel(self.postnet) self.vocoder = nn.DataParallel(self.vocoder) def generate(self, mel, speaker, device='cuda:0'): mel, speaker = mel.to(device), speaker.to(device) if not self.multigpu: codes, code_unsample = self.encoder(mel, speaker, return_unsample=True) else: codes, code_unsample = self.encoder.module(mel, speaker, return_unsample=True) tmp = [] for code in codes: tmp.append( code.unsqueeze(1).expand(-1, int(mel.size(1) / len(codes)), -1)) code_exp = torch.cat(tmp, dim=1) if not self.multigpu: if not self.use_256: v_stage1, v_stage2 = self.video_decoder(code_unsample, train=True) else: v_stage2 = self.video_decoder(code_unsample) v_stage1 = v_stage2 mel_outputs = self.decoder(code_exp) mel_outputs_postnet = self.postnet(mel_outputs.transpose(2, 1)) else: if not self.use_256: v_stage1, v_stage2 = self.video_decoder.module(code_unsample, train=True) else: v_stage2 = self.video_decoder.module(code_unsample) v_stage1 = v_stage2 mel_outputs = self.decoder.module(code_exp) mel_outputs_postnet = self.postnet.module( mel_outputs.transpose(2, 1)) mel_outputs_postnet = mel_outputs + mel_outputs_postnet.transpose(2, 1) return mel_outputs_postnet, v_stage1, v_stage2
def train(run_id='', syn_dir=None, voc_dirs=[], mel_dir_name='', models_dir=None, log_dir='',
          ground_truth=False, save_every=1000, backup_every=1000, log_every=1000,
          force_restart=False, total_epochs=10000, logger=None):
    # Check to make sure the hop length is correctly factorised
    assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length

    # Instantiate the model
    print("Initializing the model...")
    model = WaveRNN(
        rnn_dims=hp.voc_rnn_dims,                    # 512
        fc_dims=hp.voc_fc_dims,                      # 512
        bits=hp.bits,                                # 9
        pad=hp.voc_pad,                              # 2
        upsample_factors=hp.voc_upsample_factors,    # (3, 4, 5, 5) -> 300, (5, 5, 12)?
        feat_dims=hp.num_mels,                       # 80
        compute_dims=hp.voc_compute_dims,            # 128
        res_out_dims=hp.voc_res_out_dims,            # 128
        res_blocks=hp.voc_res_blocks,                # 10
        hop_length=hp.hop_length,                    # 300
        sample_rate=hp.sample_rate,                  # 24000
        mode=hp.voc_mode                             # RAW (or MOL)
    ).cuda()
    # hp.apply_preemphasis in VocoderDataset
    # hp.mu_law in VocoderDataset
    # hp.voc_seq_len in VocoderDataset
    # hp.voc_lr in optimizer
    # hp.voc_batch_size for train

    # Initialize the optimizer
    optimizer = optim.Adam(model.parameters())
    for p in optimizer.param_groups:
        p["lr"] = hp.voc_lr  # 0.0001
    loss_func = F.cross_entropy if model.mode == "RAW" else discretized_mix_logistic_loss

    # Load the weights
    model_dir = models_dir.joinpath(run_id)              # gta_model/gtaxxxx
    model_dir.mkdir(exist_ok=True)
    weights_fpath = model_dir.joinpath(run_id + ".pt")   # gta_model/gtaxxx/gtaxxx.pt
    if force_restart or not weights_fpath.exists():
        print("\nStarting the training of WaveRNN from scratch\n")
        model.save(str(weights_fpath), optimizer)
    else:
        print("\nLoading weights at %s" % weights_fpath)
        model.load(str(weights_fpath), optimizer)
        print("WaveRNN weights loaded from step %d" % model.step)

    # Initialize the dataset
    # metadata_fpath = syn_dir.joinpath("train.txt") if ground_truth else \
    #     voc_dir.joinpath("synthesized.txt")
    # mel_dir = syn_dir.joinpath("mels") if ground_truth else voc_dir.joinpath("mels_gta")
    # wav_dir = syn_dir.joinpath("audio")
    # dataset = VocoderDataset(metadata_fpath, mel_dir, wav_dir)
    # dataset = VocoderDataset(str(voc_dir), 'mels-gta-1099579078086', 'audio')
    dataset = VocoderDataset([str(voc_dir) for voc_dir in voc_dirs], mel_dir_name, 'audio')
    # test_loader = DataLoader(dataset,
    #                          batch_size=1,
    #                          shuffle=True,
    #                          pin_memory=True)

    # Begin the training
    simple_table([('Batch size', hp.voc_batch_size),
                  ('LR', hp.voc_lr),
                  ('Sequence Len', hp.voc_seq_len)])

    for epoch in range(1, total_epochs):
        data_loader = DataLoader(dataset,
                                 collate_fn=collate_vocoder,
                                 batch_size=hp.voc_batch_size,
                                 num_workers=30,
                                 shuffle=True,
                                 pin_memory=True)
        start = time.time()
        running_loss = 0.

        # start from 1
        for i, (x, y, m) in enumerate(data_loader, 1):
            # cur [B, L], future [B, L] bit label, mels [B, D, T]
            x, m, y = x.cuda(), m.cuda(), y.cuda()

            # Forward pass
            # [B, L], [B, D, T] -> [B, L, C]
            y_hat = model(x, m)
            if model.mode == 'RAW':
                # [B, L, C] -> [B, C, L, 1]
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
            elif model.mode == 'MOL':
                y = y.float()
            # [B, L, 1]
            y = y.unsqueeze(-1)

            # Backward pass
            # [B, C, L, 1], [B, L, 1]; cross_entropy for RAW
            loss = loss_func(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            speed = i / (time.time() - start)
            avg_loss = running_loss / i

            step = model.get_step()
            k = step // 1000

            if backup_every != 0 and step % backup_every == 0:
                model.checkpoint(str(model_dir), optimizer)
            if save_every != 0 and step % save_every == 0:
                model.save(str(weights_fpath), optimizer)
            if log_every != 0 and step % log_every == 0:
                logger.scalar_summary("loss", loss.item(), step)

            total_data = len(data_loader)
            msg = f"| Epoch: {epoch} ({i}/{total_data}) | " \
                  f"Loss: {avg_loss:.4f} | {speed:.1f} " \
                  f"steps/s | Step: {k}k | "
            stream(msg)

        # gen_testset(model, test_loader, hp.voc_gen_at_checkpoint, hp.voc_gen_batched,
        #             hp.voc_target, hp.voc_overlap, model_dir)
        print("")
def main():
    # Parse Arguments
    parser = argparse.ArgumentParser(description='Train WaveRNN Vocoder')
    parser.add_argument('--lr', '-l', type=float, help='[float] override hparams.py learning rate')
    parser.add_argument('--batch_size', '-b', type=int, help='[int] override hparams.py batch size')
    parser.add_argument('--force_train', '-f', action='store_true',
                        help='Forces the model to train past total steps')
    parser.add_argument('--gta', '-g', action='store_true', help='train wavernn on GTA features')
    parser.add_argument('--force_cpu', '-c', action='store_true',
                        help='Forces CPU-only training, even when in CUDA capable environment')
    parser.add_argument('--hp_file', metavar='FILE', default='hparams.py',
                        help='The file to use for the hyperparameters')
    args = parser.parse_args()

    # Set hyperparameters
    hp.training_files = "tacotron2/filelists/transcripts_korean_final_final.txt"
    hp.validation_files = "tacotron2/filelists/transcripts_korean_final_validate.txt"
    hp.filter_length = 1024
    hp.n_mel_channels = 80
    hp.sampling_rate = 16000
    hp.mel_fmin = 0.0
    hp.mel_fmax = 8000.0
    hp.max_wav_value = 32768.0
    hp.n_frames_per_step = 1

    hp.configure(args.hp_file)  # load hparams from file
    if args.lr is None:
        args.lr = hp.voc_lr
    if args.batch_size is None:
        args.batch_size = hp.voc_batch_size

    paths = Paths("../data/", hp.voc_model_id, hp.tts_model_id)

    batch_size = 64
    force_train = args.force_train
    train_gta = args.gta
    lr = args.lr

    if not args.force_cpu and torch.cuda.is_available():
        device = torch.device('cuda')
        if batch_size % torch.cuda.device_count() != 0:
            raise ValueError('`batch_size` must be evenly divisible by n_gpus!')
    else:
        device = torch.device('cpu')
    print('Using device:', device)

    print('\nInitialising Model...\n')

    # Instantiate WaveRNN Model
    voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                        fc_dims=hp.voc_fc_dims,
                        bits=hp.bits,
                        pad=hp.voc_pad,
                        upsample_factors=hp.voc_upsample_factors,
                        feat_dims=hp.num_mels,
                        compute_dims=hp.voc_compute_dims,
                        res_out_dims=hp.voc_res_out_dims,
                        res_blocks=hp.voc_res_blocks,
                        hop_length=hp.hop_length,
                        sample_rate=hp.sample_rate,
                        mode=hp.voc_mode).to(device)

    # Check to make sure the hop length is correctly factorised
    assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length

    optimizer = optim.Adam(voc_model.parameters())
    restore_checkpoint('voc', paths, voc_model, optimizer, create_if_missing=True)

    train_set, test_set = get_vocoder_datasets(paths.data, batch_size, train_gta, hp)

    total_steps = 10_000_000 if force_train else hp.voc_total_steps

    simple_table([('Remaining', str((total_steps - voc_model.get_step()) // 1000) + 'k Steps'),
                  ('Batch Size', batch_size),
                  ('LR', lr),
                  ('Sequence Len', hp.voc_seq_len),
                  ('GTA Train', train_gta)])

    loss_func = F.cross_entropy if voc_model.mode == 'RAW' else discretized_mix_logistic_loss

    voc_train_loop(paths, voc_model, loss_func, optimizer, train_set, test_set, lr, total_steps)

    print('Training Complete.')
    print('To continue training increase voc_total_steps in hparams.py or use --force_train')
class VideoAudioGenerator(nn.Module): def __init__(self, dim_neck, dim_emb, dim_pre, freq, dim_spec=80, is_train=False, lr=0.001, multigpu=False, lambda_wavenet=0.001, args=None, residual=False, attention_map=None, use_256=False, loss_content=False, test_path=None): super(VideoAudioGenerator, self).__init__() self.encoder = MyEncoder(dim_neck, freq, num_mel=dim_spec) self.decoder = Decoder(dim_neck, 0, dim_pre, num_mel=dim_spec) self.postnet = Postnet(num_mel=dim_spec) if use_256: self.video_decoder = VideoGenerator(use_256=True) else: self.video_decoder = STAGE2_G(residual=residual) self.use_256 = use_256 self.lambda_wavenet = lambda_wavenet self.loss_content = loss_content self.multigpu = multigpu self.test_path = test_path self.vocoder = WaveRNN(rnn_dims=hparams.voc_rnn_dims, fc_dims=hparams.voc_fc_dims, bits=hparams.bits, pad=hparams.voc_pad, upsample_factors=hparams.voc_upsample_factors, feat_dims=hparams.num_mels, compute_dims=hparams.voc_compute_dims, res_out_dims=hparams.voc_res_out_dims, res_blocks=hparams.voc_res_blocks, hop_length=hparams.hop_size, sample_rate=hparams.sample_rate, mode=hparams.voc_mode) if is_train: self.criterionIdt = torch.nn.L1Loss(reduction='mean') self.opt_encoder = torch.optim.Adam(self.encoder.parameters(), lr=lr) self.opt_decoder = torch.optim.Adam(itertools.chain( self.decoder.parameters(), self.postnet.parameters()), lr=lr) self.opt_video_decoder = torch.optim.Adam( self.video_decoder.parameters(), lr=lr) self.opt_vocoder = torch.optim.Adam(self.vocoder.parameters(), lr=hparams.voc_lr) self.vocoder_loss_func = F.cross_entropy # Only for RAW if multigpu: self.encoder = nn.DataParallel(self.encoder) self.decoder = nn.DataParallel(self.decoder) self.video_decoder = nn.DataParallel(self.video_decoder) self.postnet = nn.DataParallel(self.postnet) self.vocoder = nn.DataParallel(self.vocoder) def optimize_parameters_video(self, dataloader, epochs, device, display_freq=10, save_freq=1000, save_dir="./", experimentName="Train", initial_niter=0, load_model=None): writer = SummaryWriter(log_dir="logs/" + experimentName) if load_model is not None: print("Loading from %s..." % load_model) # self.load_state_dict(torch.load(load_model)) d = torch.load(load_model) newdict = d.copy() for key, value in d.items(): newkey = key if 'wavenet' in key: newdict[key.replace('wavenet', 'vocoder')] = newdict.pop(key) newkey = key.replace('wavenet', 'vocoder') if self.multigpu and 'module' not in key: newdict[newkey.replace('.', '.module.', 1)] = newdict.pop(newkey) newkey = newkey.replace('.', '.module.', 1) if newkey not in self.state_dict(): newdict.pop(newkey) print("Load " + str(len(newdict)) + " parameters!") self.load_state_dict(newdict, strict=False) print("AutoVC Model Loaded") niter = initial_niter for epoch in range(epochs): self.train() for i, data in enumerate(dataloader): # print("Processing ..." 
+ str(name)) speaker, mel, prev, wav, video, video_large = data speaker, mel, prev, wav, video, video_large = speaker.to( device), mel.to(device), prev.to(device), wav.to( device), video.to(device), video_large.to(device) codes, code_unsample = self.encoder(mel, speaker, return_unsample=True) tmp = [] for code in codes: tmp.append( code.unsqueeze(1).expand(-1, int(mel.size(1) / len(codes)), -1)) code_exp = torch.cat(tmp, dim=1) if not self.use_256: v_stage1, v_stage2 = self.video_decoder(code_unsample, train=True) else: v_stage2 = self.video_decoder(code_unsample) mel_outputs = self.decoder(code_exp) mel_outputs_postnet = self.postnet(mel_outputs.transpose(2, 1)) mel_outputs_postnet = mel_outputs + mel_outputs_postnet.transpose( 2, 1) if self.loss_content: _, recons_codes = self.encoder(mel_outputs_postnet, speaker, return_unsample=True) loss_content = self.criterionIdt(code_unsample, recons_codes) else: loss_content = torch.from_numpy(np.array(0)) if not self.use_256: loss_video = self.criterionIdt(v_stage1, video) + self.criterionIdt( v_stage2, video_large) else: loss_video = self.criterionIdt(v_stage2, video_large) loss_recon = self.criterionIdt(mel, mel_outputs) loss_recon0 = self.criterionIdt(mel, mel_outputs_postnet) loss_vocoder = 0 if not self.multigpu: y_hat = self.vocoder( prev, self.vocoder.pad_tensor(mel_outputs_postnet, hparams.voc_pad).transpose( 1, 2)) else: y_hat = self.vocoder( prev, self.vocoder.module.pad_tensor( mel_outputs_postnet, hparams.voc_pad).transpose(1, 2)) y_hat = y_hat.transpose(1, 2).unsqueeze(-1) # assert (0 <= wav < 2 ** 9).all() loss_vocoder = self.vocoder_loss_func( y_hat, wav.unsqueeze(-1).to(device)) self.opt_vocoder.zero_grad() loss = loss_video + loss_recon + loss_recon0 + self.lambda_wavenet * loss_vocoder + loss_content self.opt_encoder.zero_grad() self.opt_decoder.zero_grad() self.opt_video_decoder.zero_grad() loss.backward() self.opt_encoder.step() self.opt_decoder.step() self.opt_video_decoder.step() self.opt_vocoder.step() if niter % display_freq == 0: print("Epoch[%d] Iter[%d] Niter[%d] %s" % (epoch, i, niter, loss.data.item())) writer.add_scalars( 'data/Loss', { 'loss': loss.data.item(), 'loss_video': loss_video.data.item(), 'loss_audio': loss_recon0.data.item() + loss_recon.data.item() }, niter) if niter % save_freq == 0: torch.cuda.empty_cache() # Prevent Out of Memory print("Saving and Testing...", end='\t') torch.save( self.state_dict(), save_dir + '/Epoch' + str(epoch).zfill(3) + '_Iter' + str(niter).zfill(8) + ".pkl") # self.load_state_dict(torch.load('params.pkl')) self.test_audiovideo(device, writer, niter) print("Done") self.train() torch.cuda.empty_cache() # Prevent Out of Memory niter += 1 def generate(self, mel, speaker, device='cuda:0'): mel, speaker = mel.to(device), speaker.to(device) if not self.multigpu: codes, code_unsample = self.encoder(mel, speaker, return_unsample=True) else: codes, code_unsample = self.encoder.module(mel, speaker, return_unsample=True) tmp = [] for code in codes: tmp.append( code.unsqueeze(1).expand(-1, int(mel.size(1) / len(codes)), -1)) code_exp = torch.cat(tmp, dim=1) if not self.multigpu: if not self.use_256: v_stage1, v_stage2 = self.video_decoder(code_unsample, train=True) else: v_stage2 = self.video_decoder(code_unsample) v_stage1 = v_stage2 mel_outputs = self.decoder(code_exp) mel_outputs_postnet = self.postnet(mel_outputs.transpose(2, 1)) else: if not self.use_256: v_stage1, v_stage2 = self.video_decoder.module(code_unsample, train=True) else: v_stage2 = self.video_decoder.module(code_unsample) 
v_stage1 = v_stage2 mel_outputs = self.decoder.module(code_exp) mel_outputs_postnet = self.postnet.module( mel_outputs.transpose(2, 1)) mel_outputs_postnet = mel_outputs + mel_outputs_postnet.transpose(2, 1) return mel_outputs_postnet, v_stage1, v_stage2 def test_video(self, device): wav, sr = librosa.load( "/mnt/lustre/dengkangle/cmu/datasets/video/obama_test.mp4", hparams.sample_rate) mel_basis = librosa.filters.mel(hparams.sample_rate, hparams.n_fft, n_mels=hparams.num_mels) linear_spec = np.abs( librosa.stft(wav, n_fft=hparams.n_fft, hop_length=hparams.hop_size, win_length=hparams.win_size)) mel_spec = mel_basis.dot(linear_spec) mel_db = 20 * np.log10(mel_spec) test_data = np.clip((mel_db + 120) / 125, 0, 1) test_data = torch.Tensor(pad_seq(test_data.T, hparams.freq)).unsqueeze(0).to(device) with torch.no_grad(): codes, code_exp = self.encoder.module(test_data, return_unsample=True) v_mid, v_hat = self.video_decoder.module(code_exp, train=True) reader = imageio.get_reader( "/mnt/lustre/dengkangle/cmu/datasets/video/obama_test.mp4", 'ffmpeg', fps=20) frames = [] for i, im in enumerate(reader): frames.append(np.array(im).transpose(2, 0, 1)) frames = (np.array(frames) / 255 - 0.5) / 0.5 return frames, v_mid[0:1], v_hat[0:1] def test_audiovideo(self, device, writer, niter): source_path = self.test_path mel_basis80 = librosa.filters.mel(hparams.sample_rate, hparams.n_fft, n_mels=80) wav, sr = librosa.load(source_path, hparams.sample_rate) wav = preemphasis(wav, hparams.preemphasis, hparams.preemphasize) linear_spec = np.abs( librosa.stft(wav, n_fft=hparams.n_fft, hop_length=hparams.hop_size, win_length=hparams.win_size)) mel_spec = mel_basis80.dot(linear_spec) mel_db = 20 * np.log10(mel_spec) source_spec = np.clip((mel_db + 120) / 125, 0, 1) source_embed = torch.from_numpy(np.array([0, 1])).float().unsqueeze(0) source_wav = wav source_spec = torch.Tensor(pad_seq(source_spec.T, hparams.freq)).unsqueeze(0) # print(source_spec.shape) with torch.no_grad(): generated_spec, v_mid, v_hat = self.generate( source_spec, source_embed, device) generated_spec, v_mid, v_hat = generated_spec.cpu(), v_mid.cpu( ), v_hat.cpu() print("Generating Wavfile...") with torch.no_grad(): if not self.multigpu: generated_wav = inv_preemphasis( self.vocoder.generate(generated_spec.to(device).transpose( 2, 1), False, None, None, mu_law=True), hparams.preemphasis, hparams.preemphasize) else: generated_wav = inv_preemphasis( self.vocoder.module.generate( generated_spec.to(device).transpose(2, 1), False, None, None, mu_law=True), hparams.preemphasis, hparams.preemphasize) writer.add_video('generated', (v_hat.numpy() + 1) / 2, global_step=niter) writer.add_video('mid', (v_mid.numpy() + 1) / 2, global_step=niter) writer.add_audio('ground_truth', source_wav, niter, sample_rate=hparams.sample_rate) writer.add_audio('generated_wav', generated_wav, niter, sample_rate=hparams.sample_rate)
def train(run_id: str, syn_dir: Path, voc_dir: Path, models_dir: Path, ground_truth: bool,
          save_every: int, backup_every: int, force_restart: bool):
    # Check to make sure the hop length is correctly factorised
    assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length

    # Instantiate the model
    print("Initializing the model...")
    model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                    fc_dims=hp.voc_fc_dims,
                    bits=hp.bits,
                    pad=hp.voc_pad,
                    upsample_factors=hp.voc_upsample_factors,
                    feat_dims=hp.num_mels,
                    compute_dims=hp.voc_compute_dims,
                    res_out_dims=hp.voc_res_out_dims,
                    res_blocks=hp.voc_res_blocks,
                    hop_length=hp.hop_length,
                    sample_rate=hp.sample_rate,
                    mode=hp.voc_mode).cuda()

    # Initialize the optimizer
    optimizer = optim.Adam(model.parameters())
    for p in optimizer.param_groups:
        p["lr"] = hp.voc_lr
    loss_func = F.cross_entropy if model.mode == "RAW" else discretized_mix_logistic_loss

    # Load the weights
    model_dir = models_dir.joinpath(run_id)
    model_dir.mkdir(exist_ok=True)
    weights_fpath = model_dir.joinpath(run_id + ".pt")
    if force_restart or not weights_fpath.exists():
        print("\nStarting the training of WaveRNN from scratch\n")
        model.save(weights_fpath, optimizer)
    else:
        print("\nLoading weights at %s" % weights_fpath)
        model.load(weights_fpath, optimizer)
        print("WaveRNN weights loaded from step %d" % model.step)

    # Initialize the dataset
    metadata_fpath = syn_dir.joinpath("train.txt") if ground_truth else \
        voc_dir.joinpath("synthesized.txt")
    mel_dir = syn_dir.joinpath("mels") if ground_truth else voc_dir.joinpath("mels_gta")
    wav_dir = syn_dir.joinpath("audio")
    dataset = VocoderDataset(metadata_fpath, mel_dir, wav_dir)
    test_loader = DataLoader(dataset, batch_size=1, shuffle=True, pin_memory=True)

    # Begin the training
    simple_table([('Batch size', hp.voc_batch_size),
                  ('LR', hp.voc_lr),
                  ('Sequence Len', hp.voc_seq_len)])

    for epoch in range(1, 350):
        data_loader = DataLoader(dataset,
                                 collate_fn=collate_vocoder,
                                 batch_size=hp.voc_batch_size,
                                 num_workers=2,
                                 shuffle=True,
                                 pin_memory=True)
        start = time.time()
        running_loss = 0.

        for i, (x, y, m) in enumerate(data_loader, 1):
            x, m, y = x.cuda(), m.cuda(), y.cuda()

            # Forward pass
            y_hat = model(x, m)
            if model.mode == 'RAW':
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
            elif model.mode == 'MOL':
                y = y.float()
            y = y.unsqueeze(-1)

            print("y shape:", y.shape)
            print("y_hat shape:", y_hat.shape)

            # Backward pass
            loss = loss_func(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            speed = i / (time.time() - start)
            avg_loss = running_loss / i

            step = model.get_step()
            k = step // 1000

            if backup_every != 0 and step % backup_every == 0:
                model.checkpoint(model_dir, optimizer)
            if save_every != 0 and step % save_every == 0:
                model.save(weights_fpath, optimizer)

            msg = f"| Epoch: {epoch} ({i}/{len(data_loader)}) | " \
                  f"Loss: {avg_loss:.4f} | {speed:.1f} " \
                  f"steps/s | Step: {k}k | "
            stream(msg)

        gen_testset(model, test_loader, hp.voc_gen_at_checkpoint, hp.voc_gen_batched,
                    hp.voc_target, hp.voc_overlap, model_dir)
        print("")