def test(self):
    """Converts mel-spectrograms from the test set and saves the audio to disk."""
    for i, real_A in enumerate(tqdm(self.test_dataloader)):
        real_A = real_A.to(self.device, dtype=torch.float)
        fake_B = self.generator_A2B(real_A, torch.ones_like(real_A))
        wav_fake_B = decode_melspectrogram(
            self.vocoder, fake_B[0].detach().cpu(),
            self.dataset_A_mean, self.dataset_A_std).cpu()

        if self.model_name == 'generator_A2B':
            save_path = os.path.join(
                self.converted_audio_dir,
                f"converted_{self.speaker_A_id}_to_{self.speaker_B_id}{i}.wav")
        else:
            save_path = os.path.join(
                self.converted_audio_dir,
                f"converted_{self.speaker_B_id}_to_{self.speaker_A_id}{i}.wav")

        torchaudio.save(save_path, wav_fake_B, sample_rate=self.sample_rate)
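
# `decode_melspectrogram` (used above) is imported from elsewhere in this
# repo. A minimal sketch of the idea, assuming the mel-spectrograms were
# z-normalized with per-speaker statistics during preprocessing and that the
# vocoder is a MelGAN-style model exposing an `inverse()` method (as in the
# pretrained 'descriptinc/melgan-neurips' torch.hub model). The name below is
# hypothetical, and this is an illustration, not the repo's implementation.
def _decode_melspectrogram_sketch(vocoder, melspectrogram, mel_mean, mel_std):
    """Denormalizes a z-normalized mel-spectrogram and decodes it to audio."""
    mel = melspectrogram * mel_std + mel_mean  # undo per-speaker normalization
    return vocoder.inverse(mel.unsqueeze(0))   # (n_mels, T) -> (1, T) waveform
# The resulting (1, T) tensor is what torchaudio.save() above writes as a
# mono waveform.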
def train(self):
    """Implements the training loop for MaskCycleGAN-VC."""
    for epoch in range(self.start_epoch, self.num_epochs + 1):
        self.logger.start_epoch()

        for i, (real_A, mask_A, real_B, mask_B) in enumerate(tqdm(self.train_dataloader)):
            self.logger.start_iter()
            num_iterations = (
                self.n_samples // self.mini_batch_size) * epoch + i

            real_A = real_A.to(self.device, dtype=torch.float)
            mask_A = mask_A.to(self.device, dtype=torch.float)
            real_B = real_B.to(self.device, dtype=torch.float)
            mask_B = mask_B.to(self.device, dtype=torch.float)

            # ---------------- Train Generators ----------------
            # Generator feed-forward
            fake_B = self.generator_A2B(real_A, mask_A)
            cycle_A = self.generator_B2A(fake_B, torch.ones_like(fake_B))
            fake_A = self.generator_B2A(real_B, mask_B)
            cycle_B = self.generator_A2B(fake_A, torch.ones_like(fake_A))
            identity_A = self.generator_B2A(real_A, torch.ones_like(real_A))
            identity_B = self.generator_A2B(real_B, torch.ones_like(real_B))
            d_fake_A = self.discriminator_A(fake_A)
            d_fake_B = self.discriminator_B(fake_B)

            # For the two-step adversarial loss
            d_fake_cycle_A = self.discriminator_A2(cycle_A)
            d_fake_cycle_B = self.discriminator_B2(cycle_B)

            # Generator cycle-consistency loss (L1)
            cycleLoss = torch.mean(torch.abs(real_A - cycle_A)) + \
                torch.mean(torch.abs(real_B - cycle_B))

            # Generator identity loss (L1)
            identityLoss = torch.mean(torch.abs(real_A - identity_A)) + \
                torch.mean(torch.abs(real_B - identity_B))

            # Generator adversarial loss (least-squares GAN)
            g_loss_A2B = torch.mean((1 - d_fake_B) ** 2)
            g_loss_B2A = torch.mean((1 - d_fake_A) ** 2)

            # Generator two-step adversarial loss
            generator_loss_A2B_2nd = torch.mean((1 - d_fake_cycle_B) ** 2)
            generator_loss_B2A_2nd = torch.mean((1 - d_fake_cycle_A) ** 2)

            # Total generator loss
            g_loss = g_loss_A2B + g_loss_B2A + \
                generator_loss_A2B_2nd + generator_loss_B2A_2nd + \
                self.cycle_loss_lambda * cycleLoss + \
                self.identity_loss_lambda * identityLoss

            # Backprop for generators
            self.reset_grad()
            g_loss.backward()
            self.generator_optimizer.step()

            # ---------------- Train Discriminators ----------------
            # Discriminator feed-forward
            d_real_A = self.discriminator_A(real_A)
            d_real_B = self.discriminator_B(real_B)
            d_real_A2 = self.discriminator_A2(real_A)
            d_real_B2 = self.discriminator_B2(real_B)
            generated_A = self.generator_B2A(real_B, mask_B)
            d_fake_A = self.discriminator_A(generated_A)

            # For the two-step adversarial loss A -> B
            cycled_B = self.generator_A2B(
                generated_A, torch.ones_like(generated_A))
            d_cycled_B = self.discriminator_B2(cycled_B)

            generated_B = self.generator_A2B(real_A, mask_A)
            d_fake_B = self.discriminator_B(generated_B)

            # For the two-step adversarial loss B -> A
            cycled_A = self.generator_B2A(
                generated_B, torch.ones_like(generated_B))
            d_cycled_A = self.discriminator_A2(cycled_A)

            # Discriminator losses (least-squares GAN)
            d_loss_A_real = torch.mean((1 - d_real_A) ** 2)
            d_loss_A_fake = torch.mean((0 - d_fake_A) ** 2)
            d_loss_A = (d_loss_A_real + d_loss_A_fake) / 2.0

            d_loss_B_real = torch.mean((1 - d_real_B) ** 2)
            d_loss_B_fake = torch.mean((0 - d_fake_B) ** 2)
            d_loss_B = (d_loss_B_real + d_loss_B_fake) / 2.0

            # Two-step adversarial loss
            d_loss_A_cycled = torch.mean((0 - d_cycled_A) ** 2)
            d_loss_B_cycled = torch.mean((0 - d_cycled_B) ** 2)
            d_loss_A2_real = torch.mean((1 - d_real_A2) ** 2)
            d_loss_B2_real = torch.mean((1 - d_real_B2) ** 2)
            d_loss_A_2nd = (d_loss_A2_real + d_loss_A_cycled) / 2.0
            d_loss_B_2nd = (d_loss_B2_real + d_loss_B_cycled) / 2.0

            # Final discriminator loss, including the two-step adversarial loss
            d_loss = (d_loss_A + d_loss_B) / 2.0 + \
                (d_loss_A_2nd + d_loss_B_2nd) / 2.0

            # Backprop for discriminators
            self.reset_grad()
            d_loss.backward()
            self.discriminator_optimizer.step()

            # Log iteration
            self.logger.log_iter(
                loss_dict={'g_loss': g_loss.item(), 'd_loss': d_loss.item()})
            self.logger.end_iter()

            # Adjust learning rates after the warm-up period
            if self.logger.global_step > self.decay_after:
                self.identity_loss_lambda = 0
                self.adjust_lr_rate(self.generator_optimizer, generator=True)
                self.adjust_lr_rate(self.discriminator_optimizer, generator=False)

        # Log intermediate outputs on Tensorboard
        if self.logger.epoch % self.epochs_per_plot == 0:
            # Log Mel-spectrograms as .png
            real_mel_A_fig = get_mel_spectrogram_fig(real_A[0].detach().cpu())
            fake_mel_A_fig = get_mel_spectrogram_fig(generated_A[0].detach().cpu())
            real_mel_B_fig = get_mel_spectrogram_fig(real_B[0].detach().cpu())
            fake_mel_B_fig = get_mel_spectrogram_fig(generated_B[0].detach().cpu())
            self.logger.visualize_outputs({
                "real_voc_spec": real_mel_A_fig,
                "fake_coraal_spec": fake_mel_B_fig,
                "real_coraal_spec": real_mel_B_fig,
                "fake_voc_spec": fake_mel_A_fig,
            })

            # Convert Mel-spectrograms from the validation set to waveforms
            # and log the audio to Tensorboard
            real_mel_full_A, real_mel_full_B = next(
                iter(self.validation_dataloader))
            real_mel_full_A = real_mel_full_A.to(self.device, dtype=torch.float)
            real_mel_full_B = real_mel_full_B.to(self.device, dtype=torch.float)
            fake_mel_full_B = self.generator_A2B(
                real_mel_full_A, torch.ones_like(real_mel_full_A))
            fake_mel_full_A = self.generator_B2A(
                real_mel_full_B, torch.ones_like(real_mel_full_B))
            real_wav_full_A = decode_melspectrogram(
                self.vocoder, real_mel_full_A[0].detach().cpu(),
                self.dataset_A_mean, self.dataset_A_std).cpu()
            fake_wav_full_A = decode_melspectrogram(
                self.vocoder, fake_mel_full_A[0].detach().cpu(),
                self.dataset_A_mean, self.dataset_A_std).cpu()
            real_wav_full_B = decode_melspectrogram(
                self.vocoder, real_mel_full_B[0].detach().cpu(),
                self.dataset_B_mean, self.dataset_B_std).cpu()
            fake_wav_full_B = decode_melspectrogram(
                self.vocoder, fake_mel_full_B[0].detach().cpu(),
                self.dataset_B_mean, self.dataset_B_std).cpu()
            self.logger.log_audio(
                real_wav_full_A.T, "real_speaker_A_audio", self.sample_rate)
            self.logger.log_audio(
                fake_wav_full_A.T, "fake_speaker_A_audio", self.sample_rate)
            self.logger.log_audio(
                real_wav_full_B.T, "real_speaker_B_audio", self.sample_rate)
            self.logger.log_audio(
                fake_wav_full_B.T, "fake_speaker_B_audio", self.sample_rate)

        # Save each model checkpoint
        if self.logger.epoch % self.epochs_per_save == 0:
            self.saver.save(self.logger.epoch, self.generator_A2B,
                            self.generator_optimizer, None, self.device, "generator_A2B")
            self.saver.save(self.logger.epoch, self.generator_B2A,
                            self.generator_optimizer, None, self.device, "generator_B2A")
            self.saver.save(self.logger.epoch, self.discriminator_A,
                            self.discriminator_optimizer, None, self.device, "discriminator_A")
            self.saver.save(self.logger.epoch, self.discriminator_B,
                            self.discriminator_optimizer, None, self.device, "discriminator_B")
            self.saver.save(self.logger.epoch, self.discriminator_A2,
                            self.discriminator_optimizer, None, self.device, "discriminator_A2")
            self.saver.save(self.logger.epoch, self.discriminator_B2,
                            self.discriminator_optimizer, None, self.device, "discriminator_B2")

        self.logger.end_epoch()
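
# The training dataloader above yields (real, mask) pairs for each speaker.
# A minimal sketch of how such a mask could be sampled, following the
# "filling in frames" (FIF) idea behind MaskCycleGAN-VC: zero out one random
# contiguous run of time frames so the generator must learn to fill it in.
# The function name and `max_mask_len` parameter are hypothetical, not the
# dataset's actual API; `torch` is assumed imported as elsewhere in this file.
def _sample_frame_mask_sketch(n_mels, n_frames, max_mask_len=32):
    """Returns an (n_mels, n_frames) binary mask with one zeroed frame span."""
    mask = torch.ones(n_mels, n_frames)
    mask_len = torch.randint(0, max_mask_len + 1, (1,)).item()      # span length
    start = torch.randint(0, n_frames - mask_len + 1, (1,)).item()  # span start
    mask[:, start:start + mask_len] = 0.0
    return mask
# At inference and in the identity/cycle passes above, masking is disabled by
# passing torch.ones_like(...) in place of a sampled mask.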