def inference(self, device: str = 'cuda') -> None:
    '''
    Generates and saves a 7 x 7 grid of fake images: 7 random validation
    images crossed with masks for each of the 7 feature levels.
    :param device: (str) Device to perform inference on
    '''
    # Models to device
    self.generator.to(device)
    self.vgg16.to(device)
    # Generator into eval mode
    self.generator.eval()
    # Get random images from validation dataset
    images = [self.validation_dataset[index].unsqueeze(dim=0).to(device) for index in
              np.random.choice(range(len(self.validation_dataset)), replace=False, size=7)]
    # Get list of masks for different layers
    masks_levels = [get_masks_for_inference(layer, add_batch_size=True, device=device) for layer in range(7)]
    # Init tensor of fake images to store all fake images (7 images * 7 mask levels)
    fake_images = torch.empty(7 ** 2, images[0].shape[1], images[0].shape[2], images[0].shape[3],
                              dtype=torch.float32, device=device)
    # Init counter
    counter = 0
    # Loop over all images and masks
    for image in images:
        for masks in masks_levels:
            # Generate fake image from a fresh latent, conditioned on VGG16 features and the level masks
            fake_image = self.generator(
                input=torch.randn(1, self.latent_dimensions, dtype=torch.float32, device=device),
                features=self.vgg16(image), masks=masks)
            # Save fake image
            fake_images[counter] = fake_image.squeeze(dim=0)
            # Increment counter
            counter += 1
    # Save tensor as image
    torchvision.utils.save_image(
        misc.normalize_0_1_batch(fake_images),
        os.path.join(self.path_save_plots, 'predictions_{}.png'.format(str(datetime.now()))),
        nrow=7)
    # Bugfix: restore training mode — previously the generator was left in eval
    # mode, silently changing batch-norm/dropout behavior of subsequent training.
    # This matches the sibling inference/validate methods, which restore train().
    self.generator.train()
def validate(self, device: str = 'cuda') -> float: ''' FID score gets estimated :param plot: (bool) True if samples should be plotted :return: (float, float) IS and FID score ''' # Generator into validation mode self.generator.eval() self.vgg16.eval() # Validation samples for plotting to device self.validation_latents = self.validation_latents.to(device) self.validation_images_to_plot = self.validation_images_to_plot.to(device) for index in range(len(self.validation_masks)): self.validation_masks[index] = self.validation_masks[index].to(device) # Generate images fake_image = self.generator(input=self.validation_latents, features=self.vgg16(self.validation_images_to_plot), masks=self.validation_masks).cpu() # Save images torchvision.utils.save_image(misc.normalize_0_1_batch(fake_image), os.path.join(self.path_save_plots, str(self.progress_bar.n) + '.png'), nrow=7) # Generator back into train mode self.generator.train() return frechet_inception_distance(dataset_real=self.validation_dataset_fid, generator=self.generator, vgg16=self.vgg16)
def inference(self, device: str = 'cuda') -> None:
    '''
    Generates and saves a 7 x 7 grid of fake images: 7 random validation
    samples (image + class label) crossed with masks for each of the 7
    feature levels.
    :param device: (str) Device to perform inference on
    '''
    # Models to device
    self.generator.to(device)
    self.vgg16.to(device)
    # Generator into eval mode
    self.generator.eval()
    # Get random images from validation dataset
    images, labels, _ = image_label_list_of_masks_collate_function(
        [self.validation_dataset_fid.dataset[index] for index in
         np.random.choice(range(len(self.validation_dataset_fid)), replace=False, size=7)])
    # Get list of masks for different layers
    masks_levels = [get_masks_for_inference(layer, add_batch_size=True, device=device) for layer in range(7)]
    # Init tensor of fake images to store all fake images (7 images * 7 mask levels)
    fake_images = torch.empty(7 ** 2, images.shape[1], images.shape[2], images.shape[3],
                              dtype=torch.float32, device=device)
    # Unwrap DataParallel once: the check is loop-invariant, and unwrapping here
    # removes the duplicated generator call previously inside the inner loop
    generator = self.generator.module if isinstance(self.generator, nn.DataParallel) else self.generator
    # Init counter
    counter = 0
    # Loop over all images and masks
    for image, label in zip(images, labels):
        # Data to device, with a singleton batch dimension
        image = image.to(device)[None]
        label = label.to(device)[None]
        for masks in masks_levels:
            # Generate fake image from a fresh latent, conditioned on VGG16 features,
            # the level masks and the class label
            fake_image = generator(
                input=torch.randn(1, self.latent_dimensions, dtype=torch.float32, device=device),
                features=self.vgg16(image), masks=masks, class_id=label.float())
            # Save fake image
            fake_images[counter] = fake_image.squeeze(dim=0)
            # Increment counter
            counter += 1
    # Save tensor as image; file is named after the current progress-bar step
    torchvision.utils.save_image(
        misc.normalize_0_1_batch(fake_images),
        os.path.join(self.path_save_plots, 'predictions_{}.png'.format(self.progress_bar.n)),
        nrow=7)
    # Back into training mode
    self.generator.train()
