Example #1
0
    def __init__(self,
                 maps,
                 mask,
                 l2lam=False,
                 img_shape=None,
                 use_sigpy=False,
                 noncart=False,
                 num_spatial_dims=2):
        """Store the forward-model inputs and optionally build a SigPy model.

        Args:
            maps: sensitivity-map tensor; ``maps.shape[1] == 1`` marks the
                single-channel case (axis 1 presumably indexes coils --
                TODO confirm).
            mask: sampling tensor (for ``noncart`` presumably a trajectory --
                TODO confirm).
            l2lam: stored as-is on ``self.l2lam``.
            img_shape: image shape whose last axis holds the real/imag pair;
                on the SigPy path that axis is dropped (R^2N -> C^N), so it
                must not be None when ``use_sigpy`` is set.
            use_sigpy: convert maps/mask to SigPy arrays on the tensors'
                device and call ``self._build_model_sigpy()``.
            noncart: non-Cartesian sampling; requires ``use_sigpy``.
            num_spatial_dims: number of spatial dimensions, stored on
                ``self.num_spatial_dims``.

        Raises:
            AssertionError: if ``noncart`` is set without ``use_sigpy``
                (not raised when Python runs with ``-O``).
        """
        super(MultiChannelMRI, self).__init__()
        self.maps = maps
        self.mask = mask
        self.l2lam = l2lam
        self.img_shape = img_shape
        self.noncart = noncart
        self._normal = None
        self.num_spatial_dims = num_spatial_dims
        self.single_channel = self.maps.shape[1] == 1

        if self.noncart:
            assert use_sigpy, 'Must use SigPy for NUFFT!'

        if use_sigpy:  # FIXME: Not yet Implemented for 3D
            from sigpy import from_pytorch, to_device, Device
            # torch reports a device index of None for CPU tensors; SigPy's
            # CPU device id is -1, so map it explicitly instead of passing None.
            device_index = self.maps.device.index
            sp_device = Device(-1 if device_index is None else device_index)
            self.maps = to_device(from_pytorch(self.maps, iscomplex=True),
                                  device=sp_device)
            self.mask = to_device(from_pytorch(self.mask, iscomplex=False),
                                  device=sp_device)
            self.img_shape = self.img_shape[:-1]  # convert R^2N to C^N
            self._build_model_sigpy()
Example #2
0
    def __init__(self,
                 maps,
                 mask,
                 l2lam=False,
                 img_shape=None,
                 use_sigpy=False,
                 noncart=False):
        """Store the forward-model inputs and optionally build a SigPy model.

        Args:
            maps: sensitivity-map tensor (presumably coil maps -- TODO confirm).
            mask: sampling tensor (for ``noncart`` presumably a trajectory --
                TODO confirm).
            l2lam: stored as-is on ``self.l2lam``.
            img_shape: image shape whose last axis holds the real/imag pair;
                on the SigPy path that axis is dropped (R^2N -> C^N), so it
                must not be None when ``use_sigpy`` is set.
            use_sigpy: convert maps/mask to SigPy arrays on the tensors'
                device and call ``self._build_model_sigpy()``.
            noncart: non-Cartesian sampling; requires ``use_sigpy``.

        Raises:
            AssertionError: if ``noncart`` is set without ``use_sigpy``
                (not raised when Python runs with ``-O``).
        """
        super(MultiChannelMRI, self).__init__()
        self.maps = maps
        self.mask = mask
        self.l2lam = l2lam
        self.img_shape = img_shape
        self.noncart = noncart
        self._normal = None

        if self.noncart:
            assert use_sigpy, 'Must use SigPy for NUFFT!'

        if use_sigpy:
            from sigpy import from_pytorch, to_device, Device
            # torch reports a device index of None for CPU tensors; SigPy's
            # CPU device id is -1, so map it explicitly instead of passing None.
            device_index = self.maps.device.index
            sp_device = Device(-1 if device_index is None else device_index)
            self.maps = to_device(from_pytorch(self.maps, iscomplex=True),
                                  device=sp_device)
            self.mask = to_device(from_pytorch(self.mask, iscomplex=False),
                                  device=sp_device)
            self.img_shape = self.img_shape[:-1]  # convert R^2N to C^N
            self._build_model_sigpy()
