# Example #1
    def __init__(self, args):
        """Set up mel-spectrogram generation: vocoder, A2B generator,
        checkpoint restore, and the normalized source-speaker dataset.

        Args:
            args (Namespace): Program arguments from argparser.
        """
        # Schedule-related args; stored for parity with the trainer
        # (presumably not all are used during generation — verify callers).
        self.num_epochs = args.num_epochs
        self.start_epoch = args.start_epoch
        self.mini_batch_size = args.batch_size
        self.device = args.device

        # MelGAN vocoder fetched via torch.hub (network access on first call).
        self.vocoder = torch.hub.load('descriptinc/melgan-neurips',
                                      'load_melgan')
        self.sample_rate = args.sample_rate

        self.data_dir = args.data_dir
        self.source_id = args.source_id
        self.save_dir = args.save_dir
        self.saver = ModelSaver(args)

        # Generator
        self.generator_A2B = Generator().to(self.device)

        # Load from previous ckpt (must happen after both the saver and the
        # generator exist).
        self.saver.load_model(self.generator_A2B, "generator_A2B",
                              args.ckpt_path, None, None)

        # Build the normalized mel dataset for the source speaker from the
        # "voc" manifest; normalize_mel also returns per-bin mean/std used
        # later to de-normalize generator output.
        voc_wav_files = self.read_manifest(dataset="voc",
                                           speaker_id=self.source_id)
        print(f'Found {len(voc_wav_files)} wav files')
        self.dataset_A, self.dataset_A_mean, self.dataset_A_std = self.normalize_mel(
            voc_wav_files, self.data_dir, sr=self.sample_rate)
        self.n_samples = len(self.dataset_A)
        print(f'n_samples = {self.n_samples}')
# Example #2
    def __init__(self, args):
        """Set up testing/conversion: vocoder, both speakers' normalized
        datasets with their normalization stats, a validation dataloader over
        the source speaker, and the generator restored from a checkpoint.

        Args:
            args (Namespace): Program arguments from argparser
        """
        # Store Args
        self.device = args.device
        self.converted_audio_dir = os.path.join(args.save_dir, args.name,
                                                'converted_audio')
        os.makedirs(self.converted_audio_dir, exist_ok=True)
        # Selects conversion direction: 'generator_A2B' or 'generator_B2A'.
        self.model_name = args.model_name

        self.speaker_A_id = args.speaker_A_id
        self.speaker_B_id = args.speaker_B_id
        # Initialize MelGAN-Vocoder used to decode Mel-spectrograms
        self.vocoder = torch.hub.load('descriptinc/melgan-neurips',
                                      'load_melgan')
        self.sample_rate = args.sample_rate

        # Initialize speakerA's dataset (pre-normalized mels plus the
        # mean/std used to de-normalize generator output).
        self.dataset_A = self.loadPickleFile(
            os.path.join(args.preprocessed_data_dir, self.speaker_A_id,
                         f"{self.speaker_A_id}_normalized.pickle"))
        dataset_A_norm_stats = np.load(
            os.path.join(args.preprocessed_data_dir, self.speaker_A_id,
                         f"{self.speaker_A_id}_norm_stat.npz"))
        self.dataset_A_mean = dataset_A_norm_stats['mean']
        self.dataset_A_std = dataset_A_norm_stats['std']

        # Initialize speakerB's dataset
        self.dataset_B = self.loadPickleFile(
            os.path.join(args.preprocessed_data_dir, self.speaker_B_id,
                         f"{self.speaker_B_id}_normalized.pickle"))
        dataset_B_norm_stats = np.load(
            os.path.join(args.preprocessed_data_dir, self.speaker_B_id,
                         f"{self.speaker_B_id}_norm_stat.npz"))
        self.dataset_B_mean = dataset_B_norm_stats['mean']
        self.dataset_B_std = dataset_B_norm_stats['std']

        # The conversion source is A for A2B, otherwise B; batch_size=1 and
        # no shuffling so outputs line up with the dataset order.
        source_dataset = self.dataset_A if self.model_name == 'generator_A2B' else self.dataset_B
        self.dataset = VCDataset(datasetA=source_dataset,
                                 datasetB=None,
                                 valid=True)
        self.test_dataloader = torch.utils.data.DataLoader(
            dataset=self.dataset, batch_size=1, shuffle=False, drop_last=False)

        # Generator
        self.generator_A2B = Generator().to(self.device)

        # Load Generator from ckpt
        # NOTE(review): the checkpoint named `model_name` is loaded into the
        # object called generator_A2B even when model_name == 'generator_B2A';
        # the object appears to be reused for either direction — confirm.
        self.saver = ModelSaver(args)
        self.saver.load_model(self.generator_A2B, self.model_name)
