def forward(self, y):
    # noise-level bound (l2-ball radius) derived from the sampling mask and the noise stdev
    eps = opt.ip_batch(self.A.maps.shape[1] * self.A.mask.sum((1, 2))).sqrt() * self.hparams.stdev

    # initialize with the adjoint (zero-filled) reconstruction
    x = self.A.adjoint(y)
    z = self.A(x)
    z_old = z
    u = z.new_zeros(z.shape)

    # these ADMM variables are not tracked by autograd
    x.requires_grad = False
    z.requires_grad = False
    z_old.requires_grad = False
    u.requires_grad = False

    self.num_cg = np.zeros((self.hparams.num_unrolls, self.hparams.num_admm))

    for i in range(self.hparams.num_unrolls):
        # apply the learned denoiser to the current estimate
        r = self.denoiser(x)

        for j in range(self.hparams.num_admm):
            # x-update: solve (l2lam * A^H A + I) x = l2lam * A^H (z - u) + r with CG
            rhs = self.l2lam * self.A.adjoint(z - u) + r
            fun = lambda xx: self.l2lam * self.A.normal(xx) + xx
            cg_op = ConjGrad(rhs, fun, max_iter=self.hparams.cg_max_iter, eps=self.hparams.cg_eps, verbose=False)
            x = cg_op.forward(x)
            n_cg = cg_op.num_cg
            self.num_cg[i, j] = n_cg

            # z-update: project A(x) + u onto the l2 ball of radius eps centered at y
            Ax_plus_u = self.A(x) + u
            z_old = z
            z = y + opt.l2ball_proj_batch(Ax_plus_u - y, eps)

            # u-update: accumulate the scaled dual variable
            u = Ax_plus_u - z

            # check ADMM convergence via primal and dual residuals
            Ax = self.A(x)
            r_norm = opt.ip_batch(Ax - z).sqrt()
            s_norm = opt.ip_batch(self.l2lam * self.A.adjoint(z - z_old)).sqrt()
            if (r_norm + s_norm).max() < 1e-2:
                if self.debug_level > 0:
                    tqdm.tqdm.write('stopping early, unroll={}, admm iter={}'.format(i, j))
                break

    return x
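# The z-update above projects A(x) + u onto an l2 ball of radius eps centered at the
# measured data y, enforcing the data-consistency constraint ||A(x) - y|| <= eps.
# Below is a minimal sketch of the projection that opt.l2ball_proj_batch is assumed
# to perform; the function and its name are illustrative only, not the library code.
# eps is assumed to be a scalar or a per-example tensor of shape (batch,).
import torch

def _l2ball_proj_batch_sketch(v, eps):
    # per-example l2 norm over all non-batch dimensions
    norm = v.reshape(v.shape[0], -1).norm(dim=1)
    # shrink factor: 1 inside the ball, eps / ||v|| outside it
    scale = torch.clamp(eps / (norm + 1e-12), max=1.0)
    return v * scale.reshape((-1,) + (1,) * (v.dim() - 1))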
def training_step(self, batch, batch_nb):
    idx, data = batch
    idx = utils.itemize(idx)
    imgs = data['imgs']
    inp = data['out']

    self.batch(data)

    x_hat = self.forward(inp)

    try:
        num_cg = self.get_metadata()['num_cg']
    except KeyError:
        num_cg = 0

    # log images for the example with dataset index 0, if it is in this batch
    _b = inp.shape[0]
    if _b == 1 and idx == 0:
        _idx = 0
    elif _b > 1 and 0 in idx:
        _idx = idx.index(0)
    else:
        _idx = None

    if _idx is not None:
        with torch.no_grad():
            if self.x_adj is None:
                x_adj = self.A.adjoint(inp)
            else:
                x_adj = self.x_adj

            _x_hat = utils.t2n(x_hat[_idx, ...])
            _x_gt = utils.t2n(imgs[_idx, ...])
            _x_adj = utils.t2n(x_adj[_idx, ...])

            # log magnitude and phase of the current prediction
            myim = torch.tensor(np.stack((np.abs(_x_hat), np.angle(_x_hat)), axis=0))[:, None, ...]
            grid = make_grid(myim, scale_each=True, normalize=True, nrow=8, pad_value=10)
            self.logger.experiment.add_image('2_train_prediction', grid, self.current_epoch)

            # ground truth and adjoint input only need to be logged once
            if self.current_epoch == 0:
                myim = torch.tensor(np.stack((np.abs(_x_gt), np.angle(_x_gt)), axis=0))[:, None, ...]
                grid = make_grid(myim, scale_each=True, normalize=True, nrow=8, pad_value=10)
                self.logger.experiment.add_image('1_ground_truth', grid, 0)

                myim = torch.tensor(np.stack((np.abs(_x_adj), np.angle(_x_adj)), axis=0))[:, None, ...]
                grid = make_grid(myim, scale_each=True, normalize=True, nrow=8, pad_value=10)
                self.logger.experiment.add_image('0_input', grid, 0)

    if self.self_supervised:
        # compare in measurement space against the acquired data
        pred = self.A.forward(x_hat)
        gt = inp
    else:
        # compare in image space against the ground truth
        pred = x_hat
        gt = imgs

    loss = self.loss_fun(pred, gt)

    _loss = loss.clone().detach().requires_grad_(False)
    try:
        _lambda = self.l2lam.clone().detach().requires_grad_(False)
    except AttributeError:
        # l2lam may be a plain float rather than a learnable tensor
        _lambda = 0
    _epoch = self.current_epoch
    _nrmse = (opt.ip_batch(x_hat - imgs) / opt.ip_batch(imgs)).sqrt().mean().detach().requires_grad_(False)
    _num_cg = np.max(num_cg)

    log_dict = {
        'lambda': _lambda,
        'loss': _loss,
        'epoch': _epoch,
        'nrmse': _nrmse,
        'max_num_cg': _num_cg,
        'val_loss': 0.,
    }
    return {
        'loss': loss,
        'log': log_dict,
        'progress_bar': log_dict,
    }
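# When self_supervised is True, the loss above is computed in measurement space
# (A(x_hat) against the acquired data inp), so no fully sampled ground truth is needed
# during training; otherwise it is computed directly against imgs. loss_fun itself is
# configurable; the function below is a hypothetical example of such a loss, shown only
# as an illustration and not necessarily the loss used in this codebase.
import torch

def _example_loss_fun(pred, gt):
    # mean of |pred - gt|^2 over all elements; valid for real or complex tensors
    return torch.mean(torch.abs(pred - gt) ** 2)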
def calc_nrmse(gt, pred):
    """Normalized root mean squared error between pred and gt, averaged over the batch."""
    return (opt.ip_batch(pred - gt) / opt.ip_batch(gt)).sqrt().mean()
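# Usage sketch for calc_nrmse (illustrative values only). Assuming opt.ip_batch
# returns the per-example sum of squared magnitudes, as its use above implies,
# calc_nrmse computes mean_b sqrt( ||pred_b - gt_b||^2 / ||gt_b||^2 ):
#
#   gt = torch.ones(4, 32, 32)
#   pred = gt + 0.1 * torch.randn_like(gt)
#   calc_nrmse(gt, pred)   # approximately 0.1 for this noise level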
def conjgrad(x, b, Aop_fun, max_iter=10, l2lam=0., eps=1e-4, verbose=True):
    """Batched conjugate gradient solver; assumes the first index is the batch dimension.

    Solves (A + l2lam * I) x = b for each batch element, where the action of A is
    given by Aop_fun.

    Args:
        x (Tensor): The initial guess.
        b (Tensor): The right-hand side of the linear system.
        Aop_fun (func): A function performing the (batched) A matrix operation.
        max_iter (int): Maximum number of conjugate gradient iterations.
        l2lam (float): The L2 lambda, or regularization parameter (must be non-negative).
        eps (float): Determines how small the residuals must be before termination…
        verbose (bool): If true, prints extra information to the console.

    Returns:
        A tuple containing the updated vector x and the number of iterations performed.
    """

    # explicitly remove r from the computational graph
    r = b.new_zeros(b.shape, requires_grad=False)

    # the first calc of the residual may not be necessary in some cases...
    if l2lam > 0:
        r = b - (Aop_fun(x) + l2lam * x)
    else:
        r = b - Aop_fun(x)
    p = r

    rsnot = ip_batch(r)
    rsold = rsnot
    rsnew = rsnot

    eps_squared = eps ** 2

    # shape for broadcasting per-example scalars against x, r, and p
    reshape = (-1,) + (1,) * (len(x.shape) - 1)

    num_iter = 0
    for i in range(max_iter):
        if verbose:
            print('{i}: {rsnew}'.format(i=i, rsnew=utils.itemize(torch.sqrt(rsnew))))

        # rsnew holds squared residual norms, so compare against eps**2
        if rsnew.max() < eps_squared:
            break

        if l2lam > 0:
            Ap = Aop_fun(p) + l2lam * p
        else:
            Ap = Aop_fun(p)
        pAp = dot_batch(p, Ap)

        alpha = (rsold / pAp).reshape(reshape)

        x = x + alpha * p
        r = r - alpha * Ap

        rsnew = ip_batch(r)

        beta = (rsnew / rsold).reshape(reshape)

        rsold = rsnew

        p = beta * p + r
        num_iter += 1

    if verbose:
        print('FINAL: {rsnew}'.format(rsnew=torch.sqrt(rsnew)))

    return x, num_iter
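# Usage sketch for conjgrad on a batched diagonal system (illustrative only). It assumes
# ip_batch and dot_batch reduce over all non-batch dimensions, as the reshape of alpha
# and beta above implies:
#
#   d = torch.rand(8, 64) + 1.                  # positive diagonal entries, batch of 8
#   b = torch.randn(8, 64)
#   Aop = lambda v: d * v                       # batched action of A = diag(d)
#   x0 = torch.zeros_like(b)
#   x, n_iter = conjgrad(x0, b, Aop, max_iter=50, l2lam=0.1, eps=1e-6, verbose=False)
#   # x approximately satisfies (d + 0.1) * x == b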