Пример #1
0
    def _valid_epoch(self, epoch):
        """
        Run one pass over the validation set after an epoch of training.

        Both networks are put in eval mode and every batch is processed under
        ``torch.no_grad()``. Per batch, the generator loss (adversarial 'G'
        loss plus the content loss scaled by
        ``config['others']['content_loss_lambda']``) is written to
        ``self.writer`` in 'valid' mode, and the metrics are evaluated on the
        denormalized deblurred/sharp images. Afterwards a histogram of every
        generator parameter is written.

        :param epoch: Current (1-based) training epoch, used to compute the
            global writer step.
        :return: dict with 'val_loss' (generator loss averaged over batches)
            and 'val_metrics' (metric values averaged over batches, as a list).

        Note:
            The validation metrics in log must have the key 'val_metrics'.
        """
        for net in (self.generator, self.discriminator):
            net.eval()

        loader = self.valid_data_loader
        loss_sum = 0
        metric_sums = np.zeros(len(self.metrics))

        with torch.no_grad():
            for step, batch in enumerate(loader):
                blurred_in = batch['blurred'].to(self.device)
                sharp_ref = batch['sharp'].to(self.device)

                restored = self.generator(blurred_in)
                fake_scores = self.discriminator(restored)

                weight = self.config['others']['content_loss_lambda']
                adv_loss = self.adversarial_loss(
                    'G', deblurred_discriminator_out=fake_scores)
                content_loss = self.content_loss(restored, sharp_ref) * weight
                gen_loss = adv_loss + content_loss

                global_step = (epoch - 1) * len(loader) + step
                self.writer.set_step(global_step, 'valid')
                for tag, value in (('adversarial_loss_g', adv_loss),
                                   ('content_loss_g', content_loss),
                                   ('loss_g', gen_loss)):
                    self.writer.add_scalar(tag, value.item())
                loss_sum += gen_loss.item()

                metric_sums += self._eval_metrics(
                    denormalize(restored), denormalize(sharp_ref))

        # add histogram of model parameters to the tensorboard
        for param_name, param in self.generator.named_parameters():
            self.writer.add_histogram(param_name, param, bins='auto')

        n_batches = len(loader)
        return {
            'val_loss': loss_sum / n_batches,
            'val_metrics': (metric_sums / n_batches).tolist(),
        }
Пример #2
0
def main(blurred_dir, deblurred_dir, resume):
    """
    Deblur every image produced by ``CustomDataLoader(data_dir=blurred_dir)``.

    The generator class and its constructor arguments come from the
    ``config`` stored in the checkpoint. Each result is written to
    ``deblurred_dir`` as ``'deblurred ' + <image_name>``.

    :param blurred_dir: Directory passed to ``CustomDataLoader`` as ``data_dir``.
    :param deblurred_dir: Output directory; created if it does not exist.
    :param resume: Path to a checkpoint holding 'config' and 'generator'.
    """
    # Choose the device first so the checkpoint can be mapped onto it:
    # without map_location a checkpoint saved on a GPU cannot be loaded on a
    # CPU-only machine, even though this function falls back to the CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # load checkpoint
    checkpoint = torch.load(resume, map_location=device)
    config = checkpoint['config']

    # setup data_loader instances
    data_loader = CustomDataLoader(data_dir=blurred_dir)

    # build model architecture
    generator_class = getattr(module_arch, config['generator']['type'])
    generator = generator_class(**config['generator']['args'])

    # prepare model for deblurring
    generator.to(device)
    # Multi-GPU training saves DataParallel ('module.'-prefixed) weights; wrap
    # the model the same way so the state_dict keys line up.
    if config.get('n_gpu', 1) > 1:
        generator = torch.nn.DataParallel(generator)

    generator.load_state_dict(checkpoint['generator'])
    generator.eval()

    os.makedirs(deblurred_dir, exist_ok=True)

    # start to deblur
    with torch.no_grad():
        for sample in tqdm(data_loader, ascii=True):
            blurred = sample['blurred'].to(device)
            # Only the first name of the batch is used and the output is
            # squeezed, so this assumes a batch size of 1 — TODO confirm
            # CustomDataLoader's default batch size.
            image_name = sample['image_name'][0]

            deblurred = generator(blurred)
            deblurred_img = to_pil_image(
                denormalize(deblurred).squeeze().cpu())

            deblurred_img.save(
                os.path.join(deblurred_dir, 'deblurred ' + image_name))
