Example No. 1
    def _get_loss(self, batch, output):
        assert isinstance(self.model, LISTA)
        input = input_from_batch(batch)
        latent, reconstructed = output
        latent_best, _ = self.model_reference(input)
        loss = self.criterion(latent, latent_best)
        return loss
Example No. 2
    def _get_loss(self, batch, output):
        input = input_from_batch(batch)
        latent, reconstructed = output
        if isinstance(self.criterion, LossPenalty):
            loss = self.criterion(reconstructed, input, latent)
        else:
            loss = self.criterion(reconstructed, input)
        return loss
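For context, `LossPenalty` is called with the latent code as an extra argument, which suggests a reconstruction loss plus a sparsity penalty on the code. A minimal sketch of such a criterion, assuming an L1 penalty with weight `lambd` (the actual class in the code base may differ):

import torch.nn as nn

class LossPenalty(nn.Module):
    # Hypothetical sketch: reconstruction loss plus an L1 penalty on the
    # latent code, weighted by lambd.
    def __init__(self, criterion: nn.Module, lambd=0.1):
        super().__init__()
        self.criterion = criterion
        self.lambd = lambd

    def forward(self, reconstructed, input, latent):
        l1_penalty = latent.abs().mean()
        return self.criterion(reconstructed, input) + self.lambd * l1_penalty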
Example No. 3
    def _plot_autoencoder(self, batch, reconstructed, mode='train'):
        input = input_from_batch(batch)
        thr_lowest = self.reconstruct.threshold_optimal
        rec_binary = (reconstructed >= thr_lowest).float()
        self.monitor.plot_autoencoder_binary(input,
                                             reconstructed,
                                             rec_binary,
                                             mode=mode)
Example No. 4
    def _on_forward_pass_batch(self, batch, output, train):
        input = input_from_batch(batch)
        latent, reconstructed = output
        if isinstance(self.criterion, nn.BCEWithLogitsLoss):
            reconstructed = reconstructed.sigmoid()
        psnr = peak_to_signal_noise_ratio(input, reconstructed)
        fold = 'train' if train else 'test'
        if torch.isfinite(psnr):
            self.online[f'psnr-{fold}'].update(psnr.cpu())
        super()._on_forward_pass_batch(batch, latent, train)
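Example No. 4 reports `peak_to_signal_noise_ratio`. Assuming the conventional definition PSNR = 10 * log10(MAX^2 / MSE), a minimal sketch looks like this (the project's helper may estimate the dynamic range differently):

import torch

def peak_to_signal_noise_ratio(signal, reconstructed):
    # Conventional PSNR in dB; returns inf when the reconstruction is exact,
    # which is why Example No. 4 guards the update with torch.isfinite(psnr).
    mse = (signal - reconstructed).pow(2).mean()
    max_val = signal.max() - signal.min()  # assumed dynamic range of the signal
    return 10 * torch.log10(max_val ** 2 / mse)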
Example No. 5
def dataset_mean(data_loader: DataLoader, verbose=True):
    # Mean L1 sparsity: ||x||_1 / size(x). For inputs scaled to [0, 1] this
    # equals the mean pixel value, hence the function name.
    #
    # MNIST:         0.131
    # FashionMNIST:  0.286
    # CIFAR10:       0.473
    # CIFAR100:      0.478
    loader = data_loader.eval()
    sparsity_online = MeanOnline()
    for batch in tqdm(
            loader,
            desc=f"Computing {data_loader.dataset_cls.__name__} mean",
            disable=not verbose,
            leave=False):
        input = input_from_batch(batch)
        input = input.flatten(start_dim=1)
        sparsity = compute_sparsity(input)
        sparsity_online.update(sparsity)
    return sparsity_online.get_mean()
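The comment above spells out the statistic being averaged: the L1 sparsity of a flattened batch, ||x||_1 / size(x). A minimal sketch of `compute_sparsity` under that definition (the project's helper may average per sample instead of over the whole batch):

import torch

def compute_sparsity(x: torch.Tensor) -> torch.Tensor:
    # ||x||_1 / size(x): mean absolute value over all elements of the batch.
    return x.norm(p=1) / x.numel()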
Example No. 6
    def _on_forward_pass_batch(self, batch, output, train):
        input = input_from_batch(batch)
        latent, reconstructed = output
        if isinstance(self.criterion, nn.BCEWithLogitsLoss):
            reconstructed = reconstructed.sigmoid()

        if self.data_loader.normalize_inverse is not None:
            warnings.warn("'normalize_inverse' is not None. Applying it "
                          "to count reconstructed pixels")
            input = self.data_loader.normalize_inverse(input)
            reconstructed = self.data_loader.normalize_inverse(reconstructed)

        pix_miss, correct = self.reconstruct.compute(input, reconstructed)
        if train:
            # update only for train
            # pix_miss is of shape (B, THR)
            self.online['pixel-error'].update(pix_miss.cpu())

        fold = 'train' if train else 'test'
        self.online[f'reconstruct-exact-{fold}'].update(correct.cpu())

        super()._on_forward_pass_batch(batch, output, train)
Example No. 7
def dataset_mean_std(dataset_cls: type):
    """
    Estimates dataset mean and std.

    Parameters
    ----------
    dataset_cls : type
        A dataset class.

    Returns
    -------
    mean, std : (C, H, W) torch.Tensor
        Channel- and pixel-wise dataset mean and std, estimated over all
        samples.
    """
    mean_std_file = (DATA_DIR / "mean_std" /
                     dataset_cls.__name__).with_suffix('.pt')
    if not mean_std_file.exists():
        dataset = dataset_cls(DATA_DIR,
                              train=True,
                              download=True,
                              transform=ToTensor())
        loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=BATCH_SIZE,
                                             shuffle=False)
        var_online = VarianceOnlineBatch()
        for batch in tqdm(
                loader,
                desc=f"{dataset_cls.__name__}: running online mean, std"):
            input = input_from_batch(batch)
            var_online.update(input)
        mean, std = var_online.get_mean_std()
        mean_std_file.parent.mkdir(exist_ok=True, parents=True)
        with open(mean_std_file, 'wb') as f:
            torch.save((mean, std), f)
    with open(mean_std_file, 'rb') as f:
        mean, std = torch.load(f)
    return mean, std
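A typical call, assuming a torchvision dataset class with the standard `(root, train, download, transform)` constructor; the result is cached under DATA_DIR / "mean_std", so only the first call iterates over the dataset:

from torchvision.datasets import MNIST

mean, std = dataset_mean_std(MNIST)
print(mean.shape, std.shape)  # both torch.Size([1, 28, 28]) for MNIST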
Example No. 8
    def _plot_autoencoder(self, batch, reconstructed, mode='train'):
        input = input_from_batch(batch)
        self.monitor.plot_autoencoder(input, reconstructed, mode=mode)
Example No. 9
    def test_input_from_batch(self):
        x, y = torch.rand(5, 1, 28, 28), torch.arange(5)
        self.assertIs(input_from_batch(x), x)
        self.assertIs(input_from_batch((x, y)), x)
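The test above fixes the contract of `input_from_batch`: a bare tensor is returned as-is, and for an `(input, target)` pair the input tensor itself is returned (identity, not a copy). A minimal sketch consistent with those two assertions (the real helper may support more batch layouts):

import torch

def input_from_batch(batch):
    if torch.is_tensor(batch):
        return batch   # a bare tensor is already the input
    return batch[0]    # (input, target, ...) sequence: take the input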
Example No. 10
    def _forward(self, batch):
        input = input_from_batch(batch)
        return self.model(input)
Example No. 11
    def full_forward_pass(self, train=True):
        if not train:
            return None
        assert isinstance(self.criterion,
                          nn.MSELoss), "BMP can work only with MSE loss"

        mode_saved = self.model.training
        self.model.train(False)
        use_cuda = torch.cuda.is_available()

        loss_online = MeanOnline()
        psnr_online = MeanOnline()
        sparsity_online = MeanOnline()
        with torch.no_grad():
            for batch in self.data_loader.eval(
                    description="Full forward pass (eval)"):
                if use_cuda:
                    batch = batch_to_cuda(batch)
                input = input_from_batch(batch)
                loss = []
                psnr = []
                sparsity = []
                for bmp_param in self.bmp_params:
                    outputs = self.model(input, bmp_param)
                    latent, reconstructed = outputs
                    loss_lambd = self._get_loss(batch, outputs)
                    psnr_lambd = peak_to_signal_noise_ratio(
                        input, reconstructed)
                    sparsity_lambd = compute_sparsity(latent)
                    loss.append(loss_lambd.cpu())
                    psnr.append(psnr_lambd.cpu())
                    sparsity.append(sparsity_lambd.cpu())

                loss_online.update(torch.stack(loss))
                psnr_online.update(torch.stack(psnr))
                sparsity_online.update(torch.stack(sparsity))

        loss = loss_online.get_mean()
        self.monitor.viz.line(Y=loss,
                              X=self.bmp_params,
                              win='Loss',
                              opts=dict(xlabel=f'BMP {self.param_name}',
                                        ylabel='Loss',
                                        title='Loss'))

        psnr = psnr_online.get_mean()
        self.monitor.viz.line(Y=psnr,
                              X=self.bmp_params,
                              win='PSNR',
                              opts=dict(xlabel=f'BMP {self.param_name}',
                                        ylabel='Peak signal-to-noise ratio',
                                        title='PSNR'))

        sparsity = sparsity_online.get_mean()
        self.monitor.viz.line(Y=sparsity,
                              X=self.bmp_params,
                              win='Sparsity',
                              opts=dict(xlabel=f'BMP {self.param_name}',
                                        ylabel='sparsity',
                                        title='L1 output sparsity'))

        self.monitor.viz.close(win='Accuracy')
        self.model.train(mode_saved)

        return loss
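Examples No. 5 and No. 11 accumulate their statistics with `MeanOnline` through `update()` and `get_mean()`. A minimal running-mean sketch consistent with that interface (the real class may additionally support resetting or per-element counts):

import torch

class MeanOnline:
    # Incremental mean over successive update() calls; only the running mean
    # and the number of updates are stored.
    def __init__(self):
        self.mean = None
        self.count = 0

    def update(self, value: torch.Tensor):
        self.count += 1
        if self.mean is None:
            self.mean = value.clone().float()
        else:
            self.mean += (value.float() - self.mean) / self.count

    def get_mean(self):
        return None if self.mean is None else self.mean.clone()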