Example no. 1
0
 def inference(self, device: str = 'cuda') -> None:
     '''
     Generates a 7 x 7 grid of fake images (seven random validation samples,
     each rendered once per feature level) and saves the grid to disk.
     :param device: (str) Device on which models and tensors are placed
     '''
     # Move both networks onto the requested device
     self.generator.to(device)
     self.vgg16.to(device)
     # Switch the generator into evaluation mode
     self.generator.eval()
     # Draw seven distinct random samples from the validation dataset
     sample_indexes = np.random.choice(range(len(self.validation_dataset)), replace=False, size=7)
     sample_images = [self.validation_dataset[index].unsqueeze(dim=0).to(device)
                      for index in sample_indexes]
     # Build one mask set per feature level
     level_masks = [get_masks_for_inference(level, add_batch_size=True, device=device)
                    for level in range(7)]
     # Allocate the tensor that collects all 7 * 7 generated images
     template = sample_images[0]
     fake_images = torch.empty(7 ** 2, template.shape[1], template.shape[2], template.shape[3],
                               dtype=torch.float32, device=device)
     # Produce one fake image for every (sample, feature level) pair
     for row, image in enumerate(sample_images):
         for column, masks in enumerate(level_masks):
             generated = self.generator(
                 input=torch.randn(1, self.latent_dimensions, dtype=torch.float32, device=device),
                 features=self.vgg16(image),
                 masks=masks)
             # Slot the image into its grid position (row-major order)
             fake_images[row * 7 + column] = generated.squeeze(dim=0)
     # Normalize the whole batch to [0, 1] and write the grid image
     torchvision.utils.save_image(
         misc.normalize_0_1_batch(fake_images),
         os.path.join(self.path_save_plots, 'predictions_{}.png'.format(str(datetime.now()))), nrow=7)
Example no. 2
0
 def validate(self, device: str = 'cuda') -> float:
     '''
     Generates and saves images for the fixed validation latents/masks and
     estimates the FID score of the generator.
     :param device: (str) Device on which models and tensors are placed
     :return: (float) FID score
     '''
     # Generator into validation mode
     self.generator.eval()
     self.vgg16.eval()
     # Validation samples for plotting to device
     self.validation_latents = self.validation_latents.to(device)
     self.validation_images_to_plot = self.validation_images_to_plot.to(device)
     for index in range(len(self.validation_masks)):
         self.validation_masks[index] = self.validation_masks[index].to(device)
     # Generate images
     fake_image = self.generator(input=self.validation_latents,
                                 features=self.vgg16(self.validation_images_to_plot),
                                 masks=self.validation_masks).cpu()
     # Save images; filename carries the current training progress step
     torchvision.utils.save_image(misc.normalize_0_1_batch(fake_image),
                                  os.path.join(self.path_save_plots, str(self.progress_bar.n) + '.png'),
                                  nrow=7)
     # Generator back into train mode
     self.generator.train()
     return frechet_inception_distance(dataset_real=self.validation_dataset_fid,
                                       generator=self.generator, vgg16=self.vgg16)
 def inference(self, device: str = 'cuda') -> None:
     '''
     Random images for different feature levels are generated and saved as a
     7 x 7 grid (seven random validation samples, one column per feature level).
     :param device: (str) Device on which models and tensors are placed
     '''
     # Models to device
     self.generator.to(device)
     self.vgg16.to(device)
     # Generator into eval mode
     self.generator.eval()
     # Unwrap a DataParallel generator once, instead of branching on every call
     generator = self.generator.module if isinstance(self.generator, nn.DataParallel) else self.generator
     # Get random images from the validation dataset.
     # NOTE: sample indices over the underlying dataset (len of the DataLoader is
     # the number of batches, not of samples), consistent with the constructor.
     images, labels, _ = image_label_list_of_masks_collate_function(
         [self.validation_dataset_fid.dataset[index] for index in
          np.random.choice(range(len(self.validation_dataset_fid.dataset)), replace=False, size=7)])
     # Get list of masks for different layers
     masks_levels = [get_masks_for_inference(layer, add_batch_size=True, device=device) for layer in range(7)]
     # Init tensor of fake images to store all fake images
     fake_images = torch.empty(7 ** 2, images.shape[1], images.shape[2], images.shape[3],
                               dtype=torch.float32, device=device)
     # Init counter
     counter = 0
     # Loop over all images and masks
     for image, label in zip(images, labels):
         # Data to device, adding a batch dimension
         image = image.to(device)[None]
         label = label.to(device)[None]
         for masks in masks_levels:
             # Generate fake image
             fake_image = generator(
                 input=torch.randn(1, self.latent_dimensions, dtype=torch.float32, device=device),
                 features=self.vgg16(image),
                 masks=masks,
                 class_id=label.float())
             # Save fake image
             fake_images[counter] = fake_image.squeeze(dim=0)
             # Increment counter
             counter += 1
     # Save tensor as image
     torchvision.utils.save_image(
         misc.normalize_0_1_batch(fake_images),
         os.path.join(self.path_save_plots, 'predictions_{}.png'.format(self.progress_bar.n)), nrow=7)
     # Back into training mode
     self.generator.train()
