Exemplo n.º 1
0
    def batch(self, data):
        """Build the forward operator for one batch and cache its adjoint image.

        Constructs ``self.A`` as a ``MultiChannelMRI`` operator from the batch's
        ``'maps'`` and ``'masks'`` and stores ``self.A.adjoint(data['out'])``
        in ``self.x_adj``.

        Args:
            data: mapping with keys ``'maps'``, ``'masks'``, ``'out'`` and
                ``'imgs'``; only ``.shape`` of ``'imgs'`` is read.
        """
        maps, masks, measured = data['maps'], data['masks'], data['out']
        img_shape = data['imgs'].shape

        operator = MultiChannelMRI(
            maps,
            masks,
            l2lam=0.,
            img_shape=img_shape,
            use_sigpy=self.use_sigpy,
            noncart=self.noncart,
        )
        self.A = operator
        self.x_adj = operator.adjoint(measured)
Exemplo n.º 2
0
    def batch(self, data):
        """Build the forward operator for one batch and cache its adjoint image.

        When ``self.hparams.multiband`` is truthy, a ``MultiBandMRI`` operator is
        built using ``data['phi']`` (whose shape is passed as ``ksp_shape``);
        otherwise a plain ``MultiChannelMRI`` operator is used. In both cases
        the adjoint of ``data['out']`` is stored in ``self.x_adj``.

        Args:
            data: mapping with keys ``'maps'``, ``'masks'``, ``'out'``,
                ``'imgs'`` and, for multiband, ``'phi'``.
        """
        hp = self.hparams
        maps, masks, measured = data['maps'], data['masks'], data['out']

        # NJM: support for multiband
        if not hp.multiband:
            self.A = MultiChannelMRI(
                maps,
                masks,
                l2lam=0.,
                img_shape=data['imgs'].shape,
                use_sigpy=hp.use_sigpy,
                noncart=hp.noncart,
            )
        else:
            phi = data['phi']
            self.A = MultiBandMRI(
                maps,
                masks,
                phi,
                l2lam=0.,
                ksp_shape=phi.shape,
                img_shape=data['imgs'].shape,
                use_sigpy=hp.use_sigpy,
                noncart=hp.noncart,
            )

        self.x_adj = self.A.adjoint(measured)
Exemplo n.º 3
0
class MoDLReconOneUnroll(torch.nn.Module):
    """One MoDL unroll: apply the denoiser, then a data-consistency update.

    ``batch`` must be called first to build the forward operator ``self.A``
    and the adjoint image ``self.x_adj``. ``forward`` then denoises the input
    and, for multi-channel operators, runs ``ConjGrad`` on the
    right-hand side ``x_adj + l2lam * denoiser(x)``. For single-channel
    operators it blends measured and denoised k-space in closed form.

    Args:
        denoiser: module applied to the current image estimate.
        l2lam: weight of the denoiser term in the data-consistency update.
        hparams: config object; ``adjoint_data``, ``use_sigpy``, ``noncart``,
            ``cg_max_iter`` and ``cg_eps`` are read from it.
    """

    def __init__(self, denoiser, l2lam, hparams):
        super().__init__()
        self.l2lam = l2lam
        self.num_cg = None
        self.x_adj = None
        self.hparams = hparams
        self.denoiser = denoiser

    def batch(self, data):
        """Build ``self.A`` and ``self.x_adj`` (and ``self.inp`` if single-channel).

        If ``hparams.adjoint_data`` is set, ``data['out']`` is taken to already
        be the adjoint image. Otherwise its adjoint is computed with ``self.A``.
        """
        hp = self.hparams
        maps, masks, measured = data['maps'], data['masks'], data['out']

        self.A = MultiChannelMRI(
            maps,
            masks,
            l2lam=0.,
            img_shape=data['imgs'].shape,
            use_sigpy=hp.use_sigpy,
            noncart=hp.noncart,
        )

        if hp.adjoint_data:
            self.x_adj = measured
        else:
            self.x_adj = self.A.adjoint(measured)

        # The single-channel closed-form update in forward() needs k-space data.
        if self.A.single_channel:
            if hp.adjoint_data:
                # re-derive k-space from the supplied adjoint image
                self.inp = fft_forw(maps.squeeze(1) * self.x_adj)
            else:
                self.inp = measured.squeeze(1)

    def forward(self, x):
        """Denoise ``x`` and return the data-consistent image estimate."""
        assert self.x_adj is not None, "x_adj not computed!"
        r = self.denoiser(x)

        if not self.A.single_channel:
            solver = ConjGrad(
                self.x_adj + self.l2lam * r,
                self.A.normal,
                l2lam=self.l2lam,
                max_iter=self.hparams.cg_max_iter,
                eps=self.hparams.cg_eps,
                verbose=False,
            )
            x_new = solver.forward(x)
            self.num_cg = solver.num_cg
            return x_new

        lam = self.l2lam
        # multiply with maps because they might not be all-ones, and they include the fftmod term
        coil_maps = self.A.maps.squeeze(1)
        r_ft = fft_forw(r * coil_maps)
        blended = (self.inp + lam * r_ft) / (1 + lam)
        # sampled locations take the blended value, unsampled keep the denoiser's k-space
        mask_mag = abs(self.A.mask)
        x_ft = blended * (mask_mag != 0) + r_ft * (mask_mag == 0)
        x_new = torch.conj(coil_maps) * fft_adj(x_ft)
        self.num_cg = 0
        return x_new

    def get_metadata(self):
        """Return the number of CG iterations used by the last ``forward``."""
        return {'num_cg': self.num_cg}
