def run(self, trainer: Trainer):
    """Compute the FID score of ``self.autoencoder`` and report it.

    The score is computed by ``self.fid_calculator`` from
    ``self.data_loader`` and ``self.dataset_path``. Afterwards
    ``synchronize()`` is called (presumably a barrier across distributed
    processes — confirm). Only the process whose ``get_rank()`` is 0 adds the
    observation ``fid_score`` under the ``"evaluation"`` prefix.

    Args:
        trainer: The running trainer (not used by this extension).
    """
    score = self.fid_calculator(self.autoencoder, self.data_loader, self.dataset_path)
    synchronize()

    # Non-zero ranks have nothing left to do; only rank 0 reports.
    if get_rank() != 0:
        return

    with get_current_reporter() as reporter:
        reporter.add_observation({"fid_score": score}, "evaluation")
def run(self, trainer: Trainer):
    """Render the networks' predictions into an image grid and save it.

    Every network in ``self.networks`` is switched to eval mode while
    ``self.get_predictions()`` runs under ``torch.no_grad()``. Afterwards each
    network's previous train/eval mode is restored, even if
    ``get_predictions`` raises.

    The predictions are concatenated along dim 0 and arranged into a grid
    with ``self.input_images.shape[0]`` images per row. The grid is saved to
    ``<self.image_dir>/<iteration:08d>.png``. If ``self.log_to_logger`` is
    set, the image is also handed to the current reporter.

    Args:
        trainer: The running trainer; only ``trainer.updater.iteration`` is read.
    """
    torch.cuda.empty_cache()

    # Remember each network's mode so a network that was already in eval mode
    # is not forced into training mode afterwards.
    previous_modes = [network.training for network in self.networks]
    try:
        for network in self.networks:
            network.eval()
        with torch.no_grad():
            predictions = self.get_predictions()
    finally:
        for network, was_training in zip(self.networks, previous_modes):
            network.train(was_training)

    iteration = trainer.updater.iteration

    # Drop every intermediate tensor as soon as it is no longer needed.
    # Otherwise the final empty_cache() cannot release the memory they hold.
    display_images = torch.cat(predictions, dim=0)
    del predictions
    image_grid = torchvision.utils.make_grid(display_images, nrow=self.input_images.shape[0])
    del display_images
    dest_image = make_image(image_grid)
    del image_grid

    dest_file_name = os.path.join(self.image_dir, f"{iteration:08d}.png")
    Image.fromarray(dest_image).save(dest_file_name)

    if self.log_to_logger:
        with get_current_reporter() as reporter:
            reporter.add_image({"image_plotter": dest_image}, iteration)

    torch.cuda.empty_cache()
def __call__(self, batch):
    """Evaluate ``self.network`` on one batch and report loss and accuracy.

    ``batch`` must provide ``'images'`` and ``'labels'``. Two observations are
    reported:

    * the negative log-likelihood loss, as ``test_loss`` under prefix ``'loss'``;
    * the fraction of samples whose arg-max prediction equals the label, as
      ``accuracy`` under prefix ``'accuracy'``.
    """
    reporter = get_current_reporter()
    images = batch['images']
    labels = batch['labels']

    # Pure evaluation: no autograd graph needs to be recorded.
    with torch.no_grad():
        outputs = self.network(images)
        test_loss = F.nll_loss(outputs, labels)

        # The predicted class is the index of the highest output per sample.
        predicted = outputs.argmax(dim=1, keepdim=True)
        num_correct = predicted.eq(labels.view_as(predicted)).sum().item()
        accuracy = num_correct / len(images)

        reporter.add_observation({"test_loss": test_loss}, prefix='loss')
        reporter.add_observation({"accuracy": accuracy}, prefix='accuracy')
def run_training(self):
    """Run the epoch/iteration training loop until finished or stopped.

    For each of ``self.num_epochs`` epochs, ``self.updater`` is reset and
    ``self.iterations_per_epoch`` updates follow. Each update runs inside the
    current reporter's context. Extensions run after every update, and the
    method returns as soon as ``self.stop_trigger(self)`` is truthy.
    """
    reporter = get_current_reporter()
    epoch_bar = self.get_progressbar(self.num_epochs, desc='epoch')
    for _epoch in epoch_bar:
        self.updater.reset()
        iteration_bar = self.get_progressbar(
            self.iterations_per_epoch, leave=False, desc='iteration'
        )
        for _iteration in iteration_bar:
            # Only the update itself runs inside the reporter's context.
            with reporter:
                self.updater.update()
            self.run_extensions()
            if self.stop_trigger(self):
                return
