def test_adjoint_and_gradients(im_size, batch_size):
    """Compare the adjoint NUFFT and its gradients against an explicit NDFT.

    The adjoint NUFFT is evaluated on a slightly perturbed trajectory. Its
    output, and the gradients with respect to the k-space data, the
    trajectory and a mean-squared loss, are compared with the same quantities
    obtained from a matrix product with the output of ``get_fourier_matrix``.
    """
    tf.random.set_seed(0)
    n_dims = len(im_size)
    oversampled_grid = tuple(np.array(im_size) * 2)
    n_samples = im_size[0] * 2**n_dims
    nufft_ob = KbNufftModule(
        im_size=im_size,
        grid_size=oversampled_grid,
        norm='ortho',
        grad_traj=True,
    )

    def _uniform_trajectory():
        # Uniform draws in [-pi, pi) with layout (batch, dimension, sample).
        # NOTE: the number and order of calls must not change, the seeded
        # random stream depends on it.
        return tf.random.uniform(
            (batch_size, n_dims, n_samples), minval=-0.5, maxval=0.5,
        ) * 2 * np.pi

    def _forward():
        return kbnufft_forward(nufft_ob._extract_nufft_interpob())

    def _adjoint():
        return kbnufft_adjoint(nufft_ob._extract_nufft_interpob())

    # Reference trajectory and random (real-valued, cast to complex) signal.
    ktraj_ori = tf.Variable(_uniform_trajectory())
    signal = tf.Variable(
        tf.cast(tf.random.uniform((batch_size, 1, *im_size)), tf.complex64))

    # Reference k-space data and image, both computed on the clean trajectory.
    kdata = tf.Variable(_forward()(signal, ktraj_ori))
    reference_image = tf.Variable(_adjoint()(kdata, ktraj_ori))

    # Perturbed trajectory on which both operators are evaluated.
    # The in-place add keeps the result a numpy array.
    perturbed = np.copy(ktraj_ori)
    perturbed += 0.01 * tf.Variable(_uniform_trajectory())
    ktraj = tf.Variable(perturbed)

    with tf.GradientTape(persistent=True) as tape:
        image_nufft = _adjoint()(kdata, ktraj)
        fourier_matrix = get_fourier_matrix(
            ktraj, im_size, n_dims, do_ifft=True)
        projected = tf.transpose(tf.matmul(kdata, fourier_matrix), [0, 1, 2])
        image_ndft = tf.reshape(projected, (batch_size, 1, *im_size))
        loss_nufft = tf.math.reduce_mean(
            tf.abs(reference_image - image_nufft)**2)
        loss_ndft = tf.math.reduce_mean(
            tf.abs(reference_image - image_ndft)**2)

    checker = tf.test.TestCase()
    # The NUFFT and NDFT outputs themselves must agree.
    checker.assertAllClose(image_nufft, image_ndft, atol=2e-3)

    # (NDFT target, NUFFT target, differentiation source, tolerance):
    # gradients w.r.t. kdata, w.r.t. the trajectory, and through the loss
    # (chain rule) w.r.t. the trajectory.
    gradient_checks = (
        (image_ndft, image_nufft, kdata, 6e-3),
        (image_ndft, image_nufft, ktraj, 6e-3),
        (loss_ndft, loss_nufft, ktraj, 5e-4),
    )
    for ndft_target, nufft_target, source, tolerance in gradient_checks:
        grad_ndft = tape.gradient(ndft_target, source)[0]
        grad_nufft = tape.gradient(nufft_target, source)[0]
        checker.assertAllClose(grad_ndft, grad_nufft, atol=tolerance)
def test_forward_gradient():
    """Check that the forward NUFFT gradient has the shape of the image.

    ``ktraj_function`` and ``image_shape`` are provided by the enclosing
    module.
    """
    trajectory = ktraj_function()
    zero_image = tf.zeros([1, 1, *image_shape], dtype=tf.complex64)
    module_kwargs = dict(im_size=(640, 400), grid_size=None, norm='ortho')
    nufft_module = KbNufftModule(**module_kwargs)
    forward = kbnufft_forward(nufft_module._extract_nufft_interpob())

    with tf.GradientTape() as tape:
        # A plain tensor is not tracked automatically, so watch it.
        tape.watch(zero_image)
        kspace = forward(zero_image, trajectory)
    image_grad = tape.gradient(kspace, zero_image)

    tf.test.TestCase().assertEqual(image_grad.shape, zero_image.shape)