Example #3
0
    def __init__(self, maps, mask, phi, l2lam=False, ksp_shape=None, img_shape=None, use_sigpy=True, noncart=True):
        """Store the multiband (SMS) non-Cartesian model inputs and build it with SigPy.

        Args:
            maps: real-valued tensor of shape (batchsize, coils, SMS, N, N, 2).
            mask: real-valued tensor of shape (batchsize, nspokes, nreadout, 2).
            phi: real-valued tensor of shape (batchsize, SMS, nspokes, nreadout, 2).
            l2lam: stored as-is on ``self.l2lam``.
            ksp_shape: k-space shape with a trailing real/imag axis, dropped
                (R^2N -> C^N) on the SigPy path.
            img_shape: image shape with a trailing real/imag axis, dropped
                (R^2N -> C^N) on the SigPy path.
            use_sigpy: must be True; SigPy provides the NUFFT.
            noncart: must be True; only non-Cartesian sampling is implemented.

        Raises:
            AssertionError: if ``noncart`` or ``use_sigpy`` is False
                (not raised when Python runs with ``-O``).
        """
        super(MultiBandMRI, self).__init__()
        self.maps = maps
        self.mask = mask
        self.phi = phi
        self.l2lam = l2lam
        self.ksp_shape = ksp_shape
        self.img_shape = img_shape
        self.noncart = noncart
        self._normal = None

        assert self.noncart, 'Non-cartesian implementation only for now'
        assert use_sigpy, 'Must use SigPy for NUFFT'

        if use_sigpy:
            from sigpy import from_pytorch, to_device, Device
            # torch reports a device index of None for CPU tensors; SigPy's
            # CPU device id is -1, so both CPU and GPU inputs work unchanged.
            device_index = self.maps.device.index
            sp_device = Device(-1 if device_index is None else device_index)

            # (batchsize,coils,SMS,N,N,2) real -> (batchsize,coils,SMS,N,N) complex
            self.maps = to_device(from_pytorch(self.maps, iscomplex=True), device=sp_device)

            # (batchsize,nspokes,nreadout,2) real -> same shape, still real
            self.mask = to_device(from_pytorch(self.mask, iscomplex=False), device=sp_device)

            # (batchsize,SMS,nspokes,nreadout,2) real -> (batchsize,SMS,nspokes,nreadout) complex
            self.phi = to_device(from_pytorch(self.phi, iscomplex=True), device=sp_device)

            self.img_shape = self.img_shape[:-1]  # convert R^2N to C^N
            self.ksp_shape = self.ksp_shape[:-1]  # convert R^2N to C^N

            self._build_model_sigpy()
Example #4
0
def _read_infer_arrays(data_file, multiband, idx=None):
    """Read the inference inputs from an HDF5 file.

    Args:
        data_file: path of an HDF5 file with datasets ``imgs``, ``maps``,
            ``masks``, ``ksp`` (and ``phi`` when ``multiband``).
        multiband: whether to also read ``phi``.
        idx: if given, read only entry ``idx`` along the first axis of each
            dataset; otherwise read the datasets in full.

    Returns:
        dict of numpy arrays keyed by dataset name, in the order imgs, maps,
        masks, ksp[, phi]; ``masks`` is float64, the others complex128.
    """
    sel = Ellipsis if idx is None else (idx, Ellipsis)
    with h5py.File(data_file, 'r') as F:
        # Builtin complex/float: the np.complex/np.float aliases were removed
        # in NumPy 1.24 and were only ever aliases of these builtins.
        arrays = {
            'imgs': np.array(F['imgs'][sel], dtype=complex),
            'maps': np.array(F['maps'][sel], dtype=complex),
            'masks': np.array(F['masks'][sel], dtype=float),
            'ksp': np.array(F['ksp'][sel], dtype=complex),
        }
        if multiband:
            arrays['phi'] = np.array(F['phi'][sel], dtype=complex)
    return arrays


def _add_batch_axis(arrays):
    """Return a new dict whose arrays each gain a leading length-1 batch axis."""
    return {key: arr[None, ...] for key, arr in arrays.items()}


def _prepare_infer_batch(arrays):
    """Convert the complex inputs into the float32 tensors handed to ``M.batch``.

    Complex arrays go through ``cp.c2r`` (presumably appending a real/imag
    axis, matching how the network output is recombined -- TODO confirm);
    ``masks`` is only cast to float32.

    Returns:
        (pre, data): ``pre`` holds the float32 numpy arrays under the same keys
        as ``arrays``; ``data`` is the torch dict for ``M.batch`` with k-space
        stored under ``'out'``.
    """
    pre = {key: (arr if key == 'masks' else cp.c2r(arr)).astype(np.float32)
           for key, arr in arrays.items()}
    data = {key: sigpy.to_pytorch(pre[key]) for key in ('imgs', 'maps', 'masks')}
    if 'phi' in pre:
        data['phi'] = sigpy.to_pytorch(pre['phi'])
    data['out'] = sigpy.to_pytorch(pre['ksp'])
    return pre, data