def update_core(self):
    """Run one discriminator update, then one generator update.

    The next batch from ``self.iterators['images']`` is moved to
    ``self.device``. Each update receives its own detached clones of
    ``'input_image'`` and ``'output_image'``. The dict each update returns is
    reported under ``'discriminator'`` and ``'generator'`` respectively.
    """
    reporter = get_current_reporter()
    raw_batch = next(self.iterators['images'])
    batch = {}
    for key, value in raw_batch.items():
        batch[key] = value.to(self.device)

    def fresh_copies():
        # Clone + detach per update, presumably so that one update cannot
        # affect the tensors the other update sees.
        return (
            batch['input_image'].clone().detach(),
            batch['output_image'].clone().detach(),
        )

    discriminator_stats = self.update_discriminator(*fresh_copies())
    reporter.add_observation(discriminator_stats, 'discriminator')

    generator_stats = self.update_generator(*fresh_copies())
    reporter.add_observation(generator_stats, 'generator')
def calculate_loss(self, input_images: torch.Tensor, reconstructed_images: torch.Tensor):
    """Compute the autoencoder loss, report its parts, and backpropagate it.

    The reconstruction loss is the element-wise MSE, averaged over dims 1-3
    per sample and summed over the rest. If ``self.use_perceptual_loss`` is
    set, the summed perceptual loss is added to it. ``backward()`` is called
    on the total. Under prefix ``'loss'`` the method reports
    ``reconstruction_loss``, optionally ``perceptual_loss``, and
    ``autoencoder_loss`` (the total). Nothing is returned.

    Args:
        input_images: The images being reconstructed; must have at least
            4 dims (presumably (batch, C, H, W) — confirm).
        reconstructed_images: The autoencoder's output, same shape.
    """
    reporter = get_current_reporter()
    mse_loss = F.mse_loss(input_images, reconstructed_images, reduction='none')
    reconstruction_loss = mse_loss.mean(dim=(1, 2, 3)).sum()
    reporter.add_observation({"reconstruction_loss": reconstruction_loss}, prefix='loss')

    loss = reconstruction_loss
    if self.use_perceptual_loss:
        perceptual_loss = self.perceptual_loss(reconstructed_images, input_images).sum()
        reporter.add_observation({"perceptual_loss": perceptual_loss}, prefix='loss')
        # Out-of-place on purpose: ``loss += ...`` would modify, in place, the
        # tensor already handed to the reporter as "reconstruction_loss".
        loss = loss + perceptual_loss

    loss.backward()
    reporter.add_observation({"autoencoder_loss": loss}, prefix='loss')
def __call__(self, batch):
    """Evaluate the autoencoder on one batch and report quality metrics.

    ``batch['input_image']`` is reconstructed by ``self.autoencoder`` and
    compared with ``batch['output_image']``. Everything runs under
    ``torch.no_grad()``. The following are reported under prefix
    ``'evaluation'``:

    * ``reconstruction_loss``: per-sample MSE averaged over dims 1-3, summed;
    * ``perceptual_loss``: only if ``self.use_perceptual_loss`` is set;
    * ``psnr`` and ``ssim``: computed after ``clamp_and_unnormalize``;
    * ``autoencoder_loss``: reconstruction loss plus perceptual loss, if used.
    """
    reporter = get_current_reporter()
    with torch.no_grad():
        reconstructed_images = self.autoencoder(batch['input_image'])
        original_image = batch['output_image']

        mse_loss = F.mse_loss(original_image, reconstructed_images, reduction='none')
        reconstruction_loss = mse_loss.mean(dim=(1, 2, 3)).sum()
        reporter.add_observation({"reconstruction_loss": reconstruction_loss}, prefix='evaluation')

        loss = reconstruction_loss
        if self.use_perceptual_loss:
            perceptual_loss = self.perceptual_loss(reconstructed_images, original_image).sum()
            reporter.add_observation({"perceptual_loss": perceptual_loss}, prefix='evaluation')
            # Out-of-place on purpose: ``loss += ...`` would modify, in place,
            # the tensor already reported as "reconstruction_loss".
            loss = loss + perceptual_loss

        # max_val=1 below assumes clamp_and_unnormalize maps images into
        # [0, 1] — confirm.
        original_image = clamp_and_unnormalize(original_image)
        reconstructed_images = clamp_and_unnormalize(reconstructed_images)
        # NOTE(review): depending on the library version providing it,
        # psnr_loss may return PSNR or -PSNR; confirm the sign of "psnr".
        psnr = psnr_loss(reconstructed_images, original_image, max_val=1)
        ssim = ssim_loss(original_image, reconstructed_images, 5, reduction='mean')
        # ssim_loss yields a loss, not SSIM itself; 1 - 2 * loss inverts a
        # loss of the form (1 - SSIM) / 2 to recover the SSIM value.
        ssim = 1 - 2 * ssim
        reporter.add_observation({
            "psnr": psnr,
            "ssim": ssim,
        }, prefix='evaluation')

        reporter.add_observation({"autoencoder_loss": loss}, prefix='evaluation')
