Пример #1
0
    def forward(self, y):
        """Run the unrolled ADMM reconstruction on measurements ``y``.

        Each of ``self.hparams.num_unrolls`` outer iterations applies
        ``self.denoiser`` to the current estimate ``x``, then runs up to
        ``self.hparams.num_admm`` ADMM iterations consisting of:

        * an x-update solved with ``ConjGrad`` on the operator
          ``xx -> l2lam * A.normal(xx) + xx``,
        * a z-update ``z = y + l2ball_proj_batch(A(x) + u - y, self.eps)``
          (presumably a projection onto an l2 ball of radius ``self.eps``),
        * a dual update ``u = A(x) + u - z``.

        The inner ADMM loop (not the outer unroll loop) stops early once
        ``max(r_norm + s_norm) < 1e-2``. Here ``r_norm`` is the norm of
        ``A(x) - z`` and ``s_norm`` is the norm of
        ``l2lam * A.adjoint(z - z_old)``.

        Side effects:
            self.mean_eps: ``torch.mean(self.eps)``, computed without grad.
            self.num_cg: ndarray of shape (num_unrolls, num_admm) holding
                ``ConjGrad.num_cg`` for each solve. Entries for ADMM
                iterations skipped by early stopping stay 0.
            self.mean_residual_norm: mean of
                ``sqrt(real(zdot_single_batch(y - A(x))))``. It is refreshed
                after every ADMM iteration that did not stop early.

        Args:
            y: measurement tensor accepted by ``self.A.adjoint``.

        Returns:
            The estimate ``x`` after the last iteration.
        """
        with torch.no_grad():
            self.mean_eps = torch.mean(self.eps)
        x = self.A.adjoint(y)
        z = self.A(x)
        z_old = z
        u = z.new_zeros(z.shape)

        # NOTE(review): assigning requires_grad raises for non-leaf tensors
        # that already require grad. This assumes y (and hence x, z) carries
        # no autograd history; confirm with callers.
        x.requires_grad = False
        z.requires_grad = False
        z_old.requires_grad = False
        u.requires_grad = False

        self.num_cg = np.zeros((
            self.hparams.num_unrolls,
            self.hparams.num_admm,
        ))

        # Operator handed to CG for the x-update. It does not depend on the
        # loop indices, so it is defined once rather than per iteration.
        def normal_op(xx):
            return self.l2lam * self.A.normal(xx) + xx

        for i in range(self.hparams.num_unrolls):
            denoised = self.denoiser(x)

            for j in range(self.hparams.num_admm):
                rhs = self.l2lam * self.A.adjoint(z - u) + denoised
                cg_op = ConjGrad(rhs,
                                 normal_op,
                                 max_iter=self.hparams.cg_max_iter,
                                 eps=self.hparams.cg_eps,
                                 verbose=False)
                x = cg_op.forward(x)
                self.num_cg[i, j] = cg_op.num_cg

                Ax_plus_u = self.A(x) + u
                z_old = z
                z = y + opt.l2ball_proj_batch(Ax_plus_u - y, self.eps)
                u = Ax_plus_u - z

                # check ADMM convergence via primal and dual residual norms
                with torch.no_grad():
                    Ax = self.A(x)
                    primal = (Ax - z).contiguous()
                    r_norm = torch.real(opt.zdot_single_batch(primal)).sqrt()

                    dual = (self.l2lam * self.A.adjoint(z - z_old)).contiguous()
                    s_norm = torch.real(opt.zdot_single_batch(dual)).sqrt()

                    if (r_norm + s_norm).max() < 1E-2:
                        if self.debug_level > 0:
                            # Report the loop indices. The previous message
                            # formatted an undefined name and raised NameError.
                            tqdm.tqdm.write(
                                'stopping early, unroll={}, admm={}'.format(i, j))
                        break
                    resid = y - Ax
                    self.mean_residual_norm = torch.mean(
                        torch.sqrt(torch.real(opt.zdot_single_batch(resid))))
        return x
