def forward(self, y):
    with torch.no_grad():
        self.mean_eps = torch.mean(self.eps)
        #print('epsilon is {}'.format(self.eps))

    # initialize the ADMM variables from the adjoint reconstruction
    x = self.A.adjoint(y)
    z = self.A(x)
    z_old = z
    u = z.new_zeros(z.shape)

    x.requires_grad = False
    z.requires_grad = False
    z_old.requires_grad = False
    u.requires_grad = False

    self.num_cg = np.zeros((self.hparams.num_unrolls, self.hparams.num_admm,))

    for i in range(self.hparams.num_unrolls):
        # the denoiser output acts as the prior target for the x-update
        r = self.denoiser(x)
        for j in range(self.hparams.num_admm):
            # x-update: solve (l2lam * A^H A + I) x = l2lam * A^H (z - u) + r with CG
            rhs = self.l2lam * self.A.adjoint(z - u) + r
            fun = lambda xx: self.l2lam * self.A.normal(xx) + xx
            cg_op = ConjGrad(rhs, fun, max_iter=self.hparams.cg_max_iter,
                             eps=self.hparams.cg_eps, verbose=False)
            x = cg_op.forward(x)
            n_cg = cg_op.num_cg
            self.num_cg[i, j] = n_cg

            # z-update: project A x + u onto the l2 ball of radius eps around y
            Ax_plus_u = self.A(x) + u
            z_old = z
            z = y + opt.l2ball_proj_batch(Ax_plus_u - y, self.eps)

            # u-update: scaled dual variable
            u = Ax_plus_u - z

            # check ADMM convergence via primal and dual residual norms
            with torch.no_grad():
                Ax = self.A(x)
                tmp = Ax - z
                tmp = tmp.contiguous()
                r_norm = torch.real(opt.zdot_single_batch(tmp)).sqrt()
                tmp = self.l2lam * self.A.adjoint(z - z_old)
                tmp = tmp.contiguous()
                s_norm = torch.real(opt.zdot_single_batch(tmp)).sqrt()
                if (r_norm + s_norm).max() < 1E-2:
                    if self.debug_level > 0:
                        tqdm.tqdm.write('stopping early, j={}'.format(j))
                    break

    tmp = y - Ax
    self.mean_residual_norm = torch.mean(
        torch.sqrt(torch.real(opt.zdot_single_batch(tmp))))

    return x
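# Hypothetical sketch of the projection used in the z-update above: each batch
# element of v is projected onto the l2 ball of radius eps, so that
# z = y + proj(A x + u - y) pushes the iterates toward ||A x - y|| <= eps.
# This is an illustration only, not the actual opt.l2ball_proj_batch code.
def _l2ball_proj_batch_sketch(v, eps):
    flat = v.reshape(v.shape[0], -1)
    norm = (flat.abs() ** 2).sum(dim=1).sqrt()           # per-example l2 norm
    scale = torch.clamp(eps / (norm + 1e-12), max=1.0)   # shrink only if outside the ball
    return v * scale.reshape((-1,) + (1,) * (v.dim() - 1))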
def calc_nrmse(gt, pred):
    resid = pred - gt
    return (torch.real(opt.zdot_single_batch(resid)) /
            torch.real(opt.zdot_single_batch(gt))).sqrt().mean()
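# Usage sketch for calc_nrmse (hypothetical helper, not part of the original
# code): shapes and noise level are illustrative, and it assumes
# opt.zdot_single_batch returns the batched inner product <v, v>.
def _calc_nrmse_example():
    gt = torch.randn(4, 32, dtype=torch.cfloat)
    pred = gt + 0.1 * torch.randn(4, 32, dtype=torch.cfloat)
    return calc_nrmse(gt, pred)  # roughly 0.1 for this noise level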
def _loss_fun(self, pred, gt):
    resid = pred - gt
    #return torch.mean(torch.sum(torch.real(torch.conj(resid) * resid)))
    return torch.mean(torch.real(opt.zdot_single_batch(resid)))
def conjgrad_priv(x, b, Aop_fun, max_iter=10, l2lam=0., eps=1e-4, verbose=True, complex=True):
    """Conjugate Gradient Algorithm applied to batches; assumes the first index is batch size.

    Args:
        x (Tensor): The initial input to the algorithm.
        b (Tensor): The residual vector.
        Aop_fun (func): A function performing the normal equations, A.adjoint * A.
        max_iter (int): Maximum number of times to run conjugate gradient descent.
        l2lam (float): The L2 lambda, or regularization parameter (must be positive).
        eps (float): Determines how small the residuals must be before termination.
        verbose (bool): If true, prints extra information to the console.
        complex (bool): If true, uses complex vector space.

    Returns:
        A tuple containing the output Tensor x and the number of iterations performed.
    """
    if complex:
        _dot_single_batch = lambda r: zdot_single_batch(r).real
        _dot_batch = zdot_batch
    else:
        _dot_single_batch = dot_single_batch
        _dot_batch = dot_batch

    # explicitly remove r from the computational graph
    #r = b.new_zeros(b.shape, requires_grad=False, dtype=torch.cfloat)

    # the first calc of the residual may not be necessary in some cases...
    # note that l2lam can be less than zero when training due to finite # of CG iterations
    r = b - (Aop_fun(x) + l2lam * x)
    p = r

    rsnot = _dot_single_batch(r)
    rsold = rsnot
    rsnew = rsnot

    eps_squared = eps ** 2
    reshape = (-1,) + (1,) * (len(x.shape) - 1)

    num_iter = 0
    for i in range(max_iter):
        if verbose:
            print('{i}: {rsnew}'.format(i=i, rsnew=utils.itemize(torch.sqrt(rsnew))))

        if rsnew.max() < eps_squared:
            break

        Ap = Aop_fun(p) + l2lam * p
        pAp = _dot_batch(p, Ap)

        #print(utils.itemize(pAp))

        alpha = (rsold / pAp).reshape(reshape)

        x = x + alpha * p
        r = r - alpha * Ap

        rsnew = _dot_single_batch(r)
        beta = (rsnew / rsold).reshape(reshape)

        rsold = rsnew

        p = beta * p + r
        num_iter += 1

    if verbose:
        print('FINAL: {rsnew}'.format(rsnew=torch.sqrt(rsnew)))

    return x, num_iter
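# Minimal usage sketch for conjgrad_priv (hypothetical, not part of the
# original code): a diagonal SPD operator so the exact solution is known.
# It assumes dot_single_batch/dot_batch in this module compute batched real
# inner products, since complex=False selects those.
def _conjgrad_example():
    b = torch.randn(4, 8)
    x0 = torch.zeros_like(b)
    Aop = lambda v: 3.0 * v          # stand-in for A.adjoint(A(v))
    x, n_iter = conjgrad_priv(x0, b, Aop, max_iter=20, l2lam=0.,
                              eps=1e-6, verbose=False, complex=False)
    # for this operator the exact solution is b / 3
    return torch.allclose(x, b / 3.0, atol=1e-4), n_iter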