Пример #3
0
 def fix_image(img):
     """
     Prepare an image batch for logging according to ``self.opt['logger']``.

     ``self`` is a free variable here, so this function presumably lives
     nested inside a method of the owning object — TODO confirm.

     - 'is_mel_spectrogram': insert a channel dim at axis 1 and standardize
       (subtract the mean, divide by the std) so the spectrogram is easier
       to view.
     - If there are more than 3 channels, keep only the first 3.
     - 'reverse_n1_to_1': apply ``(img + 1) / 2``.
     - 'reverse_imagenet_norm': apply ``denormalize``.
     """
     if opt_get(self.opt, ['logger', 'is_mel_spectrogram'], False):
         img = img.unsqueeze(dim=1)
         # Normalize so spectrogram is easier to view.
         std = img.std()
         img = img - img.mean()
         # A constant (e.g. silent) spectrogram has std 0 and a single element
         # has NaN std; dividing by either fills the image with NaN/inf, so
         # only rescale when the std is a positive number.
         if std > 0:
             img = img / std
     if img.shape[1] > 3:
         img = img[:, :3, :, :]
     if opt_get(self.opt, ['logger', 'reverse_n1_to_1'], False):
         img = (img + 1) / 2
     if opt_get(self.opt, ['logger', 'reverse_imagenet_norm'], False):
         img = denormalize(img)
     return img