Exemplo n.º 4
0
class CGSenseRecon(Recon):
    """CG-SENSE reconstruction solved with conjugate gradient.

    ``batch`` builds the operator and the adjoint image. ``forward`` runs
    ``ConjGrad`` on ``self.A.normal`` with the learnable regularization weight
    ``self.l2lam``, starting from an all-zero image.
    """

    def __init__(self, hparams):
        super(CGSenseRecon, self).__init__(hparams)
        # trainable scalar weight, initialized from hparams.l2lam_init
        self.l2lam = torch.nn.Parameter(torch.tensor(hparams.l2lam_init))
        self.A = None

    def batch(self, data):
        """Build ``self.A`` and the adjoint image ``self.x_adj`` for one batch."""
        hp = self.hparams
        maps, masks, measured = data['maps'], data['masks'], data['out']

        self.A = MultiChannelMRI(
            maps,
            masks,
            l2lam=0.,
            img_shape=data['imgs'].shape,
            use_sigpy=hp.use_sigpy,
            noncart=hp.noncart,
        )

        # with adjoint_data, 'out' is used directly as the adjoint image
        self.x_adj = measured if hp.adjoint_data else self.A.adjoint(measured)

    def forward(self, y):
        """Run CG from a zero initial guess; ``y`` itself is not used."""
        hp = self.hparams
        solver = ConjGrad(
            self.x_adj,
            self.A.normal,
            l2lam=self.l2lam,
            max_iter=hp.cg_max_iter,
            eps=hp.cg_eps,
            verbose=False,
        )
        estimate = solver.forward(self.x_adj * 0)
        self.num_cg = solver.num_cg
        return estimate

    def get_metadata(self):
        """Return the CG iteration count from the last ``forward``."""
        return {'num_cg': self.num_cg}
