Пример #1
0
 def test_peak_to_signal_noise_ratio(self):
     """Check PSNR is +inf for identical signals and NaN for a constant one."""
     tensor = torch.rand(3, 20)
     psnr_inf = peak_to_signal_noise_ratio(tensor, tensor)
     # PSNR is a quality measure between the corrupted and original signals.
     # If the signals match precisely (a lossless encoding), the error is
     # zero and PSNR evaluates to +inf. ``torch.isinf`` alone would also
     # accept -inf, so the sign is asserted separately.
     self.assertTrue(torch.isinf(psnr_inf))
     self.assertGreater(psnr_inf, 0)
     # The original signal below is constant (all ones), so its max-min
     # range is zero. This test expects NaN in that case.
     # NOTE(review): the zero-range -> NaN mapping lives inside
     # peak_to_signal_noise_ratio, which is not shown here; confirm it.
     tensor = torch.ones(3, 20)
     psnr_nan = peak_to_signal_noise_ratio(tensor, torch.rand_like(tensor))
     self.assertTrue(torch.isnan(psnr_nan))
Пример #2
0
 def test_exponential_moving_average_psnr(self):
     """Smoothing seeded random noise with a window of 3 must stay close to it.

     The smoothed signal must have a PSNR of at least 15.7 against the noise,
     keep the noise's shape, and have a strictly smaller standard deviation.
     """
     set_seed(1)
     signal = torch.rand(100)
     ema = exponential_moving_average(signal, window=3)
     self.assertGreaterEqual(peak_to_signal_noise_ratio(signal, ema), 15.7)
     self.assertEqual(ema.shape, signal.shape)
     # Averaging neighbouring samples should damp the spread of the noise.
     self.assertLess(ema.std(), signal.std())
Пример #3
0
 def _on_forward_pass_batch(self, batch, output, train):
     """Record the reconstruction PSNR of a batch, then defer to the parent.

     ``output`` is unpacked as ``(latent, reconstructed)``. The PSNR between
     the batch input and the reconstruction goes into
     ``self.online['psnr-train']`` or ``self.online['psnr-test']``, depending
     on ``train``. Non-finite values are skipped. Only ``latent`` is passed on
     to the parent implementation.
     """
     sample = input_from_batch(batch)
     latent, reconstructed = output
     if isinstance(self.criterion, nn.BCEWithLogitsLoss):
         # BCEWithLogitsLoss implies the model outputs logits; map them to
         # [0, 1] before comparing with the input.
         reconstructed = reconstructed.sigmoid()
     psnr = peak_to_signal_noise_ratio(sample, reconstructed)
     if torch.isfinite(psnr):
         fold = 'train' if train else 'test'
         self.online[f'psnr-{fold}'].update(psnr.cpu())
     super()._on_forward_pass_batch(batch, latent, train)
Пример #4
0
 def solve(self, A, b, M_inv=None):
     """Solve for ``v`` by delegating to ``basis_pursuit_admm``.

     Uses ``self.lambd``, ``self.tol`` and ``self.max_iters`` as solver
     settings, and passes ``M_inv`` through unchanged. When
     ``self.save_stats`` is set, it also records into ``self.online``:
     ``dv_norm``, the iteration count, the PSNR between ``b`` and its
     reconstruction ``v @ A.T``, and the sparsity of the solution.

     Returns
     -------
     v_solution
         The solution returned by ``basis_pursuit_admm``.
     """
     v_solution, dv_norm, iteration = basis_pursuit_admm(
         A=A, b=b, lambd=self.lambd, M_inv=M_inv, tol=self.tol,
         max_iters=self.max_iters, return_stats=True)
     if not self.save_stats:
         return v_solution
     # The solver presumably reports a zero-based index of the last
     # iteration; +1 turns it into a count. Verify against
     # basis_pursuit_admm.
     n_iterations = torch.tensor(iteration + 1, dtype=torch.float32)
     self.online['dv_norm'].update(dv_norm.cpu())
     self.online['iterations'].update(n_iterations)
     b_restored = v_solution.matmul(A.t())
     psnr = peak_to_signal_noise_ratio(b, b_restored)
     self.online['psnr'].update(psnr.cpu())
     self.online['sparsity'].update(compute_sparsity(v_solution).cpu())
     return v_solution
Пример #5
0
    def full_forward_pass(self, train=True):
        """Evaluate the model for every BMP parameter and plot the results.

        Iterates over ``self.data_loader.eval(...)`` without gradients. For
        each batch it runs the model once per value in ``self.bmp_params`` and
        collects the loss, the PSNR of the reconstruction, and the sparsity of
        the latent code. The per-parameter means of these three quantities are
        then plotted against ``self.bmp_params``.

        The model is switched to eval mode for the pass. Its previous training
        mode is restored afterwards, even if an exception is raised.

        Parameters
        ----------
        train : bool
            When False, nothing is computed and None is returned.

        Returns
        -------
        loss
            Per-parameter mean loss from ``MeanOnline.get_mean()``, or None
            when ``train`` is False.
        """
        if not train:
            return None
        assert isinstance(self.criterion,
                          nn.MSELoss), "BMP can work only with MSE loss"

        def plot_curve(values, win, ylabel, title):
            # Plot one per-parameter curve against the BMP parameter values.
            self.monitor.viz.line(Y=values,
                                  X=self.bmp_params,
                                  win=win,
                                  opts=dict(xlabel=f'BMP {self.param_name}',
                                            ylabel=ylabel,
                                            title=title))

        mode_saved = self.model.training
        self.model.train(False)
        try:
            use_cuda = torch.cuda.is_available()

            loss_online = MeanOnline()
            psnr_online = MeanOnline()
            sparsity_online = MeanOnline()
            with torch.no_grad():
                for batch in self.data_loader.eval(
                        description="Full forward pass (eval)"):
                    if use_cuda:
                        batch = batch_to_cuda(batch)
                    inputs = input_from_batch(batch)
                    loss_per_param = []
                    psnr_per_param = []
                    sparsity_per_param = []
                    for bmp_param in self.bmp_params:
                        outputs = self.model(inputs, bmp_param)
                        latent, reconstructed = outputs
                        loss_lambd = self._get_loss(batch, outputs)
                        psnr_lambd = peak_to_signal_noise_ratio(
                            inputs, reconstructed)
                        sparsity_lambd = compute_sparsity(latent)
                        loss_per_param.append(loss_lambd.cpu())
                        psnr_per_param.append(psnr_lambd.cpu())
                        sparsity_per_param.append(sparsity_lambd.cpu())

                    # One stacked entry per batch: a value for each BMP param.
                    loss_online.update(torch.stack(loss_per_param))
                    psnr_online.update(torch.stack(psnr_per_param))
                    sparsity_online.update(torch.stack(sparsity_per_param))

            loss = loss_online.get_mean()
            plot_curve(loss, win='Loss', ylabel='Loss', title='Loss')
            plot_curve(psnr_online.get_mean(), win='PSNR',
                       ylabel='Peak signal-to-noise ratio', title='PSNR')
            plot_curve(sparsity_online.get_mean(), win='Sparsity',
                       ylabel='sparsity', title='L1 output sparsity')

            # NOTE(review): this pass computes no accuracy; presumably this
            # clears a stale 'Accuracy' window. Confirm intent.
            self.monitor.viz.close(win='Accuracy')
        finally:
            # Restore the caller's train/eval mode even if the pass failed.
            self.model.train(mode_saved)

        return loss