def _on_forward_pass_batch(self, batch, output, train):
    """Update the online statistics for one batch, then call the parent hook.

    Statistics are recorded only when ``train`` is true:

    * ``'sparsity'`` -- ``compute_sparsity(output)``, moved to CPU;
    * ``'l1_norm'``  -- mean absolute value of ``output`` along dim 0
      (presumably the batch dimension -- TODO confirm), moved to CPU;
    * ``'clusters'`` -- ``output`` together with the batch labels, only when
      ``self.data_loader.has_labels`` is set.

    The parent ``_on_forward_pass_batch`` is invoked unconditionally.
    """
    if train:
        online = self.online
        online['sparsity'].update(compute_sparsity(output).cpu())
        online['l1_norm'].update(output.abs().mean(dim=0).cpu())
        if self.data_loader.has_labels:
            # Supervised data: the batch unpacks into (input, labels) and
            # only the labels are needed here.
            _, labels = batch
            online['clusters'].update(output, labels)
    super()._on_forward_pass_batch(batch, output, train)
def solve(self, A, b, M_inv=None):
    """Solve for the code of ``b`` under dictionary ``A`` via ``basis_pursuit_admm``.

    Args:
        A: dictionary matrix passed straight to ``basis_pursuit_admm``.
        b: target signal(s) passed straight to ``basis_pursuit_admm``.
        M_inv: optional matrix forwarded as ``M_inv`` to the solver.

    Returns:
        The solution ``v_solution`` returned by ``basis_pursuit_admm``.

    Side effects:
        When ``self.save_stats`` is true, updates ``self.online`` with the
        solver's ``dv_norm``, the iteration count (stored as the returned
        ``iteration + 1`` -- presumably a 0-based index turned into a count),
        the PSNR between ``b`` and the reconstruction ``v_solution @ A.T``,
        and ``compute_sparsity(v_solution)``.
    """
    v_solution, dv_norm, last_iter = basis_pursuit_admm(
        A=A,
        b=b,
        lambd=self.lambd,
        M_inv=M_inv,
        tol=self.tol,
        max_iters=self.max_iters,
        return_stats=True,
    )
    if not self.save_stats:
        return v_solution

    n_iters = torch.tensor(last_iter + 1, dtype=torch.float32)
    online = self.online
    online['dv_norm'].update(dv_norm.cpu())
    online['iterations'].update(n_iters)
    # Reconstruct the input from the code to measure reconstruction quality.
    reconstruction = v_solution.matmul(A.t())
    online['psnr'].update(
        peak_to_signal_noise_ratio(b, reconstruction).cpu())
    online['sparsity'].update(compute_sparsity(v_solution).cpu())
    return v_solution
def dataset_mean(data_loader: DataLoader, verbose=True):
    """Return the average ``compute_sparsity`` of all inputs in the eval set.

    Each batch's input is flattened per sample (``flatten(start_dim=1)``)
    before ``compute_sparsity`` is applied. The per-batch results are
    averaged with ``MeanOnline``.

    L1 sparsity is ||x||_1 / size(x). For non-negative inputs this coincides
    with the mean intensity, which is why the function is called
    ``dataset_mean``. Reference values:

        MNIST:        0.131
        FashionMNIST: 0.286
        CIFAR10:      0.473
        CIFAR100:     0.478

    Args:
        data_loader: project data loader exposing ``eval()`` and
            ``dataset_cls``.
        verbose: show a tqdm progress bar when true.

    Returns:
        Whatever ``MeanOnline.get_mean()`` yields for the accumulated values.
    """
    batches = data_loader.eval()
    tracker = MeanOnline()
    description = f"Computing {data_loader.dataset_cls.__name__} mean"
    for batch in tqdm(batches, desc=description, disable=not verbose,
                      leave=False):
        flat = input_from_batch(batch).flatten(start_dim=1)
        tracker.update(compute_sparsity(flat))
    return tracker.get_mean()
def full_forward_pass(self, train=True):
    """Evaluate the model on the whole eval set for every BMP parameter.

    For each value in ``self.bmp_params``, the model is run on every batch of
    ``self.data_loader.eval(...)`` with gradients disabled. Three quantities
    are accumulated per parameter and plotted against ``self.bmp_params`` in
    visdom:
    * the loss (``self._get_loss``);
    * the PSNR between the input and its reconstruction;
    * ``compute_sparsity`` of the latent code.

    The model's train/eval mode is switched to eval for the pass. It is
    restored afterwards, even when an exception is raised.

    Args:
        train: the pass only runs when true; otherwise nothing happens.

    Returns:
        ``MeanOnline.get_mean()`` of the stacked per-parameter losses
        (presumably one value per BMP parameter -- TODO confirm), or None
        when ``train`` is false.

    Raises:
        AssertionError: if ``self.criterion`` is not an ``nn.MSELoss``.
    """
    if not train:
        return None
    assert isinstance(self.criterion, nn.MSELoss), \
        "BMP can work only with MSE loss"

    use_cuda = torch.cuda.is_available()
    loss_online = MeanOnline()
    psnr_online = MeanOnline()
    sparsity_online = MeanOnline()

    def _plot(win, values, ylabel, title):
        # One visdom curve: metric value versus BMP parameter value.
        self.monitor.viz.line(Y=values, X=self.bmp_params, win=win,
                              opts=dict(xlabel=f'BMP {self.param_name}',
                                        ylabel=ylabel,
                                        title=title))

    mode_saved = self.model.training
    self.model.train(False)
    try:
        with torch.no_grad():
            for batch in self.data_loader.eval(
                    description="Full forward pass (eval)"):
                if use_cuda:
                    batch = batch_to_cuda(batch)
                input_ = input_from_batch(batch)
                batch_losses, batch_psnrs, batch_sparsities = [], [], []
                for bmp_param in self.bmp_params:
                    outputs = self.model(input_, bmp_param)
                    latent, reconstructed = outputs
                    loss_lambd = self._get_loss(batch, outputs)
                    psnr_lambd = peak_to_signal_noise_ratio(
                        input_, reconstructed)
                    sparsity_lambd = compute_sparsity(latent)
                    batch_losses.append(loss_lambd.cpu())
                    batch_psnrs.append(psnr_lambd.cpu())
                    batch_sparsities.append(sparsity_lambd.cpu())
                # Stack so each tracker averages a per-parameter vector.
                loss_online.update(torch.stack(batch_losses))
                psnr_online.update(torch.stack(batch_psnrs))
                sparsity_online.update(torch.stack(batch_sparsities))

        loss = loss_online.get_mean()
        _plot('Loss', loss, 'Loss', 'Loss')
        _plot('PSNR', psnr_online.get_mean(),
              'Peak signal-to-noise ratio', 'PSNR')
        _plot('Sparsity', sparsity_online.get_mean(),
              'sparsity', 'L1 output sparsity')
        self.monitor.viz.close(win='Accuracy')
        return loss
    finally:
        # Restore the caller's train/eval mode even if the pass or the
        # plotting fails; previously an exception left the model in eval mode.
        self.model.train(mode_saved)