Exemplo n.º 5
0
class UNetRecon(Recon):
    """Apply an image-domain U-Net to the adjoint image ``self.x_adj``."""

    def __init__(self, hparams):
        super(UNetRecon, self).__init__(hparams)
        hp = self.hparams
        # Only network == 'UNet' builds a model here; any other value leaves
        # self.network unset in this class.
        if hp.network == 'UNet':
            self.network = UNet(batch_norm=hp.batch_norm, l2lam=hp.l2lam_init)

    def batch(self, data):
        """Build ``self.A`` (with ``l2lam=hparams.l2lam_init``) and ``self.x_adj``."""
        hp = self.hparams
        maps, masks, measured = data['maps'], data['masks'], data['out']

        self.A = MultiChannelMRI(
            maps,
            masks,
            l2lam=hp.l2lam_init,
            img_shape=data['imgs'].shape,
            use_sigpy=hp.use_sigpy,
            noncart=hp.noncart,
        )
        self.x_adj = self.A.adjoint(measured)

    def forward(self, y):
        """Return the network applied to the cached adjoint; ``y`` is unused."""
        return self.network(self.x_adj)

    def get_metadata(self):
        """This reconstruction reports no metadata."""
        return {}
Exemplo n.º 6
0
class ResNetRecon(Recon):
    """Apply a residual network to the adjoint image ``self.x_adj``.

    The network's channel count is derived from ``self.D.shape`` (``self.D``
    is not set in this class; presumably provided by ``Recon`` — confirm).
    """

    def __init__(self, hparams):
        super(ResNetRecon, self).__init__(hparams)

        copy_shape = np.array(self.D.shape)
        spatial_dims = hparams.num_spatial_dimensions
        if spatial_dims == 2:
            trailing = 2
        elif spatial_dims == 3:
            trailing = 3
        else:
            raise ValueError('only 2D or 3D number of spatial dimensions are supported!')
        # Product of the axes between the first and the trailing spatial axes,
        # doubled (presumably real + imaginary parts — confirm).
        self.in_channels = 2 * np.prod(copy_shape[1:-trailing])

        hp = self.hparams
        # NOTE: original author left a "FIX ALSO" marker on the ResNet5Block branch.
        if hp.network == 'ResNet5Block':
            self.network = ResNet5Block(
                num_filters_start=self.in_channels,
                num_filters_end=self.in_channels,
                num_filters=hp.latent_channels,
                filter_size=7,
                batch_norm=hp.batch_norm,
            )
        elif hp.network == 'ResNet':
            self.network = ResNet(
                in_channels=self.in_channels,
                latent_channels=hp.latent_channels,
                num_blocks=hp.num_blocks,
                kernel_size=7,
                batch_norm=hp.batch_norm,
            )

    def batch(self, data):
        """Build ``self.A`` and cache the adjoint image ``self.x_adj``."""
        hp = self.hparams
        maps, masks, measured = data['maps'], data['masks'], data['out']

        self.A = MultiChannelMRI(
            maps,
            masks,
            l2lam=0.,
            img_shape=data['imgs'].shape,
            use_sigpy=hp.use_sigpy,
            noncart=hp.noncart,
        )
        self.x_adj = self.A.adjoint(measured)

    def forward(self, y):
        """Return the network applied to the cached adjoint; ``y`` is unused."""
        return self.network(self.x_adj)

    def get_metadata(self):
        """This reconstruction reports no metadata."""
        return {}
Exemplo n.º 7
0
    def batch(self, data):
        """Build the operator, adjoint image and noise-ball radius for a batch.

        Everything is computed under ``torch.no_grad()``. ``self.eps`` is
        ``sqrt(maps.shape[1] * per-example mask sum) * hparams.stdev``, one
        value per leading index of ``self.A.mask``.
        """
        hp = self.hparams
        maps, masks, measured = data['maps'], data['masks'], data['out']

        with torch.no_grad():
            self.A = MultiChannelMRI(
                maps,
                masks,
                l2lam=0.,
                img_shape=data['imgs'].shape,
                use_sigpy=hp.use_sigpy,
                noncart=hp.noncart,
            )
            self.x_adj = self.A.adjoint(measured)

            num_maps = self.A.maps.shape[1]
            mask = self.A.mask
            # flatten everything but the leading axis and sum per example
            per_example = torch.sum(mask.reshape((mask.shape[0], -1)), dim=1)
            self.eps = (num_maps * per_example).sqrt() * hp.stdev