# Example #3
    def __init__(self, args):
        """Set up MaskCycleGAN-VC training: vocoder, both speakers' datasets,
        lr-decay schedule, train/validation dataloaders, the four-network
        GAN (2 generators, 4 discriminators), optimizers, and optional
        checkpoint restore.

        Args:
            args (Namespace): Program arguments from argparser
        """
        # Store args
        self.num_epochs = args.num_epochs
        self.start_epoch = args.start_epoch
        self.generator_lr = args.generator_lr
        self.discriminator_lr = args.discriminator_lr
        # Step count after which lr decay (and identity loss removal) begins.
        self.decay_after = args.decay_after
        self.mini_batch_size = args.batch_size
        self.cycle_loss_lambda = args.cycle_loss_lambda
        self.identity_loss_lambda = args.identity_loss_lambda
        self.device = args.device
        self.epochs_per_save = args.epochs_per_save
        self.epochs_per_plot = args.epochs_per_plot

        # Initialize MelGAN-Vocoder used to decode Mel-spectrograms
        self.vocoder = torch.hub.load(
            'descriptinc/melgan-neurips', 'load_melgan')
        self.sample_rate = args.sample_rate

        # Initialize speakerA's dataset (pre-normalized mels + norm stats).
        self.dataset_A = self.loadPickleFile(os.path.join(
            args.preprocessed_data_dir, args.speaker_A_id, f"{args.speaker_A_id}_normalized.pickle"))
        dataset_A_norm_stats = np.load(os.path.join(
            args.preprocessed_data_dir, args.speaker_A_id, f"{args.speaker_A_id}_norm_stat.npz"))
        self.dataset_A_mean = dataset_A_norm_stats['mean']
        self.dataset_A_std = dataset_A_norm_stats['std']

        # Initialize speakerB's dataset
        self.dataset_B = self.loadPickleFile(os.path.join(
            args.preprocessed_data_dir, args.speaker_B_id, f"{args.speaker_B_id}_normalized.pickle"))
        dataset_B_norm_stats = np.load(os.path.join(
            args.preprocessed_data_dir, args.speaker_B_id, f"{args.speaker_B_id}_norm_stat.npz"))
        self.dataset_B_mean = dataset_B_norm_stats['mean']
        self.dataset_B_std = dataset_B_norm_stats['std']

        # Compute lr decay rate: linear decay to zero over all training steps.
        # NOTE(review): if n_samples < mini_batch_size the batch count is 0
        # and this raises ZeroDivisionError — guard upstream if needed.
        self.n_samples = len(self.dataset_A)
        print(f'n_samples = {self.n_samples}')
        self.generator_lr_decay = self.generator_lr / \
            float(self.num_epochs * (self.n_samples // self.mini_batch_size))
        self.discriminator_lr_decay = self.discriminator_lr / \
            float(self.num_epochs * (self.n_samples // self.mini_batch_size))
        print(f'generator_lr_decay = {self.generator_lr_decay}')
        print(f'discriminator_lr_decay = {self.discriminator_lr_decay}')

        # Initialize Train Dataloader
        self.num_frames = args.num_frames
        self.dataset = VCDataset(datasetA=self.dataset_A,
                                 datasetB=self.dataset_B,
                                 n_frames=args.num_frames,
                                 max_mask_len=args.max_mask_len)
        self.train_dataloader = torch.utils.data.DataLoader(dataset=self.dataset,
                                                            batch_size=self.mini_batch_size,
                                                            shuffle=True,
                                                            drop_last=False)

        # Initialize Validation Dataloader (used to generate intermediate outputs)
        self.validation_dataset = VCDataset(datasetA=self.dataset_A,
                                            datasetB=self.dataset_B,
                                            n_frames=args.num_frames_validation,
                                            max_mask_len=args.max_mask_len,
                                            valid=True)
        self.validation_dataloader = torch.utils.data.DataLoader(dataset=self.validation_dataset,
                                                                 batch_size=1,
                                                                 shuffle=False,
                                                                 drop_last=False)

        # Initialize logger and saver objects
        self.logger = TrainLogger(args, len(self.train_dataloader.dataset))
        self.saver = ModelSaver(args)

        # Initialize Generators and Discriminators
        self.generator_A2B = Generator().to(self.device)
        self.generator_B2A = Generator().to(self.device)
        self.discriminator_A = Discriminator().to(self.device)
        self.discriminator_B = Discriminator().to(self.device)
        # Discriminator to compute 2 step adversarial loss
        self.discriminator_A2 = Discriminator().to(self.device)
        # Discriminator to compute 2 step adversarial loss
        self.discriminator_B2 = Discriminator().to(self.device)

        # Initialize Optimizers: one shared optimizer per group (both
        # generators together; all four discriminators together).
        g_params = list(self.generator_A2B.parameters()) + \
            list(self.generator_B2A.parameters())
        d_params = list(self.discriminator_A.parameters()) + \
            list(self.discriminator_B.parameters()) + \
            list(self.discriminator_A2.parameters()) + \
            list(self.discriminator_B2.parameters())
        self.generator_optimizer = torch.optim.Adam(
            g_params, lr=self.generator_lr, betas=(0.5, 0.999))
        self.discriminator_optimizer = torch.optim.Adam(
            d_params, lr=self.discriminator_lr, betas=(0.5, 0.999))

        # Load from previous ckpt. The optimizer state is restored once per
        # shared optimizer (with A2B / discriminator_A respectively);
        # presumably the remaining calls restore weights only.
        if args.continue_train:
            self.saver.load_model(
                self.generator_A2B, "generator_A2B", None, self.generator_optimizer)
            self.saver.load_model(self.generator_B2A,
                                  "generator_B2A", None, None)
            self.saver.load_model(self.discriminator_A,
                                  "discriminator_A", None, self.discriminator_optimizer)
            self.saver.load_model(self.discriminator_B,
                                  "discriminator_B", None, None)
            self.saver.load_model(self.discriminator_A2,
                                  "discriminator_A2", None, None)
            self.saver.load_model(self.discriminator_B2,
                                  "discriminator_B2", None, None)
# Example #4
class MaskCycleGANVCTraining(object):
    """Trainer for MaskCycleGAN-VC.

    Owns the two generators, four discriminators (two of them dedicated to
    the two-step adversarial loss), their shared optimizers, the linear
    learning-rate decay schedule, and TensorBoard logging/checkpointing.
    """

    def __init__(self, args):
        """
        Args:
            args (Namespace): Program arguments from argparser
        """
        # Store args
        self.num_epochs = args.num_epochs
        self.start_epoch = args.start_epoch
        self.generator_lr = args.generator_lr
        self.discriminator_lr = args.discriminator_lr
        self.decay_after = args.decay_after
        self.mini_batch_size = args.batch_size
        self.cycle_loss_lambda = args.cycle_loss_lambda
        self.identity_loss_lambda = args.identity_loss_lambda
        self.device = args.device
        self.epochs_per_save = args.epochs_per_save
        self.epochs_per_plot = args.epochs_per_plot

        # Initialize MelGAN-Vocoder used to decode Mel-spectrograms
        self.vocoder = torch.hub.load(
            'descriptinc/melgan-neurips', 'load_melgan')
        self.sample_rate = args.sample_rate

        # Initialize speakerA's dataset (pre-normalized mels + norm stats)
        self.dataset_A = self.loadPickleFile(os.path.join(
            args.preprocessed_data_dir, args.speaker_A_id, f"{args.speaker_A_id}_normalized.pickle"))
        dataset_A_norm_stats = np.load(os.path.join(
            args.preprocessed_data_dir, args.speaker_A_id, f"{args.speaker_A_id}_norm_stat.npz"))
        self.dataset_A_mean = dataset_A_norm_stats['mean']
        self.dataset_A_std = dataset_A_norm_stats['std']

        # Initialize speakerB's dataset
        self.dataset_B = self.loadPickleFile(os.path.join(
            args.preprocessed_data_dir, args.speaker_B_id, f"{args.speaker_B_id}_normalized.pickle"))
        dataset_B_norm_stats = np.load(os.path.join(
            args.preprocessed_data_dir, args.speaker_B_id, f"{args.speaker_B_id}_norm_stat.npz"))
        self.dataset_B_mean = dataset_B_norm_stats['mean']
        self.dataset_B_std = dataset_B_norm_stats['std']

        # Compute lr decay rate: linear decay to zero over all training steps.
        self.n_samples = len(self.dataset_A)
        print(f'n_samples = {self.n_samples}')
        # Guard against n_samples < mini_batch_size, which would make the
        # per-epoch batch count zero and raise ZeroDivisionError.
        n_batches_per_epoch = max(1, self.n_samples // self.mini_batch_size)
        self.generator_lr_decay = self.generator_lr / \
            float(self.num_epochs * n_batches_per_epoch)
        self.discriminator_lr_decay = self.discriminator_lr / \
            float(self.num_epochs * n_batches_per_epoch)
        print(f'generator_lr_decay = {self.generator_lr_decay}')
        print(f'discriminator_lr_decay = {self.discriminator_lr_decay}')

        # Initialize Train Dataloader
        self.num_frames = args.num_frames
        self.dataset = VCDataset(datasetA=self.dataset_A,
                                 datasetB=self.dataset_B,
                                 n_frames=args.num_frames,
                                 max_mask_len=args.max_mask_len)
        self.train_dataloader = torch.utils.data.DataLoader(dataset=self.dataset,
                                                            batch_size=self.mini_batch_size,
                                                            shuffle=True,
                                                            drop_last=False)

        # Initialize Validation Dataloader (used to generate intermediate outputs)
        self.validation_dataset = VCDataset(datasetA=self.dataset_A,
                                            datasetB=self.dataset_B,
                                            n_frames=args.num_frames_validation,
                                            max_mask_len=args.max_mask_len,
                                            valid=True)
        self.validation_dataloader = torch.utils.data.DataLoader(dataset=self.validation_dataset,
                                                                 batch_size=1,
                                                                 shuffle=False,
                                                                 drop_last=False)

        # Initialize logger and saver objects
        self.logger = TrainLogger(args, len(self.train_dataloader.dataset))
        self.saver = ModelSaver(args)

        # Initialize Generators and Discriminators
        self.generator_A2B = Generator().to(self.device)
        self.generator_B2A = Generator().to(self.device)
        self.discriminator_A = Discriminator().to(self.device)
        self.discriminator_B = Discriminator().to(self.device)
        # Discriminator to compute 2 step adversarial loss
        self.discriminator_A2 = Discriminator().to(self.device)
        # Discriminator to compute 2 step adversarial loss
        self.discriminator_B2 = Discriminator().to(self.device)

        # Initialize Optimizers: one shared optimizer per group.
        g_params = list(self.generator_A2B.parameters()) + \
            list(self.generator_B2A.parameters())
        d_params = list(self.discriminator_A.parameters()) + \
            list(self.discriminator_B.parameters()) + \
            list(self.discriminator_A2.parameters()) + \
            list(self.discriminator_B2.parameters())
        self.generator_optimizer = torch.optim.Adam(
            g_params, lr=self.generator_lr, betas=(0.5, 0.999))
        self.discriminator_optimizer = torch.optim.Adam(
            d_params, lr=self.discriminator_lr, betas=(0.5, 0.999))

        # Load from previous ckpt (optimizer state restored once per shared
        # optimizer; remaining calls restore weights only).
        if args.continue_train:
            self.saver.load_model(
                self.generator_A2B, "generator_A2B", None, self.generator_optimizer)
            self.saver.load_model(self.generator_B2A,
                                  "generator_B2A", None, None)
            self.saver.load_model(self.discriminator_A,
                                  "discriminator_A", None, self.discriminator_optimizer)
            self.saver.load_model(self.discriminator_B,
                                  "discriminator_B", None, None)
            self.saver.load_model(self.discriminator_A2,
                                  "discriminator_A2", None, None)
            self.saver.load_model(self.discriminator_B2,
                                  "discriminator_B2", None, None)

    def adjust_lr_rate(self, optimizer, generator):
        """Decays learning rate linearly by one step and writes it into the
        given optimizer's param groups (floored at 0).

        Args:
            optimizer (torch.optim): torch optimizer
            generator (bool): Whether to adjust generator lr.
        """
        if generator:
            self.generator_lr = max(
                0., self.generator_lr - self.generator_lr_decay)
            for param_groups in optimizer.param_groups:
                param_groups['lr'] = self.generator_lr
        else:
            self.discriminator_lr = max(
                0., self.discriminator_lr - self.discriminator_lr_decay)
            for param_groups in optimizer.param_groups:
                param_groups['lr'] = self.discriminator_lr

    def reset_grad(self):
        """Sets gradients of the generators and discriminators to zero before backpropagation.
        """
        self.generator_optimizer.zero_grad()
        self.discriminator_optimizer.zero_grad()

    def loadPickleFile(self, fileName):
        """Loads a Pickle file.

        Args:
            fileName (str): pickle file path

        Returns:
            file object: The loaded pickle file object
        """
        with open(fileName, 'rb') as f:
            return pickle.load(f)

    def train(self):
        """Implements the training loop for MaskCycleGAN-VC
        """
        for epoch in range(self.start_epoch, self.num_epochs + 1):
            self.logger.start_epoch()

            for i, (real_A, mask_A, real_B, mask_B) in enumerate(tqdm(self.train_dataloader)):
                self.logger.start_iter()

                real_A = real_A.to(self.device, dtype=torch.float)
                mask_A = mask_A.to(self.device, dtype=torch.float)
                real_B = real_B.to(self.device, dtype=torch.float)
                mask_B = mask_B.to(self.device, dtype=torch.float)

                # Train Generator

                # Generator Feed Forward (masks of ones on the second pass:
                # masking is only applied to real inputs)
                fake_B = self.generator_A2B(real_A, mask_A)
                cycle_A = self.generator_B2A(fake_B, torch.ones_like(fake_B))
                fake_A = self.generator_B2A(real_B, mask_B)
                cycle_B = self.generator_A2B(fake_A, torch.ones_like(fake_A))
                identity_A = self.generator_B2A(
                    real_A, torch.ones_like(real_A))
                identity_B = self.generator_A2B(
                    real_B, torch.ones_like(real_B))
                d_fake_A = self.discriminator_A(fake_A)
                d_fake_B = self.discriminator_B(fake_B)

                # For Two Step Adverserial Loss
                d_fake_cycle_A = self.discriminator_A2(cycle_A)
                d_fake_cycle_B = self.discriminator_B2(cycle_B)

                # Generator Cycle Loss
                cycleLoss = torch.mean(
                    torch.abs(real_A - cycle_A)) + torch.mean(torch.abs(real_B - cycle_B))

                # Generator Identity Loss
                identityLoss = torch.mean(
                    torch.abs(real_A - identity_A)) + torch.mean(torch.abs(real_B - identity_B))

                # Generator Loss (LSGAN objective)
                g_loss_A2B = torch.mean((1 - d_fake_B) ** 2)
                g_loss_B2A = torch.mean((1 - d_fake_A) ** 2)

                # Generator Two Step Adverserial Loss
                generator_loss_A2B_2nd = torch.mean((1 - d_fake_cycle_B) ** 2)
                generator_loss_B2A_2nd = torch.mean((1 - d_fake_cycle_A) ** 2)

                # Total Generator Loss
                g_loss = g_loss_A2B + g_loss_B2A + \
                    generator_loss_A2B_2nd + generator_loss_B2A_2nd + \
                    self.cycle_loss_lambda * cycleLoss + self.identity_loss_lambda * identityLoss

                # Backprop for Generator
                self.reset_grad()
                g_loss.backward()
                self.generator_optimizer.step()

                # Train Discriminator

                # Discriminator Feed Forward (fakes regenerated after the
                # generator update so the discriminator sees current outputs)
                d_real_A = self.discriminator_A(real_A)
                d_real_B = self.discriminator_B(real_B)
                d_real_A2 = self.discriminator_A2(real_A)
                d_real_B2 = self.discriminator_B2(real_B)
                generated_A = self.generator_B2A(real_B, mask_B)
                d_fake_A = self.discriminator_A(generated_A)

                # For Two Step Adverserial Loss A->B
                cycled_B = self.generator_A2B(
                    generated_A, torch.ones_like(generated_A))
                d_cycled_B = self.discriminator_B2(cycled_B)

                generated_B = self.generator_A2B(real_A, mask_A)
                d_fake_B = self.discriminator_B(generated_B)

                # For Two Step Adverserial Loss B->A
                cycled_A = self.generator_B2A(
                    generated_B, torch.ones_like(generated_B))
                d_cycled_A = self.discriminator_A2(cycled_A)

                # Loss Functions
                d_loss_A_real = torch.mean((1 - d_real_A) ** 2)
                d_loss_A_fake = torch.mean((0 - d_fake_A) ** 2)
                d_loss_A = (d_loss_A_real + d_loss_A_fake) / 2.0

                d_loss_B_real = torch.mean((1 - d_real_B) ** 2)
                d_loss_B_fake = torch.mean((0 - d_fake_B) ** 2)
                d_loss_B = (d_loss_B_real + d_loss_B_fake) / 2.0

                # Two Step Adverserial Loss
                d_loss_A_cycled = torch.mean((0 - d_cycled_A) ** 2)
                d_loss_B_cycled = torch.mean((0 - d_cycled_B) ** 2)
                d_loss_A2_real = torch.mean((1 - d_real_A2) ** 2)
                d_loss_B2_real = torch.mean((1 - d_real_B2) ** 2)
                d_loss_A_2nd = (d_loss_A2_real + d_loss_A_cycled) / 2.0
                d_loss_B_2nd = (d_loss_B2_real + d_loss_B_cycled) / 2.0

                # Final Loss for discriminator with the Two Step Adverserial Loss
                d_loss = (d_loss_A + d_loss_B) / 2.0 + \
                    (d_loss_A_2nd + d_loss_B_2nd) / 2.0

                # Backprop for Discriminator
                self.reset_grad()
                d_loss.backward()
                self.discriminator_optimizer.step()

                # Log Iteration
                self.logger.log_iter(
                    loss_dict={'g_loss': g_loss.item(), 'd_loss': d_loss.item()})
                self.logger.end_iter()

                # Adjust learning rates after decay_after steps; identity
                # loss is also dropped from that point on.
                if self.logger.global_step > self.decay_after:
                    self.identity_loss_lambda = 0
                    self.adjust_lr_rate(
                        self.generator_optimizer, generator=True)
                    # BUG FIX: decay the discriminator optimizer with its own
                    # rate. The original passed generator_optimizer here too,
                    # overwriting the generator lr with the discriminator lr
                    # and never decaying the discriminator optimizer.
                    self.adjust_lr_rate(
                        self.discriminator_optimizer, generator=False)

            # Log intermediate outputs on Tensorboard
            if self.logger.epoch % self.epochs_per_plot == 0:
                # Log Mel-spectrograms .png (from the last training batch)
                real_mel_A_fig = get_mel_spectrogram_fig(
                    real_A[0].detach().cpu())
                fake_mel_A_fig = get_mel_spectrogram_fig(
                    generated_A[0].detach().cpu())
                real_mel_B_fig = get_mel_spectrogram_fig(
                    real_B[0].detach().cpu())
                fake_mel_B_fig = get_mel_spectrogram_fig(
                    generated_B[0].detach().cpu())
                self.logger.visualize_outputs({"real_voc_spec": real_mel_A_fig, "fake_coraal_spec": fake_mel_B_fig,
                                               "real_coraal_spec": real_mel_B_fig, "fake_voc_spec": fake_mel_A_fig})

                # Convert Mel-spectrograms from validation set to waveform and log to tensorboard
                real_mel_full_A, real_mel_full_B = next(
                    iter(self.validation_dataloader))
                real_mel_full_A = real_mel_full_A.to(
                    self.device, dtype=torch.float)
                real_mel_full_B = real_mel_full_B.to(
                    self.device, dtype=torch.float)
                fake_mel_full_B = self.generator_A2B(
                    real_mel_full_A, torch.ones_like(real_mel_full_A))
                fake_mel_full_A = self.generator_B2A(
                    real_mel_full_B, torch.ones_like(real_mel_full_B))
                real_wav_full_A = decode_melspectrogram(self.vocoder, real_mel_full_A[0].detach(
                ).cpu(), self.dataset_A_mean, self.dataset_A_std).cpu()
                fake_wav_full_A = decode_melspectrogram(self.vocoder, fake_mel_full_A[0].detach(
                ).cpu(), self.dataset_A_mean, self.dataset_A_std).cpu()
                real_wav_full_B = decode_melspectrogram(self.vocoder, real_mel_full_B[0].detach(
                ).cpu(), self.dataset_B_mean, self.dataset_B_std).cpu()
                fake_wav_full_B = decode_melspectrogram(self.vocoder, fake_mel_full_B[0].detach(
                ).cpu(), self.dataset_B_mean, self.dataset_B_std).cpu()
                self.logger.log_audio(
                    real_wav_full_A.T, "real_speaker_A_audio", self.sample_rate)
                self.logger.log_audio(
                    fake_wav_full_A.T, "fake_speaker_A_audio", self.sample_rate)
                self.logger.log_audio(
                    real_wav_full_B.T, "real_speaker_B_audio", self.sample_rate)
                self.logger.log_audio(
                    fake_wav_full_B.T, "fake_speaker_B_audio", self.sample_rate)

            # Save each model checkpoint
            # BUG FIX: the original referenced `args.device` here, but `args`
            # is not in scope inside train() — it raised NameError on the
            # first save epoch. Use self.device instead.
            if self.logger.epoch % self.epochs_per_save == 0:
                self.saver.save(self.logger.epoch, self.generator_A2B,
                                self.generator_optimizer, None, self.device, "generator_A2B")
                self.saver.save(self.logger.epoch, self.generator_B2A,
                                self.generator_optimizer, None, self.device, "generator_B2A")
                self.saver.save(self.logger.epoch, self.discriminator_A,
                                self.discriminator_optimizer, None, self.device, "discriminator_A")
                self.saver.save(self.logger.epoch, self.discriminator_B,
                                self.discriminator_optimizer, None, self.device, "discriminator_B")
                self.saver.save(self.logger.epoch, self.discriminator_A2,
                                self.discriminator_optimizer, None, self.device, "discriminator_A2")
                self.saver.save(self.logger.epoch, self.discriminator_B2,
                                self.discriminator_optimizer, None, self.device, "discriminator_B2")

            self.logger.end_epoch()
# Example #5
class CycleGANGenerate(object):
    """Runs a trained A2B generator over a source speaker's wavs and pickles
    the converted (de-normalized) mel-spectrograms.
    """

    def __init__(self, args):
        """
        Args:
            args (Namespace): Program arguments from argparser
        """
        # Schedule-related args (stored for parity with the trainer).
        self.num_epochs = args.num_epochs
        self.start_epoch = args.start_epoch
        self.mini_batch_size = args.batch_size
        self.device = args.device

        # MelGAN vocoder: used forward (wav -> mel) in normalize_mel.
        self.vocoder = torch.hub.load('descriptinc/melgan-neurips',
                                      'load_melgan')
        self.sample_rate = args.sample_rate

        self.data_dir = args.data_dir
        self.source_id = args.source_id
        self.save_dir = args.save_dir
        self.saver = ModelSaver(args)

        # Generator
        self.generator_A2B = Generator().to(self.device)

        # Load from previous ckpt
        self.saver.load_model(self.generator_A2B, "generator_A2B",
                              args.ckpt_path, None, None)

        # Build the normalized mel dataset for the source speaker.
        voc_wav_files = self.read_manifest(dataset="voc",
                                           speaker_id=self.source_id)
        print(f'Found {len(voc_wav_files)} wav files')
        self.dataset_A, self.dataset_A_mean, self.dataset_A_std = self.normalize_mel(
            voc_wav_files, self.data_dir, sr=self.sample_rate)
        self.n_samples = len(self.dataset_A)
        print(f'n_samples = {self.n_samples}')

    def read_manifest(self, split=None, dataset=None, speaker_id=None):
        """Return the wav file paths listed for `speaker_id` in the dataset's
        manifest CSV (columns: wav_file, speaker_id).
        """
        # Load manifest file which defines dataset
        manifest_path = os.path.join('./manifests', f'{dataset}_manifest.csv')
        df = pd.read_csv(manifest_path, sep=',')

        # Filter by speaker_id. BUG FIX: compare string-to-string — the
        # original compared the stringified column against a possibly-int
        # speaker_id, which silently returned an empty list.
        df['speaker_id'] = df['speaker_id'].astype(str)
        df = df[df['speaker_id'] == str(speaker_id)]
        wav_files = df['wav_file'].tolist()

        return wav_files

    def normalize_mel(self, wav_files, data_dir, sr=22050):
        """Compute mel-spectrograms for all wavs and normalize them with
        statistics pooled over every frame of every file.

        Returns:
            (dict wavpath -> normalized mel, mel_mean, mel_std)
        """
        # Reuse the vocoder already loaded in __init__ instead of fetching a
        # second copy from torch.hub (the original reloaded it here).
        vocoder = self.vocoder

        mel_list = dict()
        for wavpath in tqdm(wav_files, desc='Preprocess wav to mel'):
            wav_orig, _ = librosa.load(os.path.join(data_dir, wavpath),
                                       sr=sr,
                                       mono=True)
            spec = vocoder(torch.tensor([wav_orig]))
            assert wavpath not in mel_list
            mel_list[wavpath] = spec.cpu().detach().numpy()[0]

        # Per-mel-bin statistics over the concatenated frames of all files.
        mel_concatenated = np.concatenate(list(mel_list.values()), axis=1)
        mel_mean = np.mean(mel_concatenated, axis=1, keepdims=True)
        # epsilon keeps constant bins from dividing by zero
        mel_std = np.std(mel_concatenated, axis=1, keepdims=True) + 1e-9

        mel_normalized = dict()
        for wavpath, mel in mel_list.items():
            app = (mel - mel_mean) / mel_std
            assert wavpath not in mel_normalized
            mel_normalized[wavpath] = app

        return mel_normalized, mel_mean, mel_std

    def save_pickle(self, variable, fileName):
        """Pickle `variable` to `fileName`."""
        with open(fileName, 'wb') as f:
            pickle.dump(variable, f)

    def run(self):
        """Convert every source mel with the A2B generator and pickle the
        de-normalized results keyed by wav path.
        """
        converted_specs = dict()
        # (unused enumerate index removed)
        for wavpath, melspec in tqdm(self.dataset_A.items()):
            real_A = torch.tensor(melspec).unsqueeze(0).to(self.device,
                                                           dtype=torch.float)
            fake_B_normalized = self.generator_A2B(
                real_A,
                torch.ones_like(real_A)).squeeze(0).detach().cpu().numpy()
            # De-normalize with the source speaker's statistics.
            fake_B = fake_B_normalized * self.dataset_A_std + self.dataset_A_mean
            converted_specs[wavpath] = fake_B

        # NOTE(review): output path is hard-coded to /home/ubuntu/data;
        # consider deriving it from self.save_dir instead.
        print(
            f"Saving to ~/data/converted/voc_converted_{self.source_id}.pickle"
        )
        self.save_pickle(variable=converted_specs,
                         fileName=os.path.join(
                             '/home/ubuntu/data', "converted",
                             f"voc_converted_{self.source_id}.pickle"))
# Example #6
def SaveSubgraph(option, subg):
    """Save the subgraph's configuration info via ModelSaver, if requested.

    Args:
        option: Options object providing a `save_config` flag and a
            `save_prefix` path prefix.
        subg: Subgraph object handed to ModelSaver.
    """
    saver = ModelSaver(subg)

    # Truthiness test instead of '== True' (PEP 8): also honors truthy
    # non-bool flag values (e.g. 'yes' from a config file).
    if option.save_config:
        saver.SaveConfigInfo(option.save_prefix)
# Example #7
class MaskCycleGANVCTesting(object):
    """Tester for MaskCycleGAN-VC.

    Loads a trained generator checkpoint, converts every mel-spectrogram of
    the source speaker's dataset, decodes the conversions to waveforms with
    the MelGAN vocoder, and writes them to disk as .wav files.
    """
    def __init__(self, args):
        """
        Args:
            args (Namespace): Program arguments from argparser
        """
        # Store Args
        self.device = args.device
        self.converted_audio_dir = os.path.join(args.save_dir, args.name,
                                                'converted_audio')
        os.makedirs(self.converted_audio_dir, exist_ok=True)
        # Checkpoint name to load; also selects the conversion direction
        # ('generator_A2B' converts A->B, anything else B->A).
        self.model_name = args.model_name

        self.speaker_A_id = args.speaker_A_id
        self.speaker_B_id = args.speaker_B_id
        # Initialize MelGAN-Vocoder used to decode Mel-spectrograms
        self.vocoder = torch.hub.load('descriptinc/melgan-neurips',
                                      'load_melgan')
        self.sample_rate = args.sample_rate

        # Initialize speakerA's dataset: normalized mel-spectrograms plus the
        # 'mean'/'std' arrays saved by the preprocessing script
        self.dataset_A = self.loadPickleFile(
            os.path.join(args.preprocessed_data_dir, self.speaker_A_id,
                         f"{self.speaker_A_id}_normalized.pickle"))
        dataset_A_norm_stats = np.load(
            os.path.join(args.preprocessed_data_dir, self.speaker_A_id,
                         f"{self.speaker_A_id}_norm_stat.npz"))
        self.dataset_A_mean = dataset_A_norm_stats['mean']
        self.dataset_A_std = dataset_A_norm_stats['std']

        # Initialize speakerB's dataset
        self.dataset_B = self.loadPickleFile(
            os.path.join(args.preprocessed_data_dir, self.speaker_B_id,
                         f"{self.speaker_B_id}_normalized.pickle"))
        dataset_B_norm_stats = np.load(
            os.path.join(args.preprocessed_data_dir, self.speaker_B_id,
                         f"{self.speaker_B_id}_norm_stat.npz"))
        self.dataset_B_mean = dataset_B_norm_stats['mean']
        self.dataset_B_std = dataset_B_norm_stats['std']

        # Source utterances depend on the conversion direction
        source_dataset = self.dataset_A if self.model_name == 'generator_A2B' else self.dataset_B
        self.dataset = VCDataset(datasetA=source_dataset,
                                 datasetB=None,
                                 valid=True)
        self.test_dataloader = torch.utils.data.DataLoader(
            dataset=self.dataset, batch_size=1, shuffle=False, drop_last=False)

        # Generator; holds whichever direction `model_name` names
        self.generator_A2B = Generator().to(self.device)

        # Load Generator from ckpt
        self.saver = ModelSaver(args)
        self.saver.load_model(self.generator_A2B, self.model_name)

    def loadPickleFile(self, fileName):
        """Loads a Pickle file.

        Args:
            fileName (str): pickle file path

        Returns:
            file object: The loaded pickle file object
        """
        with open(fileName, 'rb') as f:
            return pickle.load(f)

    def test(self):
        """Convert every source utterance and save the decoded audio files.

        NOTE(review): decoding always de-normalizes with speaker A's
        mean/std, even when converting B->A -- confirm this matches the
        preprocessing convention.
        """
        for i, (real_A) in enumerate(tqdm(self.test_dataloader)):
            real_A = real_A.to(self.device, dtype=torch.float)
            # Mask of all ones = no masking at inference time
            fake_B = self.generator_A2B(real_A, torch.ones_like(real_A))
            wav_fake_B = decode_melspectrogram(self.vocoder,
                                               fake_B[0].detach().cpu(),
                                               self.dataset_A_mean,
                                               self.dataset_A_std).cpu()
            # NOTE(review): filenames have no separator before the index i
            save_path = None
            if self.model_name == 'generator_A2B':
                save_path = os.path.join(
                    self.converted_audio_dir,
                    f"converted_{self.speaker_A_id}_to_{self.speaker_B_id}{i}.wav"
                )
            else:
                save_path = os.path.join(
                    self.converted_audio_dir,
                    f"converted_{self.speaker_B_id}_to_{self.speaker_A_id}{i}.wav"
                )
            torchaudio.save(save_path,
                            wav_fake_B,
                            sample_rate=self.sample_rate)
def train(args):
    """Training loop for the multi-task 3D ResNet classifier.

    Builds the model, multi-task loss, dataloaders, optimizer, and LR
    scheduler, then trains for the number of epochs managed by TrainLogger,
    periodically evaluating on pediatric and adult validation cohorts and
    saving checkpoints.

    Args:
        args (Namespace): Program arguments
    """
    # Get model and multi-task loss function
    # (the previously-constructed BCEWithLogitsLoss was unused and removed)
    model = MTClassifier3D(args).to(args.device)
    loss_wrapper = MultiTaskLoss(args)

    # Train and validation dataloaders; volumes resampled to a fixed size
    resample_shape = (args.num_slices, args.slice_size, args.slice_size)
    train_dataset = ClassifierDataset(args.csv_dir, 'train', args.features,
                                      resample=resample_shape)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              num_workers=args.num_workers, shuffle=True,
                              pin_memory=True)

    peds_validation_dataset = ClassifierDataset(args.peds_csv_dir, 'val',
                                                args.peds_features,
                                                resample=resample_shape)
    peds_validation_loader = DataLoader(peds_validation_dataset,
                                        batch_size=args.batch_size,
                                        num_workers=args.num_workers,
                                        shuffle=False, pin_memory=True)

    adult_validation_dataset = ClassifierDataset(args.adult_csv_dir, 'val',
                                                 args.adult_features,
                                                 resample=resample_shape)
    adult_validation_loader = DataLoader(adult_validation_dataset,
                                         batch_size=args.batch_size,
                                         num_workers=args.num_workers,
                                         shuffle=False, pin_memory=True)

    # Optimizer plus warmup + multi-step LR schedule (milestones are given
    # in epochs and converted here to iterations)
    optimizer = optim.Adam(model.parameters(), args.lr)
    warmup_iters = args.lr_warmup_epochs * len(train_loader)
    lr_milestones = [len(train_loader) * m for m in args.lr_milestones]
    lr_scheduler = WarmupMultiStepLR(
        optimizer, milestones=lr_milestones, gamma=args.lr_gamma,
        warmup_iters=warmup_iters, warmup_factor=1e-5)

    # Checkpoint saver tracking the configured best-checkpoint metric
    saver = ModelSaver(args, max_ckpts=args.max_ckpts,
                       metric_name=args.best_ckpt_metric,
                       maximize_metric=args.maximize_metric)

    # Load model from checkpoint if applicable
    if args.continue_train:
        saver.load_model(model, args.name, ckpt_path=args.load_path,
                         optimizer=optimizer, scheduler=lr_scheduler)
    logger = TrainLogger(args, len(train_loader.dataset))

    # Multi GPU training if applicable
    if len(args.gpu_ids) > 1:
        print("Using", len(args.gpu_ids), "GPUs.")
        model = nn.DataParallel(model)

    loss_meter = meter.AverageValueMeter()

    # Train model
    logger.log_hparams(args)
    while not logger.is_finished_training():
        logger.start_epoch()

        for inputs, targets in tqdm(train_loader):
            logger.start_iter()
            with torch.set_grad_enabled(True):
                inputs = inputs.to(args.device)
                targets = targets.to(args.device)
                head_preds = model(inputs)

                loss = loss_wrapper(head_preds, targets)
                loss_meter.add(loss.item())

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Log the running average of the train loss periodically
            if logger.iter % args.steps_per_print == 0 and logger.iter != 0:
                logger.log_metrics({'train_loss': loss_meter.value()[0]})
                loss_meter.reset()

            logger.end_iter()

        # Evaluate model on both validation cohorts
        if logger.epoch % args.epochs_per_eval == 0:
            peds_metrics = evaluate(args, model, loss_wrapper,
                                    peds_validation_loader, "validation",
                                    args.device, 'peds')
            logger.log_metrics(peds_metrics)
            adult_metrics = evaluate(args, model, loss_wrapper,
                                     adult_validation_loader, "validation",
                                     args.device, 'adult')
            logger.log_metrics(adult_metrics)

        # Save model ckpt
        if logger.epoch % args.epochs_per_save == 0:
            saver.save(logger.epoch, model, optimizer, lr_scheduler,
                       args.device, args.name)
        lr_scheduler.step()
        logger.end_epoch()
Пример #9
0
def main(args):
    """End-to-end training entry point for the speech-recognition model.

    Builds train/validation datasets (LibriSpeech or a local Dataset),
    dataloaders with CTC-style collation, the model, optimizer, scheduler,
    and checkpoint saver, then runs the train/eval loop.

    Args:
        args (Namespace): Program arguments
    """
    if args.librispeech:
        print("Loading Librispeech dataset!")
        train_dataset = torchaudio.datasets.LIBRISPEECH(
            args.data_dir, url="train-clean-360", download=True)
        valid_dataset = torchaudio.datasets.LIBRISPEECH(
            args.data_dir, url="test-clean", download=True)
    else:
        train_dataset = Dataset(args, "train", return_pair=args.return_pair)
        valid_dataset = Dataset(args, "val", return_pair=args.return_pair)

    print(f"Training set has {len(train_dataset)} samples. Validation set has {len(valid_dataset)} samples.")

    text_transform = TextTransform()

    train_loader = data.DataLoader(dataset=train_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   collate_fn=lambda x: data_processing(
                                       x, "train", text_transform),
                                   num_workers=args.num_workers,
                                   pin_memory=True)
    valid_loader = data.DataLoader(dataset=valid_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   collate_fn=lambda x: data_processing(
                                       x, "valid", text_transform),
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    model = SpeechRecognitionModel(
        args.n_cnn_layers, args.n_rnn_layers, args.rnn_dim,
        args.n_class, args.n_feats, args.stride, args.dropout
    ).to(args.device)

    # Generator expression instead of a throwaway list inside sum()
    print('Num Model Parameters',
          sum(param.nelement() for param in model.parameters()))

    optimizer = optim.AdamW(model.parameters(), args.lr)
    # Index 28 is reserved for the CTC blank token
    criterion = nn.CTCLoss(blank=28).to(args.device)
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr,
                                              steps_per_epoch=len(train_loader),
                                              epochs=args.num_epochs,
                                              anneal_strategy='linear')

    # Track word error rate; lower is better
    saver = ModelSaver(args, max_ckpts=args.max_ckpts,
                       metric_name="test_wer", maximize_metric=False)

    # Resume training (with optimizer/scheduler state) or warm-start from a
    # pretrained checkpoint (weights only)
    if args.continue_train:
        saver.load_model(model, "SpeechRecognitionModel",
                         args.ckpt_path, optimizer, scheduler)
    elif args.pretrained_ckpt_path:
        saver.load_model(model, "SpeechRecognitionModel",
                         args.pretrained_ckpt_path, None, None)

    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs!")
        model = nn.DataParallel(model)

    logger = TrainLogger(args, len(train_loader.dataset))
    logger.log_hparams(args)

    for epoch in range(args.start_epoch, args.num_epochs + 1):
        train(args, model, train_loader, criterion,
              optimizer, scheduler, logger)
        if logger.epoch % args.epochs_per_save == 0:
            metric_dict = test(args, model, valid_loader, criterion, logger)
            saver.save(logger.epoch, model, optimizer, scheduler, args.device,
                       "SpeechRecognitionModel", metric_dict["test_wer"])
        logger.end_epoch()
Пример #10
0
    def __init__(self, args):
        """Initialize the MaskCycleGAN-VC trainer from program arguments.

        Loads the normalized mel-spectrogram datasets and their
        normalization statistics, builds train/validation dataloaders, the
        two generators and four discriminators with their optimizers, and
        optionally restores all networks from a previous checkpoint.

        Args:
            args (Namespace): Program arguments from argparser
        """
        self.num_epochs = args.num_epochs
        self.start_epoch = args.start_epoch
        self.generator_lr = args.generator_lr
        self.discriminator_lr = args.discriminator_lr
        # Global step after which the identity loss is dropped and both
        # learning rates start to decay linearly (see train()).
        self.decay_after = args.decay_after
        self.mini_batch_size = args.batch_size
        self.cycle_loss_lambda = args.cycle_loss_lambda
        self.identity_loss_lambda = args.identity_loss_lambda
        self.device = args.device
        self.epochs_per_save = args.epochs_per_save
        self.epochs_per_plot = args.epochs_per_plot

        # MelGAN vocoder used to decode mel-spectrograms back to waveforms
        self.vocoder = torch.hub.load('descriptinc/melgan-neurips',
                                      'load_melgan')
        self.sample_rate = args.sample_rate

        # Normalized mel datasets plus the 'mean'/'std' arrays saved by the
        # preprocessing script (needed to de-normalize before vocoding)
        self.dataset_A = self.loadPickleFile(args.normalized_dataset_A_path)
        dataset_A_norm_stats = np.load(args.norm_stats_A_path)
        # TODO: fix to mean and std after running data preprocessing script again
        self.dataset_A_mean = dataset_A_norm_stats['mean']
        self.dataset_A_std = dataset_A_norm_stats['std']
        self.dataset_B = self.loadPickleFile(args.normalized_dataset_B_path)
        dataset_B_norm_stats = np.load(args.norm_stats_B_path)
        self.dataset_B_mean = dataset_B_norm_stats['mean']
        self.dataset_B_std = dataset_B_norm_stats['std']

        self.n_samples = len(self.dataset_A)
        print(f'n_samples = {self.n_samples}')
        # Per-iteration linear decay so each LR reaches ~0 at the end of
        # training.  NOTE(review): divides by zero when n_samples <
        # mini_batch_size -- confirm datasets always exceed one batch.
        self.generator_lr_decay = self.generator_lr / float(
            self.num_epochs * (self.n_samples // self.mini_batch_size))
        self.discriminator_lr_decay = self.discriminator_lr / float(
            self.num_epochs * (self.n_samples // self.mini_batch_size))
        print(f'generator_lr_decay = {self.generator_lr_decay}')
        print(f'discriminator_lr_decay = {self.discriminator_lr_decay}')
        self.num_frames = args.num_frames
        # Training set yields (mel, mask) crops of num_frames frames
        self.dataset = trainingDataset(datasetA=self.dataset_A,
                                       datasetB=self.dataset_B,
                                       n_frames=args.num_frames,
                                       max_mask_len=args.max_mask_len)
        self.train_dataloader = torch.utils.data.DataLoader(
            dataset=self.dataset,
            batch_size=self.mini_batch_size,
            shuffle=True,
            drop_last=False)

        # Validation uses longer deterministic crops, one sample at a time
        self.validation_dataset = trainingDataset(
            datasetA=self.dataset_A,
            datasetB=self.dataset_B,
            n_frames=args.num_frames_validation,
            max_mask_len=args.max_mask_len,
            valid=True)
        self.validation_dataloader = torch.utils.data.DataLoader(
            dataset=self.validation_dataset,
            batch_size=1,
            shuffle=False,
            drop_last=False)

        self.logger = TrainLogger(args, len(self.train_dataloader.dataset))
        self.saver = ModelSaver(args)

        # Generator and Discriminator.  *_A2 / *_B2 are the second-step
        # adversarial discriminators of MaskCycleGAN-VC.
        self.generator_A2B = Generator().to(self.device)
        self.generator_B2A = Generator().to(self.device)
        self.discriminator_A = Discriminator().to(self.device)
        self.discriminator_B = Discriminator().to(self.device)
        self.discriminator_A2 = Discriminator().to(self.device)
        self.discriminator_B2 = Discriminator().to(self.device)

        # Optimizer: one Adam over both generators, one over all four
        # discriminators
        g_params = list(self.generator_A2B.parameters()) + \
            list(self.generator_B2A.parameters())
        d_params = list(self.discriminator_A.parameters()) + \
            list(self.discriminator_B.parameters()) + \
            list(self.discriminator_A2.parameters()) + \
            list(self.discriminator_B2.parameters())

        self.generator_optimizer = torch.optim.Adam(g_params,
                                                    lr=self.generator_lr,
                                                    betas=(0.5, 0.999))
        self.discriminator_optimizer = torch.optim.Adam(
            d_params, lr=self.discriminator_lr, betas=(0.5, 0.999))

        # Load from previous ckpt.  NOTE(review): only the first generator /
        # discriminator loads pass an optimizer to restore its state --
        # confirm this 4-argument call matches ModelSaver.load_model's
        # signature (other call sites in this file pass 5 arguments).
        if args.continue_train:
            self.saver.load_model(self.generator_A2B, "generator_A2B", None,
                                  self.generator_optimizer)
            self.saver.load_model(self.generator_B2A, "generator_B2A", None,
                                  None)
            self.saver.load_model(self.discriminator_A, "discriminator_A",
                                  None, self.discriminator_optimizer)
            self.saver.load_model(self.discriminator_B, "discriminator_B",
                                  None, None)
            self.saver.load_model(self.discriminator_A2, "discriminator_A2",
                                  None, None)
            self.saver.load_model(self.discriminator_B2, "discriminator_B2",
                                  None, None)
Пример #11
0
class CycleGANTraining(object):
    def __init__(self, args):
        """Initialize the MaskCycleGAN-VC trainer from program arguments.

        Loads the normalized mel-spectrogram datasets and their
        normalization statistics, builds train/validation dataloaders, the
        two generators and four discriminators with their optimizers, and
        optionally restores all networks from a previous checkpoint.

        Args:
            args (Namespace): Program arguments from argparser
        """
        self.num_epochs = args.num_epochs
        self.start_epoch = args.start_epoch
        self.generator_lr = args.generator_lr
        self.discriminator_lr = args.discriminator_lr
        # Global step after which the identity loss is dropped and both
        # learning rates start to decay linearly (see train()).
        self.decay_after = args.decay_after
        self.mini_batch_size = args.batch_size
        self.cycle_loss_lambda = args.cycle_loss_lambda
        self.identity_loss_lambda = args.identity_loss_lambda
        self.device = args.device
        self.epochs_per_save = args.epochs_per_save
        self.epochs_per_plot = args.epochs_per_plot

        # MelGAN vocoder used to decode mel-spectrograms back to waveforms
        self.vocoder = torch.hub.load('descriptinc/melgan-neurips',
                                      'load_melgan')
        self.sample_rate = args.sample_rate

        # Normalized mel datasets plus the 'mean'/'std' arrays saved by the
        # preprocessing script (needed to de-normalize before vocoding)
        self.dataset_A = self.loadPickleFile(args.normalized_dataset_A_path)
        dataset_A_norm_stats = np.load(args.norm_stats_A_path)
        # TODO: fix to mean and std after running data preprocessing script again
        self.dataset_A_mean = dataset_A_norm_stats['mean']
        self.dataset_A_std = dataset_A_norm_stats['std']
        self.dataset_B = self.loadPickleFile(args.normalized_dataset_B_path)
        dataset_B_norm_stats = np.load(args.norm_stats_B_path)
        self.dataset_B_mean = dataset_B_norm_stats['mean']
        self.dataset_B_std = dataset_B_norm_stats['std']

        self.n_samples = len(self.dataset_A)
        print(f'n_samples = {self.n_samples}')
        # Per-iteration linear decay so each LR reaches ~0 at the end of
        # training.  NOTE(review): divides by zero when n_samples <
        # mini_batch_size -- confirm datasets always exceed one batch.
        self.generator_lr_decay = self.generator_lr / float(
            self.num_epochs * (self.n_samples // self.mini_batch_size))
        self.discriminator_lr_decay = self.discriminator_lr / float(
            self.num_epochs * (self.n_samples // self.mini_batch_size))
        print(f'generator_lr_decay = {self.generator_lr_decay}')
        print(f'discriminator_lr_decay = {self.discriminator_lr_decay}')
        self.num_frames = args.num_frames
        # Training set yields (mel, mask) crops of num_frames frames
        self.dataset = trainingDataset(datasetA=self.dataset_A,
                                       datasetB=self.dataset_B,
                                       n_frames=args.num_frames,
                                       max_mask_len=args.max_mask_len)
        self.train_dataloader = torch.utils.data.DataLoader(
            dataset=self.dataset,
            batch_size=self.mini_batch_size,
            shuffle=True,
            drop_last=False)

        # Validation uses longer deterministic crops, one sample at a time
        self.validation_dataset = trainingDataset(
            datasetA=self.dataset_A,
            datasetB=self.dataset_B,
            n_frames=args.num_frames_validation,
            max_mask_len=args.max_mask_len,
            valid=True)
        self.validation_dataloader = torch.utils.data.DataLoader(
            dataset=self.validation_dataset,
            batch_size=1,
            shuffle=False,
            drop_last=False)

        self.logger = TrainLogger(args, len(self.train_dataloader.dataset))
        self.saver = ModelSaver(args)

        # Generator and Discriminator.  *_A2 / *_B2 are the second-step
        # adversarial discriminators of MaskCycleGAN-VC.
        self.generator_A2B = Generator().to(self.device)
        self.generator_B2A = Generator().to(self.device)
        self.discriminator_A = Discriminator().to(self.device)
        self.discriminator_B = Discriminator().to(self.device)
        self.discriminator_A2 = Discriminator().to(self.device)
        self.discriminator_B2 = Discriminator().to(self.device)

        # Optimizer: one Adam over both generators, one over all four
        # discriminators
        g_params = list(self.generator_A2B.parameters()) + \
            list(self.generator_B2A.parameters())
        d_params = list(self.discriminator_A.parameters()) + \
            list(self.discriminator_B.parameters()) + \
            list(self.discriminator_A2.parameters()) + \
            list(self.discriminator_B2.parameters())

        self.generator_optimizer = torch.optim.Adam(g_params,
                                                    lr=self.generator_lr,
                                                    betas=(0.5, 0.999))
        self.discriminator_optimizer = torch.optim.Adam(
            d_params, lr=self.discriminator_lr, betas=(0.5, 0.999))

        # Load from previous ckpt.  NOTE(review): only the first generator /
        # discriminator loads pass an optimizer to restore its state --
        # confirm this 4-argument call matches ModelSaver.load_model's
        # signature (other call sites in this file pass 5 arguments).
        if args.continue_train:
            self.saver.load_model(self.generator_A2B, "generator_A2B", None,
                                  self.generator_optimizer)
            self.saver.load_model(self.generator_B2A, "generator_B2A", None,
                                  None)
            self.saver.load_model(self.discriminator_A, "discriminator_A",
                                  None, self.discriminator_optimizer)
            self.saver.load_model(self.discriminator_B, "discriminator_B",
                                  None, None)
            self.saver.load_model(self.discriminator_A2, "discriminator_A2",
                                  None, None)
            self.saver.load_model(self.discriminator_B2, "discriminator_B2",
                                  None, None)

    def adjust_lr_rate(self, optimizer, name='generator'):
        if name == 'generator':
            self.generator_lr = max(
                0., self.generator_lr - self.generator_lr_decay)
            for param_groups in optimizer.param_groups:
                param_groups['lr'] = self.generator_lr
        else:
            self.discriminator_lr = max(
                0., self.discriminator_lr - self.discriminator_lr_decay)
            for param_groups in optimizer.param_groups:
                param_groups['lr'] = self.discriminator_lr

    def reset_grad(self):
        self.generator_optimizer.zero_grad()
        self.discriminator_optimizer.zero_grad()

    def loadPickleFile(self, fileName):
        with open(fileName, 'rb') as f:
            return pickle.load(f)

    def train(self):
        """Run the MaskCycleGAN-VC training loop.

        Per iteration: update both generators (adversarial, second-step
        adversarial, cycle, and identity losses), then all four
        discriminators.  After ``decay_after`` global steps the identity
        loss is disabled and both learning rates decay linearly.  Every
        ``epochs_per_plot`` epochs spectrograms and decoded audio are
        logged; every ``epochs_per_save`` epochs all networks are saved.

        Fixes relative to the previous revision:
        - the discriminator LR decay was mistakenly applied to the
          generator optimizer (it received the discriminator rate and the
          discriminator optimizer never decayed);
        - checkpoint saving referenced ``args.device``, but ``args`` is not
          in scope inside this method (NameError); uses ``self.device``.
        """
        for epoch in range(self.start_epoch, self.num_epochs):
            self.logger.start_epoch()

            for i, (real_A, mask_A, real_B,
                    mask_B) in enumerate(tqdm(self.train_dataloader)):
                self.logger.start_iter()

                real_A = real_A.to(self.device, dtype=torch.float)
                mask_A = mask_A.to(self.device, dtype=torch.float)
                real_B = real_B.to(self.device, dtype=torch.float)
                mask_B = mask_B.to(self.device, dtype=torch.float)

                # ----- Train Generators -----
                fake_B = self.generator_A2B(real_A, mask_A)
                cycle_A = self.generator_B2A(fake_B, torch.ones_like(fake_B))
                fake_A = self.generator_B2A(real_B, mask_B)
                cycle_B = self.generator_A2B(fake_A, torch.ones_like(fake_A))
                identity_A = self.generator_B2A(real_A,
                                                torch.ones_like(real_A))
                identity_B = self.generator_A2B(real_B,
                                                torch.ones_like(real_B))
                d_fake_A = self.discriminator_A(fake_A)
                d_fake_B = self.discriminator_B(fake_B)

                # For the second-step adversarial loss
                d_fake_cycle_A = self.discriminator_A2(cycle_A)
                d_fake_cycle_B = self.discriminator_B2(cycle_B)

                # Generator Cycle Loss
                cycleLoss = torch.mean(
                    torch.abs(real_A - cycle_A)) + torch.mean(
                        torch.abs(real_B - cycle_B))

                # Generator Identity Loss
                identityLoss = torch.mean(
                    torch.abs(real_A - identity_A)) + torch.mean(
                        torch.abs(real_B - identity_B))

                # LSGAN generator losses (fakes pushed toward 1)
                g_loss_A2B = torch.mean((1 - d_fake_B)**2)
                g_loss_B2A = torch.mean((1 - d_fake_A)**2)

                # Generator second-step adversarial loss
                generator_loss_A2B_2nd = torch.mean((1 - d_fake_cycle_B)**2)
                generator_loss_B2A_2nd = torch.mean((1 - d_fake_cycle_A)**2)

                # Total Generator Loss
                g_loss = g_loss_A2B + g_loss_B2A + \
                    generator_loss_A2B_2nd + generator_loss_B2A_2nd + \
                    self.cycle_loss_lambda * cycleLoss + self.identity_loss_lambda * identityLoss

                # Backprop for Generator
                self.reset_grad()
                g_loss.backward()
                self.generator_optimizer.step()

                # ----- Train Discriminators -----
                d_real_A = self.discriminator_A(real_A)
                d_real_B = self.discriminator_B(real_B)

                d_real_A2 = self.discriminator_A2(real_A)
                d_real_B2 = self.discriminator_B2(real_B)

                # Regenerate fakes so the discriminator update does not
                # reuse the generator-update graph
                generated_A = self.generator_B2A(real_B, mask_B)
                d_fake_A = self.discriminator_A(generated_A)

                # For second-step adversarial loss A->B
                cycled_B = self.generator_A2B(generated_A,
                                              torch.ones_like(generated_A))
                d_cycled_B = self.discriminator_B2(cycled_B)

                generated_B = self.generator_A2B(real_A, mask_A)
                d_fake_B = self.discriminator_B(generated_B)

                # For second-step adversarial loss B->A
                cycled_A = self.generator_B2A(generated_B,
                                              torch.ones_like(generated_B))
                d_cycled_A = self.discriminator_A2(cycled_A)

                # LSGAN discriminator losses (real -> 1, fake -> 0)
                d_loss_A_real = torch.mean((1 - d_real_A)**2)
                d_loss_A_fake = torch.mean((0 - d_fake_A)**2)
                d_loss_A = (d_loss_A_real + d_loss_A_fake) / 2.0

                d_loss_B_real = torch.mean((1 - d_real_B)**2)
                d_loss_B_fake = torch.mean((0 - d_fake_B)**2)
                d_loss_B = (d_loss_B_real + d_loss_B_fake) / 2.0

                # Second-step adversarial loss
                d_loss_A_cycled = torch.mean((0 - d_cycled_A)**2)
                d_loss_B_cycled = torch.mean((0 - d_cycled_B)**2)
                d_loss_A2_real = torch.mean((1 - d_real_A2)**2)
                d_loss_B2_real = torch.mean((1 - d_real_B2)**2)
                d_loss_A_2nd = (d_loss_A2_real + d_loss_A_cycled) / 2.0
                d_loss_B_2nd = (d_loss_B2_real + d_loss_B_cycled) / 2.0

                # Final discriminator loss including the second-step terms
                d_loss = (d_loss_A + d_loss_B) / 2.0 + \
                    (d_loss_A_2nd + d_loss_B_2nd) / 2.0

                # Backprop for Discriminator
                self.reset_grad()
                d_loss.backward()
                self.discriminator_optimizer.step()

                self.logger.log_iter(loss_dict={
                    'g_loss': g_loss.item(),
                    'd_loss': d_loss.item()
                })

                self.logger.end_iter()

                # After decay_after steps: drop identity loss and decay both
                # learning rates.  BUG FIX: the discriminator decay now
                # targets the discriminator optimizer.
                if self.logger.global_step > self.decay_after:
                    self.identity_loss_lambda = 0
                    self.adjust_lr_rate(self.generator_optimizer,
                                        name='generator')
                    self.adjust_lr_rate(self.discriminator_optimizer,
                                        name='discriminator')

            if self.logger.epoch % self.epochs_per_plot == 0:
                # Log spectrograms from the last training batch
                real_mel_A_fig = get_mel_spectrogram_fig(
                    real_A[0].detach().cpu())
                fake_mel_A_fig = get_mel_spectrogram_fig(
                    generated_A[0].detach().cpu())
                real_mel_B_fig = get_mel_spectrogram_fig(
                    real_B[0].detach().cpu())
                fake_mel_B_fig = get_mel_spectrogram_fig(
                    generated_B[0].detach().cpu())
                self.logger.visualize_outputs({
                    "real_voc_spec": real_mel_A_fig,
                    "fake_coraal_spec": fake_mel_B_fig,
                    "real_coraal_spec": real_mel_B_fig,
                    "fake_voc_spec": fake_mel_A_fig
                })

                # Convert full-length validation spectrograms to wav and log
                # the audio to tensorboard
                real_mel_full_A, real_mel_full_B = next(
                    iter(self.validation_dataloader))
                real_mel_full_A = real_mel_full_A.to(self.device,
                                                     dtype=torch.float)
                real_mel_full_B = real_mel_full_B.to(self.device,
                                                     dtype=torch.float)
                fake_mel_full_B = self.generator_A2B(
                    real_mel_full_A, torch.ones_like(real_mel_full_A))
                fake_mel_full_A = self.generator_B2A(
                    real_mel_full_B, torch.ones_like(real_mel_full_B))
                real_wav_full_A = decode_melspectrogram(
                    self.vocoder, real_mel_full_A[0].detach().cpu(),
                    self.dataset_A_mean, self.dataset_A_std).cpu()
                fake_wav_full_A = decode_melspectrogram(
                    self.vocoder, fake_mel_full_A[0].detach().cpu(),
                    self.dataset_A_mean, self.dataset_A_std).cpu()
                real_wav_full_B = decode_melspectrogram(
                    self.vocoder, real_mel_full_B[0].detach().cpu(),
                    self.dataset_B_mean, self.dataset_B_std).cpu()
                fake_wav_full_B = decode_melspectrogram(
                    self.vocoder, fake_mel_full_B[0].detach().cpu(),
                    self.dataset_B_mean, self.dataset_B_std).cpu()
                self.logger.log_audio(real_wav_full_A.T, "real_voc_audio",
                                      self.sample_rate)
                self.logger.log_audio(fake_wav_full_A.T, "fake_voc_audio",
                                      self.sample_rate)
                self.logger.log_audio(real_wav_full_B.T, "real_coraal_audio",
                                      self.sample_rate)
                self.logger.log_audio(fake_wav_full_B.T, "fake_coraal_audio",
                                      self.sample_rate)

            if self.logger.epoch % self.epochs_per_save == 0:
                # BUG FIX: use self.device; `args` is not in scope here
                self.saver.save(self.logger.epoch, self.generator_A2B,
                                self.generator_optimizer, None, self.device,
                                "generator_A2B")
                self.saver.save(self.logger.epoch, self.generator_B2A,
                                self.generator_optimizer, None, self.device,
                                "generator_B2A")
                self.saver.save(self.logger.epoch, self.discriminator_A,
                                self.discriminator_optimizer, None,
                                self.device, "discriminator_A")
                self.saver.save(self.logger.epoch, self.discriminator_B,
                                self.discriminator_optimizer, None,
                                self.device, "discriminator_B")
                self.saver.save(self.logger.epoch, self.discriminator_A2,
                                self.discriminator_optimizer, None,
                                self.device, "discriminator_A2")
                self.saver.save(self.logger.epoch, self.discriminator_B2,
                                self.discriminator_optimizer, None,
                                self.device, "discriminator_B2")

            self.logger.end_epoch()