def __init__(self, maps, mask, l2lam=False, img_shape=None, use_sigpy=False, noncart=False, num_spatial_dims=2):
    super(MultiChannelMRI, self).__init__()
    self.maps = maps
    self.mask = mask
    self.l2lam = l2lam
    self.img_shape = img_shape
    self.noncart = noncart
    self._normal = None
    self.num_spatial_dims = num_spatial_dims

    self.single_channel = (self.maps.shape[1] == 1)

    if self.noncart:
        assert use_sigpy, 'Must use SigPy for NUFFT!'

    if use_sigpy:  # FIXME: not yet implemented for 3D
        from sigpy import from_pytorch, to_device, Device
        sp_device = Device(self.maps.device.index)
        self.maps = to_device(from_pytorch(self.maps, iscomplex=True), device=sp_device)
        self.mask = to_device(from_pytorch(self.mask, iscomplex=False), device=sp_device)
        self.img_shape = self.img_shape[:-1]  # convert R^2N to C^N
        self._build_model_sigpy()
def __init__(self, maps, mask, l2lam=False, img_shape=None, use_sigpy=False, noncart=False):
    super(MultiChannelMRI, self).__init__()
    self.maps = maps
    self.mask = mask
    self.l2lam = l2lam
    self.img_shape = img_shape
    self.noncart = noncart
    self._normal = None

    if self.noncart:
        assert use_sigpy, 'Must use SigPy for NUFFT!'

    if use_sigpy:
        from sigpy import from_pytorch, to_device, Device
        sp_device = Device(self.maps.device.index)
        self.maps = to_device(from_pytorch(self.maps, iscomplex=True), device=sp_device)
        self.mask = to_device(from_pytorch(self.mask, iscomplex=False), device=sp_device)
        self.img_shape = self.img_shape[:-1]  # convert R^2N to C^N
        self._build_model_sigpy()
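# Hedged usage sketch (not from the original source): constructing a Cartesian
# MultiChannelMRI operator from real-valued (..., 2) tensors. The sizes
# (batch, coils, N) and the helper function below are illustrative assumptions.
import torch

def example_multichannel_mri():
    batch, coils, N = 1, 8, 256
    maps = torch.randn(batch, coils, N, N, 2)   # coil sensitivity maps, real/imag stacked
    mask = torch.ones(batch, N, N)              # fully sampled Cartesian mask (assumed shape)
    # img_shape carries the trailing real/imag dim; it is stripped (R^2N -> C^N) when use_sigpy=True
    A = MultiChannelMRI(maps, mask, l2lam=0.01, img_shape=(batch, N, N, 2),
                        use_sigpy=False, noncart=False)
    return A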
def __init__(self, maps, mask, phi, l2lam=False, ksp_shape=None, img_shape=None, use_sigpy=True, noncart=True):
    super(MultiBandMRI, self).__init__()
    self.maps = maps
    self.mask = mask
    self.phi = phi
    self.l2lam = l2lam
    self.ksp_shape = ksp_shape
    self.img_shape = img_shape
    self.noncart = noncart
    self._normal = None

    assert self.noncart, 'Non-Cartesian implementation only for now'
    assert use_sigpy, 'Must use SigPy for NUFFT'

    if use_sigpy:
        from sigpy import from_pytorch, to_device, Device
        sp_device = Device(self.maps.device.index)  # NJM: change to Device(-1) for CPU

        # from real-valued pytorch tensor of shape (batchsize, coils, SMS, N, N, 2)
        # to complex-valued numpy/cupy array of shape (batchsize, coils, SMS, N, N)
        self.maps = to_device(from_pytorch(self.maps, iscomplex=True), device=sp_device)

        # from real-valued pytorch tensor of shape (batchsize, nspokes, nreadout, 2)
        # to real-valued numpy/cupy array of the same shape
        self.mask = to_device(from_pytorch(self.mask, iscomplex=False), device=sp_device)

        # from real-valued pytorch tensor of shape (batchsize, SMS, nspokes, nreadout, 2)
        # to complex-valued numpy/cupy array of shape (batchsize, SMS, nspokes, nreadout)
        self.phi = to_device(from_pytorch(self.phi, iscomplex=True), device=sp_device)

        self.img_shape = self.img_shape[:-1]  # convert R^2N to C^N
        self.ksp_shape = self.ksp_shape[:-1]  # convert R^2N to C^N
        self._build_model_sigpy()
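# Hedged usage sketch (assumptions marked): building inputs with the shapes the
# MultiBandMRI constructor comments describe. The sizes (coils, SMS, N, nspokes,
# nreadout) are illustrative, and the tensors are assumed to live on a CUDA device
# because Device(self.maps.device.index) is used above (see the Device(-1) note for CPU).
import torch

def example_multiband_mri():
    batch, coils, SMS, N = 1, 8, 3, 256
    nspokes, nreadout = 64, 512
    maps = torch.randn(batch, coils, SMS, N, N, 2).cuda()      # complex maps as (..., 2)
    mask = torch.randn(batch, nspokes, nreadout, 2).cuda()      # real-valued non-Cartesian sampling info
    phi = torch.randn(batch, SMS, nspokes, nreadout, 2).cuda()  # multiband modulation, complex as (..., 2)
    A = MultiBandMRI(maps, mask, phi, l2lam=0.,
                     ksp_shape=(batch, coils, nspokes, nreadout, 2),
                     img_shape=(batch, SMS, N, N, 2),
                     use_sigpy=True, noncart=True)
    return A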
def main_infer(args):
    if args.recon == 'cgsense':
        MyRecon = CGSenseRecon
    elif args.recon == 'modl':
        MyRecon = MoDLRecon
    elif args.recon == 'resnet':
        MyRecon = ResNetRecon
    elif args.recon == 'dbp':
        MyRecon = DeepBasisPursuitRecon

    # load from checkpoint
    print('loading checkpoint: {}'.format(args.checkpoint_init))
    M = MyRecon.load_from_checkpoint(args.checkpoint_init)

    # create a dataset
    # M.D = MultiChannelMRIDataset()

    useBatchsizeOne = True
    if useBatchsizeOne:
        # get number of datasets
        with h5py.File(args.data_file, 'r') as F:
            imgs = np.array(F['imgs'], dtype=np.complex128)
            ndatasets = imgs.shape[0]

        for idx in range(ndatasets):
            print('%i/%i' % (idx, ndatasets))

            # load data for a single example
            with h5py.File(args.data_file, 'r') as F:
                imgs = np.array(F['imgs'][idx, ...], dtype=np.complex128)
                maps = np.array(F['maps'][idx, ...], dtype=np.complex128)
                ksp = np.array(F['ksp'][idx, ...], dtype=np.complex128)
                masks = np.array(F['masks'][idx, ...], dtype=np.float64)
                if args.multiband:
                    phi = np.array(F['phi'][idx, ...], dtype=np.complex128)

            # add a singleton batch dimension
            if len(imgs.shape) == 3 and args.multiband:
                imgs, maps, masks, ksp = imgs[None, ...], maps[None, ...], masks[None, ...], ksp[None, ...]
                phi = phi[None, ...]
            elif len(imgs.shape) == 2:
                imgs, maps, masks, ksp = imgs[None, ...], maps[None, ...], masks[None, ...], ksp[None, ...]

            imgs_torch_pre = cp.c2r(imgs).astype(np.float32)
            maps_torch_pre = cp.c2r(maps).astype(np.float32)
            masks_torch_pre = masks.astype(np.float32)
            ksp_torch_pre = cp.c2r(ksp).astype(np.float32)
            if args.multiband:
                phi_torch_pre = cp.c2r(phi).astype(np.float32)

            print(imgs.shape)
            print(maps.shape)
            print(masks.shape)
            print(ksp.shape)
            if args.multiband:
                print(phi.shape)
            print('')
            print(imgs_torch_pre.shape)
            print(maps_torch_pre.shape)
            print(masks_torch_pre.shape)
            print(ksp_torch_pre.shape)
            if args.multiband:
                print(phi_torch_pre.shape)

            # store a data dictionary in memory
            data = {
                'imgs': sigpy.to_pytorch(imgs_torch_pre),
                'maps': sigpy.to_pytorch(maps_torch_pre),
                'masks': sigpy.to_pytorch(masks_torch_pre),
                'out': sigpy.to_pytorch(ksp_torch_pre),
            }
            if args.multiband:
                data['phi'] = sigpy.to_pytorch(phi_torch_pre)

            # batch this data
            M.batch(data)

            # predict output
            output = sigpy.from_pytorch(M(sigpy.to_pytorch(ksp_torch_pre)))

            # complexify and drop the batch dimension
            output = output[..., 0] + 1j * output[..., 1]
            output = output[0, ...]
            # if args.multiband:
            #     output = output[0, ::, ::, ::, 0] + 1j * output[0, ::, ::, ::, 1]
            # else:
            #     output = output[0, ::, ::, 0] + 1j * output[0, ::, ::, 1]

            # allocate output
            if idx == 0:
                if args.multiband:
                    pred = np.zeros((ndatasets, output.shape[0], output.shape[1], output.shape[2]),
                                    dtype=output.dtype)
                else:
                    pred = np.zeros((ndatasets, output.shape[0], output.shape[1]), dtype=output.dtype)

            if args.multiband:
                plt.figure()
                for slc in range(output.shape[0]):
                    plt.subplot(1, output.shape[0], slc + 1)
                    plt.imshow(np.abs(output[slc, ::, ::]), cmap='gray')
                plt.subplots_adjust(wspace=0)
                plt.show()
            else:
                plt.figure()
                plt.imshow(np.abs(output), cmap='gray')
                plt.show()

            # store in output
            if args.multiband:
                pred[idx, ::, ::, ::] = output
            else:
                pred[idx, ::, ::] = output

            np.savez('inference.npz', pred=pred)

    else:
        print('Reading from input file')
        with h5py.File(args.data_file, 'r') as F:
            imgs = np.array(F['imgs'], dtype=np.complex128)
            maps = np.array(F['maps'], dtype=np.complex128)
            ksp = np.array(F['ksp'], dtype=np.complex128)
            masks = np.array(F['masks'], dtype=np.float64)
            if args.multiband:
                phi = np.array(F['phi'], dtype=np.complex128)

        if len(masks.shape) == 2:
            imgs, maps, masks, ksp = imgs[None, ...], maps[None, ...], masks[None, ...], ksp[None, ...]
            if args.multiband:
                phi = phi[None, ...]

        imgs_torch_pre = cp.c2r(imgs).astype(np.float32)
        maps_torch_pre = cp.c2r(maps).astype(np.float32)
        masks_torch_pre = masks.astype(np.float32)
        ksp_torch_pre = cp.c2r(ksp).astype(np.float32)
        if args.multiband:
            phi_torch_pre = cp.c2r(phi).astype(np.float32)

        print(imgs.shape)
        print(maps.shape)
        print(masks.shape)
        print(ksp.shape)
        if args.multiband:
            print(phi.shape)
        print(imgs_torch_pre.shape)
        print(maps_torch_pre.shape)
        print(masks_torch_pre.shape)
        print(ksp_torch_pre.shape)
        if args.multiband:
            print(phi_torch_pre.shape)

        print('Writing to dictionary')
        data = {
            'imgs': sigpy.to_pytorch(imgs_torch_pre),
            'maps': sigpy.to_pytorch(maps_torch_pre),
            'masks': sigpy.to_pytorch(masks_torch_pre),
            'out': sigpy.to_pytorch(ksp_torch_pre),
        }
        if args.multiband:
            data['phi'] = sigpy.to_pytorch(phi_torch_pre)

        print('Preparing data batch')
        M.batch(data)

        print('Calling M(y)')
        # `y` was undefined in the original; the measured k-space tensor is assumed here
        pred = M(sigpy.to_pytorch(ksp_torch_pre))
        print(pred.shape)

        np.savez('inference.npz', pred=pred)

    return pred
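# Hedged invocation sketch (not part of the original file): a minimal argparse CLI
# wired to main_infer. The flag names mirror the attributes used above
# (recon, checkpoint_init, data_file, multiband); defaults and help text are assumptions.
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='run inference with a trained recon model')
    parser.add_argument('--recon', choices=['cgsense', 'modl', 'resnet', 'dbp'], default='modl')
    parser.add_argument('--checkpoint_init', type=str, required=True)
    parser.add_argument('--data_file', type=str, required=True)
    parser.add_argument('--multiband', action='store_true')
    args = parser.parse_args()
    main_infer(args)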