def test_adjoint_gradient():
    """Check that the adjoint NUFFT gradient has the shape of the k-space.

    ``ktraj_function`` and ``kspace_shape`` are provided by the enclosing
    module.
    """
    trajectory = ktraj_function()
    zero_kspace = tf.zeros([1, 1, kspace_shape], dtype=tf.complex64)
    module_kwargs = dict(im_size=(640, 400), grid_size=None, norm='ortho')
    nufft_module = KbNufftModule(**module_kwargs)
    adjoint = kbnufft_adjoint(nufft_module._extract_nufft_interpob())

    with tf.GradientTape() as tape:
        # A plain tensor is not tracked automatically, so watch it.
        tape.watch(zero_kspace)
        image = adjoint(zero_kspace, trajectory)
    kspace_grad = tape.gradient(image, zero_kspace)

    tf.test.TestCase().assertEqual(kspace_grad.shape, zero_kspace.shape)
def test_ncpdnet_init_and_call_3d(dcomp, volume_shape):
    """Build a small 3D ``NCPDNet`` and call it on an all-zero k-space.

    Args:
        dcomp: truthy to compute a density compensation and append it to the
            extra arguments given to the model.
        volume_shape: size of the volume; also the expected shape of axes
            1 to 3 of the model output.
    """
    model = NCPDNet(
        n_iter=1,
        n_primal=2,
        n_filters=2,
        multicoil=False,
        im_size=volume_shape,
        three_d=True,
        dcomp=dcomp,
        fastmri=False,
    )
    af = 16
    traj = get_stacks_of_radial_trajectory(volume_shape, af=af)

    # Size of the flattened k-space axis fed to the model.
    points_per_spoke = volume_shape[-2]
    spokes_per_stack = volume_shape[-1] // af
    stack_count = volume_shape[0]
    n_kspace_samples = spokes_per_stack * points_per_spoke * stack_count

    extra_args = (tf.constant([volume_shape]),)
    if dcomp:
        nufft_module = KbNufftModule(
            im_size=volume_shape,
            grid_size=None,
            norm='ortho',
        )
        interpob = nufft_module._extract_nufft_interpob()
        forward = kbnufft_forward(interpob)
        adjoint = kbnufft_adjoint(interpob)
        weights = calculate_radial_dcomp_tf(
            interpob,
            forward,
            adjoint,
            traj[0],
            stacks=True,
        )
        # Add a leading axis of size one in front of the weights.
        leading_ones = tf.ones([1, tf.shape(weights)[0]], dtype=weights.dtype)
        extra_args = extra_args + (leading_ones * weights[None, :],)

    zero_kspace = tf.zeros([1, 1, n_kspace_samples, 1], dtype=tf.complex64)
    res = model([zero_kspace, traj, extra_args])
    assert res.shape[1:4] == volume_shape
def profile_tfkbnufft(
        image,
        ktraj,
        im_size,
        device,
):
    """Print the average run time of the forward and adjoint NUFFT.

    Args:
        image: image data, converted with ``tf.constant`` (and cast to
            ``tf.complex64`` when ``device`` is ``'GPU'``).
        ktraj: trajectory, converted with ``tf.constant``.
        im_size: image size handed to ``KbNufftModule``.
        device: device type used to build the ``/{device}:0`` device string.
            ``'CPU'`` times 20 calls per operator, any other value 50.
    """
    num_nuffts = 20 if device == 'CPU' else 50
    print(f'Using {device}')

    with tf.device(f'/{device}:0'):
        image = tf.constant(image)
        if device == 'GPU':
            image = tf.cast(image, tf.complex64)
        ktraj = tf.constant(ktraj)

        nufft_ob = KbNufftModule(im_size=im_size, grid_size=None, norm='ortho')
        forward_op = kbnufft_forward(nufft_ob._extract_nufft_interpob())
        adjoint_op = kbnufft_adjoint(nufft_ob._extract_nufft_interpob())

        def _average_time(operator, operand):
            # Two untimed warm-up calls precede the timed loop.
            for _ in range(2):
                result = operator(operand, ktraj)
            tic = time.perf_counter()
            for _ in range(num_nuffts):
                result = operator(operand, ktraj)
            toc = time.perf_counter()
            return result, (toc - tic) / num_nuffts

        # The adjoint is timed on the output of the last forward call.
        y, avg_time = _average_time(forward_op, image)
        print('forward average time: {}'.format(avg_time))

        _, avg_time = _average_time(adjoint_op, y)
        print('backward average time: {}'.format(avg_time))