def __init__(self, generator: Union[Generator, nn.DataParallel],
             discriminator: Union[Discriminator, nn.DataParallel],
             training_dataset: DataLoader,
             validation_dataset: Dataset,
             validation_dataset_fid: DataLoader,
             vgg16: Union[VGG16, nn.DataParallel] = VGG16(),
             generator_optimizer: torch.optim.Optimizer = None,
             discriminator_optimizer: torch.optim.Optimizer = None,
             generator_loss: nn.Module = LSGANGeneratorLoss(),
             discriminator_loss: nn.Module = LSGANDiscriminatorLoss(),
             semantic_reconstruction_loss: nn.Module = SemanticReconstructionLoss(),
             diversity_loss: nn.Module = DiversityLoss(),
             save_data_path: str = 'saved_data') -> None:
    '''
    Constructor
    :param generator: (nn.Module, nn.DataParallel) Generator network
    :param discriminator: (nn.Module, nn.DataParallel) Discriminator network
    :param training_dataset: (DataLoader) Training dataset
    :param validation_dataset: (Dataset) Validation dataset
    :param validation_dataset_fid: (DataLoader) Validation dataset used for FID estimation and plotting
    :param vgg16: (nn.Module, nn.DataParallel) VGG16 module
    :param generator_optimizer: (torch.optim.Optimizer) Optimizer of the generator network
    :param discriminator_optimizer: (torch.optim.Optimizer) Optimizer of the discriminator network
    :param generator_loss: (nn.Module) Generator loss function
    :param discriminator_loss: (nn.Module) Discriminator loss function
    :param semantic_reconstruction_loss: (nn.Module) Semantic reconstruction loss function
    :param diversity_loss: (nn.Module) Diversity loss function
    :param save_data_path: (str) Root directory for logs, plots and models
    '''
    # Save parameters
    self.generator = generator
    self.discriminator = discriminator
    self.training_dataset = training_dataset
    self.validation_dataset = validation_dataset
    self.validation_dataset_fid = validation_dataset_fid
    self.vgg16 = vgg16
    self.generator_optimizer = generator_optimizer
    self.discriminator_optimizer = discriminator_optimizer
    self.generator_loss = generator_loss
    self.discriminator_loss = discriminator_loss
    self.semantic_reconstruction_loss = semantic_reconstruction_loss
    self.diversity_loss = diversity_loss
    # Unwrap DataParallel to reach the latent dimensionality attribute
    self.latent_dimensions = self.generator.module.latent_dimensions \
        if isinstance(self.generator, nn.DataParallel) else self.generator.latent_dimensions
    # Init logger
    self.logger = Logger()
    # Make directories to save logs, plots and models during training
    time_and_date = str(datetime.now())
    self.path_save_models = os.path.join(save_data_path, 'models_' + time_and_date)
    if not os.path.exists(self.path_save_models):
        os.makedirs(self.path_save_models)
    self.path_save_plots = os.path.join(save_data_path, 'plots_' + time_and_date)
    if not os.path.exists(self.path_save_plots):
        os.makedirs(self.path_save_plots)
    self.path_save_metrics = os.path.join(save_data_path, 'metrics_' + time_and_date)
    if not os.path.exists(self.path_save_metrics):
        os.makedirs(self.path_save_metrics)
    # Make indexes for validation plots (49 samples for a 7 x 7 grid)
    validation_plot_indexes = np.random.choice(range(len(self.validation_dataset_fid.dataset)), 49,
                                               replace=False)
    # Plot and save validation images used to plot generated images
    self.validation_images_to_plot, _, self.validation_masks = image_label_list_of_masks_collate_function(
        [self.validation_dataset_fid.dataset[index] for index in validation_plot_indexes])
    torchvision.utils.save_image(misc.normalize_0_1_batch(self.validation_images_to_plot),
                                 os.path.join(self.path_save_plots, 'validation_images.png'),
                                 nrow=7)
    # Plot masks
    torchvision.utils.save_image(self.validation_masks[0],
                                 os.path.join(self.path_save_plots, 'validation_masks.png'),
                                 nrow=7)
    # Generate latents for validation (fixed so validation plots are comparable across epochs)
    self.validation_latents = torch.randn(49, self.latent_dimensions, dtype=torch.float32)
    # Log hyperparameter
    self.logger.hyperparameter['generator'] = str(self.generator)
    self.logger.hyperparameter['discriminator'] = str(self.discriminator)
    self.logger.hyperparameter['vgg16'] = str(self.vgg16)
    self.logger.hyperparameter['generator_optimizer'] = str(self.generator_optimizer)
    self.logger.hyperparameter['discriminator_optimizer'] = str(self.discriminator_optimizer)
    self.logger.hyperparameter['generator_loss'] = str(self.generator_loss)
    self.logger.hyperparameter['discriminator_loss'] = str(self.discriminator_loss)
    self.logger.hyperparameter['diversity_loss'] = str(self.diversity_loss)
    # Bugfix: previously logged under the 'discriminator_loss' key, which clobbered
    # the discriminator-loss entry and never recorded this loss under its own name
    self.logger.hyperparameter['semantic_reconstruction_loss'] = str(self.semantic_reconstruction_loss)