Exemplo n.º 8
0
    def batch(self, data):
        """Build ``self.A`` and ``self.x_adj`` (and ``self.inp`` if single-channel).

        With ``hparams.adjoint_data`` set, ``data['out']`` is used directly as
        the adjoint image. Otherwise it is passed through ``self.A.adjoint``.
        """
        hp = self.hparams
        maps, masks, measured = data['maps'], data['masks'], data['out']

        self.A = MultiChannelMRI(
            maps,
            masks,
            l2lam=0.,
            img_shape=data['imgs'].shape,
            use_sigpy=hp.use_sigpy,
            noncart=hp.noncart,
        )

        if hp.adjoint_data:
            self.x_adj = measured
        else:
            self.x_adj = self.A.adjoint(measured)

        if self.A.single_channel:
            if hp.adjoint_data:
                # re-derive k-space from the supplied adjoint image
                self.inp = fft_forw(maps.squeeze(1) * self.x_adj)
            else:
                self.inp = measured.squeeze(1)
Exemplo n.º 9
0
    def batch(self, data):
        """Build the operator, adjoint image and noise-ball radius for a batch.

        Everything is computed under ``torch.no_grad()``. Single-channel
        operators are rejected by an ``assert``. ``self.eps`` is
        ``sqrt(maps.shape[1] * per-example mask sum) * hparams.stdev``.
        """
        hp = self.hparams
        maps, masks, measured = data['maps'], data['masks'], data['out']

        with torch.no_grad():
            self.A = MultiChannelMRI(
                maps,
                masks,
                l2lam=0.,
                img_shape=data['imgs'].shape,
                use_sigpy=hp.use_sigpy,
                noncart=hp.noncart,
            )
            self.x_adj = self.A.adjoint(measured)

            # FIXME: look at modl for single channel support
            assert self.A.single_channel is False, 'single channel support not yet implemented!'

            num_maps = self.A.maps.shape[1]
            mask = self.A.mask
            per_example = torch.sum(mask.reshape((mask.shape[0], -1)), dim=1)
            self.eps = (num_maps * per_example).sqrt() * hp.stdev
Exemplo n.º 10
0
class MoDLReconOneUnroll(torch.nn.Module):
    """One MoDL unroll: apply the denoiser, then a CG data-consistency solve.

    ``batch`` must be called first to build ``self.A`` and ``self.x_adj``.

    Args:
        denoiser: module applied to the current image estimate.
        l2lam: weight of the denoiser term; also passed to ``ConjGrad``.
        hparams: config object; ``adjoint_data``, ``use_sigpy``, ``noncart``,
            ``cg_max_iter`` and ``cg_eps`` are read from it.
    """

    def __init__(self, denoiser, l2lam, hparams):
        super().__init__()
        self.l2lam = l2lam
        self.num_cg = None
        self.x_adj = None
        self.hparams = hparams
        self.denoiser = denoiser

    def batch(self, data):
        """Build ``self.A`` and the adjoint image ``self.x_adj`` for one batch."""
        hp = self.hparams
        maps, masks, measured = data['maps'], data['masks'], data['out']

        self.A = MultiChannelMRI(
            maps,
            masks,
            l2lam=0.,
            img_shape=data['imgs'].shape,
            use_sigpy=hp.use_sigpy,
            noncart=hp.noncart,
        )
        # with adjoint_data, 'out' is used directly as the adjoint image
        self.x_adj = measured if hp.adjoint_data else self.A.adjoint(measured)

    def forward(self, x):
        """Denoise ``x`` then run CG on ``x_adj + l2lam * denoiser(x)``, warm-started at ``x``."""
        assert self.x_adj is not None, "x_adj not computed!"
        hp = self.hparams
        denoised = self.denoiser(x)

        solver = ConjGrad(
            self.x_adj + self.l2lam * denoised,
            self.A.normal,
            l2lam=self.l2lam,
            max_iter=hp.cg_max_iter,
            eps=hp.cg_eps,
            verbose=False,
        )
        x_new = solver.forward(x)
        self.num_cg = solver.num_cg
        return x_new

    def get_metadata(self):
        """Return the number of CG iterations used by the last ``forward``."""
        return {'num_cg': self.num_cg}