def main(blurred_image, resume, output_path="./deblurred.png"):
    """
    Deblur a single image file with a trained generator and save the result.

    The input is converted to RGB and bicubically resized so that height and
    width are rounded up to multiples of 4 (presumably required by the
    generator's down/up-sampling — TODO confirm). It is then mapped from
    [0, 1] to [-1, 1] per channel before being fed to the generator. The
    output is saved at that resized resolution.

    :param blurred_image: Path of the image to deblur.
    :param resume: Path to a checkpoint holding 'config' and 'generator'.
    :param output_path: Where to write the deblurred image; defaults to the
        previously hard-coded "./deblurred.png".
    """
    # Choose the device first so the checkpoint can be mapped onto it; a
    # GPU-saved checkpoint would otherwise fail to load on a CPU-only machine.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # load checkpoint
    checkpoint = torch.load(resume, map_location=device)
    config = checkpoint['config']

    # build model architecture
    generator_class = getattr(module_arch, config['generator']['type'])
    generator = generator_class(**config['generator']['args'])

    # prepare model for deblurring
    generator.to(device)
    if config['n_gpu'] > 1:
        generator = torch.nn.DataParallel(generator)

    generator.load_state_dict(checkpoint['generator'])
    generator.eval()

    # Load the image; the context manager closes the file once the RGB copy
    # has been made.
    with Image.open(blurred_image) as img:
        blurred = img.convert('RGB')
    w, h = blurred.size  # PIL reports (width, height)
    # Round each side up to the next multiple of 4.
    new_h = (h + 3) // 4 * 4
    new_w = (w + 3) // 4 * 4
    blurred = transforms.Resize([new_h, new_w], Image.BICUBIC)(blurred)
    transform = transforms.Compose([
        transforms.ToTensor(),  # convert to tensor
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    blurred = transform(blurred)
    blurred.unsqueeze_(0)  # add the batch dimension
    print(blurred.shape)
    blurred = blurred.to(device)

    # start to deblur
    with torch.no_grad():
        deblurred = generator(blurred)
    deblurred_img = to_pil_image(denormalize(deblurred).squeeze().cpu())

    deblurred_img.save(output_path)
Пример #5
0
    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        For every batch the discriminator (critic) is updated
        ``critic_updates`` times (5 for 'wgan_gp_loss', otherwise 1), then
        the generator is updated once. Per-batch losses are written to
        ``self.writer``. Every 100 batches the blurred/sharp/deblurred images
        are written as grids. After the epoch, validation runs if
        ``self.do_validation`` is set, and both LR schedulers step.

        :param epoch: Current training epoch (1-based; used to compute the
            global writer step).
        :return: A log that contains all information you want to save.

        Note:
            If you have additional information to record, for example:
                > additional_log = {"x": x, "y": y}
            merge it with log before return. i.e.
                > log = {**log, **additional_log}
                > return log

            The metrics in log must have the key 'metrics'.
        """
        # set models to train mode
        self.generator.train()
        self.discriminator.train()

        adversarial_type = self.config['loss']['adversarial']
        # WGAN-GP updates the critic several times per generator update.
        critic_updates = 5 if adversarial_type == 'wgan_gp_loss' else 1

        total_generator_loss = 0
        total_discriminator_loss = 0
        total_metrics = np.zeros(len(self.metrics))

        for batch_idx, sample in enumerate(self.data_loader):
            self.writer.set_step((epoch - 1) * len(self.data_loader) +
                                 batch_idx)

            # get data and send them to GPU
            blurred = sample['blurred'].to(self.device)
            sharp = sample['sharp'].to(self.device)

            # get G's output
            deblurred = self.generator(blurred)

            # denormalize
            with torch.no_grad():
                denormalized_blurred = denormalize(blurred)
                denormalized_sharp = denormalize(sharp)
                denormalized_deblurred = denormalize(deblurred)

            if batch_idx % 100 == 0:
                # save blurred, sharp and deblurred image
                self.writer.add_image('blurred',
                                      make_grid(denormalized_blurred.cpu()))
                self.writer.add_image('sharp',
                                      make_grid(denormalized_sharp.cpu()))
                self.writer.add_image('deblurred',
                                      make_grid(denormalized_deblurred.cpu()))

            # train discriminator
            # D is trained on a detached copy of G's output so the D loss does
            # not backpropagate into the generator.
            deblurred_detached = deblurred.detach()
            discriminator_loss = 0
            for _ in range(critic_updates):
                self.discriminator_optimizer.zero_grad()

                # Re-run D on every update. Reusing outputs computed before an
                # optimizer step would backpropagate through weights that
                # step() has already modified in place (stale gradients, and
                # a RuntimeError on recent PyTorch).
                sharp_discriminator_out = self.discriminator(sharp)
                deblurred_discriminator_out = self.discriminator(
                    deblurred_detached)

                # train discriminator on real and fake
                if adversarial_type == 'wgan_gp_loss':
                    gp_lambda = self.config['others']['gp_lambda']
                    alpha = random.random()
                    # The gradient penalty presumably differentiates D's output
                    # w.r.t. the interpolates (loss not visible here), so they
                    # must require grad now that the fake branch is detached.
                    interpolates = (alpha * sharp + (1 - alpha) *
                                    deblurred_detached).requires_grad_(True)
                    interpolates_discriminator_out = self.discriminator(
                        interpolates)
                    kwargs = {
                        'gp_lambda': gp_lambda,
                        'interpolates': interpolates,
                        'interpolates_discriminator_out':
                        interpolates_discriminator_out,
                        'sharp_discriminator_out': sharp_discriminator_out,
                        'deblurred_discriminator_out':
                        deblurred_discriminator_out
                    }
                    wgan_loss_d, gp_d = self.adversarial_loss('D', **kwargs)
                    discriminator_loss_per_update = wgan_loss_d + gp_d

                    self.writer.add_scalar('wgan_loss_d', wgan_loss_d.item())
                    self.writer.add_scalar('gp_d', gp_d.item())
                elif adversarial_type == 'gan_loss':
                    kwargs = {
                        'sharp_discriminator_out': sharp_discriminator_out,
                        'deblurred_discriminator_out':
                        deblurred_discriminator_out
                    }
                    gan_loss_d = self.adversarial_loss('D', **kwargs)
                    discriminator_loss_per_update = gan_loss_d

                    self.writer.add_scalar('gan_loss_d', gan_loss_d.item())
                else:
                    # add other loss if you like
                    raise NotImplementedError

                # Each update builds a fresh graph, so no retain_graph needed.
                discriminator_loss_per_update.backward()
                self.discriminator_optimizer.step()
                discriminator_loss += discriminator_loss_per_update.item()

            discriminator_loss /= critic_updates
            self.writer.add_scalar('discriminator_loss', discriminator_loss)
            total_discriminator_loss += discriminator_loss

            # train generator
            self.generator_optimizer.zero_grad()

            # Score the fake batch with the freshly updated discriminator; this
            # (non-detached) graph is what carries gradients to the generator.
            deblurred_discriminator_out = self.discriminator(deblurred)

            content_loss_lambda = self.config['others']['content_loss_lambda']
            kwargs = {
                'deblurred_discriminator_out': deblurred_discriminator_out
            }
            adversarial_loss_g = self.adversarial_loss('G', **kwargs)
            content_loss_g = self.content_loss(deblurred,
                                               sharp) * content_loss_lambda
            generator_loss = adversarial_loss_g + content_loss_g

            self.writer.add_scalar('adversarial_loss_g',
                                   adversarial_loss_g.item())
            self.writer.add_scalar('content_loss_g', content_loss_g.item())
            self.writer.add_scalar('generator_loss', generator_loss.item())

            generator_loss.backward()
            self.generator_optimizer.step()
            total_generator_loss += generator_loss.item()

            # calculate the metrics
            total_metrics += self._eval_metrics(denormalized_deblurred,
                                                denormalized_sharp)

            if self.verbosity >= 2 and batch_idx % self.log_step == 0:
                self.logger.info(
                    'Train Epoch: {} [{}/{} ({:.0f}%)] generator_loss: {:.6f} discriminator_loss: {:.6f}'
                    .format(
                        epoch,
                        batch_idx * self.data_loader.batch_size,
                        self.data_loader.n_samples,
                        100.0 * batch_idx / len(self.data_loader),
                        generator_loss.item(
                        ),  # it's a tensor, so we call .item() method
                        discriminator_loss  # just a num
                    ))

        log = {
            'generator_loss': total_generator_loss / len(self.data_loader),
            'discriminator_loss':
            total_discriminator_loss / len(self.data_loader),
            'metrics': (total_metrics / len(self.data_loader)).tolist()
        }

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log = {**log, **val_log}

        self.generator_lr_scheduler.step()
        self.discriminator_lr_scheduler.step()

        return log
Пример #6
0
def main(resume):
    """
    Evaluate a generator/discriminator checkpoint on its full dataset.

    The data loader type and data_dir come from the checkpoint's config. The
    batch size, shuffling, validation split and worker count are overridden
    for evaluation. For every batch the generator loss (adversarial 'G' loss
    plus ``content_loss_lambda`` * content loss) and every configured metric
    are computed. The per-sample averages are printed as a dict.

    :param resume: Path to a checkpoint holding 'config', 'generator' and
        'discriminator'.
    :raises ValueError: if the data loader yields no samples.
    """
    # Choose the device first so the checkpoint can be mapped onto it; a
    # GPU-saved checkpoint would otherwise fail to load on a CPU-only machine.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # load checkpoint
    checkpoint = torch.load(resume, map_location=device)
    config = checkpoint['config']

    # setup data_loader instances
    data_loader_class = getattr(module_data, config['data_loader']['type'])
    data_loader_config_args = {
        "data_dir": config['data_loader']['args']['data_dir'],
        'batch_size': 16,  # use large batch_size
        'shuffle': False,  # do not shuffle
        'validation_split': 0.0,  # do not split, just use the full dataset
        'num_workers': 16  # use large num_workers
    }
    data_loader = data_loader_class(**data_loader_config_args)

    # build model architecture
    generator_class = getattr(module_arch, config['generator']['type'])
    generator = generator_class(**config['generator']['args'])

    discriminator_class = getattr(module_arch, config['discriminator']['type'])
    discriminator = discriminator_class(**config['discriminator']['args'])

    # get function handles of loss and metrics
    loss_fn = {k: getattr(module_loss, v) for k, v in config['loss'].items()}
    metric_fns = [getattr(module_metric, met) for met in config['metrics']]
    # Loop-invariant loss settings, looked up once.
    content_loss_lambda = config['others']['content_loss_lambda']
    adversarial_loss_fn = loss_fn['adversarial']
    content_loss_fn = loss_fn['content']

    # prepare model for testing
    generator = generator.to(device)
    discriminator = discriminator.to(device)
    if config['n_gpu'] > 1:
        generator = torch.nn.DataParallel(generator)
        discriminator = torch.nn.DataParallel(discriminator)

    generator.load_state_dict(checkpoint['generator'])
    discriminator.load_state_dict(checkpoint['discriminator'])

    generator.eval()
    discriminator.eval()

    total_loss = 0.0
    total_metrics = np.zeros(len(metric_fns))
    n_samples = 0

    with torch.no_grad():
        for sample in tqdm(data_loader, ascii=True):
            blurred = sample['blurred'].to(device)
            sharp = sample['sharp'].to(device)
            # Loss and metrics are assumed to be batch means — TODO confirm.
            # Weighting by the batch size keeps a short final batch from
            # counting as much as a full one.
            batch_size = blurred.size(0)

            deblurred = generator(blurred)
            deblurred_discriminator_out = discriminator(deblurred)

            denormalized_deblurred = denormalize(deblurred)
            denormalized_sharp = denormalize(sharp)

            # computing loss, metrics on test set
            loss = adversarial_loss_fn(
                'G', deblurred_discriminator_out=deblurred_discriminator_out
            ) + content_loss_fn(deblurred, sharp) * content_loss_lambda

            total_loss += loss.item() * batch_size
            for i, metric in enumerate(metric_fns):
                total_metrics[i] += metric(denormalized_deblurred,
                                           denormalized_sharp) * batch_size
            n_samples += batch_size

    if n_samples == 0:
        raise ValueError('data loader produced no samples to evaluate')

    log = {'loss': total_loss / n_samples}
    log.update({
        met.__name__: total_metrics[i].item() / n_samples
        for i, met in enumerate(metric_fns)
    })
    print(log)