Example #1
 def test(self):
     """Run inference over the test set and save converted audio to disk.

     For each mel-spectrogram batch in ``self.test_dataloader``, converts it
     with ``self.generator_A2B`` using an all-ones mask (i.e. no masking),
     vocodes the first item of the batch back to a waveform, and writes it
     as a ``.wav`` file under ``self.converted_audio_dir``.
     """
     for i, real_A in enumerate(tqdm(self.test_dataloader)):
         real_A = real_A.to(self.device, dtype=torch.float)
         # Inference only: skip autograd graph construction to save memory.
         with torch.no_grad():
             fake_B = self.generator_A2B(real_A, torch.ones_like(real_A))
         # NOTE(review): dataset A's normalization stats are used regardless
         # of conversion direction, and the B->A case below still calls
         # generator_A2B — presumably the chosen checkpoint is loaded into
         # that attribute; confirm against the loading code.
         wav_fake_B = decode_melspectrogram(self.vocoder,
                                            fake_B[0].detach().cpu(),
                                            self.dataset_A_mean,
                                            self.dataset_A_std).cpu()
         # File name encodes the conversion direction (A->B or B->A).
         if self.model_name == 'generator_A2B':
             direction = f"{self.speaker_A_id}_to_{self.speaker_B_id}"
         else:
             direction = f"{self.speaker_B_id}_to_{self.speaker_A_id}"
         save_path = os.path.join(self.converted_audio_dir,
                                  f"converted_{direction}{i}.wav")
         torchaudio.save(save_path,
                         wav_fake_B,
                         sample_rate=self.sample_rate)