class NFFTBase(Layer):
    """Base layer exposing a forward and an adjoint non-uniform FFT.

    Args:
        multicoil (bool): when True, sensitivity maps are expected among the
            inputs of `op` and `adj_op`. Defaults to False.
        im_size (tuple): image size handed to `KbNufftModule` and to the
            padding/cropping helpers. Defaults to (640, 472).
        density_compensation (bool): when True, `adj_op` expects a density
            compensation tensor as last input and multiplies the k-space by
            it. Defaults to False.
        **kwargs: forwarded to the parent constructor.
    """

    def __init__(self, multicoil=False, im_size=(640, 472), density_compensation=False, **kwargs):
        super(NFFTBase, self).__init__(**kwargs)
        self.multicoil = multicoil
        self.im_size = im_size
        self.nufft_ob = KbNufftModule(im_size=im_size, grid_size=None, norm='ortho')
        self.density_compensation = density_compensation
        # Each operator gets its own interpolation object.
        forward_interpob = self.nufft_ob._extract_nufft_interpob()
        self.forward_op = kbnufft_forward(forward_interpob)
        adjoint_interpob = self.nufft_ob._extract_nufft_interpob()
        self.backward_op = kbnufft_adjoint(adjoint_interpob)

    def pad_for_nufft(self, image):
        """Apply `_pad_for_nufft` with this layer's image size."""
        return _pad_for_nufft(image, self.im_size)

    def crop_for_pad(self, image, shape):
        """Apply `_crop_for_pad` with this layer's image size."""
        return _crop_for_pad(image, shape, self.im_size)

    def crop_for_nufft(self, image):
        """Apply `_crop_for_nufft` with this layer's image size."""
        return _crop_for_nufft(image, self.im_size)

    def op(self, inputs):
        """Compute the non-Cartesian k-space of an image.

        Args:
            inputs: `(image, ktraj, smaps)` when multicoil, `(image, ktraj)`
                otherwise.

        Returns:
            tuple: the k-space with a trailing singleton axis, and a
            one-element list holding, for each batch element, the size of the
            last axis of the image given to the NUFFT.
        """
        if self.multicoil:
            image, ktraj, smaps = inputs
            # The inserted axis broadcasts against the sensitivity maps.
            coil_images = image[:, None, ..., 0] * smaps
        else:
            image, ktraj = inputs
            # tfkbnufft needs a coil dimension even if there is none.
            coil_images = image[:, None, ..., 0]
        kspace = nufft(self.nufft_ob, coil_images, ktraj, image_size=self.im_size)
        # TODO: get rid of shape return as not needed in the end.
        # shape is computed once in the preprocessing and passed on as is.
        dims = tf.shape(coil_images)
        shape = tf.ones([dims[0]], dtype=tf.int32) * dims[-1]
        return kspace[..., None], [shape]

    def adj_op(self, inputs):
        """Compute an image from non-Cartesian k-space with the adjoint NUFFT.

        Args:
            inputs: `kspace, ktraj, [smaps,] shape[, dcomp]`; `smaps` is
                present when multicoil, `dcomp` when density compensation is
                enabled.

        Returns:
            The image with a trailing singleton axis, cropped through
            `crop_for_pad` when `shape` is smaller than the layer image size.
        """
        if self.multicoil and self.density_compensation:
            kspace, ktraj, smaps, shape, dcomp = inputs
        elif self.multicoil:
            kspace, ktraj, smaps, shape = inputs
        elif self.density_compensation:
            kspace, ktraj, shape, dcomp = inputs
        else:
            kspace, ktraj, shape = inputs

        flat_kspace = kspace[..., 0]
        if self.density_compensation:
            flat_kspace = tf.cast(dcomp, kspace.dtype) * flat_kspace
        adjoint_image = self.backward_op(flat_kspace, ktraj)

        ## image resizing
        if len(self.im_size) < 3:
            # NOTE: for now very ugly way to deal with this condition
            target_shape = tf.reshape(shape[0], [])
            needs_crop = tf.math.less(target_shape, self.im_size[-1])
        else:
            target_shape = shape[0]
            needs_crop = tf.reduce_any(tf.math.less(target_shape, self.im_size))
        image_reshaped = tf.cond(
            pred=needs_crop,
            true_fn=lambda: self.crop_for_pad(adjoint_image, target_shape),
            false_fn=lambda: adjoint_image,
        )

        if self.multicoil:
            # Combine the coils with the conjugate sensitivity maps.
            combined = tf.reduce_sum(image_reshaped * tf.math.conj(smaps), axis=1)
        else:
            combined = image_reshaped[:, 0]
        return combined[..., None]
