Beispiel #1
0
    def __init__(self, opts, prev_train_checkpoint=None):
        """Set up every training component from ``opts``.

        Builds the e4e network, the average-latent image, the enabled losses,
        the optimizer, the optional latent-code discriminator, the train/test
        data loaders, the TensorBoard logger and the checkpoint directory.

        Args:
            opts: Options namespace. Its ``device`` attribute is overwritten
                with ``'cuda:0'``, and ``save_interval`` is set to
                ``max_steps`` when it is ``None``.
            prev_train_checkpoint: Optional checkpoint dict to resume from;
                forwarded to ``self.load_from_train_checkpoint``.

        Raises:
            ValueError: If both ``id_lambda`` and ``moco_lambda`` are > 0.
        """
        self.opts = opts

        self.global_step = 0

        self.device = 'cuda:0'
        self.opts.device = self.device

        # Initialize network
        self.net = e4e(self.opts).to(self.device)

        # Estimate latent_avg via dense sampling if latent_avg is not available
        if self.net.latent_avg is None:
            self.net.latent_avg = self.net.decoder.mean_latent(
                int(1e5))[0].detach()

        # Decode the average latent code once and keep the resulting image.
        self.avg_image = self.net(self.net.latent_avg.unsqueeze(0),
                                  input_code=True,
                                  randomize_noise=False,
                                  return_latents=False,
                                  average_code=True)[0]
        self.avg_image = self.avg_image.to(self.device).float().detach()
        if self.opts.dataset_type == "cars_encode":
            # Keep only rows 32:224 of the average image for the cars dataset.
            self.avg_image = self.avg_image[:, 32:224, :]
        common.tensor2im(self.avg_image).save(
            os.path.join(self.opts.exp_dir, 'avg_image.jpg'))

        # Initialize loss
        if self.opts.id_lambda > 0 and self.opts.moco_lambda > 0:
            raise ValueError(
                'Both ID and MoCo loss have lambdas > 0! Please select only one to have non-zero lambda!'
            )
        self.mse_loss = nn.MSELoss().to(self.device).eval()
        if self.opts.lpips_lambda > 0:
            self.lpips_loss = LPIPS(net_type='alex').to(self.device).eval()
        if self.opts.id_lambda > 0:
            self.id_loss = id_loss.IDLoss().to(self.device).eval()
        if self.opts.moco_lambda > 0:
            self.moco_loss = moco_loss.MocoLoss()

        # Initialize optimizer
        self.optimizer = self.configure_optimizers()

        # Initialize discriminator (only when its loss weight is enabled)
        if self.opts.w_discriminator_lambda > 0:
            self.discriminator = LatentCodesDiscriminator(512,
                                                          4).to(self.device)
            self.discriminator_optimizer = torch.optim.Adam(
                list(self.discriminator.parameters()),
                lr=opts.w_discriminator_lr)
            self.real_w_pool = LatentCodesPool(self.opts.w_pool_size)
            self.fake_w_pool = LatentCodesPool(self.opts.w_pool_size)

        # Initialize dataset; drop_last keeps every batch at full size.
        self.train_dataset, self.test_dataset = self.configure_datasets()
        self.train_dataloader = DataLoader(self.train_dataset,
                                           batch_size=self.opts.batch_size,
                                           shuffle=True,
                                           num_workers=int(self.opts.workers),
                                           drop_last=True)
        self.test_dataloader = DataLoader(self.test_dataset,
                                          batch_size=self.opts.test_batch_size,
                                          shuffle=False,
                                          num_workers=int(
                                              self.opts.test_workers),
                                          drop_last=True)

        # Initialize logger
        log_dir = os.path.join(opts.exp_dir, 'logs')
        os.makedirs(log_dir, exist_ok=True)
        self.logger = SummaryWriter(log_dir=log_dir)

        # Initialize checkpoint dir
        self.checkpoint_dir = os.path.join(opts.exp_dir, 'checkpoints')
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        self.best_val_loss = None
        if self.opts.save_interval is None:
            # Without an explicit interval, only the final step is saved.
            self.opts.save_interval = self.opts.max_steps

        if prev_train_checkpoint is not None:
            self.load_from_train_checkpoint(prev_train_checkpoint)
