Example #1
    def run(self, trainer: Trainer):
        fid = self.fid_calculator(self.autoencoder, self.data_loader,
                                  self.dataset_path)
        synchronize()
        if get_rank() == 0:
            with get_current_reporter() as reporter:
                reporter.add_observation({"fid_score": fid}, "evaluation")
Example #2
    def run(self, trainer: Trainer):
        torch.cuda.empty_cache()
        try:
            for network in self.networks:
                network.eval()
            with torch.no_grad():
                predictions = self.get_predictions()
        finally:
            for network in self.networks:
                network.train()

        display_images = torch.cat(predictions, dim=0)

        image_grid = torchvision.utils.make_grid(
            display_images, nrow=self.input_images.shape[0])

        dest_file_name = os.path.join(self.image_dir,
                                      f"{trainer.updater.iteration:08d}.png")
        dest_image = make_image(image_grid)
        Image.fromarray(dest_image).save(dest_file_name)

        if self.log_to_logger:
            with get_current_reporter() as reporter:
                reporter.add_image({"image_plotter": dest_image},
                                   trainer.updater.iteration)

        del display_images
        torch.cuda.empty_cache()
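The make_image helper used in Example #2 is not shown on this page. A plausible minimal version, assuming the grid is a CHW float tensor roughly in [0, 1] that has to become an HWC uint8 array for PIL.Image.fromarray (make_image_sketch is a hypothetical name, not the library's implementation):

import numpy as np
import torch


def make_image_sketch(image_grid: torch.Tensor) -> np.ndarray:
    # Clamp to [0, 1], scale to [0, 255] and convert CHW float -> HWC uint8.
    array = image_grid.detach().cpu().clamp(0, 1).mul(255).byte().numpy()
    return np.transpose(array, (1, 2, 0))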
Example #3
    def __call__(self, batch):
        reporter = get_current_reporter()

        # since we only evaluate, we do not need to build the computation graph
        with torch.no_grad():
            output = self.network(batch['images'])

            loss = F.nll_loss(output, batch['labels'])
            # calculate accuracy as the fraction of most probable predictions
            # that match the labels
            predictions = output.argmax(dim=1, keepdim=True)
            num_correct = predictions.eq(batch['labels'].view_as(predictions)).sum().item()
            accuracy = num_correct / len(batch['images'])

            reporter.add_observation({"test_loss": loss}, prefix='loss')
            reporter.add_observation({"accuracy": accuracy}, prefix='accuracy')
Example #4
    def run_training(self):
        reporter = get_current_reporter()

        for _ in self.get_progressbar(self.num_epochs, desc='epoch'):
            self.updater.reset()
            for __ in self.get_progressbar(self.iterations_per_epoch,
                                           leave=False,
                                           desc='iteration'):
                with reporter:
                    self.updater.update()

                self.run_extensions()

                if self.stop_trigger(self):
                    return
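In Example #4 the reporter returned by get_current_reporter() is used as a context manager around every updater.update() call. Its real implementation is not part of this page; a minimal sketch that is merely consistent with how the examples use it (collect observations inside the with block, flush them on exit) could look like this:

class IterationReporter:
    """Illustrative stand-in for the reporter protocol used in these examples."""

    def __init__(self):
        self.observations = {}

    def add_observation(self, values: dict, prefix: str = ""):
        for name, value in values.items():
            key = f"{prefix}/{name}" if prefix else name
            self.observations[key] = value

    def add_image(self, images: dict, step: int):
        # Hand the images to whatever logging backend is configured.
        pass

    def __enter__(self):
        # Start a fresh set of observations for this iteration.
        self.observations = {}
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Flushing to a logger or file would happen here; do not swallow errors.
        return False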
Example #5
    def update_core(self):
        reporter = get_current_reporter()
        image_batch = next(self.iterators['images'])
        image_batch = {k: v.to(self.device) for k, v in image_batch.items()}

        discriminator_observations = self.update_discriminator(
            image_batch['input_image'].clone().detach(),
            image_batch['output_image'].clone().detach(),
        )
        reporter.add_observation(discriminator_observations, 'discriminator')

        generator_observations = self.update_generator(
            image_batch['input_image'].clone().detach(),
            image_batch['output_image'].clone().detach(),
        )
        reporter.add_observation(generator_observations, 'generator')
Example #6
    def calculate_loss(self, input_images: torch.Tensor,
                       reconstructed_images: torch.Tensor):
        reporter = get_current_reporter()

        mse_loss = F.mse_loss(input_images,
                              reconstructed_images,
                              reduction='none')
        loss = mse_loss.mean(dim=(1, 2, 3)).sum()
        reporter.add_observation({"reconstruction_loss": loss}, prefix='loss')
        if self.use_perceptual_loss:
            perceptual_loss = self.perceptual_loss(reconstructed_images,
                                                   input_images).sum()
            reporter.add_observation({"perceptual_loss": perceptual_loss},
                                     prefix='loss')
            loss += perceptual_loss

        loss.backward()
        reporter.add_observation({"autoencoder_loss": loss}, prefix='loss')
Example #7
    def __call__(self, batch):
        reporter = get_current_reporter()

        with torch.no_grad():
            reconstructed_images = self.autoencoder(batch['input_image'])
            original_image = batch['output_image']

            mse_loss = F.mse_loss(original_image,
                                  reconstructed_images,
                                  reduction='none')
            loss = mse_loss.mean(dim=(1, 2, 3)).sum()
            reporter.add_observation({"reconstruction_loss": loss},
                                     prefix='evaluation')
            if self.use_perceptual_loss:
                perceptual_loss = self.perceptual_loss(reconstructed_images,
                                                       original_image).sum()
                reporter.add_observation({"perceptual_loss": perceptual_loss},
                                         prefix='evaluation')
                loss += perceptual_loss

            original_image = clamp_and_unnormalize(original_image)
            reconstructed_images = clamp_and_unnormalize(reconstructed_images)
            psnr = psnr_loss(reconstructed_images, original_image, max_val=1)

            ssim = ssim_loss(original_image,
                             reconstructed_images,
                             5,
                             reduction='mean')
            # ssim_loss returns the structural dissimilarity (1 - SSIM) / 2,
            # so invert it to recover the actual SSIM value
            ssim = 1 - 2 * ssim

            reporter.add_observation({"psnr": psnr, "ssim": ssim},
                                     prefix='evaluation')

        reporter.add_observation({"autoencoder_loss": loss},
                                 prefix='evaluation')
Example #8
    def update_core(self):
        # get the network we want to optimize
        net = self.networks['net']

        # GradientApplier helps us save some boilerplate code
        with GradientApplier([net], self.optimizers.values()):
            # get the batch and transfer it to the training device
            batch = next(self.iterators['images'])
            batch = {k: v.to(self.device) for k, v in batch.items()}

            # perform forward pass through network
            prediction = net(batch['images'])

            # calculate loss
            loss = F.nll_loss(prediction, batch['labels'])

            # log the loss
            reporter = get_current_reporter()
            reporter.add_observation({"loss": loss}, prefix='loss')

            # perform backward pass for later weight update
            loss.backward()
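The comment in Example #8 only says that GradientApplier saves boilerplate; its implementation is not shown here. A plausible minimal equivalent, assuming it zeroes the networks' gradients on entry and steps the optimizers on exit (SimpleGradientApplier is a hypothetical name):

class SimpleGradientApplier:
    """Hypothetical stand-in for GradientApplier."""

    def __init__(self, networks, optimizers):
        self.networks = list(networks)
        self.optimizers = list(optimizers)

    def __enter__(self):
        # Clear stale gradients before the forward/backward pass in the body.
        for network in self.networks:
            network.zero_grad()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Apply the accumulated gradients, but only if the body succeeded.
        if exc_type is None:
            for optimizer in self.optimizers:
                optimizer.step()
        return False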
Example #9
    def update_generator(self, input_images: torch.Tensor, output_images: torch.Tensor) -> dict:
        autoencoder = self.get_autoencoder()
        discriminator = self.get_discriminator()

        reporter = get_current_reporter()

        autoencoder_optimizer = self.optimizers['main']
        log_data = {}

        with UpdateDisabler(autoencoder.decoder), GradientApplier([autoencoder], [autoencoder_optimizer]):
            reconstructed_images = autoencoder(input_images)

            mse_loss = F.mse_loss(output_images, reconstructed_images, reduction='none')
            loss = mse_loss.mean(dim=(1, 2, 3)).sum()
            reporter.add_observation({"reconstruction_loss": loss}, prefix='loss')
            if self.use_perceptual_loss:
                perceptual_loss = self.perceptual_loss(reconstructed_images, output_images).sum()
                loss += perceptual_loss
                reporter.add_observation(
                    {"autoencoder_loss": loss, "perceptual_loss": perceptual_loss},
                    prefix='loss'
                )

            discriminator_prediction = discriminator(reconstructed_images)
            discriminator_loss = F.softplus(-discriminator_prediction).mean()

            loss += discriminator_loss
            loss.backward()

        log_data.update({
            "loss": loss,
            "discriminator_loss": discriminator_loss,
        })
        torch.cuda.empty_cache()

        return log_data
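UpdateDisabler(autoencoder.decoder) in Example #9 apparently keeps the decoder fixed while the rest of the autoencoder is optimized. Its implementation is not shown either; one way such a context manager could work is to toggle requires_grad and restore the previous flags afterwards (FreezeParameters is a hypothetical name, not the library's class):

import torch.nn as nn


class FreezeParameters:
    """Hypothetical stand-in for UpdateDisabler."""

    def __init__(self, module: nn.Module):
        self.module = module
        self.saved_flags = []

    def __enter__(self):
        # Remember the current flags and stop gradient accumulation for this module.
        self.saved_flags = [p.requires_grad for p in self.module.parameters()]
        for param in self.module.parameters():
            param.requires_grad_(False)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the original requires_grad flags.
        for param, flag in zip(self.module.parameters(), self.saved_flags):
            param.requires_grad_(flag)
        return False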
Example #10
    def log_lr(self, scheduler_name: str, scheduler: _LRScheduler):
        for i, param_group in enumerate(scheduler.optimizer.param_groups):
            lr = param_group['lr']
            suffix = f"/{i}" if len(scheduler.optimizer.param_groups) > 1 else ""
            with get_current_reporter() as reporter:
                reporter.add_observation({f"lr/{scheduler_name}{suffix}": lr},
                                         prefix='metrics')