Example #1
import torch
from torch.autograd import Variable
from torch.nn.functional import l1_loss

# Generator, Discriminator and GANLoss are project-local modules.


class Model(object):
    def __init__(self, opt):
        super(Model, self).__init__()

        # Generator
        self.gen = Generator(opt).cuda(opt.gpu_id)

        self.gen_params = self.gen.parameters()

        # Print the generator architecture and its parameter count
        num_params = sum(p.numel() for p in self.gen.parameters())
        print(self.gen)
        print(num_params)

        # Discriminator
        self.dis = Discriminator(opt).cuda(opt.gpu_id)

        self.dis_params = self.dis.parameters()

        # Print the discriminator architecture and its parameter count
        num_params = sum(p.numel() for p in self.dis.parameters())
        print(self.dis)
        print(num_params)

        # Regressor
        if opt.mse_weight:
            self.reg = torch.load('data/utils/classifier.pth').cuda(
                opt.gpu_id).eval()
        else:
            self.reg = None

        # Losses
        self.criterion_gan = GANLoss(opt, self.dis)
        # Note: despite its name, this criterion is an L1 loss scaled by mse_weight
        self.criterion_mse = lambda x, y: l1_loss(x, y) * opt.mse_weight

        self.loss_mse = Variable(torch.zeros(1).cuda(opt.gpu_id))
        self.loss_adv = Variable(torch.zeros(1).cuda(opt.gpu_id))
        self.loss = Variable(torch.zeros(1).cuda(opt.gpu_id))

        self.path = opt.experiments_dir + opt.experiment_name + '/checkpoints/'
        self.gpu_id = opt.gpu_id
        self.noise_channels = opt.in_channels - len(opt.input_idx.split(','))

    def forward(self, inputs):

        input, input_orig, target = inputs

        self.input = Variable(input.cuda(self.gpu_id))
        self.input_orig = Variable(input_orig.cuda(self.gpu_id))
        self.target = Variable(target.cuda(self.gpu_id))

        noise = Variable(
            torch.randn(self.input.size(0),
                        self.noise_channels).cuda(self.gpu_id))

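        # Concatenate the conditioning input with the noise along the channel dimension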
        self.fake = self.gen(torch.cat([self.input, noise], 1))

    def backward_G(self):

        # Regressor loss
        if self.reg is not None:

            fake_input = self.reg(self.fake)

            self.loss_mse = self.criterion_mse(fake_input, self.input_orig)

        # GAN loss
        loss_adv, _ = self.criterion_gan(self.fake)

        loss_G = self.loss_mse + loss_adv
        loss_G.backward()

    def backward_D(self):

        loss_adv, self.loss_adv = self.criterion_gan(self.target, self.fake)

        loss_D = loss_adv
        loss_D.backward()

    def train(self):

        self.gen.train()
        self.dis.train()

    def eval(self):

        self.gen.eval()
        self.dis.eval()

    def save_checkpoint(self, epoch):

        torch.save(
            {
                'epoch': epoch,
                'gen_state_dict': self.gen.state_dict(),
                'dis_state_dict': self.dis.state_dict()
            }, self.path + '%d.pkl' % epoch)

    def load_checkpoint(self, path):

        weights = torch.load(path)

        self.gen.load_state_dict(weights['gen_state_dict'])
        self.dis.load_state_dict(weights['dis_state_dict'])
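
For context, a minimal training-loop sketch for the Model class above. The optimizer choice, the learning rate and the loader yielding (input, input_orig, target) tuples are assumptions, not part of the original code.

# Hypothetical driver for Model; the optimizers, lr and train_loader are assumed.
model = Model(opt)

gen_optim = torch.optim.Adam(model.gen_params, lr=2e-4)
dis_optim = torch.optim.Adam(model.dis_params, lr=2e-4)

model.train()
for inputs in train_loader:  # yields (input, input_orig, target) tuples
    model.forward(inputs)

    # Update the discriminator first
    dis_optim.zero_grad()
    model.backward_D()
    dis_optim.step()

    # Then update the generator
    gen_optim.zero_grad()
    model.backward_G()
    gen_optim.step()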

Example #2

import os

import numpy as np
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
# Assumed import: Spectrogram(normalized=True) below matches the
# torchaudio.transforms API.
from torchaudio.transforms import Spectrogram

# AutoEncoder, Generator, Discriminator, Trainer, snr and lsd are
# project-local modules.


class GanTrainer(Trainer):
    def __init__(self, train_loader, test_loader, valid_loader, general_args,
                 trainer_args):
        super(GanTrainer, self).__init__(train_loader, test_loader,
                                         valid_loader, general_args)
        # Paths
        self.loadpath = trainer_args.loadpath
        self.savepath = trainer_args.savepath

        # Load the auto-encoder
        self.use_autoencoder = False
        if trainer_args.autoencoder_path and os.path.exists(
                trainer_args.autoencoder_path):
            self.use_autoencoder = True
            self.autoencoder = AutoEncoder(general_args=general_args).to(
                self.device)
            self.load_pretrained_autoencoder(trainer_args.autoencoder_path)
            self.autoencoder.eval()

        # Load the generator
        self.generator = Generator(general_args=general_args).to(self.device)
        if trainer_args.generator_path and os.path.exists(
                trainer_args.generator_path):
            self.load_pretrained_generator(trainer_args.generator_path)

        self.discriminator = Discriminator(general_args=general_args).to(
            self.device)

        # Optimizers and schedulers
        self.generator_optimizer = torch.optim.Adam(
            params=self.generator.parameters(), lr=trainer_args.generator_lr)
        self.discriminator_optimizer = torch.optim.Adam(
            params=self.discriminator.parameters(),
            lr=trainer_args.discriminator_lr)
        self.generator_scheduler = lr_scheduler.StepLR(
            optimizer=self.generator_optimizer,
            step_size=trainer_args.generator_scheduler_step,
            gamma=trainer_args.generator_scheduler_gamma)
        self.discriminator_scheduler = lr_scheduler.StepLR(
            optimizer=self.discriminator_optimizer,
            step_size=trainer_args.discriminator_scheduler_step,
            gamma=trainer_args.discriminator_scheduler_gamma)

        # Load saved states
        if self.loadpath and os.path.exists(self.loadpath):
            self.load()

        # Loss function and stored losses
        self.adversarial_criterion = nn.BCEWithLogitsLoss()
        self.generator_time_criterion = nn.MSELoss()
        self.generator_frequency_criterion = nn.MSELoss()
        self.generator_autoencoder_criterion = nn.MSELoss()

        # Define labels (floats, as required by BCEWithLogitsLoss)
        self.real_label = 1.
        self.generated_label = 0.

        # Loss scaling factors
        self.lambda_adv = trainer_args.lambda_adversarial
        self.lambda_freq = trainer_args.lambda_freq
        self.lambda_autoencoder = trainer_args.lambda_autoencoder

        # Spectrogram converter
        self.spectrogram = Spectrogram(normalized=True).to(self.device)

        # Boolean indicating if the model needs to be saved
        self.need_saving = True

        # Boolean indicating if the generator receives feedback from the discriminator
        self.use_adversarial = trainer_args.use_adversarial

    def load_pretrained_generator(self, generator_path):
        """
        Loads a pre-trained generator. Can be used to stabilize the training.
        :param generator_path: location of the pre-trained generator (string).
        :return: None
        """
        checkpoint = torch.load(generator_path, map_location=self.device)
        self.generator.load_state_dict(checkpoint['generator_state_dict'])

    def load_pretrained_autoencoder(self, autoencoder_path):
        """
        Loads a pre-trained auto-encoder, whose embeddings are used for the embedding-space loss.
        :param autoencoder_path: location of the pre-trained auto-encoder (string).
        :return: None
        """
        checkpoint = torch.load(autoencoder_path, map_location=self.device)
        self.autoencoder.load_state_dict(checkpoint['autoencoder_state_dict'])

    def train(self, epochs):
        """
        Trains the GAN for a given number of pseudo-epochs.
        :param epochs: Number of times to iterate over a part of the dataset (int).
        :return: None
        """
        for epoch in range(epochs):
            for i in range(self.train_batches_per_epoch):
                self.generator.train()
                self.discriminator.train()
                # Transfer to GPU
                local_batch = next(self.train_loader_iter)
                input_batch, target_batch = local_batch[0].to(
                    self.device), local_batch[1].to(self.device)
                batch_size = input_batch.shape[0]

                ############################
                # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
                ###########################
                # Train the discriminator with real data
                self.discriminator_optimizer.zero_grad()
                label = torch.full((batch_size, ),
                                   self.real_label,
                                   device=self.device)
                output = self.discriminator(target_batch)

                # Compute and store the discriminator loss on real data
                loss_discriminator_real = self.adversarial_criterion(
                    output, torch.unsqueeze(label, dim=1))
                self.train_losses['discriminator_adversarial']['real'].append(
                    loss_discriminator_real.item())
                loss_discriminator_real.backward()

                # Train the discriminator with fake data
                generated_batch = self.generator(input_batch)
                label.fill_(self.generated_label)
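                # Detach so no gradients from the discriminator flow back into the generator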
                output = self.discriminator(generated_batch.detach())

                # Compute and store the discriminator loss on fake data
                loss_discriminator_generated = self.adversarial_criterion(
                    output, torch.unsqueeze(label, dim=1))
                self.train_losses['discriminator_adversarial']['fake'].append(
                    loss_discriminator_generated.item())
                loss_discriminator_generated.backward()

                # Update the discriminator weights
                self.discriminator_optimizer.step()

                ############################
                # Update G network: maximize log(D(G(z)))
                ###########################
                self.generator_optimizer.zero_grad()

                # Get the spectrogram
                specgram_target_batch = self.spectrogram(target_batch)
                specgram_fake_batch = self.spectrogram(generated_batch)

                # Fake labels are real for the generator cost
                label.fill_(self.real_label)
                output = self.discriminator(generated_batch)

                # Compute the generator loss on fake data
                # Get the adversarial loss
                loss_generator_adversarial = torch.zeros(size=[1],
                                                         device=self.device)
                if self.use_adversarial:
                    loss_generator_adversarial = self.adversarial_criterion(
                        output, torch.unsqueeze(label, dim=1))
                self.train_losses['generator_adversarial'].append(
                    loss_generator_adversarial.item())

                # Get the L2 loss in time domain
                loss_generator_time = self.generator_time_criterion(
                    generated_batch, target_batch)
                self.train_losses['time_l2'].append(loss_generator_time.item())

                # Get the L2 loss in frequency domain
                loss_generator_frequency = self.generator_frequency_criterion(
                    specgram_fake_batch, specgram_target_batch)
                self.train_losses['freq_l2'].append(
                    loss_generator_frequency.item())

                # Get the L2 loss in embedding space
                loss_generator_autoencoder = torch.zeros(size=[1],
                                                         device=self.device,
                                                         requires_grad=True)
                if self.use_autoencoder:
                    # Get the embeddings
                    _, embedding_target_batch = self.autoencoder(target_batch)
                    _, embedding_generated_batch = self.autoencoder(
                        generated_batch)
                    loss_generator_autoencoder = self.generator_autoencoder_criterion(
                        embedding_generated_batch, embedding_target_batch)
                    self.train_losses['autoencoder_l2'].append(
                        loss_generator_autoencoder.item())

                # Combine the different losses
                loss_generator = self.lambda_adv * loss_generator_adversarial + loss_generator_time + \
                                 self.lambda_freq * loss_generator_frequency + \
                                 self.lambda_autoencoder * loss_generator_autoencoder

                # Back-propagate and update the generator weights
                loss_generator.backward()
                self.generator_optimizer.step()

                # Print message
                if not (i % 10):
                    message = 'Batch {}: \n' \
                              '\t Generator: \n' \
                              '\t\t Time: {} \n' \
                              '\t\t Frequency: {} \n' \
                              '\t\t Autoencoder: {} \n' \
                              '\t\t Adversarial: {} \n' \
                              '\t Discriminator: \n' \
                              '\t\t Real: {} \n' \
                              '\t\t Fake: {} \n'.format(i,
                                                       loss_generator_time.item(),
                                                       loss_generator_frequency.item(),
                                                       loss_generator_autoencoder.item(),
                                                       loss_generator_adversarial.item(),
                                                       loss_discriminator_real.item(),
                                                       loss_discriminator_generated.item())
                    print(message)

            # Evaluate the model
            with torch.no_grad():
                self.eval()

            # Save the trainer state
            self.save()
            # if self.need_saving:
            #     self.save()

            # Increment epoch counter
            self.epoch += 1
            self.generator_scheduler.step()
            self.discriminator_scheduler.step()

    def eval(self):
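        """
        Evaluates the generator on the validation set and stores the validation losses.
        :return: None
        """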
        self.generator.eval()
        self.discriminator.eval()
        batch_losses = {'time_l2': [], 'freq_l2': []}
        for i in range(self.valid_batches_per_epoch):
            # Transfer to GPU
            local_batch = next(self.valid_loader_iter)
            input_batch, target_batch = local_batch[0].to(
                self.device), local_batch[1].to(self.device)

            generated_batch = self.generator(input_batch)

            # Get the spectrogram
            specgram_target_batch = self.spectrogram(target_batch)
            specgram_generated_batch = self.spectrogram(generated_batch)

            loss_generator_time = self.generator_time_criterion(
                generated_batch, target_batch)
            batch_losses['time_l2'].append(loss_generator_time.item())
            loss_generator_frequency = self.generator_frequency_criterion(
                specgram_generated_batch, specgram_target_batch)
            batch_losses['freq_l2'].append(loss_generator_frequency.item())

        # Store the validation losses
        self.valid_losses['time_l2'].append(np.mean(batch_losses['time_l2']))
        self.valid_losses['freq_l2'].append(np.mean(batch_losses['freq_l2']))

        # Display validation losses
        message = 'Epoch {}: \n' \
                  '\t Time: {} \n' \
                  '\t Frequency: {} \n'.format(self.epoch,
                                               np.mean(batch_losses['time_l2']),
                                               np.mean(batch_losses['freq_l2']))
        print(message)

        # Check if the loss is decreasing
        self.check_improvement()

    def save(self):
        """
        Saves the model(s), optimizer(s), scheduler(s) and losses
        :return: None
        """
        torch.save(
            {
                'epoch':
                self.epoch,
                'generator_state_dict':
                self.generator.state_dict(),
                'discriminator_state_dict':
                self.discriminator.state_dict(),
                'generator_optimizer_state_dict':
                self.generator_optimizer.state_dict(),
                'discriminator_optimizer_state_dict':
                self.discriminator_optimizer.state_dict(),
                'generator_scheduler_state_dict':
                self.generator_scheduler.state_dict(),
                'discriminator_scheduler_state_dict':
                self.discriminator_scheduler.state_dict(),
                'train_losses':
                self.train_losses,
                'test_losses':
                self.test_losses,
                'valid_losses':
                self.valid_losses
            }, self.savepath)

    def load(self):
        """
        Loads the model(s), optimizer(s), scheduler(s) and losses
        :return: None
        """
        checkpoint = torch.load(self.loadpath, map_location=self.device)
        self.epoch = checkpoint['epoch']
        self.generator.load_state_dict(checkpoint['generator_state_dict'])
        self.discriminator.load_state_dict(
            checkpoint['discriminator_state_dict'])
        self.generator_optimizer.load_state_dict(
            checkpoint['generator_optimizer_state_dict'])
        self.discriminator_optimizer.load_state_dict(
            checkpoint['discriminator_optimizer_state_dict'])
        self.generator_scheduler.load_state_dict(
            checkpoint['generator_scheduler_state_dict'])
        self.discriminator_scheduler.load_state_dict(
            checkpoint['discriminator_scheduler_state_dict'])
        self.train_losses = checkpoint['train_losses']
        self.test_losses = checkpoint['test_losses']
        self.valid_losses = checkpoint['valid_losses']

    def evaluate_metrics(self, n_batches):
        """
        Evaluates the quality of the reconstruction with the SNR and LSD metrics on a specified number of batches
        :param n_batches: number of batches to process (int).
        :return: mean and std for each metric
        """
        with torch.no_grad():
            snrs = []
            lsds = []
            generator = self.generator.eval()
            for k in range(n_batches):
                # Transfer to GPU
                local_batch = next(self.test_loader_iter)
                input_batch, target_batch = local_batch[0].to(
                    self.device), local_batch[1].to(self.device)

                # Generates a batch
                generated_batch = generator(input_batch)

                # Get the metrics
                snrs.append(
                    snr(x=generated_batch.squeeze(),
                        x_ref=target_batch.squeeze()))
                lsds.append(
                    lsd(x=generated_batch.squeeze(),
                        x_ref=target_batch.squeeze()))

            snrs = torch.cat(snrs).cpu().numpy()
            lsds = torch.cat(lsds).cpu().numpy()

            # Some signals corresponding to silence will be all zeros and cause trouble due to the logarithm
            snrs[np.isinf(snrs)] = np.nan
            lsds[np.isinf(lsds)] = np.nan
        return np.nanmean(snrs), np.nanstd(snrs), np.nanmean(lsds), np.nanstd(
            lsds)
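
A minimal driver sketch for the trainer above; the loaders and the two argument namespaces are assumptions mirroring the constructor signature.

# Hypothetical driver for GanTrainer; all names below are assumed.
trainer = GanTrainer(train_loader, test_loader, valid_loader,
                     general_args, trainer_args)
trainer.train(epochs=100)

# Reconstruction quality on held-out data
snr_mean, snr_std, lsd_mean, lsd_std = trainer.evaluate_metrics(n_batches=10)
print('SNR: {:.2f} (+/- {:.2f}), LSD: {:.2f} (+/- {:.2f})'.format(
    snr_mean, snr_std, lsd_mean, lsd_std))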

Example #3
import os

import numpy as np
import torch
import torch.nn as nn
from torch import autograd
from torch.optim import lr_scheduler

# Generator, Discriminator, Trainer, snr and lsd are project-local modules.


class WGanTrainer(Trainer):
    def __init__(self, train_loader, test_loader, valid_loader, general_args,
                 trainer_args):
        super(WGanTrainer, self).__init__(train_loader, test_loader,
                                          valid_loader, general_args)
        # Paths
        self.loadpath = trainer_args.loadpath
        self.savepath = trainer_args.savepath

        # Load the generator
        self.generator = Generator(general_args=general_args).to(self.device)
        if trainer_args.generator_path and os.path.exists(
                trainer_args.generator_path):
            self.load_pretrained_generator(trainer_args.generator_path)

        self.discriminator = Discriminator(general_args=general_args).to(
            self.device)

        # Optimizers and schedulers
        self.generator_optimizer = torch.optim.Adam(
            params=self.generator.parameters(), lr=trainer_args.generator_lr)
        self.discriminator_optimizer = torch.optim.Adam(
            params=self.discriminator.parameters(),
            lr=trainer_args.discriminator_lr)
        self.generator_scheduler = lr_scheduler.StepLR(
            optimizer=self.generator_optimizer,
            step_size=trainer_args.generator_scheduler_step,
            gamma=trainer_args.generator_scheduler_gamma)
        self.discriminator_scheduler = lr_scheduler.StepLR(
            optimizer=self.discriminator_optimizer,
            step_size=trainer_args.discriminator_scheduler_step,
            gamma=trainer_args.discriminator_scheduler_gamma)

        # Load saved states
        if self.loadpath and os.path.exists(self.loadpath):
            self.load()

        # Loss function and stored losses
        self.generator_time_criterion = nn.MSELoss()

        # Loss scaling factors
        self.lambda_adv = trainer_args.lambda_adversarial
        self.lambda_time = trainer_args.lambda_time

        # Boolean indicating if the model needs to be saved
        self.need_saving = True

        # Overrides losses from parent class
        self.train_losses = {
            'generator': {
                'time_l2': [],
                'adversarial': []
            },
            'discriminator': {
                'penalty': [],
                'adversarial': []
            }
        }
        self.test_losses = {
            'generator': {
                'time_l2': [],
                'adversarial': []
            },
            'discriminator': {
                'penalty': [],
                'adversarial': []
            }
        }
        self.valid_losses = {
            'generator': {
                'time_l2': [],
                'adversarial': []
            },
            'discriminator': {
                'penalty': [],
                'adversarial': []
            }
        }

        # Select either the WGAN (weight clipping) or WGAN-GP (gradient penalty) variant
        self.use_penalty = trainer_args.use_penalty
        self.gamma = trainer_args.gamma_wgan_gp
        self.clipping_limit = trainer_args.clipping_limit
        self.n_critic = trainer_args.n_critic
        self.coupling_epoch = trainer_args.coupling_epoch

    def load_pretrained_generator(self, generator_path):
        """
        Loads a pre-trained generator. Can be used to stabilize the training.
        :param generator_path: location of the pre-trained generator (string).
        :return: None
        """
        checkpoint = torch.load(generator_path, map_location=self.device)
        self.generator.load_state_dict(checkpoint['generator_state_dict'])

    def compute_gradient_penalty(self, input_batch, generated_batch):
        """
        Compute the gradient penalty as described in the original paper
        (https://papers.nips.cc/paper/7159-improved-training-of-wasserstein-gans.pdf).
        :param input_batch: batch of input data (torch tensor).
        :param generated_batch: batch of generated data (torch tensor).
        :return: penalty as a scalar (torch tensor).
        """
        batch_size = input_batch.size(0)
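        # One interpolation coefficient per example, broadcast over the remaining dimensions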
        epsilon = torch.rand(batch_size, 1, 1)
        epsilon = epsilon.expand_as(input_batch).to(self.device)

        # Interpolate
        interpolation = epsilon * input_batch.data + (
            1 - epsilon) * generated_batch.data
        interpolation = interpolation.requires_grad_(True).to(self.device)

        # Computes the discriminator's prediction for the interpolated input
        interpolation_logits = self.discriminator(interpolation)

        # Computes a vector of outputs to make it work with two output classes if needed
        grad_outputs = torch.ones_like(interpolation_logits).to(
            self.device).requires_grad_(True)

        # Get the gradients and retain the graph so that the penalty can be back-propagated
        gradients = autograd.grad(outputs=interpolation_logits,
                                  inputs=interpolation,
                                  grad_outputs=grad_outputs,
                                  create_graph=True,
                                  retain_graph=True,
                                  only_inputs=True)[0]
        gradients = gradients.view(batch_size, -1)

        # Computes the norm of the gradients
        gradients_norm = torch.sqrt(torch.sum(gradients**2, dim=1))
        return ((gradients_norm - 1)**2).mean()

    def train_discriminator_step(self, input_batch, target_batch):
        """
        Trains the discriminator for a single step based on the Wasserstein GAN-GP framework.
        :param input_batch: batch of input data (torch tensor).
        :param target_batch: batch of target data (torch tensor).
        :return: a batch of generated data (torch tensor).
        """
        # Activate gradient tracking for the discriminator
        self.change_discriminator_grad_requirement(requires_grad=True)

        # Set the discriminator's gradients to zero
        self.discriminator_optimizer.zero_grad()

        # Generate a batch and compute the penalty
        generated_batch = self.generator(input_batch)

        # Wasserstein critic loss: E[D(generated)] - E[D(real)]
        loss_d = (self.discriminator(generated_batch.detach()).mean()
                  - self.discriminator(target_batch).mean())
        self.train_losses['discriminator']['adversarial'].append(loss_d.item())
        if self.use_penalty:
            # Note: the WGAN-GP paper interpolates between real and generated
            # samples; here the conditioning input is used as the first endpoint.
            penalty = self.compute_gradient_penalty(input_batch,
                                                    generated_batch.detach())
            self.train_losses['discriminator']['penalty'].append(
                penalty.item())
            loss_d = loss_d + self.gamma * penalty

        # Update the discriminator's weights
        loss_d.backward()
        self.discriminator_optimizer.step()

        # Weight clipping (original WGAN constraint), applied when the gradient penalty is disabled
        if not self.use_penalty:
            for p in self.discriminator.parameters():
                p.data.clamp_(min=-self.clipping_limit,
                              max=self.clipping_limit)

        # Return the generated batch to avoid redundant computation
        return generated_batch

    def train_generator_step(self, target_batch, generated_batch):
        """
        Trains the generator for a single step based on the Wasserstein GAN-GP framework.
        :param target_batch: batch of target data (torch tensor).
        :param generated_batch: batch of generated data (torch tensor).
        :return: None
        """
        # Deactivate gradient tracking for the discriminator
        self.change_discriminator_grad_requirement(requires_grad=False)

        # Set generator's gradients to zero
        self.generator_optimizer.zero_grad()

        # Get the generator losses
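        # The adversarial term maximizes the critic's score on generated data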
        loss_g_adversarial = -self.discriminator(generated_batch).mean()
        loss_g_time = self.generator_time_criterion(generated_batch,
                                                    target_batch)

        # Combine the different losses
        loss_g = loss_g_time
        if self.epoch >= self.coupling_epoch:
            loss_g = loss_g + self.lambda_adv * loss_g_adversarial

        # Back-propagate and update the generator weights
        loss_g.backward()
        self.generator_optimizer.step()

        # Store the losses
        self.train_losses['generator']['time_l2'].append(loss_g_time.item())
        self.train_losses['generator']['adversarial'].append(
            loss_g_adversarial.item())

    def change_discriminator_grad_requirement(self, requires_grad):
        """
        Changes the requires_grad flag of the discriminator's parameters. This is not strictly needed, as the
        discriminator's optimizer is not called after the generator's update, but it reduces the computational cost.
        :param requires_grad: flag indicating if the discriminator's parameter require gradient tracking (boolean).
        :return: None
        """
        for p in self.discriminator.parameters():
            p.requires_grad_(requires_grad)

    def train(self, epochs):
        """
        Trains the WGAN-GP for a given number of pseudo-epochs.
        :param epochs: Number of times to iterate over a part of the dataset (int).
        :return: None
        """
        self.generator.train()
        self.discriminator.train()
        for epoch in range(epochs):
            for i in range(self.train_batches_per_epoch):
                # Transfer to GPU
                local_batch = next(self.train_loader_iter)
                input_batch, target_batch = local_batch[0].to(
                    self.device), local_batch[1].to(self.device)

                # Train the discriminator
                generated_batch = self.train_discriminator_step(
                    input_batch, target_batch)

                # Train the generator every n_critic batches
                if not (i % self.n_critic):
                    self.train_generator_step(target_batch, generated_batch)

                # Print message (the penalty list stays empty when use_penalty is False)
                if not (i % 10):
                    penalties = self.train_losses['discriminator']['penalty']
                    message = 'Batch {}: \n' \
                              '\t Generator: \n' \
                              '\t\t Time: {} \n' \
                              '\t\t Adversarial: {} \n' \
                              '\t Discriminator: \n' \
                              '\t\t Penalty: {} \n' \
                              '\t\t Adversarial: {} \n'.format(i,
                                                               self.train_losses['generator']['time_l2'][-1],
                                                               self.train_losses['generator']['adversarial'][-1],
                                                               penalties[-1] if penalties else 'n/a',
                                                               self.train_losses['discriminator']['adversarial'][-1])
                    print(message)

            # Evaluate the model
            with torch.no_grad():
                self.eval()

            # Save the trainer state
            self.save()

            # Increment epoch counter
            self.epoch += 1
            self.generator_scheduler.step()
            self.discriminator_scheduler.step()

    def eval(self):
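        """
        Evaluates the generator on the validation set and stores the time-domain validation loss.
        :return: None
        """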
        # Set the models in evaluation mode
        self.generator.eval()
        self.discriminator.eval()
        batch_losses = {'time_l2': []}
        for i in range(self.valid_batches_per_epoch):
            # Transfer to GPU
            local_batch = next(self.valid_loader_iter)
            input_batch, target_batch = local_batch[0].to(
                self.device), local_batch[1].to(self.device)

            generated_batch = self.generator(input_batch)

            loss_g_time = self.generator_time_criterion(
                generated_batch, target_batch)
            batch_losses['time_l2'].append(loss_g_time.item())

        # Store the validation losses
        self.valid_losses['generator']['time_l2'].append(
            np.mean(batch_losses['time_l2']))

        # Display validation losses
        message = 'Epoch {}: \n' \
                  '\t Time: {} \n'.format(self.epoch, np.mean(batch_losses['time_l2']))
        print(message)

        # Set the models back in train mode
        self.generator.train()
        self.discriminator.train()

    def save(self):
        """
        Saves the model(s), optimizer(s), scheduler(s) and losses
        :return: None
        """
        # Suffix the save path with the epoch count (a new file every 5 epochs)
        root, ext = os.path.splitext(self.savepath)
        savepath = root + '_' + str(self.epoch // 5) + ext
        torch.save(
            {
                'epoch':
                self.epoch,
                'generator_state_dict':
                self.generator.state_dict(),
                'discriminator_state_dict':
                self.discriminator.state_dict(),
                'generator_optimizer_state_dict':
                self.generator_optimizer.state_dict(),
                'discriminator_optimizer_state_dict':
                self.discriminator_optimizer.state_dict(),
                'generator_scheduler_state_dict':
                self.generator_scheduler.state_dict(),
                'discriminator_scheduler_state_dict':
                self.discriminator_scheduler.state_dict(),
                'train_losses':
                self.train_losses,
                'test_losses':
                self.test_losses,
                'valid_losses':
                self.valid_losses
            }, savepath)

    def load(self):
        """
        Loads the model(s), optimizer(s), scheduler(s) and losses
        :return: None
        """
        checkpoint = torch.load(self.loadpath, map_location=self.device)
        self.epoch = checkpoint['epoch']
        self.generator.load_state_dict(checkpoint['generator_state_dict'])
        # self.discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
        self.generator_optimizer.load_state_dict(
            checkpoint['generator_optimizer_state_dict'])
        # self.discriminator_optimizer.load_state_dict(checkpoint['discriminator_optimizer_state_dict'])
        self.generator_scheduler.load_state_dict(
            checkpoint['generator_scheduler_state_dict'])
        self.discriminator_scheduler.load_state_dict(
            checkpoint['discriminator_scheduler_state_dict'])
        self.train_losses = checkpoint['train_losses']
        self.test_losses = checkpoint['test_losses']
        self.valid_losses = checkpoint['valid_losses']

    def evaluate_metrics(self, n_batches):
        """
        Evaluates the quality of the reconstruction with the SNR and LSD metrics on a specified number of batches
        :param n_batches: number of batches to process (int).
        :return: mean and std for each metric
        """
        with torch.no_grad():
            snrs = []
            lsds = []
            generator = self.generator.eval()
            for k in range(n_batches):
                # Transfer to GPU
                local_batch = next(self.test_loader_iter)
                input_batch, target_batch = local_batch[0].to(
                    self.device), local_batch[1].to(self.device)

                # Generates a batch
                generated_batch = generator(input_batch)

                # Get the metrics
                snrs.append(
                    snr(x=generated_batch.squeeze(),
                        x_ref=target_batch.squeeze()))
                lsds.append(
                    lsd(x=generated_batch.squeeze(),
                        x_ref=target_batch.squeeze()))

            snrs = torch.cat(snrs).cpu().numpy()
            lsds = torch.cat(lsds).cpu().numpy()

            # Some signals corresponding to silence will be all zeros and cause trouble due to the logarithm
            snrs[np.isinf(snrs)] = np.nan
            lsds[np.isinf(lsds)] = np.nan
        return np.nanmean(snrs), np.nanstd(snrs), np.nanmean(lsds), np.nanstd(
            lsds)
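
The corresponding driver sketch for the WGAN trainer; as above, every name is assumed. With use_penalty=True the trainer follows WGAN-GP; otherwise it falls back to weight clipping.

# Hypothetical driver for WGanTrainer; all names below are assumed.
trainer = WGanTrainer(train_loader, test_loader, valid_loader,
                      general_args, trainer_args)
trainer.train(epochs=100)
snr_mean, snr_std, lsd_mean, lsd_std = trainer.evaluate_metrics(n_batches=10)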