Example #2
    def train(self):
        """Implements the training loop for MaskCycleGAN-VC.

        Each iteration performs one generator update (adversarial +
        two-step adversarial + cycle + identity losses) followed by one
        discriminator update (real/fake + two-step adversarial losses).
        Per ``epochs_per_plot`` epochs, logs mel-spectrograms and vocoded
        validation audio to TensorBoard; per ``epochs_per_save`` epochs,
        checkpoints all six models.
        """
        for epoch in range(self.start_epoch, self.num_epochs + 1):
            self.logger.start_epoch()

            for i, (real_A, mask_A, real_B, mask_B) in enumerate(tqdm(self.train_dataloader)):
                self.logger.start_iter()

                real_A = real_A.to(self.device, dtype=torch.float)
                mask_A = mask_A.to(self.device, dtype=torch.float)
                real_B = real_B.to(self.device, dtype=torch.float)
                mask_B = mask_B.to(self.device, dtype=torch.float)

                # ---------------- Train Generator ----------------

                # Generator feed forward. Masks are applied only on the
                # first pass; subsequent passes use an all-ones mask.
                fake_B = self.generator_A2B(real_A, mask_A)
                cycle_A = self.generator_B2A(fake_B, torch.ones_like(fake_B))
                fake_A = self.generator_B2A(real_B, mask_B)
                cycle_B = self.generator_A2B(fake_A, torch.ones_like(fake_A))
                identity_A = self.generator_B2A(
                    real_A, torch.ones_like(real_A))
                identity_B = self.generator_A2B(
                    real_B, torch.ones_like(real_B))
                d_fake_A = self.discriminator_A(fake_A)
                d_fake_B = self.discriminator_B(fake_B)

                # For two-step adversarial loss (second discriminators
                # judge the cycled outputs).
                d_fake_cycle_A = self.discriminator_A2(cycle_A)
                d_fake_cycle_B = self.discriminator_B2(cycle_B)

                # Generator cycle-consistency loss (L1).
                cycleLoss = torch.mean(
                    torch.abs(real_A - cycle_A)) + torch.mean(torch.abs(real_B - cycle_B))

                # Generator identity loss (L1).
                identityLoss = torch.mean(
                    torch.abs(real_A - identity_A)) + torch.mean(torch.abs(real_B - identity_B))

                # Generator adversarial loss (LSGAN: target 1 for fakes).
                g_loss_A2B = torch.mean((1 - d_fake_B) ** 2)
                g_loss_B2A = torch.mean((1 - d_fake_A) ** 2)

                # Generator two-step adversarial loss.
                generator_loss_A2B_2nd = torch.mean((1 - d_fake_cycle_B) ** 2)
                generator_loss_B2A_2nd = torch.mean((1 - d_fake_cycle_A) ** 2)

                # Total generator loss.
                g_loss = g_loss_A2B + g_loss_B2A + \
                    generator_loss_A2B_2nd + generator_loss_B2A_2nd + \
                    self.cycle_loss_lambda * cycleLoss + self.identity_loss_lambda * identityLoss

                # Backprop for generator.
                self.reset_grad()
                g_loss.backward()
                self.generator_optimizer.step()

                # ---------------- Train Discriminator ----------------

                # Discriminator feed forward (fakes regenerated so they
                # reflect the just-updated generators).
                d_real_A = self.discriminator_A(real_A)
                d_real_B = self.discriminator_B(real_B)
                d_real_A2 = self.discriminator_A2(real_A)
                d_real_B2 = self.discriminator_B2(real_B)
                generated_A = self.generator_B2A(real_B, mask_B)
                d_fake_A = self.discriminator_A(generated_A)

                # For two-step adversarial loss A->B.
                cycled_B = self.generator_A2B(
                    generated_A, torch.ones_like(generated_A))
                d_cycled_B = self.discriminator_B2(cycled_B)

                generated_B = self.generator_A2B(real_A, mask_A)
                d_fake_B = self.discriminator_B(generated_B)

                # For two-step adversarial loss B->A.
                cycled_A = self.generator_B2A(
                    generated_B, torch.ones_like(generated_B))
                d_cycled_A = self.discriminator_A2(cycled_A)

                # LSGAN discriminator losses: real -> 1, fake -> 0.
                d_loss_A_real = torch.mean((1 - d_real_A) ** 2)
                d_loss_A_fake = torch.mean((0 - d_fake_A) ** 2)
                d_loss_A = (d_loss_A_real + d_loss_A_fake) / 2.0

                d_loss_B_real = torch.mean((1 - d_real_B) ** 2)
                d_loss_B_fake = torch.mean((0 - d_fake_B) ** 2)
                d_loss_B = (d_loss_B_real + d_loss_B_fake) / 2.0

                # Two-step adversarial loss for the second discriminators.
                d_loss_A_cycled = torch.mean((0 - d_cycled_A) ** 2)
                d_loss_B_cycled = torch.mean((0 - d_cycled_B) ** 2)
                d_loss_A2_real = torch.mean((1 - d_real_A2) ** 2)
                d_loss_B2_real = torch.mean((1 - d_real_B2) ** 2)
                d_loss_A_2nd = (d_loss_A2_real + d_loss_A_cycled) / 2.0
                d_loss_B_2nd = (d_loss_B2_real + d_loss_B_cycled) / 2.0

                # Final discriminator loss including the two-step term.
                d_loss = (d_loss_A + d_loss_B) / 2.0 + \
                    (d_loss_A_2nd + d_loss_B_2nd) / 2.0

                # Backprop for discriminator.
                self.reset_grad()
                d_loss.backward()
                self.discriminator_optimizer.step()

                # Log iteration.
                self.logger.log_iter(
                    loss_dict={'g_loss': g_loss.item(), 'd_loss': d_loss.item()})
                self.logger.end_iter()

                # Adjust learning rates after the warm-up period; identity
                # loss is also disabled from this point on.
                if self.logger.global_step > self.decay_after:
                    self.identity_loss_lambda = 0
                    self.adjust_lr_rate(
                        self.generator_optimizer, generator=True)
                    # BUG FIX: the discriminator optimizer must be decayed
                    # here; the original passed generator_optimizer twice,
                    # so the discriminator LR never changed.
                    self.adjust_lr_rate(
                        self.discriminator_optimizer, generator=False)

            # Log intermediate outputs on TensorBoard.
            if self.logger.epoch % self.epochs_per_plot == 0:
                # Log mel-spectrograms as .png figures (uses the last
                # batch of the epoch).
                real_mel_A_fig = get_mel_spectrogram_fig(
                    real_A[0].detach().cpu())
                fake_mel_A_fig = get_mel_spectrogram_fig(
                    generated_A[0].detach().cpu())
                real_mel_B_fig = get_mel_spectrogram_fig(
                    real_B[0].detach().cpu())
                fake_mel_B_fig = get_mel_spectrogram_fig(
                    generated_B[0].detach().cpu())
                self.logger.visualize_outputs({"real_voc_spec": real_mel_A_fig, "fake_coraal_spec": fake_mel_B_fig,
                                               "real_coraal_spec": real_mel_B_fig, "fake_voc_spec": fake_mel_A_fig})

                # Convert validation mel-spectrograms to waveforms and log
                # the audio to TensorBoard.
                real_mel_full_A, real_mel_full_B = next(
                    iter(self.validation_dataloader))
                real_mel_full_A = real_mel_full_A.to(
                    self.device, dtype=torch.float)
                real_mel_full_B = real_mel_full_B.to(
                    self.device, dtype=torch.float)
                fake_mel_full_B = self.generator_A2B(
                    real_mel_full_A, torch.ones_like(real_mel_full_A))
                fake_mel_full_A = self.generator_B2A(
                    real_mel_full_B, torch.ones_like(real_mel_full_B))
                real_wav_full_A = decode_melspectrogram(self.vocoder, real_mel_full_A[0].detach(
                ).cpu(), self.dataset_A_mean, self.dataset_A_std).cpu()
                fake_wav_full_A = decode_melspectrogram(self.vocoder, fake_mel_full_A[0].detach(
                ).cpu(), self.dataset_A_mean, self.dataset_A_std).cpu()
                real_wav_full_B = decode_melspectrogram(self.vocoder, real_mel_full_B[0].detach(
                ).cpu(), self.dataset_B_mean, self.dataset_B_std).cpu()
                fake_wav_full_B = decode_melspectrogram(self.vocoder, fake_mel_full_B[0].detach(
                ).cpu(), self.dataset_B_mean, self.dataset_B_std).cpu()
                self.logger.log_audio(
                    real_wav_full_A.T, "real_speaker_A_audio", self.sample_rate)
                self.logger.log_audio(
                    fake_wav_full_A.T, "fake_speaker_A_audio", self.sample_rate)
                self.logger.log_audio(
                    real_wav_full_B.T, "real_speaker_B_audio", self.sample_rate)
                self.logger.log_audio(
                    fake_wav_full_B.T, "fake_speaker_B_audio", self.sample_rate)

            # Save each model checkpoint. (Uses self.device instead of the
            # module-level `args` global the original referenced.)
            if self.logger.epoch % self.epochs_per_save == 0:
                self.saver.save(self.logger.epoch, self.generator_A2B,
                                self.generator_optimizer, None, self.device, "generator_A2B")
                self.saver.save(self.logger.epoch, self.generator_B2A,
                                self.generator_optimizer, None, self.device, "generator_B2A")
                self.saver.save(self.logger.epoch, self.discriminator_A,
                                self.discriminator_optimizer, None, self.device, "discriminator_A")
                self.saver.save(self.logger.epoch, self.discriminator_B,
                                self.discriminator_optimizer, None, self.device, "discriminator_B")
                self.saver.save(self.logger.epoch, self.discriminator_A2,
                                self.discriminator_optimizer, None, self.device, "discriminator_A2")
                self.saver.save(self.logger.epoch, self.discriminator_B2,
                                self.discriminator_optimizer, None, self.device, "discriminator_B2")

            self.logger.end_epoch()