Exemplo n.º 1
0
 def _on_forward_pass_batch(self, batch, output, train):
     """Accumulate statistics of the model output for a training batch.

     When ``train`` is true, the ``compute_sparsity`` value of ``output``
     and the mean absolute value of ``output`` along dim 0 are pushed to
     the ``'sparsity'`` and ``'l1_norm'`` online meters (moved to CPU
     first). If ``self.data_loader.has_labels`` is true, ``output`` and the
     batch labels are also fed to the ``'clusters'`` meter. The parent
     implementation is always called afterwards with the same arguments.

     Args:
         batch: the current batch; unpacked as ``(inputs, labels)`` when
             the data loader has labels.
         output: the model output for this batch (a tensor).
         train: whether this forward pass is a training pass.
     """
     if train:
         sparsity = compute_sparsity(output)
         self.online['sparsity'].update(sparsity.cpu())
         # dim 0 is presumably the batch dimension -- TODO confirm
         self.online['l1_norm'].update(output.abs().mean(dim=0).cpu())
         if self.data_loader.has_labels:
             # Supervised: only the labels are needed, so the inputs are
             # discarded instead of being bound to a name that shadows the
             # builtin ``input``.
             _, labels = batch
             self.online['clusters'].update(output, labels)
     super()._on_forward_pass_batch(batch, output, train)
Exemplo n.º 2
0
 def solve(self, A, b, M_inv=None):
     """Solve for ``b`` with ``basis_pursuit_admm`` and optionally log stats.

     The solver is configured with this object's ``lambd``, ``tol`` and
     ``max_iters`` and is asked to return its statistics. When
     ``self.save_stats`` is set, the following are pushed to the
     ``self.online`` meters:
     - the returned ``dv_norm``;
     - the iteration value plus one, as a float32 tensor;
     - the PSNR between ``b`` and the reconstruction ``solution @ A.t()``;
     - the sparsity of the solution.

     Returns:
         The solution tensor produced by ``basis_pursuit_admm``.
     """
     solution, dv_norm, n_iter = basis_pursuit_admm(
         A=A,
         b=b,
         lambd=self.lambd,
         M_inv=M_inv,
         tol=self.tol,
         max_iters=self.max_iters,
         return_stats=True,
     )
     if not self.save_stats:
         return solution

     # "+ 1" presumably turns a zero-based iteration index into a count --
     # confirm against basis_pursuit_admm.
     iterations_done = torch.tensor(n_iter + 1, dtype=torch.float32)
     meters = self.online
     meters['dv_norm'].update(dv_norm.cpu())
     meters['iterations'].update(iterations_done)
     reconstruction = solution.matmul(A.t())
     psnr = peak_to_signal_noise_ratio(b, reconstruction)
     meters['psnr'].update(psnr.cpu())
     meters['sparsity'].update(compute_sparsity(solution).cpu())
     return solution
Exemplo n.º 3
0
def dataset_mean(data_loader: DataLoader, verbose=True):
    """Average ``compute_sparsity`` over every input of ``data_loader``.

    The batches come from ``data_loader.eval()``. Each batch's input is
    flattened from dim 1 onward, and its ``compute_sparsity`` value is
    accumulated in a ``MeanOnline`` meter.

    Per the original notes, the quantity measured is the L1 sparsity
    ``||x||_1 / size(x)``. Reference values recorded by the author:

    ============  =====
    MNIST         0.131
    FashionMNIST  0.286
    CIFAR10       0.473
    CIFAR100      0.478
    ============  =====

    Args:
        data_loader: project data loader. ``eval()`` is called on it, and
            ``dataset_cls.__name__`` labels the progress bar.
        verbose: show a tqdm progress bar when true.

    Returns:
        ``MeanOnline.get_mean()`` after all batches have been consumed.
    """
    batches = data_loader.eval()
    meter = MeanOnline()
    label = f"Computing {data_loader.dataset_cls.__name__} mean"
    progress = tqdm(batches, desc=label, disable=not verbose, leave=False)
    for batch in progress:
        flat = input_from_batch(batch).flatten(start_dim=1)
        meter.update(compute_sparsity(flat))
    return meter.get_mean()
Exemplo n.º 4
0
    def full_forward_pass(self, train=True):
        """Sweep the BMP parameters over the eval data and plot the results.

        Does nothing and returns ``None`` when ``train`` is false.
        Otherwise the model is switched to eval mode. Then, under
        ``torch.no_grad()``, for every batch of ``self.data_loader.eval(...)``
        and every value in ``self.bmp_params``, it computes:
        - the loss;
        - the PSNR between the input and its reconstruction;
        - the sparsity of the latent code.

        The running means of these metrics are plotted against
        ``self.bmp_params`` in the 'Loss', 'PSNR' and 'Sparsity' windows of
        ``self.monitor.viz``, and the 'Accuracy' window is closed. The
        model's previous training mode is restored even if an exception is
        raised midway.

        Args:
            train: whether this pass belongs to training; the sweep is
                skipped entirely when false.

        Returns:
            ``loss_online.get_mean()`` (the mean loss per BMP parameter),
            or ``None`` when ``train`` is false.

        Raises:
            AssertionError: if ``self.criterion`` is not ``nn.MSELoss``.
        """
        if not train:
            return None
        assert isinstance(self.criterion,
                          nn.MSELoss), "BMP can work only with MSE loss"

        mode_saved = self.model.training
        self.model.train(False)
        try:
            use_cuda = torch.cuda.is_available()

            loss_online = MeanOnline()
            psnr_online = MeanOnline()
            sparsity_online = MeanOnline()
            with torch.no_grad():
                for batch in self.data_loader.eval(
                        description="Full forward pass (eval)"):
                    if use_cuda:
                        batch = batch_to_cuda(batch)
                    inputs = input_from_batch(batch)
                    loss_per_param = []
                    psnr_per_param = []
                    sparsity_per_param = []
                    for bmp_param in self.bmp_params:
                        outputs = self.model(inputs, bmp_param)
                        latent, reconstructed = outputs
                        loss_lambd = self._get_loss(batch, outputs)
                        psnr_lambd = peak_to_signal_noise_ratio(
                            inputs, reconstructed)
                        sparsity_lambd = compute_sparsity(latent)
                        loss_per_param.append(loss_lambd.cpu())
                        psnr_per_param.append(psnr_lambd.cpu())
                        sparsity_per_param.append(sparsity_lambd.cpu())

                    # One stacked entry per BMP parameter, in bmp_params order.
                    loss_online.update(torch.stack(loss_per_param))
                    psnr_online.update(torch.stack(psnr_per_param))
                    sparsity_online.update(torch.stack(sparsity_per_param))

            loss = loss_online.get_mean()
            # (values, window, y-axis label, title) for each metric plot.
            plots = (
                (loss, 'Loss', 'Loss', 'Loss'),
                (psnr_online.get_mean(), 'PSNR',
                 'Peak signal-to-noise ratio', 'PSNR'),
                (sparsity_online.get_mean(), 'Sparsity', 'sparsity',
                 'L1 output sparsity'),
            )
            xlabel = f'BMP {self.param_name}'
            for values, win, ylabel, title in plots:
                self.monitor.viz.line(
                    Y=values, X=self.bmp_params, win=win,
                    opts=dict(xlabel=xlabel, ylabel=ylabel, title=title))

            self.monitor.viz.close(win='Accuracy')
        finally:
            # Restore the caller's train/eval mode even if the sweep fails,
            # so an exception does not leave the model stuck in eval mode.
            self.model.train(mode_saved)

        return loss