Beispiel #2
0
class Coach:
    """Training driver for an e4e encoder that refines its output iteratively.

    Builds the network, losses, optimizers, optional latent-code
    discriminator, data loaders, TensorBoard logger and checkpoint directory,
    and runs the train / validate / checkpoint loop in :meth:`train`.
    """

    def __init__(self, opts, prev_train_checkpoint=None):
        """Set up every training component from ``opts``.

        Args:
            opts: Options namespace. Its ``device`` attribute is overwritten
                with ``'cuda:0'``, and ``save_interval`` is set to
                ``max_steps`` when it is ``None``.
            prev_train_checkpoint: Optional checkpoint dict (as produced by
                :meth:`checkpoint_me`) to resume training from.

        Raises:
            ValueError: If both ``id_lambda`` and ``moco_lambda`` are > 0.
        """
        self.opts = opts

        self.global_step = 0

        self.device = 'cuda:0'
        self.opts.device = self.device

        # Initialize network
        self.net = e4e(self.opts).to(self.device)

        # Estimate latent_avg via dense sampling if latent_avg is not available
        if self.net.latent_avg is None:
            self.net.latent_avg = self.net.decoder.mean_latent(
                int(1e5))[0].detach()

        # Decode the average latent once; this image is concatenated with the
        # input on the first refinement pass of every batch.
        self.avg_image = self.net(self.net.latent_avg.unsqueeze(0),
                                  input_code=True,
                                  randomize_noise=False,
                                  return_latents=False,
                                  average_code=True)[0]
        self.avg_image = self.avg_image.to(self.device).float().detach()
        if self.opts.dataset_type == "cars_encode":
            # Same row crop (32:224) that is applied to network outputs.
            self.avg_image = self.avg_image[:, 32:224, :]
        common.tensor2im(self.avg_image).save(
            os.path.join(self.opts.exp_dir, 'avg_image.jpg'))

        # Initialize loss
        if self.opts.id_lambda > 0 and self.opts.moco_lambda > 0:
            raise ValueError(
                'Both ID and MoCo loss have lambdas > 0! Please select only one to have non-zero lambda!'
            )
        # NOTE(review): self.mse_loss is not used inside this class
        # (calc_loss calls F.mse_loss); kept in case external code uses it.
        self.mse_loss = nn.MSELoss().to(self.device).eval()
        if self.opts.lpips_lambda > 0:
            self.lpips_loss = LPIPS(net_type='alex').to(self.device).eval()
        if self.opts.id_lambda > 0:
            self.id_loss = id_loss.IDLoss().to(self.device).eval()
        if self.opts.moco_lambda > 0:
            self.moco_loss = moco_loss.MocoLoss()

        # Initialize optimizer
        self.optimizer = self.configure_optimizers()

        # Initialize discriminator (only when its loss weight is enabled)
        if self.opts.w_discriminator_lambda > 0:
            self.discriminator = LatentCodesDiscriminator(512,
                                                          4).to(self.device)
            self.discriminator_optimizer = torch.optim.Adam(
                list(self.discriminator.parameters()),
                lr=opts.w_discriminator_lr)
            self.real_w_pool = LatentCodesPool(self.opts.w_pool_size)
            self.fake_w_pool = LatentCodesPool(self.opts.w_pool_size)

        # Initialize dataset; drop_last keeps every batch at full size.
        self.train_dataset, self.test_dataset = self.configure_datasets()
        self.train_dataloader = DataLoader(self.train_dataset,
                                           batch_size=self.opts.batch_size,
                                           shuffle=True,
                                           num_workers=int(self.opts.workers),
                                           drop_last=True)
        self.test_dataloader = DataLoader(self.test_dataset,
                                          batch_size=self.opts.test_batch_size,
                                          shuffle=False,
                                          num_workers=int(
                                              self.opts.test_workers),
                                          drop_last=True)

        # Initialize logger
        log_dir = os.path.join(opts.exp_dir, 'logs')
        os.makedirs(log_dir, exist_ok=True)
        self.logger = SummaryWriter(log_dir=log_dir)

        # Initialize checkpoint dir
        self.checkpoint_dir = os.path.join(opts.exp_dir, 'checkpoints')
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        self.best_val_loss = None
        if self.opts.save_interval is None:
            # Without an explicit interval, only the final step is saved.
            self.opts.save_interval = self.opts.max_steps

        if prev_train_checkpoint is not None:
            self.load_from_train_checkpoint(prev_train_checkpoint)

    def load_from_train_checkpoint(self, ckpt):
        """Restore training state from a checkpoint dict.

        Restores the step counter (training resumes at the step after the
        saved one), the best validation loss and the network weights. When
        the discriminator is enabled, it also restores the discriminator
        weights and its optimizer state. If progressive training is
        configured, the encoder's progressive stage is re-derived from the
        restored step.

        NOTE(review): the encoder optimizer state written under
        ``'optimizer'`` by ``__get_save_dict`` is not restored here, so the
        optimizer restarts from a fresh state on resume.
        """
        print('Loading previous training data...')
        self.global_step = ckpt['global_step'] + 1
        self.best_val_loss = ckpt['best_val_loss']
        self.net.load_state_dict(ckpt['state_dict'])
        if self.opts.w_discriminator_lambda > 0:
            self.discriminator.load_state_dict(
                ckpt['discriminator_state_dict'])
            self.discriminator_optimizer.load_state_dict(
                ckpt['discriminator_optimizer_state_dict'])
        if self.opts.progressive_steps:
            self.check_for_progressive_training_update(
                is_resume_from_ckpt=True)
        print(f'Resuming training from step {self.global_step}')

    def compute_discriminator_loss(self, x):
        """Run one discriminator update on this batch when it is enabled.

        The discriminator input is ``x`` concatenated along dim 1 with the
        average image, repeated once per sample.

        Returns:
            Dict of discriminator losses, or ``{}`` when the discriminator is
            disabled.
        """
        avg_image_for_batch = self.avg_image.unsqueeze(0).repeat(
            x.shape[0], 1, 1, 1)
        x_input = torch.cat([x, avg_image_for_batch], dim=1)
        disc_loss_dict = {}
        if self.is_training_discriminator():
            disc_loss_dict = self.train_discriminator(x_input)
        return disc_loss_dict

    def perform_train_iteration_on_batch(self, x, y):
        """Run ``n_iters_per_batch`` refinement passes and backprop each.

        Pass 0 feeds ``x`` concatenated with the average image and no latent.
        Each later pass feeds ``x`` concatenated with the previous output and
        passes the previous latent. Both are detached from the previous
        pass's graph. ``loss.backward()`` is called on every pass, so
        gradients accumulate until the caller steps the optimizer.

        Returns:
            ``(y_hats, loss_dict, id_logs)``. ``y_hats`` maps each sample
            index to a list of ``[output, id_logs[idx]['diff_target']]``
            pairs, one per pass. The two dicts come from the last pass.

        NOTE(review): ``id_logs`` is only produced by the ID or MoCo loss.
        With both lambdas at 0 it is ``None``, and indexing it below fails.
        """
        y_hat, latent = None, None
        loss_dict, id_logs = None, None
        y_hats = {idx: [] for idx in range(x.shape[0])}
        for iter_idx in range(self.opts.n_iters_per_batch):
            if iter_idx == 0:
                avg_image_for_batch = self.avg_image.unsqueeze(0).repeat(
                    x.shape[0], 1, 1, 1)
                x_input = torch.cat([x, avg_image_for_batch], dim=1)
                y_hat, latent = self.net.forward(x_input,
                                                 latent=None,
                                                 return_latents=True)
            else:
                y_hat_clone = y_hat.clone().detach().requires_grad_(True)
                latent_clone = latent.clone().detach().requires_grad_(True)
                x_input = torch.cat([x, y_hat_clone], dim=1)
                y_hat, latent = self.net.forward(x_input,
                                                 latent=latent_clone,
                                                 return_latents=True)

            if self.opts.dataset_type == "cars_encode":
                y_hat = y_hat[:, :, 32:224, :]

            loss, loss_dict, id_logs = self.calc_loss(x, y, y_hat, latent)
            loss.backward()
            # store intermediate outputs
            for idx in range(x.shape[0]):
                y_hats[idx].append([y_hat[idx], id_logs[idx]['diff_target']])

        return y_hats, loss_dict, id_logs

    def train(self):
        """Main training loop; runs until ``global_step`` reaches ``max_steps``.

        Each step does the following:

        1. Updates the discriminator, if enabled.
        2. Runs the encoder's refinement passes.
        3. Takes one optimizer step.
        4. Logs images and metrics, validates, and checkpoints at the
           configured intervals.
        """
        self.net.train()
        if self.opts.progressive_steps:
            self.check_for_progressive_training_update()

        while self.global_step < self.opts.max_steps:
            for batch in self.train_dataloader:
                x, y = batch
                x, y = x.to(self.device).float(), y.to(self.device).float()

                # Discriminator update (returns {} when disabled).
                disc_loss_dict = self.compute_discriminator_loss(x)

                # Encoder update: gradients of all refinement passes are
                # accumulated, then applied with a single optimizer step.
                self.optimizer.zero_grad()
                y_hats, encoder_loss_dict, id_logs = self.perform_train_iteration_on_batch(
                    x, y)
                self.optimizer.step()

                loss_dict = {**disc_loss_dict, **encoder_loss_dict}

                # Logging related: every image_interval steps, plus every 25
                # steps during the first 1000.
                if self.global_step % self.opts.image_interval == 0 or (
                        self.global_step < 1000
                        and self.global_step % 25 == 0):
                    self.parse_and_log_images(id_logs,
                                              x,
                                              y,
                                              y_hats,
                                              title='images/train')

                if self.global_step % self.opts.board_interval == 0:
                    self.print_metrics(loss_dict, prefix='train')
                    self.log_metrics(loss_dict, prefix='train')

                # Validation related
                val_loss_dict = None
                if self.global_step % self.opts.val_interval == 0 or self.global_step == self.opts.max_steps:
                    val_loss_dict = self.validate()
                    if val_loss_dict and (
                            self.best_val_loss is None
                            or val_loss_dict['loss'] < self.best_val_loss):
                        self.best_val_loss = val_loss_dict['loss']
                        self.checkpoint_me(val_loss_dict, is_best=True)

                if self.global_step % self.opts.save_interval == 0 or self.global_step == self.opts.max_steps:
                    if val_loss_dict is not None:
                        self.checkpoint_me(val_loss_dict, is_best=False)
                    else:
                        self.checkpoint_me(loss_dict, is_best=False)

                if self.global_step == self.opts.max_steps:
                    print('OMG, finished training!')
                    break

                self.global_step += 1
                if self.opts.progressive_steps:
                    self.check_for_progressive_training_update()

    def perform_val_iteration_on_batch(self, x, y):
        """Validation counterpart of :meth:`perform_train_iteration_on_batch`.

        Runs the same refinement passes without detaching and without calling
        backward. The caller (:meth:`validate`) wraps this in
        ``torch.no_grad()``.

        Returns:
            ``(y_hats, cur_loss_dict, id_logs)``, shaped as in the training
            variant.
        """
        y_hat, latent = None, None
        cur_loss_dict, id_logs = None, None
        y_hats = {idx: [] for idx in range(x.shape[0])}
        for iter_idx in range(self.opts.n_iters_per_batch):
            if iter_idx == 0:
                avg_image_for_batch = self.avg_image.unsqueeze(0).repeat(
                    x.shape[0], 1, 1, 1)
                x_input = torch.cat([x, avg_image_for_batch], dim=1)
            else:
                x_input = torch.cat([x, y_hat], dim=1)

            y_hat, latent = self.net.forward(x_input,
                                             latent=latent,
                                             return_latents=True)
            if self.opts.dataset_type == "cars_encode":
                y_hat = y_hat[:, :, 32:224, :]

            _, cur_loss_dict, id_logs = self.calc_loss(x, y, y_hat, latent)
            # store intermediate outputs
            for idx in range(x.shape[0]):
                y_hats[idx].append([y_hat[idx], id_logs[idx]['diff_target']])

        return y_hats, cur_loss_dict, id_logs

    def validate(self):
        """Evaluate on the test loader and log the aggregated metrics.

        The network is put in eval mode and restored to train mode before
        returning.

        Returns:
            The aggregated loss dict. Returns ``None`` at step 0, where only a
            sanity run over the first 5 batches is done.
        """
        self.net.eval()
        agg_loss_dict = []
        for batch_idx, batch in enumerate(self.test_dataloader):
            x, y = batch
            x, y = x.to(self.device).float(), y.to(self.device).float()

            # validate discriminator on batch
            avg_image_for_batch = self.avg_image.unsqueeze(0).repeat(
                x.shape[0], 1, 1, 1)
            x_input = torch.cat([x, avg_image_for_batch], dim=1)
            cur_disc_loss_dict = {}
            if self.is_training_discriminator():
                cur_disc_loss_dict = self.validate_discriminator(x_input)

            # validate encoder on batch
            with torch.no_grad():
                y_hats, cur_enc_loss_dict, id_logs = self.perform_val_iteration_on_batch(
                    x, y)

            cur_loss_dict = {**cur_disc_loss_dict, **cur_enc_loss_dict}
            agg_loss_dict.append(cur_loss_dict)

            # Logging related
            self.parse_and_log_images(id_logs,
                                      x,
                                      y,
                                      y_hats,
                                      title='images/test',
                                      subscript='{:04d}'.format(batch_idx))

            # For first step just do sanity test on small amount of data
            if self.global_step == 0 and batch_idx >= 4:
                self.net.train()
                return None  # Do not log, inaccurate in first batch

        loss_dict = train_utils.aggregate_loss_dict(agg_loss_dict)
        self.log_metrics(loss_dict, prefix='test')
        self.print_metrics(loss_dict, prefix='test')

        self.net.train()
        return loss_dict

    def checkpoint_me(self, loss_dict, is_best):
        """Save a checkpoint and append a summary line to ``timestamp.txt``.

        Args:
            loss_dict: Losses recorded alongside the checkpoint in the log.
            is_best: When True, write ``best_model.pt`` (overwriting any
                previous one). Otherwise write ``iteration_<step>.pt``.
        """
        save_name = 'best_model.pt' if is_best else 'iteration_{}.pt'.format(
            self.global_step)
        save_dict = self.__get_save_dict()
        checkpoint_path = os.path.join(self.checkpoint_dir, save_name)
        torch.save(save_dict, checkpoint_path)
        with open(os.path.join(self.checkpoint_dir, 'timestamp.txt'),
                  'a') as f:
            if is_best:
                f.write('**Best**: Step - {}, Loss - {:.3f} \n{}\n'.format(
                    self.global_step, self.best_val_loss, loss_dict))
            else:
                f.write('Step - {}, \n{}\n'.format(self.global_step,
                                                   loss_dict))

    def configure_optimizers(self):
        """Create the encoder optimizer (Adam when ``optim_name == 'adam'``, else Ranger).

        Decoder parameters are optimized too when ``train_decoder`` is set.
        Otherwise the decoder's ``requires_grad`` flags are switched off.
        """
        params = list(self.net.encoder.parameters())
        if self.opts.train_decoder:
            params += list(self.net.decoder.parameters())
        else:
            self.requires_grad(self.net.decoder, False)
        if self.opts.optim_name == 'adam':
            optimizer = torch.optim.Adam(params, lr=self.opts.learning_rate)
        else:
            optimizer = Ranger(params, lr=self.opts.learning_rate)
        return optimizer

    def configure_datasets(self):
        """Build the train and test ``ImagesDataset`` for ``opts.dataset_type``.

        Returns:
            ``(train_dataset, test_dataset)``.

        Raises:
            Exception: If ``dataset_type`` is not a key of
                ``data_configs.DATASETS``.
        """
        if self.opts.dataset_type not in data_configs.DATASETS.keys():
            raise Exception('{} is not a valid dataset_type'.format(
                self.opts.dataset_type))
        print('Loading dataset for {}'.format(self.opts.dataset_type))
        dataset_args = data_configs.DATASETS[self.opts.dataset_type]
        transforms_dict = dataset_args['transforms'](
            self.opts).get_transforms()
        train_dataset = ImagesDataset(
            source_root=dataset_args['train_source_root'],
            target_root=dataset_args['train_target_root'],
            source_transform=transforms_dict['transform_source'],
            target_transform=transforms_dict['transform_gt_train'],
            opts=self.opts)
        test_dataset = ImagesDataset(
            source_root=dataset_args['test_source_root'],
            target_root=dataset_args['test_target_root'],
            source_transform=transforms_dict['transform_source'],
            target_transform=transforms_dict['transform_test'],
            opts=self.opts)
        print("Number of training samples: {}".format(len(train_dataset)))
        print("Number of test samples: {}".format(len(test_dataset)))
        return train_dataset, test_dataset

    def calc_loss(self, x, y, y_hat, latent):
        """Compute the weighted sum of all enabled encoder losses.

        Terms with a non-zero lambda are added. These are: adversarial (when
        the discriminator is enabled), delta regularization (during
        progressive training before stage 18), ID, L2, LPIPS and MoCo.

        Returns:
            ``(loss, loss_dict, id_logs)``. ``loss_dict`` holds each term as a
            float plus the total under ``'loss'``. ``id_logs`` comes from the
            ID/MoCo loss and is ``None`` when neither is enabled.
        """
        loss_dict = {}
        loss = 0.0
        id_logs = None

        # Adversarial loss
        if self.is_training_discriminator():
            loss_disc = self.compute_adversarial_loss(latent, loss_dict)
            loss += self.opts.w_discriminator_lambda * loss_disc

        # delta regularization loss
        if self.opts.progressive_steps and self.net.encoder.progressive_stage.value != 18:
            total_delta_loss = self.compute_delta_regularization_loss(
                latent, loss_dict)
            loss += self.opts.delta_norm_lambda * total_delta_loss

        # similarity losses
        if self.opts.id_lambda > 0:
            loss_id, sim_improvement, id_logs = self.id_loss(y_hat, y, x)
            loss_dict['loss_id'] = float(loss_id)
            loss_dict['id_improve'] = float(sim_improvement)
            loss += loss_id * self.opts.id_lambda
        if self.opts.l2_lambda > 0:
            loss_l2 = F.mse_loss(y_hat, y)
            loss_dict['loss_l2'] = float(loss_l2)
            loss += loss_l2 * self.opts.l2_lambda
        if self.opts.lpips_lambda > 0:
            loss_lpips = self.lpips_loss(y_hat, y)
            loss_dict['loss_lpips'] = float(loss_lpips)
            loss += loss_lpips * self.opts.lpips_lambda
        if self.opts.moco_lambda > 0:
            loss_moco, sim_improvement, id_logs = self.moco_loss(y_hat, y, x)
            loss_dict['loss_moco'] = float(loss_moco)
            loss_dict['id_improve'] = float(sim_improvement)
            loss += loss_moco * self.opts.moco_lambda

        loss_dict['loss'] = float(loss)
        return loss, loss_dict, id_logs

    def compute_adversarial_loss(self, latent, loss_dict):
        """Non-saturating generator loss ``softplus(-D(w))``, averaged over latent dims.

        During progressive training only the dims returned by
        :meth:`get_dims_to_discriminate` are scored. Otherwise all
        ``decoder.n_latent`` dims are scored.
        """
        loss_disc = 0.
        if self.is_progressive_training():
            dims_to_discriminate = self.get_dims_to_discriminate()
        else:
            dims_to_discriminate = list(range(self.net.decoder.n_latent))
        for i in dims_to_discriminate:
            w = latent[:, i, :]
            fake_pred = self.discriminator(w)
            loss_disc += F.softplus(-fake_pred).mean()
        loss_disc /= len(dims_to_discriminate)
        loss_dict['encoder_discriminator_loss'] = float(loss_disc)
        return loss_disc

    def compute_delta_regularization_loss(self, latent, loss_dict):
        """Sum of mean ``delta_norm``-norms of each active delta from the first code.

        For every stage up to the current progressive stage, the code at
        that stage's starting dim is compared against ``latent[:, 0, :]``.
        """
        total_delta_loss = 0
        deltas_latent_dims = self.net.encoder.get_deltas_starting_dimensions()
        first_w = latent[:, 0, :]
        for i in range(1, self.net.encoder.progressive_stage.value + 1):
            curr_dim = deltas_latent_dims[i]
            delta = latent[:, curr_dim, :] - first_w
            delta_loss = torch.norm(delta, self.opts.delta_norm, dim=1).mean()
            loss_dict[f"delta{i}_loss"] = float(delta_loss)
            total_delta_loss += delta_loss
        loss_dict['total_delta_loss'] = float(total_delta_loss)
        return total_delta_loss

    def log_metrics(self, metrics_dict, prefix):
        """Write each metric as a TensorBoard scalar under ``<prefix>/<key>``."""
        for key, value in metrics_dict.items():
            self.logger.add_scalar('{}/{}'.format(prefix, key), value,
                                   self.global_step)

    def print_metrics(self, metrics_dict, prefix):
        """Print each metric for the current step to stdout."""
        print('Metrics for {}, step {}'.format(prefix, self.global_step))
        for key, value in metrics_dict.items():
            print('\t{} = '.format(key), value)

    def parse_and_log_images(self,
                             id_logs,
                             x,
                             y,
                             y_hat,
                             title,
                             subscript=None,
                             display_count=2):
        """Log a figure of input, target and output(s) for the first samples.

        Args:
            id_logs: Per-sample dicts merged into each sample's entry, or None.
            x: Input batch.
            y: Target batch.
            y_hat: Either the per-sample dict of ``[output, similarity]`` lists
                built by the iteration methods, or a batch of outputs.
            title: Subdirectory (under the log dir) for the image.
            subscript: Optional filename prefix.
            display_count: Maximum number of samples to show. It is clamped
                to the batch size so that smaller batches do not raise
                ``IndexError``.
        """
        display_count = min(display_count, len(x))
        im_data = []
        for i in range(display_count):
            if isinstance(y_hat, dict):
                output_face = [[
                    common.tensor2im(y_hat[i][iter_idx][0]),
                    y_hat[i][iter_idx][1]
                ] for iter_idx in range(len(y_hat[i]))]
            else:
                output_face = [common.tensor2im(y_hat[i])]
            cur_im_data = {
                'input_face': common.tensor2im(x[i]),
                'target_face': common.tensor2im(y[i]),
                'output_face': output_face,
            }
            if id_logs is not None:
                for key in id_logs[i]:
                    cur_im_data[key] = id_logs[i][key]
            im_data.append(cur_im_data)
        self.log_images(title, im_data=im_data, subscript=subscript)

    def log_images(self, name, im_data, subscript=None, log_latest=False):
        """Render ``im_data`` and save it as a JPEG under ``<log_dir>/<name>``.

        The filename is the (optionally subscripted) step number. With
        ``log_latest`` the step is 0, so the file is overwritten on each call.
        """
        fig = common.vis_faces(im_data)
        step = self.global_step
        if log_latest:
            step = 0
        if subscript:
            path = os.path.join(self.logger.log_dir, name,
                                '{}_{:04d}.jpg'.format(subscript, step))
        else:
            path = os.path.join(self.logger.log_dir, name,
                                '{:04d}.jpg'.format(step))
        os.makedirs(os.path.dirname(path), exist_ok=True)
        fig.savefig(path)
        plt.close(fig)

    def __get_save_dict(self):
        """Collect everything written to a checkpoint file.

        Includes the network and encoder-optimizer state, the options, the
        step, the best validation loss and ``latent_avg``. When enabled, the
        discriminator and its optimizer state are included too.
        """
        save_dict = {
            'state_dict': self.net.state_dict(),
            'opts': vars(self.opts),
            'global_step': self.global_step,
            'optimizer': self.optimizer.state_dict(),
            'best_val_loss': self.best_val_loss,
            'latent_avg': self.net.latent_avg
        }
        if self.opts.w_discriminator_lambda > 0:
            save_dict[
                'discriminator_state_dict'] = self.discriminator.state_dict()
            save_dict[
                'discriminator_optimizer_state_dict'] = self.discriminator_optimizer.state_dict(
                )
        return save_dict

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Util Functions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

    def get_dims_to_discriminate(self):
        """Starting latent dims of the deltas active at the current progressive stage."""
        deltas_starting_dimensions = self.net.encoder.get_deltas_starting_dimensions(
        )
        return deltas_starting_dimensions[:self.net.encoder.progressive_stage.
                                          value + 1]

    def is_progressive_training(self):
        """True when ``opts.progressive_steps`` is set (not ``None``).

        NOTE(review): other call sites test ``progressive_steps`` for
        truthiness, so an empty list is treated differently here.
        """
        return self.opts.progressive_steps is not None

    def check_for_progressive_training_update(self, is_resume_from_ckpt=False):
        """Advance the encoder's progressive stage according to the current step.

        In normal training, stage ``i`` is set exactly when ``global_step``
        equals ``progressive_steps[i]``. When resuming from a checkpoint,
        every stage whose step has already been reached is applied in order,
        so the encoder ends up at the latest one.
        """
        for i, stage_step in enumerate(self.opts.progressive_steps):
            if is_resume_from_ckpt and self.global_step >= stage_step:  # Case checkpoint
                self.net.encoder.set_progressive_stage(ProgressiveStage(i))
            if self.global_step == stage_step:  # Case training reached progressive step
                self.net.encoder.set_progressive_stage(ProgressiveStage(i))

    @staticmethod
    def requires_grad(model, flag=True):
        """Set ``requires_grad`` to ``flag`` on every parameter of ``model``."""
        for p in model.parameters():
            p.requires_grad = flag

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Discriminator ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

    def is_training_discriminator(self):
        """True when the latent discriminator has a positive loss weight."""
        return self.opts.w_discriminator_lambda > 0

    @staticmethod
    def discriminator_loss(real_pred, fake_pred, loss_dict):
        """Logistic discriminator loss ``softplus(-D(real)) + softplus(D(fake))``.

        Both terms are also recorded as floats in ``loss_dict``.
        """
        real_loss = F.softplus(-real_pred).mean()
        fake_loss = F.softplus(fake_pred).mean()
        loss_dict['d_real_loss'] = float(real_loss)
        loss_dict['d_fake_loss'] = float(fake_loss)
        return real_loss + fake_loss

    @staticmethod
    def discriminator_r1_loss(real_pred, real_w):
        """R1 penalty: mean over the batch of the squared gradient norm w.r.t. ``real_w``."""
        grad_real, = autograd.grad(outputs=real_pred.sum(),
                                   inputs=real_w,
                                   create_graph=True)
        grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0],
                                                -1).sum(1).mean()
        return grad_penalty

    def train_discriminator(self, x):
        """Take one discriminator optimizer step, plus an R1 step every ``d_reg_every`` steps.

        The latents are sampled without gradients. The discriminator's
        ``requires_grad`` flags are enabled only for the duration of this
        call.

        Returns:
            Dict of the discriminator losses as floats.
        """
        loss_dict = {}
        self.requires_grad(self.discriminator, True)

        with torch.no_grad():
            real_w, fake_w = self.sample_real_and_fake_latents(x)
        real_pred = self.discriminator(real_w)
        fake_pred = self.discriminator(fake_w)
        loss = self.discriminator_loss(real_pred, fake_pred, loss_dict)
        loss_dict['discriminator_loss'] = float(loss)

        self.discriminator_optimizer.zero_grad()
        loss.backward()
        self.discriminator_optimizer.step()

        # r1 regularization (lazy: scaled by d_reg_every to compensate for
        # being applied only every d_reg_every steps)
        d_regularize = self.global_step % self.opts.d_reg_every == 0
        if d_regularize:
            real_w = real_w.detach()
            real_w.requires_grad = True
            real_pred = self.discriminator(real_w)
            r1_loss = self.discriminator_r1_loss(real_pred, real_w)

            self.discriminator.zero_grad()
            r1_final_loss = self.opts.r1 / 2 * r1_loss * self.opts.d_reg_every + 0 * real_pred[
                0]
            r1_final_loss.backward()
            self.discriminator_optimizer.step()
            loss_dict['discriminator_r1_loss'] = float(r1_final_loss)

        # Reset to previous state
        self.requires_grad(self.discriminator, False)

        return loss_dict

    def validate_discriminator(self, x):
        """Compute the discriminator loss on this batch without any update."""
        with torch.no_grad():
            loss_dict = {}
            real_w, fake_w = self.sample_real_and_fake_latents(x)
            real_pred = self.discriminator(real_w)
            fake_pred = self.discriminator(fake_w)
            loss = self.discriminator_loss(real_pred, fake_pred, loss_dict)
            loss_dict['discriminator_loss'] = float(loss)
            return loss_dict

    def sample_real_and_fake_latents(self, x):
        """Draw a batch of "real" and a batch of "fake" latent codes.

        Real codes are standard-normal noise mapped through the decoder's
        ``get_latent``. Fake codes are the encoder's output for ``x``. One
        noise vector is drawn per sample in ``x``, so real and fake batches
        have the same size. This also holds in validation, where
        ``test_batch_size`` may differ from ``batch_size``. Codes may pass
        through the latent pools, and a 3-D fake tensor is reduced to its
        first code.
        """
        sample_z = torch.randn(x.shape[0], 512, device=self.device)
        real_w = self.net.decoder.get_latent(sample_z)
        fake_w = self.net.encoder(x)
        if self.is_progressive_training():
            # When progressive training, feed only unique w's
            dims_to_discriminate = self.get_dims_to_discriminate()
            fake_w = fake_w[:, dims_to_discriminate, :]
        if self.opts.use_w_pool:
            real_w = self.real_w_pool.query(real_w)
            fake_w = self.fake_w_pool.query(fake_w)
        if fake_w.ndim == 3:
            fake_w = fake_w[:, 0, :]
        return real_w, fake_w