def test_peak_to_signal_noise_ratio(self):
    """Check the two degenerate outputs of ``peak_to_signal_noise_ratio``."""
    # PSNR measures the quality of a corrupted signal relative to the
    # original one. Comparing a signal with itself corresponds to a
    # lossless encoder, for which the PSNR is expected to be +inf.
    signal = torch.rand(3, 20)
    self.assertTrue(torch.isinf(peak_to_signal_noise_ratio(signal, signal)))

    # A constant reference compared with random noise is expected to
    # yield NaN (presumably because a constant signal has zero dynamic
    # range -- confirm against the PSNR implementation).
    constant = torch.ones(3, 20)
    corrupted = torch.rand_like(constant)
    self.assertTrue(
        torch.isnan(peak_to_signal_noise_ratio(constant, corrupted)))
def test_exponential_moving_average_psnr(self):
    """Smoothing random noise keeps its shape, lowers its std and gives
    a PSNR of at least 15.7 against the raw noise (seed fixed to 1)."""
    set_seed(1)  # the PSNR threshold below is tied to this seed
    raw = torch.rand(100)
    filtered = exponential_moving_average(raw, window=3)

    quality = peak_to_signal_noise_ratio(raw, filtered)
    self.assertGreaterEqual(quality, 15.7)

    # Smoothing must not alter the signal length ...
    self.assertEqual(filtered.shape, raw.shape)

    # ... and the smoothed signal must fluctuate less than the raw one.
    filtered_std = filtered.std()
    raw_std = raw.std()
    self.assertLess(filtered_std, raw_std)
def _on_forward_pass_batch(self, batch, output, train):
    """Record the per-batch PSNR, then delegate to the parent hook.

    ``output`` is unpacked as a ``(latent, reconstructed)`` pair. Only the
    latent part is forwarded to the parent implementation.
    """
    source = input_from_batch(batch)
    latent, reconstructed = output

    # With BCEWithLogitsLoss the reconstruction holds raw logits, so it is
    # passed through a sigmoid before being compared with the input.
    if isinstance(self.criterion, nn.BCEWithLogitsLoss):
        reconstructed = reconstructed.sigmoid()

    psnr = peak_to_signal_noise_ratio(source, reconstructed)

    if train:
        fold = 'train'
    else:
        fold = 'test'

    # Non-finite values (inf / nan) are not fed to the running statistic.
    if torch.isfinite(psnr):
        self.online[f'psnr-{fold}'].update(psnr.cpu())

    super()._on_forward_pass_batch(batch, latent, train)
def solve(self, A, b, M_inv=None):
    """Solve the basis pursuit problem with ADMM and return the solution.

    The call is forwarded to ``basis_pursuit_admm`` using this object's
    ``lambd``, ``tol`` and ``max_iters``. When ``self.save_stats`` is set,
    the update norm, iteration count, PSNR of the restored ``b`` and the
    sparsity of the solution are pushed to ``self.online``.
    """
    solution, update_norm, n_iters = basis_pursuit_admm(
        A=A,
        b=b,
        lambd=self.lambd,
        M_inv=M_inv,
        tol=self.tol,
        max_iters=self.max_iters,
        return_stats=True,
    )

    if not self.save_stats:
        return solution

    stats = self.online
    stats['dv_norm'].update(update_norm.cpu())
    # The reported count is the returned iteration value plus one
    # (presumably the solver's index is zero-based -- confirm).
    stats['iterations'].update(
        torch.tensor(n_iters + 1, dtype=torch.float32))

    # Map the solution back through A to compare it with the original b.
    restored = solution.matmul(A.t())
    stats['psnr'].update(peak_to_signal_noise_ratio(b, restored).cpu())
    stats['sparsity'].update(compute_sparsity(solution).cpu())

    return solution
def full_forward_pass(self, train=True):
    """Evaluate the model for every value in ``self.bmp_params`` and plot it.

    For each evaluation batch and each BMP parameter the loss, the PSNR
    between the input and its reconstruction, and the sparsity of the
    latent code are computed. The per-parameter values are accumulated
    over batches with ``MeanOnline`` and drawn as three line plots
    ('Loss', 'PSNR', 'Sparsity') against ``self.bmp_params``.

    The model is put into evaluation mode for the duration of the pass.
    Its previous training mode is restored on exit, including when an
    exception interrupts the pass.

    Args:
        train: the pass is only performed when truthy.

    Returns:
        ``None`` when ``train`` is falsy, otherwise the accumulated mean
        loss returned by ``MeanOnline.get_mean()``.

    Raises:
        AssertionError: if ``self.criterion`` is not an ``nn.MSELoss``.
    """
    if not train:
        return None
    assert isinstance(self.criterion, nn.MSELoss), \
        "BMP can work only with MSE loss"

    mode_saved = self.model.training
    self.model.train(False)
    try:
        use_cuda = torch.cuda.is_available()
        loss_online = MeanOnline()
        psnr_online = MeanOnline()
        sparsity_online = MeanOnline()

        with torch.no_grad():
            for batch in self.data_loader.eval(
                    description="Full forward pass (eval)"):
                if use_cuda:
                    batch = batch_to_cuda(batch)
                input_batch = input_from_batch(batch)

                # One entry per BMP parameter for the current batch.
                loss = []
                psnr = []
                sparsity = []
                for bmp_param in self.bmp_params:
                    outputs = self.model(input_batch, bmp_param)
                    latent, reconstructed = outputs
                    loss_lambd = self._get_loss(batch, outputs)
                    psnr_lambd = peak_to_signal_noise_ratio(
                        input_batch, reconstructed)
                    sparsity_lambd = compute_sparsity(latent)
                    loss.append(loss_lambd.cpu())
                    psnr.append(psnr_lambd.cpu())
                    sparsity.append(sparsity_lambd.cpu())

                loss_online.update(torch.stack(loss))
                psnr_online.update(torch.stack(psnr))
                sparsity_online.update(torch.stack(sparsity))

        loss = loss_online.get_mean()
        self.monitor.viz.line(Y=loss, X=self.bmp_params, win='Loss',
                              opts=dict(xlabel=f'BMP {self.param_name}',
                                        ylabel='Loss',
                                        title='Loss'))

        psnr = psnr_online.get_mean()
        self.monitor.viz.line(Y=psnr, X=self.bmp_params, win='PSNR',
                              opts=dict(xlabel=f'BMP {self.param_name}',
                                        ylabel='Peak signal-to-noise ratio',
                                        title='PSNR'))

        sparsity = sparsity_online.get_mean()
        self.monitor.viz.line(Y=sparsity, X=self.bmp_params, win='Sparsity',
                              opts=dict(xlabel=f'BMP {self.param_name}',
                                        ylabel='sparsity',
                                        title='L1 output sparsity'))

        self.monitor.viz.close(win='Accuracy')
    finally:
        # Restore the original mode even if evaluation or plotting failed,
        # so an error here cannot leave the model stuck in eval mode.
        self.model.train(mode_saved)

    return loss