def conjgrad(x, b, Aop_fun, max_iter=10, l2lam=0., eps=1e-4, verbose=True):
    ''' Batched conjugate gradient descent; assumes the first index is the batch size. '''

    # explicitly remove r from the computational graph
    r = b.new_zeros(b.shape, requires_grad=False)

    # the first calc of the residual may not be necessary in some cases...
    if l2lam > 0:
        r = b - (Aop_fun(x) + l2lam * x)
    else:
        r = b - Aop_fun(x)
    p = r

    rsnot = ip_batch(r)
    rsold = rsnot
    rsnew = rsnot

    eps_squared = eps**2

    reshape = (-1, ) + (1, ) * (len(x.shape) - 1)

    num_iter = 0
    for i in range(max_iter):

        if verbose:
            print('{i}: {rsnew}'.format(i=i,
                                        rsnew=utils.itemize(
                                            torch.sqrt(rsnew))))

        if rsnew.max() < eps_squared:
            break

        if l2lam > 0:
            Ap = Aop_fun(p) + l2lam * p
        else:
            Ap = Aop_fun(p)
        pAp = dot_batch(p, Ap)

        #print(utils.itemize(pAp))

        alpha = (rsold / pAp).reshape(reshape)

        x = x + alpha * p
        r = r - alpha * Ap

        rsnew = ip_batch(r)

        beta = (rsnew / rsold).reshape(reshape)

        rsold = rsnew

        p = beta * p + r
        num_iter += 1

    if verbose:
        print('FINAL: {rsnew}'.format(rsnew=torch.sqrt(rsnew)))

    return x, num_iter
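A minimal usage sketch follows, assuming conjgrad is in scope together with the ip_batch / dot_batch helpers it references; the helper definitions here are illustrative stand-ins that reduce over all non-batch dimensions, not the originals.

import torch

def ip_batch(x):
    # illustrative stand-in: batched squared norm over all non-batch dims
    return (x * x).reshape(x.shape[0], -1).sum(dim=1)

def dot_batch(x, y):
    # illustrative stand-in: batched inner product over all non-batch dims
    return (x * y).reshape(x.shape[0], -1).sum(dim=1)

torch.manual_seed(0)
B, n = 4, 8
M = torch.randn(B, n, n)
A = M @ M.transpose(-1, -2) + n * torch.eye(n)  # SPD matrix per batch element

def Aop_fun(v):
    # batched matrix-vector product A @ v
    return torch.einsum('bij,bj->bi', A, v)

x_true = torch.randn(B, n)
b = Aop_fun(x_true)

x_hat, num_iter = conjgrad(torch.zeros(B, n), b, Aop_fun,
                           max_iter=2 * n, eps=1e-6, verbose=False)
print(num_iter, (x_hat - x_true).abs().max())  # reconstruction error should be tiny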
Example #2
    def get_progress_bar_dict(self):
        items = super().get_progress_bar_dict()
        if self.log_dict:
            for key in self.log_dict.keys():
                if isinstance(self.log_dict[key], torch.Tensor):
                    items[key] = utils.itemize(self.log_dict[key])
                else:
                    items[key] = self.log_dict[key]
        return items
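utils.itemize is used throughout these examples to turn tensors into plain Python numbers so they render cleanly in the progress bar. A hedged stand-in for what it presumably does (an assumption, not the original implementation):

def itemize(t):
    # assumption: unwrap a one-element tensor to a Python scalar,
    # and a multi-element tensor to a list of scalars
    if t.numel() == 1:
        return t.item()
    return [v.item() for v in t.flatten()]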
Example #3
    def training_step(self, batch, batch_nb):
        idx, data = batch
        idx = utils.itemize(idx)
        imgs = data['imgs']
        inp = data['out']

        self.batch(data)

        x_hat = self.forward(inp)

        try:
            num_cg = self.get_metadata()['num_cg']
        except KeyError:
            num_cg = 0

        _b = inp.shape[0]
        if _b == 1 and idx == 0:
            _idx = 0
        elif _b > 1 and 0 in idx:
            _idx = idx.index(0)
        else:
            _idx = None
        if _idx is not None:
            with torch.no_grad():
                if self.x_adj is None:
                    x_adj = self.A.adjoint(inp)
                else:
                    x_adj = self.x_adj
                _x_hat = utils.t2n(x_hat[_idx, ...])
                _x_gt = utils.t2n(imgs[_idx, ...])
                _x_adj = utils.t2n(x_adj[_idx, ...])

                myim = torch.tensor(
                    np.stack((np.abs(_x_hat), np.angle(_x_hat)),
                             axis=0))[:, None, ...]
                grid = make_grid(myim,
                                 scale_each=True,
                                 normalize=True,
                                 nrow=8,
                                 pad_value=10)
                self.logger.experiment.add_image('2_train_prediction', grid,
                                                 self.current_epoch)

                if self.current_epoch == 0:
                    myim = torch.tensor(
                        np.stack((np.abs(_x_gt), np.angle(_x_gt)),
                                 axis=0))[:, None, ...]
                    grid = make_grid(myim,
                                     scale_each=True,
                                     normalize=True,
                                     nrow=8,
                                     pad_value=10)
                    self.logger.experiment.add_image('1_ground_truth', grid, 0)

                    myim = torch.tensor(
                        np.stack((np.abs(_x_adj), np.angle(_x_adj)),
                                 axis=0))[:, None, ...]
                    grid = make_grid(myim,
                                     scale_each=True,
                                     normalize=True,
                                     nrow=8,
                                     pad_value=10)
                    self.logger.experiment.add_image('0_input', grid, 0)

        if self.self_supervised:
            pred = self.A.forward(x_hat)
            gt = inp
        else:
            pred = x_hat
            gt = imgs

        loss = self.loss_fun(pred, gt)

        _loss = loss.clone().detach().requires_grad_(False)
        try:
            _lambda = self.l2lam.clone().detach().requires_grad_(False)
        except AttributeError:
            # l2lam is not a tensor (e.g., a plain float or None)
            _lambda = 0
        _nrmse = (
            opt.ip_batch(x_hat - imgs) /
            opt.ip_batch(imgs)).sqrt().mean().detach().requires_grad_(False)
        _num_cg = np.max(num_cg)

        log_dict = {
            'lambda': _lambda,
            'loss': _loss,
            'epoch': self.current_epoch,
            'nrmse': _nrmse,
            'max_num_cg': _num_cg,
            'val_loss': 0.,
        }
        return {
            'loss': loss,
            'log': log_dict,
            'progress_bar': log_dict,
        }
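The inline NRMSE above is the same quantity that Example #4 obtains from a calc_nrmse helper. Presumably that helper reduces to the following sketch, where opt.ip_batch is the batched squared norm used by the conjugate gradient code (an assumption, not the original definition):

def calc_nrmse(gt, pred):
    # normalized root-mean-square error, averaged over the batch
    return (opt.ip_batch(pred - gt) / opt.ip_batch(gt)).sqrt().mean()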
Example #4
    def training_step(self, batch, batch_nb):
        """Defines a training step solving deep inverse problems, including batching, performing a forward pass through
        the model, and logging data. This may either be supervised or unsupervised based on hyperparameters.

        Args:
            batch (tuple): Should hold the indices of data and the corresponding data, in said order.
            batch_nb (None): Currently unimplemented.

        Returns:
            A dict holding performance data and current epoch for performance tracking over time.
        """

        idx, data = batch
        idx = utils.itemize(idx)
        imgs = data['imgs']
        inp = data['out']

        self.batch(data)

        x_hat = self.forward(inp)

        try:
            num_cg = self.get_metadata()['num_cg']
        except KeyError:
            num_cg = 0

        if self.logger and (
                self.current_epoch % self.hparams.save_every_N_epochs == 0
                or self.current_epoch == self.hparams.num_epochs - 1):
            _b = inp.shape[0]
            if _b == 1 and idx == 0:
                _idx = 0
            elif _b > 1 and 0 in idx:
                _idx = idx.index(0)
            else:
                _idx = None
            if _idx is not None:
                with torch.no_grad():
                    if self.x_adj is None:
                        x_adj = self.A.adjoint(inp)
                    else:
                        x_adj = self.x_adj

                    _x_hat = utils.t2n(x_hat[_idx, ...])
                    _x_gt = utils.t2n(imgs[_idx, ...])
                    _x_adj = utils.t2n(x_adj[_idx, ...])

                    if len(_x_hat.shape) > 2:
                        _d = tuple(range(len(_x_hat.shape) - 2))
                        _x_hat_rss = np.linalg.norm(_x_hat, axis=_d)
                        _x_gt_rss = np.linalg.norm(_x_gt, axis=_d)
                        _x_adj_rss = np.linalg.norm(_x_adj, axis=_d)

                        myim = torch.tensor(
                            np.stack((_x_adj_rss, _x_hat_rss, _x_gt_rss),
                                     axis=0))[:, None, ...]
                        grid = make_grid(myim,
                                         scale_each=True,
                                         normalize=True,
                                         nrow=8,
                                         pad_value=10)
                        self.logger.experiment.add_image(
                            '3_train_prediction_rss', grid, self.current_epoch)

                        while len(_x_hat.shape) > 2:
                            _x_hat = _x_hat[0, ...]
                            _x_gt = _x_gt[0, ...]
                            _x_adj = _x_adj[0, ...]

                    myim = torch.tensor(
                        np.stack((np.abs(_x_hat), np.angle(_x_hat)),
                                 axis=0))[:, None, ...]
                    grid = make_grid(myim,
                                     scale_each=True,
                                     normalize=True,
                                     nrow=8,
                                     pad_value=10)
                    self.logger.experiment.add_image('2_train_prediction',
                                                     grid, self.current_epoch)

                    if self.current_epoch == 0:
                        myim = torch.tensor(
                            np.stack((np.abs(_x_gt), np.angle(_x_gt)),
                                     axis=0))[:, None, ...]
                        grid = make_grid(myim,
                                         scale_each=True,
                                         normalize=True,
                                         nrow=8,
                                         pad_value=10)
                        self.logger.experiment.add_image(
                            '1_ground_truth', grid, 0)

                        myim = torch.tensor(
                            np.stack((np.abs(_x_adj), np.angle(_x_adj)),
                                     axis=0))[:, None, ...]
                        grid = make_grid(myim,
                                         scale_each=True,
                                         normalize=True,
                                         nrow=8,
                                         pad_value=10)
                        self.logger.experiment.add_image('0_input', grid, 0)

        if self.hparams.self_supervised:
            pred = self.A.forward(x_hat)
            gt = inp
        else:
            pred = x_hat
            gt = imgs

        loss = self.loss_fun(pred, gt)

        _loss = loss.clone().detach().requires_grad_(False)
        try:
            _lambda = self.l2lam.clone().detach().requires_grad_(False)
        except AttributeError:
            # l2lam is not a tensor (e.g., a plain float or None)
            _lambda = 0
        _nrmse = calc_nrmse(imgs, x_hat).detach().requires_grad_(False)
        _num_cg = np.max(num_cg)

        log_dict = {
            'lambda': _lambda,
            'train_loss': _loss,
            'epoch': self.current_epoch,
            'nrmse': _nrmse,
            'max_num_cg': _num_cg,
            'val_loss': 0.,
        }

        if self.logger:
            for key in log_dict.keys():
                self.logger.experiment.add_scalar(key, log_dict[key],
                                                  self.global_step)

        return {
            'loss': loss,
            'progress_bar': log_dict,
        }
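This variant gates image logging on hyperparameters read from self.hparams. For reference, a sketch of the attributes it expects; the names come from the code above, but the values are illustrative only:

from argparse import Namespace

hparams = Namespace(
    save_every_N_epochs=10,   # log images every N epochs
    num_epochs=100,           # images are also logged on the final epoch
    self_supervised=False,    # if True, the loss is computed in measurement space
)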
Example #5
def conjgrad(x, b, Aop_fun, max_iter=10, l2lam=0., eps=1e-4, verbose=True):
    """A function that implements batched conjugate gradient descent; assumes the first index is batch size.

    Args:
    	x (Tensor): The initial input to the algorithm.
    b (Tensor): The residual vector
    Aop_fun (func): A function performing the A matrix operation.
    max_iter (int): Maximum number of times to run conjugate gradient descent.
    l2lam (float): The L2 lambda, or regularization parameter (must be positive).
    eps (float): Determines how small the residuals must be before termination…
    verbose (bool): If true, prints extra information to the console.

    Returns:
    	A tuple containing the updated vector x and the number of iterations performed.
    """

    # explicitly remove r from the computational graph
    r = b.new_zeros(b.shape, requires_grad=False)

    # the first calc of the residual may not be necessary in some cases...
    if l2lam > 0:
        r = b - (Aop_fun(x) + l2lam * x)
    else:
        r = b - Aop_fun(x)
    p = r

    rsnot = ip_batch(r)
    rsold = rsnot
    rsnew = rsnot

    eps_squared = eps ** 2

    reshape = (-1,) + (1,) * (len(x.shape) - 1)

    num_iter = 0
    for i in range(max_iter):

        if verbose:
            print('{i}: {rsnew}'.format(i=i, rsnew=utils.itemize(torch.sqrt(rsnew))))

        if rsnew.max() < eps_squared:
            break

        if l2lam > 0:
            Ap = Aop_fun(p) + l2lam * p
        else:
            Ap = Aop_fun(p)
        pAp = dot_batch(p, Ap)

        #print(utils.itemize(pAp))

        alpha = (rsold / pAp).reshape(reshape)

        x = x + alpha * p
        r = r - alpha * Ap

        rsnew = ip_batch(r)

        beta = (rsnew / rsold).reshape(reshape)

        rsold = rsnew

        p = beta * p + r
        num_iter += 1

    if verbose:
        print('FINAL: {rsnew}'.format(rsnew=torch.sqrt(rsnew)))

    return x, num_iter
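For reference, the loop above implements the standard conjugate gradient recurrences (with A replaced by A + \lambda I when l2lam > 0); rsnew holds the squared residual norm, which is why it is compared against eps squared:

\alpha_k = \frac{r_k^\top r_k}{p_k^\top A p_k}, \qquad
x_{k+1} = x_k + \alpha_k p_k, \qquad
r_{k+1} = r_k - \alpha_k A p_k,

\beta_k = \frac{r_{k+1}^\top r_{k+1}}{r_k^\top r_k}, \qquad
p_{k+1} = r_{k+1} + \beta_k p_k.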
Example #6
def conjgrad_priv(x,
                  b,
                  Aop_fun,
                  max_iter=10,
                  l2lam=0.,
                  eps=1e-4,
                  verbose=True,
                  complex=True):
    """Conjugate Gradient Algorithm applied to batches; assumes the first index is batch size.

    Args:
    x (Tensor): The initial input to the algorithm.
    b (Tensor): The residual vector
    Aop_fun (func): A function performing the normal equations, A.adjoint * A
    max_iter (int): Maximum number of times to run conjugate gradient descent.
    l2lam (float): The L2 lambda, or regularization parameter (must be positive).
    eps (float): Determines how small the residuals must be before termination…
    verbose (bool): If true, prints extra information to the console.
    complex (bool): If true, uses complex vector space

    Returns:
    	A tuple containing the output Tensor x and the number of iterations performed.
    """

    if complex:
        _dot_single_batch = lambda r: zdot_single_batch(r).real
        _dot_batch = zdot_batch
    else:
        _dot_single_batch = dot_single_batch
        _dot_batch = dot_batch

    # explicitly remove r from the computational graph
    #r = b.new_zeros(b.shape, requires_grad=False, dtype=torch.cfloat)

    # the first calc of the residual may not be necessary in some cases...
    # note that l2lam can be less than zero when training due to finite # of CG iterations
    r = b - (Aop_fun(x) + l2lam * x)
    p = r

    rsnot = _dot_single_batch(r)
    rsold = rsnot
    rsnew = rsnot

    eps_squared = eps**2

    reshape = (-1, ) + (1, ) * (len(x.shape) - 1)

    num_iter = 0

    for i in range(max_iter):

        if verbose:
            print('{i}: {rsnew}'.format(i=i,
                                        rsnew=utils.itemize(
                                            torch.sqrt(rsnew))))

        if rsnew.max() < eps_squared:
            break

        Ap = Aop_fun(p) + l2lam * p
        pAp = _dot_batch(p, Ap)

        #print(utils.itemize(pAp))

        alpha = (rsold / pAp).reshape(reshape)

        x = x + alpha * p
        r = r - alpha * Ap

        rsnew = _dot_single_batch(r)

        beta = (rsnew / rsold).reshape(reshape)

        rsold = rsnew

        p = beta * p + r
        num_iter += 1

    if verbose:
        print('FINAL: {rsnew}'.format(rsnew=torch.sqrt(rsnew)))

    return x, num_iter
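The complex branch relies on zdot_single_batch and zdot_batch; taking .real of zdot_single_batch(r) suggests the usual conjugate inner product. Illustrative stand-ins under that assumption (not the original definitions):

import torch

def zdot_batch(x, y):
    # assumed batched complex inner product <x, y> = sum(conj(x) * y),
    # reduced over all non-batch dimensions
    return torch.sum(x.conj() * y, dim=tuple(range(1, x.dim())))

def zdot_single_batch(x):
    # <x, x>: real-valued up to round-off, hence the .real taken above
    return zdot_batch(x, x)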