class NFFTBase(Layer):
    """Base layer exposing a forward and an adjoint non-uniform FFT.

    Args:
        multicoil (bool): when True, sensitivity maps are expected among the
            inputs of `op` and `adj_op`. Defaults to False.
        im_size (tuple): image size handed to `KbNufftModule` and to the
            padding/cropping helpers. Defaults to (640, 472).
        density_compensation (bool): when True, `adj_op` expects a density
            compensation tensor as last input and multiplies the k-space by
            it. Defaults to False.
        **kwargs: forwarded to the parent constructor.
    """

    def __init__(self, multicoil=False, im_size=(640, 472), density_compensation=False, **kwargs):
        super(NFFTBase, self).__init__(**kwargs)
        self.multicoil = multicoil
        self.im_size = im_size
        nufft_settings = dict(im_size=im_size, grid_size=None, norm='ortho')
        self.nufft_ob = KbNufftModule(**nufft_settings)
        self.density_compensation = density_compensation
        # Each operator gets its own interpolation object.
        self.forward_op = kbnufft_forward(
            self.nufft_ob._extract_nufft_interpob(),
        )
        self.backward_op = kbnufft_adjoint(
            self.nufft_ob._extract_nufft_interpob(),
        )

    def pad_for_nufft(self, image):
        """Apply `_pad_for_nufft` with this layer's image size."""
        return _pad_for_nufft(image, self.im_size)

    def crop_for_pad(self, image, shape):
        """Apply `_crop_for_pad` with this layer's image size."""
        return _crop_for_pad(image, shape, self.im_size)

    def crop_for_nufft(self, image):
        """Apply `_crop_for_nufft` with this layer's image size."""
        return _crop_for_nufft(image, self.im_size)

    def op(self, inputs):
        """Compute the non-Cartesian k-space of an image.

        Args:
            inputs: `(image, ktraj, smaps)` when multicoil, `(image, ktraj)`
                otherwise.

        Returns:
            tuple: the k-space with a trailing singleton axis, and a
            one-element list holding, for each batch element, the size of the
            last axis of the image given to the NUFFT.
        """
        if self.multicoil:
            image, ktraj, smaps = inputs
            # The inserted axis broadcasts against the sensitivity maps.
            nufft_input = image[:, None, ..., 0] * smaps
        else:
            image, ktraj = inputs
            # tfkbnufft needs a coil dimension even if there is none.
            nufft_input = image[:, None, ..., 0]
        kspace = nufft(self.nufft_ob, nufft_input, ktraj, image_size=self.im_size)
        input_dims = tf.shape(nufft_input)
        shape = tf.ones([input_dims[0]], dtype=tf.int32) * input_dims[-1]
        return kspace[..., None], [shape]

    def adj_op(self, inputs):
        """Compute an image from non-Cartesian k-space with the adjoint NUFFT.

        Args:
            inputs: `kspace, ktraj, [smaps,] shape[, dcomp]`; `smaps` is
                present when multicoil, `dcomp` when density compensation is
                enabled.

        Returns:
            The image with a trailing singleton axis, cropped through
            `crop_for_pad` when the first entry of `shape` is smaller than
            the last entry of the layer image size.
        """
        if self.multicoil and self.density_compensation:
            kspace, ktraj, smaps, shape, dcomp = inputs
        elif self.multicoil:
            kspace, ktraj, smaps, shape = inputs
        elif self.density_compensation:
            kspace, ktraj, shape, dcomp = inputs
        else:
            kspace, ktraj, shape = inputs

        # Only the first entry of the shape input is used, as a scalar.
        scalar_shape = tf.reshape(shape[0], [])

        squeezed_kspace = kspace[..., 0]
        if self.density_compensation:
            squeezed_kspace = tf.cast(dcomp, kspace.dtype) * squeezed_kspace
        full_image = self.backward_op(squeezed_kspace, ktraj)

        def _keep():
            return full_image

        def _crop():
            return self.crop_for_pad(full_image, scalar_shape)

        large_enough = tf.math.greater_equal(scalar_shape, self.im_size[-1])
        image_reshaped = tf.cond(large_enough, _keep, _crop)

        if self.multicoil:
            # Combine the coils with the conjugate sensitivity maps.
            result = tf.reduce_sum(image_reshaped * tf.math.conj(smaps), axis=1)
        else:
            result = image_reshaped[:, 0]
        return result[..., None]