Пример #2
0
def calc_nrmse(gt, pred):
    """Return the normalized root-mean-square error of ``pred`` against ``gt``.

    The error energy ``real(zdot_single_batch(pred - gt))`` is divided by the
    reference energy ``real(zdot_single_batch(gt))``. Both are presumably one
    value per batch element. The square root of that ratio is then averaged.
    """
    err_energy = torch.real(opt.zdot_single_batch(pred - gt))
    ref_energy = torch.real(opt.zdot_single_batch(gt))
    ratio = err_energy / ref_energy
    return ratio.sqrt().mean()
Пример #3
0
 def _loss_fun(self, pred, gt):
     """Return the mean of ``real(zdot_single_batch(pred - gt))``.

     Presumably this is the squared l2 error per batch element, averaged over
     the batch.
     """
     per_sample = torch.real(opt.zdot_single_batch(pred - gt))
     return torch.mean(per_sample)
Пример #4
0
def conjgrad_priv(x,
                  b,
                  Aop_fun,
                  max_iter=10,
                  l2lam=0.,
                  eps=1e-4,
                  verbose=True,
                  complex=True):
    """Conjugate Gradient Algorithm applied to batches; assumes the first index is batch size.

    Solves ``(Aop_fun + l2lam * I) x = b`` for every batch element, starting
    from ``x``. Iteration stops after ``max_iter`` steps, or earlier once the
    squared residual norm of every batch element is below ``eps**2``.

    A batch element whose search direction or residual becomes exactly zero
    (e.g. an all-zero ``b`` with a zero initial ``x``) is frozen with a zero
    step instead of producing a 0/0 = NaN step size.

    Args:
    x (Tensor): The initial input to the algorithm.
    b (Tensor): The right-hand side; the initial residual is b - (Aop_fun(x) + l2lam * x).
    Aop_fun (func): A function performing the normal equations, A.adjoint * A
    max_iter (int): Maximum number of times to run conjugate gradient descent.
    l2lam (float): The L2 lambda, or regularization parameter. Normally positive, but it may be negative during training.
    eps (float): Stop once the residual norm of every batch element is below eps.
    verbose (bool): If true, prints extra information to the console.
    complex (bool): If true, uses complex vector space. (This name shadows the builtin; it is kept for keyword-argument compatibility.)

    Returns:
        A tuple containing the output Tensor x and the number of iterations performed.
    """

    if complex:
        def _dot_single_batch(v):
            return zdot_single_batch(v).real
        _dot_batch = zdot_batch
    else:
        _dot_single_batch = dot_single_batch
        _dot_batch = dot_batch

    def _safe_div(num, den):
        # Computes num / den, but returns 0 wherever den == 0. The denominator
        # is replaced by 1 at those positions before dividing. That way neither
        # the forward value nor the backward pass through torch.where sees 0/0.
        zero = den == 0
        ratio = num / torch.where(zero, torch.ones_like(den), den)
        return torch.where(zero, torch.zeros_like(ratio), ratio)

    # the first calc of the residual may not be necessary in some cases...
    # note that l2lam can be less than zero when training due to finite # of CG iterations
    r = b - (Aop_fun(x) + l2lam * x)
    p = r

    rsold = _dot_single_batch(r)
    rsnew = rsold

    eps_squared = eps**2

    # shape that broadcasts one scalar per batch element over the rest of x
    reshape = (-1, ) + (1, ) * (len(x.shape) - 1)

    num_iter = 0

    for i in range(max_iter):

        if verbose:
            print('{i}: {rsnew}'.format(i=i,
                                        rsnew=utils.itemize(
                                            torch.sqrt(rsnew))))

        if rsnew.max() < eps_squared:
            break

        Ap = Aop_fun(p) + l2lam * p
        pAp = _dot_batch(p, Ap)

        # pAp == 0 when p == 0 (converged element): take no step rather than NaN
        alpha = _safe_div(rsold, pAp).reshape(reshape)

        x = x + alpha * p
        r = r - alpha * Ap

        rsnew = _dot_single_batch(r)

        # rsold == 0 for an element that already had zero residual
        beta = _safe_div(rsnew, rsold).reshape(reshape)

        rsold = rsnew

        p = beta * p + r
        num_iter += 1

    if verbose:
        print('FINAL: {rsnew}'.format(rsnew=torch.sqrt(rsnew)))

    return x, num_iter