Exemplo n.º 11
0
class ResNetRecon(Recon):
    """Apply a residual-network denoiser to the adjoint image ``self.x_adj``."""

    def __init__(self, args):
        super(ResNetRecon, self).__init__(args)

        choice = args.network
        # Only 'ResNet5Block' and 'ResNet' build a denoiser; other values leave
        # self.denoiser unset in this class.
        if choice == 'ResNet5Block':
            self.denoiser = ResNet5Block(
                num_filters=args.latent_channels,
                filter_size=7,
                batch_norm=args.batch_norm,
            )
        elif choice == 'ResNet':
            self.denoiser = ResNet(
                latent_channels=args.latent_channels,
                num_blocks=args.num_blocks,
                kernel_size=7,
                batch_norm=args.batch_norm,
            )

    def batch(self, data):
        """Build ``self.A`` and cache the adjoint image ``self.x_adj``.

        ``self.use_sigpy`` / ``self.noncart`` are not set in this class
        (presumably by ``Recon`` — confirm).
        """
        maps, masks, measured = data['maps'], data['masks'], data['out']

        self.A = MultiChannelMRI(
            maps,
            masks,
            l2lam=0.,
            img_shape=data['imgs'].shape,
            use_sigpy=self.use_sigpy,
            noncart=self.noncart,
        )
        self.x_adj = self.A.adjoint(measured)

    def forward(self, y):
        """Return the denoiser applied to the cached adjoint; ``y`` is unused."""
        return self.denoiser(self.x_adj)

    def get_metadata(self):
        """This reconstruction reports no metadata."""
        return {}
Exemplo n.º 12
0
class DeepBasisPursuitRecon(Recon):
    """Deep Basis Pursuit: learned denoiser unrolled with ADMM data consistency.

    For each of ``self.num_unrolls`` outer iterations, the denoiser is applied
    to the current image. Then up to ``self.num_admm`` ADMM iterations alternate:
    a CG solve for ``x``, an update of ``z`` via ``opt.l2ball_proj_batch``
    centred on ``y``, and a dual update of ``u``. ``self.num_unrolls``,
    ``self.cg_max_iter``, ``self.eps``, ``self.stdev``, ``self.use_sigpy`` and
    ``self.noncart`` are not set in this class (presumably by ``Recon`` —
    confirm).
    """

    def __init__(self, args):
        super(DeepBasisPursuitRecon, self).__init__(args)
        # trainable ADMM weight, initialized from args.l2lam_init
        self.l2lam = torch.nn.Parameter(torch.tensor(args.l2lam_init))
        self.num_admm = args.num_admm

        if args.network == 'ResNet5Block':
            self.denoiser = ResNet5Block(num_filters=args.latent_channels, filter_size=7, batch_norm=args.batch_norm)
        elif args.network == 'ResNet':
            self.denoiser = ResNet(latent_channels=args.latent_channels, num_blocks=args.num_blocks, kernel_size=7, batch_norm=args.batch_norm)

        # > 0 enables the early-stop log message in forward()
        self.debug_level = 0

    def batch(self, data):
        """Build ``self.A`` and cache the adjoint image ``self.x_adj``."""
        maps = data['maps']
        masks = data['masks']
        inp = data['out']

        self.A = MultiChannelMRI(maps, masks, l2lam=0., img_shape=data['imgs'].shape, use_sigpy=self.use_sigpy, noncart=self.noncart)
        self.x_adj = self.A.adjoint(inp)

    def forward(self, y):
        """Reconstruct an image from measurements ``y``.

        Returns:
            The final image estimate ``x``. Per-iteration CG counts are stored
            in ``self.num_cg`` with shape ``(num_unrolls, num_admm)``.
        """
        # Radius of the data-consistency ball used in the z-update.
        # NOTE(review): confirm opt.ip_batch yields the intended radius when
        # applied to a per-example sample count.
        eps = opt.ip_batch(self.A.maps.shape[1] * self.A.mask.sum((1, 2))).sqrt() * self.stdev
        x = self.A.adjoint(y)
        z = self.A(x)
        z_old = z
        u = z.new_zeros(z.shape)

        x.requires_grad = False
        z.requires_grad = False
        z_old.requires_grad = False
        u.requires_grad = False

        self.num_cg = np.zeros((self.num_unrolls, self.num_admm,))

        for i in range(self.num_unrolls):
            r = self.denoiser(x)

            for j in range(self.num_admm):

                # x-update: solve (l2lam * A^H A + I) x = l2lam * A^H (z - u) + r
                rhs = self.l2lam * self.A.adjoint(z - u) + r
                fun = lambda xx: self.l2lam * self.A.normal(xx) + xx
                cg_op = ConjGrad(rhs, fun, max_iter=self.cg_max_iter, eps=self.eps, verbose=False)
                x = cg_op.forward(x)
                n_cg = cg_op.num_cg
                self.num_cg[i, j] = n_cg

                # z-update (projection around y) and dual update
                Ax_plus_u = self.A(x) + u
                z_old = z
                z = y + opt.l2ball_proj_batch(Ax_plus_u - y, eps)
                u = Ax_plus_u - z

                # check ADMM convergence
                Ax = self.A(x)
                r_norm = opt.ip_batch(Ax-z).sqrt()
                s_norm = opt.ip_batch(self.l2lam * self.A.adjoint(z - z_old)).sqrt()
                if (r_norm + s_norm).max() < 1E-2:
                    if self.debug_level > 0:
                        # log the ADMM iteration index (previously referenced
                        # an undefined name `a`, raising NameError)
                        tqdm.tqdm.write('stopping early, a={}'.format(j))
                    break
        return x

    def get_metadata(self):
        """Return the flattened per-iteration CG counts from the last ``forward``."""
        return {
                'num_cg': self.num_cg.ravel(),
                }