def update_core(self):
    """Perform one training step of ``self.networks['net']``.

    The step runs inside ``GradientApplier`` with ``self.optimizers``. By its
    name, ``GradientApplier`` presumably handles gradient reset/application
    around the block — confirm. Inside it:

    1. The next batch from ``self.iterators['images']`` is moved to
       ``self.device``.
    2. The NLL loss of ``net(batch['images'])`` against ``batch['labels']`` is
       computed.
    3. The loss is reported as ``loss`` under prefix ``'loss'``.
    4. ``backward()`` is called on the loss.
    """
    net = self.networks['net']
    optimizers = self.optimizers.values()
    with GradientApplier([net], optimizers):
        raw_batch = next(self.iterators['images'])
        batch = {name: value.to(self.device) for name, value in raw_batch.items()}

        outputs = net(batch['images'])
        loss = F.nll_loss(outputs, batch['labels'])

        get_current_reporter().add_observation({"loss": loss}, prefix='loss')

        # Gradients are accumulated here; the weight update happens later.
        loss.backward()
def update_generator(self, input_images: torch.Tensor, output_images: torch.Tensor) -> dict:
    """Run one adversarial update step of the autoencoder.

    The step runs inside ``UpdateDisabler(autoencoder.decoder)`` (by its
    name, presumably keeping the decoder's weights fixed — confirm) and a
    ``GradientApplier`` over the autoencoder with ``self.optimizers['main']``.
    The total loss is the sum of:

    * the reconstruction loss: per-sample MSE against ``output_images``,
      averaged over dims 1-3 and summed over the rest;
    * the summed perceptual loss, if ``self.use_perceptual_loss`` is set;
    * the adversarial term ``softplus(-D(reconstruction)).mean()``.

    ``backward()`` is called on the total. Under prefix ``'loss'`` the method
    reports ``reconstruction_loss``. When the perceptual loss is enabled it
    also reports ``perceptual_loss`` and ``autoencoder_loss`` (reconstruction
    plus perceptual, before the adversarial term).

    Args:
        input_images: Images fed to the autoencoder.
        output_images: Target images for the reconstruction loss.

    Returns:
        dict with ``"loss"`` (the total) and ``"discriminator_loss"`` (the
        adversarial term).
    """
    autoencoder = self.get_autoencoder()
    discriminator = self.get_discriminator()
    reporter = get_current_reporter()
    autoencoder_optimizer = self.optimizers['main']
    log_data = {}

    with UpdateDisabler(autoencoder.decoder), GradientApplier([autoencoder], [autoencoder_optimizer]):
        reconstructed_images = autoencoder(input_images)
        mse_loss = F.mse_loss(output_images, reconstructed_images, reduction='none')
        reconstruction_loss = mse_loss.mean(dim=(1, 2, 3)).sum()
        reporter.add_observation({"reconstruction_loss": reconstruction_loss}, prefix='loss')

        # Every sum below is built out of place. ``loss += ...`` would mutate
        # tensors already handed to the reporter, so "reconstruction_loss" and
        # "autoencoder_loss" would silently change.
        autoencoder_loss = reconstruction_loss
        if self.use_perceptual_loss:
            perceptual_loss = self.perceptual_loss(reconstructed_images, output_images).sum()
            autoencoder_loss = autoencoder_loss + perceptual_loss
            reporter.add_observation(
                {"autoencoder_loss": autoencoder_loss, "perceptual_loss": perceptual_loss}, prefix='loss'
            )

        discriminator_prediction = discriminator(reconstructed_images)
        # softplus(-d) == -log(sigmoid(d)). Assuming the discriminator returns
        # logits, this term is small when reconstructions are scored as real.
        discriminator_loss = F.softplus(-discriminator_prediction).mean()

        loss = autoencoder_loss + discriminator_loss
        loss.backward()

        log_data.update({
            "loss": loss,
            "discriminator_loss": discriminator_loss,
        })

    torch.cuda.empty_cache()
    return log_data
def log_lr(self, scheduler_name: str, scheduler: _LRScheduler):
    """Report the learning rate of every parameter group of ``scheduler``.

    Each value is reported under prefix ``'metrics'``. The key is
    ``lr/<scheduler_name>`` when the optimizer has a single parameter group,
    and ``lr/<scheduler_name>/<group index>`` when it has several. A
    reporter context is opened per group.
    """
    param_groups = scheduler.optimizer.param_groups
    has_several_groups = len(param_groups) > 1
    for index, group in enumerate(param_groups):
        learning_rate = group['lr']
        suffix = f"/{index}" if has_several_groups else ""
        with get_current_reporter() as reporter:
            reporter.add_observation({f"lr/{scheduler_name}{suffix}": learning_rate}, prefix='metrics')