def _print_infer_shapes(arrays, pre, blank_line):
    """Print array shapes before and after conversion (debug output)."""
    for arr in arrays.values():
        print(arr.shape)
    if blank_line:
        print('')
    for arr in pre.values():
        print(arr.shape)


def _show_infer_output(output, multiband):
    """Plot the magnitude of one reconstruction (one panel per SMS slice)."""
    plt.figure()
    if multiband:
        nslices = output.shape[0]
        for slc in range(nslices):
            plt.subplot(1, nslices, slc + 1)
            plt.imshow(np.abs(output[slc, ::, ::]), cmap='gray')
        plt.subplots_adjust(wspace=0)
    else:
        plt.imshow(np.abs(output), cmap='gray')
    plt.show()


def _infer_per_example(M, args):
    """Run ``M`` on each example separately, plotting and stacking the results.

    Returns the stacked predictions (first axis = example index), which are
    also saved to ``inference.npz``.
    """
    with h5py.File(args.data_file, 'r') as F:
        # Only the example count is needed; don't load the whole dataset.
        ndatasets = F['imgs'].shape[0]

    pred = None
    for idx in range(ndatasets):
        print('%i/%i' % (idx, ndatasets))

        arrays = _read_infer_arrays(args.data_file, args.multiband, idx)
        ndim = arrays['imgs'].ndim
        if ndim == 2 or (ndim == 3 and args.multiband):
            # Every input (phi included) needs the same leading batch axis.
            arrays = _add_batch_axis(arrays)
        pre, data = _prepare_infer_batch(arrays)
        _print_infer_shapes(arrays, pre, blank_line=True)

        M.batch(data)
        output = sigpy.from_pytorch(M(sigpy.to_pytorch(pre['ksp'])))

        # Recombine the trailing (real, imag) axis, then drop the batch axis.
        output = output[..., 0] + 1j * output[..., 1]
        output = output[0, ...]

        if pred is None:
            # Sized from the first output, so any number of image axes works.
            pred = np.zeros((ndatasets,) + output.shape, dtype=output.dtype)

        _show_infer_output(output, args.multiband)
        pred[idx, ...] = output

    if pred is None:
        # The file holds no examples: save an empty result instead of failing.
        pred = np.zeros((0,), dtype=complex)
    np.savez('inference.npz', pred=pred)
    return pred


def _infer_full_batch(M, args):
    """Run ``M`` once on all examples in ``args.data_file`` and save the raw output."""
    print('Reading from input file')
    arrays = _read_infer_arrays(args.data_file, args.multiband)
    if arrays['masks'].ndim == 2:
        arrays = _add_batch_axis(arrays)
    pre, data = _prepare_infer_batch(arrays)
    _print_infer_shapes(arrays, pre, blank_line=False)

    print('Writing to dictionary')
    print('Preparing data batch')
    M.batch(data)

    print('Calling M(y)')
    # y is the measured k-space tensor, the same input the per-example path uses.
    pred = sigpy.from_pytorch(M(sigpy.to_pytorch(pre['ksp'])))

    print(pred.shape)
    np.savez('inference.npz', pred=pred)
    return pred


def main_infer(args):
    """Run inference with a trained reconstruction model and save ``inference.npz``.

    Args:
        args: namespace providing ``recon`` (one of 'cgsense', 'modl',
            'resnet', 'dbp'), ``checkpoint_init`` (checkpoint path),
            ``data_file`` (HDF5 input) and ``multiband`` (bool).

    Returns:
        The prediction array that was also written to ``inference.npz``.

    Raises:
        ValueError: if ``args.recon`` names an unknown reconstruction.
    """
    if args.recon == 'cgsense':
        MyRecon = CGSenseRecon
    elif args.recon == 'modl':
        MyRecon = MoDLRecon
    elif args.recon == 'resnet':
        MyRecon = ResNetRecon
    elif args.recon == 'dbp':
        MyRecon = DeepBasisPursuitRecon
    else:
        raise ValueError(
            "unknown recon '{}': expected 'cgsense', 'modl', 'resnet' or 'dbp'"
            .format(args.recon))

    # load from checkpoint
    print('loading checkpoint: {}'.format(args.checkpoint_init))
    M = MyRecon.load_from_checkpoint(args.checkpoint_init)

    # Per-example inference plots each result; the full-batch path feeds the
    # whole file to the model in a single call.
    use_batchsize_one = True
    if use_batchsize_one:
        return _infer_per_example(M, args)
    return _infer_full_batch(M, args)