Example no. 4
0
    def __init__(self,
                 generator: Union[Generator, nn.DataParallel],
                 discriminator: Union[Discriminator, nn.DataParallel],
                 training_dataset: DataLoader,
                 validation_dataset: Dataset,
                 validation_dataset_fid: DataLoader,
                 vgg16: Union[VGG16, nn.DataParallel] = VGG16(),
                 generator_optimizer: torch.optim.Optimizer = None,
                 discriminator_optimizer: torch.optim.Optimizer = None,
                 generator_loss: nn.Module = LSGANGeneratorLoss(),
                 discriminator_loss: nn.Module = LSGANDiscriminatorLoss(),
                 semantic_reconstruction_loss: nn.Module = SemanticReconstructionLoss(),
                 diversity_loss: nn.Module = DiversityLoss(),
                 save_data_path: str = 'saved_data') -> None:
        '''
        Constructor
        :param generator: (nn.Module, nn.DataParallel) Generator network
        :param discriminator: (nn.Module, nn.DataParallel) Discriminator network
        :param training_dataset: (DataLoader) Training dataset
        :param validation_dataset: (Dataset) Validation dataset
        :param validation_dataset_fid: (DataLoader) Validation dataset used for FID estimation
        :param vgg16: (nn.Module, nn.DataParallel) VGG16 module
        :param generator_optimizer: (torch.optim.Optimizer) Optimizer of the generator network
        :param discriminator_optimizer: (torch.optim.Optimizer) Optimizer of the discriminator network
        :param generator_loss: (nn.Module) Generator loss function
        :param discriminator_loss: (nn.Module) Discriminator loss function
        :param semantic_reconstruction_loss: (nn.Module) Semantic reconstruction loss function
        :param diversity_loss: (nn.Module) Diversity loss function
        :param save_data_path: (str) Root directory in which models, plots and metrics are stored
        '''
        # Save parameters
        self.generator = generator
        self.discriminator = discriminator
        self.training_dataset = training_dataset
        self.validation_dataset = validation_dataset
        self.validation_dataset_fid = validation_dataset_fid
        self.vgg16 = vgg16
        self.generator_optimizer = generator_optimizer
        self.discriminator_optimizer = discriminator_optimizer
        self.generator_loss = generator_loss
        self.discriminator_loss = discriminator_loss
        self.semantic_reconstruction_loss = semantic_reconstruction_loss
        self.diversity_loss = diversity_loss
        # Unwrap DataParallel to reach the generator's latent dimensionality
        self.latent_dimensions = self.generator.module.latent_dimensions \
            if isinstance(self.generator, nn.DataParallel) else self.generator.latent_dimensions
        # Init logger
        self.logger = Logger()
        # Make directories to save logs, plots and models during training
        # (exist_ok avoids the check-then-create race of os.path.exists + makedirs)
        time_and_date = str(datetime.now())
        self.path_save_models = os.path.join(save_data_path, 'models_' + time_and_date)
        os.makedirs(self.path_save_models, exist_ok=True)
        self.path_save_plots = os.path.join(save_data_path, 'plots_' + time_and_date)
        os.makedirs(self.path_save_plots, exist_ok=True)
        self.path_save_metrics = os.path.join(save_data_path, 'metrics_' + time_and_date)
        os.makedirs(self.path_save_metrics, exist_ok=True)
        # Make indexes for validation plots
        validation_plot_indexes = np.random.choice(range(len(self.validation_dataset_fid.dataset)), 49, replace=False)
        # Plot and save validation images used to plot generated images
        self.validation_images_to_plot, _, self.validation_masks = image_label_list_of_masks_collate_function(
            [self.validation_dataset_fid.dataset[index] for index in validation_plot_indexes])

        torchvision.utils.save_image(misc.normalize_0_1_batch(self.validation_images_to_plot),
                                     os.path.join(self.path_save_plots, 'validation_images.png'), nrow=7)
        # Plot masks
        torchvision.utils.save_image(self.validation_masks[0],
                                     os.path.join(self.path_save_plots, 'validation_masks.png'),
                                     nrow=7)
        # Generate latents for validation
        self.validation_latents = torch.randn(49, self.latent_dimensions, dtype=torch.float32)
        # Log hyperparameter
        self.logger.hyperparameter['generator'] = str(self.generator)
        self.logger.hyperparameter['discriminator'] = str(self.discriminator)
        self.logger.hyperparameter['vgg16'] = str(self.vgg16)
        self.logger.hyperparameter['generator_optimizer'] = str(self.generator_optimizer)
        self.logger.hyperparameter['discriminator_optimizer'] = str(self.discriminator_optimizer)
        self.logger.hyperparameter['generator_loss'] = str(self.generator_loss)
        self.logger.hyperparameter['discriminator_loss'] = str(self.discriminator_loss)
        self.logger.hyperparameter['diversity_loss'] = str(self.diversity_loss)
        # Bug fix: this entry previously reused the 'discriminator_loss' key,
        # overwriting the line above and never logging this loss under its own name
        self.logger.hyperparameter['semantic_reconstruction_loss'] = str(self.semantic_reconstruction_loss)