Exemplo n.º 13
0
class DeepBasisPursuitRecon(Recon):
    """Deep Basis Pursuit: learned denoiser unrolled with ADMM data consistency.

    For each of ``hparams.num_unrolls`` outer iterations, ``self.denoiser`` is
    applied to the current image. Then up to ``hparams.num_admm`` ADMM
    iterations alternate: a CG solve for ``x``, an update of ``z`` via
    ``opt.l2ball_proj_batch`` centred on ``y`` with radius ``self.eps``
    (computed in ``batch``), and a dual update of ``u``.
    """

    def __init__(self, hparams):
        super(DeepBasisPursuitRecon, self).__init__(hparams)
        # trainable ADMM weight, initialized from hparams.l2lam_init
        self.l2lam = torch.nn.Parameter(torch.tensor(hparams.l2lam_init))
        self.num_admm = hparams.num_admm

        # self.D is not set in this class (presumably by Recon — confirm)
        copy_shape = np.array(self.D.shape)
        if hparams.num_spatial_dimensions == 2:
            num_channels = 2 * np.prod(copy_shape[1:-2])
        elif hparams.num_spatial_dimensions == 3:
            num_channels = 2 * np.prod(copy_shape[1:-3])
        else:
            raise ValueError(
                'only 2D or 3D number of spatial dimensions are supported!')
        self.in_channels = num_channels

        if hparams.network == 'ResNet5Block':
            self.denoiser = ResNet5Block(num_filters_start=self.in_channels,
                                         num_filters_end=self.in_channels,
                                         num_filters=hparams.latent_channels,
                                         filter_size=7,
                                         batch_norm=hparams.batch_norm)
        elif hparams.network == 'ResNet':
            self.denoiser = ResNet(in_channels=self.in_channels,
                                   latent_channels=hparams.latent_channels,
                                   num_blocks=hparams.num_blocks,
                                   kernel_size=7,
                                   batch_norm=hparams.batch_norm)

        # > 0 enables the early-stop log message in forward()
        self.debug_level = 0

        self.mean_residual_norm = 0

    def batch(self, data):
        """Build ``self.A``, ``self.x_adj`` and the ball radius ``self.eps``.

        ``self.eps`` is ``sqrt(maps.shape[1] * per-example mask sum) * stdev``
        (``maps.shape[1]`` presumably the coil count — confirm).
        """
        maps = data['maps']
        masks = data['masks']
        inp = data['out']

        with torch.no_grad():
            self.A = MultiChannelMRI(maps,
                                     masks,
                                     l2lam=0.,
                                     img_shape=data['imgs'].shape,
                                     use_sigpy=self.hparams.use_sigpy,
                                     noncart=self.hparams.noncart)
            self.x_adj = self.A.adjoint(inp)

            # FIXME: look at modl for single channel support
            assert self.A.single_channel is False, 'single channel support not yet implemented!'
            self.eps = (self.A.maps.shape[1] *
                        torch.sum(self.A.mask.reshape(
                            (self.A.mask.shape[0], -1)),
                                  dim=1)).sqrt() * self.hparams.stdev

    def forward(self, y):
        """Reconstruct an image from measurements ``y``.

        Returns:
            The final image estimate ``x``. Side effects: ``self.num_cg``
            (shape ``(num_unrolls, num_admm)``), ``self.mean_eps`` and
            ``self.mean_residual_norm`` are updated for ``get_metadata``.
        """
        with torch.no_grad():
            self.mean_eps = torch.mean(self.eps)
        x = self.A.adjoint(y)
        z = self.A(x)
        z_old = z
        u = z.new_zeros(z.shape)

        x.requires_grad = False
        z.requires_grad = False
        z_old.requires_grad = False
        u.requires_grad = False

        self.num_cg = np.zeros((
            self.hparams.num_unrolls,
            self.hparams.num_admm,
        ))

        for i in range(self.hparams.num_unrolls):
            r = self.denoiser(x)

            for j in range(self.hparams.num_admm):

                # x-update: solve (l2lam * A^H A + I) x = l2lam * A^H (z - u) + r
                rhs = self.l2lam * self.A.adjoint(z - u) + r
                fun = lambda xx: self.l2lam * self.A.normal(xx) + xx
                cg_op = ConjGrad(rhs,
                                 fun,
                                 max_iter=self.hparams.cg_max_iter,
                                 eps=self.hparams.cg_eps,
                                 verbose=False)
                x = cg_op.forward(x)
                n_cg = cg_op.num_cg
                self.num_cg[i, j] = n_cg

                # z-update (projection around y) and dual update
                Ax_plus_u = self.A(x) + u
                z_old = z
                z = y + opt.l2ball_proj_batch(Ax_plus_u - y, self.eps)
                u = Ax_plus_u - z

                # check ADMM convergence
                with torch.no_grad():
                    Ax = self.A(x)
                    tmp = Ax - z
                    tmp = tmp.contiguous()
                    r_norm = torch.real(opt.zdot_single_batch(tmp)).sqrt()

                    tmp = self.l2lam * self.A.adjoint(z - z_old)
                    tmp = tmp.contiguous()
                    s_norm = torch.real(opt.zdot_single_batch(tmp)).sqrt()

                    if (r_norm + s_norm).max() < 1E-2:
                        if self.debug_level > 0:
                            # log the ADMM iteration index (previously referenced
                            # an undefined name `a`, raising NameError)
                            tqdm.tqdm.write('stopping early, a={}'.format(j))
                        break
                    tmp = y - Ax
                    self.mean_residual_norm = torch.mean(
                        torch.sqrt(torch.real(opt.zdot_single_batch(tmp))))
        return x

    def get_metadata(self):
        """Return CG counts, mean residual norm and mean ball radius from the last ``forward``."""
        return {
            'num_cg': self.num_cg.ravel(),
            'mean_residual_norm': self.mean_residual_norm,
            'mean_eps': self.mean_eps,
        }
        }