class NonCartesianFastMRIDatasetBuilder(FastMRIDatasetBuilder):
    """Dataset builder turning fastMRI data into non-Cartesian model inputs.

    The Cartesian k-space is brought back to image space with
    ``ortho_ifft2d`` and re-sampled on a non-Cartesian trajectory with a
    NUFFT.

    Args:
        image_size: image size handed to ``KbNufftModule`` and to the
            trajectory generators. Defaults to ``IMAGE_SIZE``.
        acq_type (str): one of ``'radial'``, ``'spiral'``,
            ``'cartesian_debug'`` or ``'other'``. Defaults to ``'radial'``.
        dcomp (bool): whether to compute a density compensation and add it to
            the model inputs. Defaults to True.
        scale_factor (float): passed to ``scale_tensors``. Defaults to 1e6.
        traj: user-provided trajectory, required when ``acq_type`` is
            ``'other'``. Defaults to None.
        crop_image_data (bool): when True the image is resized with
            ``adjust_image_size``; when False the image shape is added to the
            model inputs instead. Defaults to False.
        **kwargs: forwarded to ``FastMRIDatasetBuilder``.

    Raises:
        ValueError: on an unknown ``acq_type``, a missing trajectory for
            ``acq_type='other'``, brain data, test mode, or multicoil data
            without density compensation.

    NOTE(review): ``self.af``, ``self.brain``, ``self.mode`` and
    ``self.multicoil`` are not set here; presumably the parent constructor
    sets them -- confirm against ``FastMRIDatasetBuilder``.
    """

    # Values accepted for ``acq_type``.
    _ACQ_TYPES = ('spiral', 'radial', 'cartesian_debug', 'other')

    def __init__(
            self,
            image_size=IMAGE_SIZE,
            acq_type='radial',
            dcomp=True,
            scale_factor=1e6,
            traj=None,
            crop_image_data=False,
            **kwargs,
    ):
        self.image_size = image_size
        self.acq_type = acq_type
        self.traj = traj
        # Validated before the parent constructor runs.
        self._check_acq_type()
        self.dcomp = dcomp
        self.scale_factor = scale_factor
        self.crop_image_data = crop_image_data
        self.nufft_obj = KbNufftModule(
            im_size=self.image_size,
            grid_size=None,
            norm='ortho',
        )
        super().__init__(**kwargs)
        # The remaining checks read attributes that are only available once
        # the parent constructor has run.
        if self.brain:
            raise ValueError(
                'Currently the non cartesian data works only with knee data.')
        self._check_mode()
        self._check_dcomp_multicoil()

    def _check_acq_type(self,):
        """Raise ValueError on an unsupported acquisition configuration."""
        if self.acq_type not in self._ACQ_TYPES:
            raise ValueError(
                f'acq_type must be spiral, radial, cartesian_debug or other '
                f'but is {self.acq_type}'
            )
        if self.acq_type == 'other' and self.traj is None:
            raise ValueError(
                'Please provide a trajectory as input in case `acq_type` is `other`'
            )

    def _check_mode(self,):
        """Raise ValueError when the builder is in test mode."""
        if self.mode == 'test':
            raise ValueError('NonCartesian dataset cannot be used in test mode')

    def _check_dcomp_multicoil(self,):
        """Raise ValueError for multicoil data without density compensation."""
        if self.multicoil and not self.dcomp:
            raise ValueError('You must use density compensation when in multicoil')

    def generate_trajectory(self,):
        """Return the trajectory matching ``self.acq_type``.

        Raises:
            ValueError: when ``self.acq_type`` has no associated trajectory.
        """
        if self.acq_type == 'radial':
            return get_radial_trajectory(self.image_size, af=self.af)
        # 'cartesian_debug' is the value accepted by `_check_acq_type`; the
        # former spelling 'cartesian' is still honoured for compatibility.
        if self.acq_type in ('cartesian_debug', 'cartesian'):
            return get_debugging_cartesian_trajectory()
        if self.acq_type == 'spiral':
            return get_spiral_trajectory(self.image_size, af=self.af)
        if self.acq_type == 'other':
            return self.traj
        raise ValueError(
            f'No trajectory available for acq_type {self.acq_type}')

    def preprocessing(self, image, kspace):
        """Build the non-Cartesian model inputs from an image and its k-space.

        Args:
            image: target image; its first axis size is used to repeat the
                trajectory.
            kspace: Cartesian k-space, inverse-transformed with
                ``ortho_ifft2d`` before the NUFFT.

        Returns:
            tuple: ``(model_inputs, image)``. ``model_inputs`` holds, in
            order: the non-Cartesian k-space with a trailing singleton axis,
            the repeated trajectory, the sensitivity maps (multicoil only),
            the output shape (only when ``crop_image_data`` is False) and the
            tuple ``(orig_shape[, dcomp])``. ``image`` is the scaled image
            with a trailing singleton axis.
        """
        traj = self.generate_trajectory()
        interpob = self.nufft_obj._extract_nufft_interpob()
        nufftob_forw = kbnufft_forward(interpob, multiprocessing=True)
        nufftob_back = kbnufft_adjoint(interpob, multiprocessing=True)
        if self.dcomp:
            # Computed from the first element of the trajectory, before the
            # trajectory is repeated below.
            dcomp = calculate_density_compensator(
                interpob,
                nufftob_forw,
                nufftob_back,
                traj[0],
            )
        traj = tf.repeat(traj, tf.shape(image)[0], axis=0)
        orig_image_channels = ortho_ifft2d(kspace)
        if self.crop_image_data:
            image = adjust_image_size(image, self.image_size)
        nc_kspace = nufft(
            self.nufft_obj,
            orig_image_channels,
            traj,
            self.image_size,
            multicoil=self.multicoil,
        )
        nc_kspace, image = scale_tensors(
            nc_kspace, image, scale_factor=self.scale_factor)
        image = image[..., None]
        nc_kspaces_channeled = nc_kspace[..., None]
        # One entry per element of the first k-space axis, each equal to the
        # last image dimension.
        orig_shape = tf.ones(
            [tf.shape(kspace)[0]], dtype=tf.int32) * self.image_size[-1]
        if not self.crop_image_data:
            # Shape of the image without its first axis, tiled once per
            # element of that first axis.
            output_shape = tf.shape(image)[1:][None, :]
            output_shape = tf.tile(output_shape, [tf.shape(image)[0], 1])
        extra_args = (orig_shape,)
        if self.dcomp:
            # Replicate the compensation along the first k-space axis.
            dcomp = tf.ones(
                [tf.shape(kspace)[0], tf.shape(dcomp)[0]],
                dtype=dcomp.dtype,
            ) * dcomp[None, :]
            extra_args += (dcomp,)
        model_inputs = (nc_kspaces_channeled, traj)
        if self.multicoil:
            # `dcomp` is always defined here: `_check_dcomp_multicoil`
            # rejects multicoil builders without density compensation.
            smaps = non_cartesian_extract_smaps(
                nc_kspace, traj, dcomp, nufftob_back, self.image_size)
            model_inputs += (smaps,)
        if not self.crop_image_data:
            model_inputs += (output_shape,)
        model_inputs += (extra_args,)
        return model_inputs, image