def conjgrad(x, b, Aop_fun, max_iter=10, l2lam=0., eps=1e-4, verbose=True):
    ''' batched conjugate gradient descent. assumes the first index is batch size '''

    # explicitly remove r from the computational graph
    r = b.new_zeros(b.shape, requires_grad=False)

    # the first calc of the residual may not be necessary in some cases...
    if l2lam > 0:
        r = b - (Aop_fun(x) + l2lam * x)
    else:
        r = b - Aop_fun(x)
    p = r

    rsnot = ip_batch(r)
    rsold = rsnot
    rsnew = rsnot

    eps_squared = eps**2

    reshape = (-1, ) + (1, ) * (len(x.shape) - 1)

    num_iter = 0
    for i in range(max_iter):

        if verbose:
            print('{i}: {rsnew}'.format(i=i,
                                        rsnew=utils.itemize(
                                            torch.sqrt(rsnew))))

        # rsnew holds squared residual norms, so compare against eps**2
        if rsnew.max() < eps_squared:
            break

        if l2lam > 0:
            Ap = Aop_fun(p) + l2lam * p
        else:
            Ap = Aop_fun(p)
        pAp = dot_batch(p, Ap)

        #print(utils.itemize(pAp))

        alpha = (rsold / pAp).reshape(reshape)

        x = x + alpha * p
        r = r - alpha * Ap

        rsnew = ip_batch(r)

        beta = (rsnew / rsold).reshape(reshape)

        rsold = rsnew

        p = beta * p + r
        num_iter += 1

    if verbose:
        print('FINAL: {rsnew}'.format(rsnew=torch.sqrt(rsnew)))

    return x, num_iter
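
A minimal usage sketch for the function above, assuming it sits in the same file as the helpers defined below. The ip_batch and dot_batch definitions here are hypothetical stand-ins for the project's batched squared-norm and inner-product helpers (elsewhere in these examples they appear as opt.ip_batch), and the diagonal operator is just a toy SPD system, not anything from the library.

import torch

# hypothetical stand-ins for the project's batched helpers:
# per-example squared l2 norm / inner product, summed over all non-batch dims
def ip_batch(v):
    return (v * v).reshape(v.shape[0], -1).sum(1)

def dot_batch(u, v):
    return (u * v).reshape(u.shape[0], -1).sum(1)

# toy batched SPD operator: element-wise scaling by a positive diagonal
d = torch.rand(4, 16) + 0.5
Aop_fun = lambda v: d * v

b = torch.randn(4, 16)          # one right-hand side per batch element
x0 = torch.zeros_like(b)        # initial guess

x, num_iter = conjgrad(x0, b, Aop_fun, max_iter=50, l2lam=0., eps=1e-6, verbose=False)
print(num_iter, (Aop_fun(x) - b).abs().max())   # residual should be near zero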
Example #2
    def forward(self, y):
        eps = opt.ip_batch(self.A.maps.shape[1] * self.A.mask.sum(
            (1, 2))).sqrt() * self.hparams.stdev
        x = self.A.adjoint(y)
        z = self.A(x)
        z_old = z
        u = z.new_zeros(z.shape)

        x.requires_grad = False
        z.requires_grad = False
        z_old.requires_grad = False
        u.requires_grad = False

        self.num_cg = np.zeros((
            self.hparams.num_unrolls,
            self.hparams.num_admm,
        ))

        for i in range(self.hparams.num_unrolls):
            r = self.denoiser(x)

            for j in range(self.hparams.num_admm):

                rhs = self.l2lam * self.A.adjoint(z - u) + r
                fun = lambda xx: self.l2lam * self.A.normal(xx) + xx
                cg_op = ConjGrad(rhs,
                                 fun,
                                 max_iter=self.hparams.cg_max_iter,
                                 eps=self.hparams.cg_eps,
                                 verbose=False)
                x = cg_op.forward(x)
                n_cg = cg_op.num_cg
                self.num_cg[i, j] = n_cg

                Ax_plus_u = self.A(x) + u
                z_old = z
                z = y + opt.l2ball_proj_batch(Ax_plus_u - y, eps)
                u = Ax_plus_u - z

                # check ADMM convergence
                Ax = self.A(x)
                r_norm = opt.ip_batch(Ax - z).sqrt()
                s_norm = opt.ip_batch(self.l2lam *
                                      self.A.adjoint(z - z_old)).sqrt()
                if (r_norm + s_norm).max() < 1e-2:
                    if self.debug_level > 0:
                        tqdm.tqdm.write('stopping early at unroll {}, ADMM iter {}'.format(i, j))
                    break
        return x
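
The z-update above calls opt.l2ball_proj_batch, which is not shown in these examples. The sketch below is a plausible batched projection onto an l2 ball of per-example radius eps, taking the norm over all non-batch dimensions; the name and behavior are assumptions, not the library's implementation.

import torch

def l2ball_proj_batch_sketch(x, eps):
    # project each batch element of x onto the l2 ball of radius eps
    # (a guess at what opt.l2ball_proj_batch does, not the library's code)
    flat = x.reshape(x.shape[0], -1)
    norms = flat.norm(dim=1)                                    # one l2 norm per example
    eps = torch.as_tensor(eps, dtype=norms.dtype, device=norms.device).reshape(-1)
    scale = torch.clamp(eps / norms.clamp_min(1e-12), max=1.0)  # shrink only points outside the ball
    return x * scale.reshape((-1,) + (1,) * (x.ndim - 1))

Read this way, z = y + proj(A(x) + u - y, eps) keeps each example's residual z - y inside the eps-ball, consistent with eps being derived from the noise level at the top of forward.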
Example #3
    def training_step(self, batch, batch_nb):
        idx, data = batch
        idx = utils.itemize(idx)
        imgs = data['imgs']
        inp = data['out']

        self.batch(data)

        x_hat = self.forward(inp)

        try:
            num_cg = self.get_metadata()['num_cg']
        except KeyError:
            num_cg = 0

        _b = inp.shape[0]
        if _b == 1 and idx == 0:
            _idx = 0
        elif _b > 1 and 0 in idx:
            _idx = idx.index(0)
        else:
            _idx = None
        if _idx is not None:
            with torch.no_grad():
                if self.x_adj is None:
                    x_adj = self.A.adjoint(inp)
                else:
                    x_adj = self.x_adj
                _x_hat = utils.t2n(x_hat[_idx, ...])
                _x_gt = utils.t2n(imgs[_idx, ...])
                _x_adj = utils.t2n(x_adj[_idx, ...])

                myim = torch.tensor(
                    np.stack((np.abs(_x_hat), np.angle(_x_hat)),
                             axis=0))[:, None, ...]
                grid = make_grid(myim,
                                 scale_each=True,
                                 normalize=True,
                                 nrow=8,
                                 pad_value=10)
                self.logger.experiment.add_image('2_train_prediction', grid,
                                                 self.current_epoch)

                if self.current_epoch == 0:
                    myim = torch.tensor(
                        np.stack((np.abs(_x_gt), np.angle(_x_gt)),
                                 axis=0))[:, None, ...]
                    grid = make_grid(myim,
                                     scale_each=True,
                                     normalize=True,
                                     nrow=8,
                                     pad_value=10)
                    self.logger.experiment.add_image('1_ground_truth', grid, 0)

                    myim = torch.tensor(
                        np.stack((np.abs(_x_adj), np.angle(_x_adj)),
                                 axis=0))[:, None, ...]
                    grid = make_grid(myim,
                                     scale_each=True,
                                     normalize=True,
                                     nrow=8,
                                     pad_value=10)
                    self.logger.experiment.add_image('0_input', grid, 0)

        if self.self_supervised:
            pred = self.A.forward(x_hat)
            gt = inp
        else:
            pred = x_hat
            gt = imgs

        loss = self.loss_fun(pred, gt)

        _loss = loss.clone().detach().requires_grad_(False)
        try:
            _lambda = self.l2lam.clone().detach().requires_grad_(False)
        except AttributeError:
            _lambda = 0
        _epoch = self.current_epoch
        _nrmse = (
            opt.ip_batch(x_hat - imgs) /
            opt.ip_batch(imgs)).sqrt().mean().detach().requires_grad_(False)
        _num_cg = np.max(num_cg)

        log_dict = {
            'lambda': _lambda,
            'loss': _loss,
            'epoch': self.current_epoch,
            'nrmse': _nrmse,
            'max_num_cg': _num_cg,
            'val_loss': 0.,
        }
        return {
            'loss': loss,
            'log': log_dict,
            'progress_bar': log_dict,
        }
Example #4
def calc_nrmse(gt, pred):
    return (opt.ip_batch(pred - gt) / opt.ip_batch(gt)).sqrt().mean()
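
A quick sanity check, assuming opt.ip_batch returns the per-example squared l2 norm (summed over all non-batch dimensions), so that calc_nrmse is the batch mean of ||pred - gt|| / ||gt||; it requires the project's opt module to be importable.

import torch

# with pred = 1.1 * gt, ||pred - gt|| / ||gt|| = 0.1 for every example,
# so calc_nrmse should return ~0.1 under the assumed ip_batch
gt = torch.ones(2, 8)
pred = 1.1 * gt
print(calc_nrmse(gt, pred))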
Example #5
def conjgrad(x, b, Aop_fun, max_iter=10, l2lam=0., eps=1e-4, verbose=True):
    """A function that implements batched conjugate gradient descent; assumes the first index is batch size.

    Args:
        x (Tensor): The initial input to the algorithm.
        b (Tensor): The right-hand side vector.
        Aop_fun (func): A function performing the A matrix operation.
        max_iter (int): Maximum number of times to run conjugate gradient descent.
        l2lam (float): The L2 lambda, or regularization parameter (set to 0 to disable regularization).
        eps (float): Determines how small the residuals must be before termination.
        verbose (bool): If true, prints extra information to the console.

    Returns:
        A tuple containing the updated vector x and the number of iterations performed.
    """

    # explicitly remove r from the computational graph
    r = b.new_zeros(b.shape, requires_grad=False)

    # the first calc of the residual may not be necessary in some cases...
    if l2lam > 0:
        r = b - (Aop_fun(x) + l2lam * x)
    else:
        r = b - Aop_fun(x)
    p = r

    rsnot = ip_batch(r)
    rsold = rsnot
    rsnew = rsnot

    eps_squared = eps ** 2

    reshape = (-1,) + (1,) * (len(x.shape) - 1)

    num_iter = 0
    for i in range(max_iter):

        if verbose:
            print('{i}: {rsnew}'.format(i=i, rsnew=utils.itemize(torch.sqrt(rsnew))))

        if rsnew.max() < eps_squared:
            break

        if l2lam > 0:
            Ap = Aop_fun(p) + l2lam * p
        else:
            Ap = Aop_fun(p)
        pAp = dot_batch(p, Ap)

        #print(utils.itemize(pAp))

        alpha = (rsold / pAp).reshape(reshape)

        x = x + alpha * p
        r = r - alpha * Ap

        rsnew = ip_batch(r)

        beta = (rsnew / rsold).reshape(reshape)

        rsold = rsnew

        p = beta * p + r
        num_iter += 1


    if verbose:
        print('FINAL: {rsnew}'.format(rsnew=torch.sqrt(rsnew)))

    return x, num_iter