def test_definition(self):
    """fftshift/ifftshift reorder odd- and even-length 1-D inputs as expected.

    Each case pairs an "FFT order" sequence with its zero-centred counterpart;
    fftshift must map the first to the second and ifftshift must invert it.
    """
    cases = (
        ([0, 1, 2, 3, 4, -4, -3, -2, -1],
         [-4, -3, -2, -1, 0, 1, 2, 3, 4]),
        ([0, 1, 2, 3, 4, -5, -4, -3, -2, -1],
         [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]),
    )
    with self.session():
        for fft_order, centred in cases:
            self.assertAllEqual(fft_ops.fftshift(fft_order), centred)
            self.assertAllEqual(fft_ops.ifftshift(centred), fft_order)
def test_negative_axes(self):
    """Negative axis indices behave like their non-negative equivalents."""
    freqs = [[0, 1, 2], [3, 4, -4], [-3, -2, -1]]
    shifted = [[-1, -3, -2], [2, 0, 1], [-4, 3, 4]]
    mixed_axes = (0, -1)
    with self.session():
        self.assertAllEqual(fft_ops.fftshift(freqs, axes=mixed_axes), shifted)
        self.assertAllEqual(fft_ops.ifftshift(shifted, axes=mixed_axes), freqs)
        # axes=-1 must be the same as the explicit last axis, (1,).
        for shift_fn, data in ((fft_ops.fftshift, freqs),
                               (fft_ops.ifftshift, shifted)):
            self.assertAllEqual(shift_fn(data, axes=-1),
                                shift_fn(data, axes=(1,)))
def test_axes_keyword(self):
    """`axes` accepts a tuple, a bare int, or nothing (meaning all axes)."""
    freqs = [[0, 1, 2], [3, 4, -4], [-3, -2, -1]]
    shifted = [[-1, -3, -2], [2, 0, 1], [-4, 3, 4]]
    both_axes = (0, 1)
    with self.session():
        self.assertAllEqual(fft_ops.fftshift(freqs, axes=both_axes), shifted)
        # A scalar axis is equivalent to a one-element tuple.
        self.assertAllEqual(
            fft_ops.fftshift(freqs, axes=0),
            fft_ops.fftshift(freqs, axes=(0,)),
        )
        self.assertAllEqual(fft_ops.ifftshift(shifted, axes=both_axes), freqs)
        self.assertAllEqual(
            fft_ops.ifftshift(shifted, axes=0),
            fft_ops.ifftshift(shifted, axes=(0,)),
        )
        # Omitting `axes` shifts every dimension.
        self.assertAllEqual(fft_ops.fftshift(freqs), shifted)
        self.assertAllEqual(fft_ops.ifftshift(shifted), freqs)
def tf_ortho_ifft2d(kspace, enable_multiprocessing=True):
    """Centered, orthonormally scaled 2-D inverse FFT over the last two axes.

    The input is ifftshift-ed over its last two axes, inverse transformed
    with ``ifft2d``, fftshift-ed back, and multiplied by sqrt(H * W), where
    H and W are the sizes of the last two axes.

    Args:
        kspace: tensor of rank >= 2 whose last two axes are transformed. Its
            dtype must be one ``ifft2d`` accepts (presumably complex64 —
            confirm against callers).
        enable_multiprocessing: if True, all leading axes are flattened into
            one batch axis and each 2-D slice is transformed by ``tf.map_fn``
            with ``multiprocessing.cpu_count()`` parallel iterations. If
            False, ``ifft2d`` is applied to the whole tensor directly.

    Returns:
        Tensor with the same shape and dtype as ``kspace``.
    """
    ndim = len(kspace.shape)
    axes = [ndim - 2, ndim - 1]
    kspace_shape = tf.shape(kspace)
    scaling_norm = tf.cast(
        tf.math.sqrt(
            tf.cast(tf.math.reduce_prod(kspace_shape[-2:]), 'float32')),
        kspace.dtype,
    )
    shifted_kspace = ifftshift(kspace, axes=axes)
    if enable_multiprocessing:
        # Collapse every leading axis into one batch axis so map_fn runs one
        # 2-D transform per slice. Then restore the input's own dynamic
        # shape. This works for any rank >= 2, whereas a hand-built
        # per-rank shape list only covers the ranks it enumerates.
        batched_shifted_kspace = tf.reshape(
            shifted_kspace, (-1, kspace_shape[-2], kspace_shape[-1]))
        batched_shifted_image = tf.map_fn(
            ifft2d,
            batched_shifted_kspace,
            parallel_iterations=multiprocessing.cpu_count(),
        )
        shifted_image = tf.reshape(batched_shifted_image, kspace_shape)
    else:
        shifted_image = ifft2d(shifted_kspace)
    image = fftshift(shifted_image, axes=axes)
    return scaling_norm * image
def test_numpy_compatibility(self):
    """fftshift/ifftshift agree with numpy.fft on 1-D and 2-D inputs."""
    one_d_cases = (
        ([0, 1, 2, 3, 4, -4, -3, -2, -1],
         [-4, -3, -2, -1, 0, 1, 2, 3, 4]),
        ([0, 1, 2, 3, 4, -5, -4, -3, -2, -1],
         [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]),
    )
    freqs = [[0, 1, 2], [3, 4, -4], [-3, -2, -1]]
    shifted = [[-1, -3, -2], [2, 0, 1], [-4, 3, 4]]
    both_axes = (0, 1)
    with self.session():
        for fft_order, centred in one_d_cases:
            self.assertAllEqual(fft_ops.fftshift(fft_order),
                                np.fft.fftshift(fft_order))
            self.assertAllEqual(fft_ops.ifftshift(centred),
                                np.fft.ifftshift(centred))
        self.assertAllEqual(fft_ops.fftshift(freqs, axes=both_axes),
                            np.fft.fftshift(freqs, axes=both_axes))
        self.assertAllEqual(fft_ops.ifftshift(shifted, axes=both_axes),
                            np.fft.ifftshift(shifted, axes=both_axes))
def ortho_fft2d(image):
    """Centered, orthonormally scaled 2-D forward FFT over the last two axes.

    ``image`` is first cast to complex64. It is then fftshift-ed over its
    last two axes, and every 2-D slice is transformed with ``fft2d`` through
    ``tf.map_fn`` (``multiprocessing.cpu_count()`` parallel iterations). The
    result is ifftshift-ed back and divided by sqrt(H * W), where H and W
    are the sizes of the last two axes.

    Args:
        image: tensor of rank >= 2 whose last two axes are transformed.

    Returns:
        complex64 tensor with the same shape as ``image``.
    """
    image = tf.cast(image, 'complex64')
    ndim = len(image.shape)
    axes = [ndim - 2, ndim - 1]
    image_shape = tf.shape(image)
    scaling_norm = tf.cast(
        tf.math.sqrt(
            tf.cast(tf.math.reduce_prod(image_shape[-2:]), 'float32')),
        image.dtype,
    )
    shifted_image = fftshift(image, axes=axes)
    # Collapse all leading axes into one batch axis for map_fn, then restore
    # the input's dynamic shape. This works for any rank >= 2, not just the
    # ranks a hand-built shape list enumerates.
    batched_shifted_image = tf.reshape(
        shifted_image, (-1, image_shape[-2], image_shape[-1]))
    batched_shifted_kspace = tf.map_fn(
        fft2d,
        batched_shifted_image,
        parallel_iterations=multiprocessing.cpu_count(),
    )
    shifted_kspace = tf.reshape(batched_shifted_kspace, image_shape)
    kspace = ifftshift(shifted_kspace, axes=axes)
    return kspace / scaling_norm
def tf_unmasked_op(x, idx=0):
    """Centered, orthonormally scaled forward 2-D FFT of one channel of ``x``.

    Channel ``idx`` of the last axis is selected and fftshift-ed over the two
    axes just before the channel axis. It is then transformed with
    ``fft2d``, ifftshift-ed back, divided by sqrt(H * W) (the sizes of those
    two axes), and given a trailing axis of size 1.

    Args:
        x: tensor of rank >= 3 laid out as (..., H, W, channels).
        idx: index of the channel to transform. Defaults to 0.

    Returns:
        Tensor of shape (..., H, W, 1) with the dtype of ``x``.
    """
    # Derive the spatial axes from the rank, as tf_unmasked_adj_op does.
    # After selecting x[..., idx] these are the last two axes, which are
    # exactly the ones fft2d transforms. For a rank-4 x this gives [1, 2].
    ndim = len(x.shape)
    axes = [ndim - 3, ndim - 2]
    scaling_norm = tf.dtypes.cast(
        tf.math.sqrt(
            tf.dtypes.cast(tf.math.reduce_prod(tf.shape(x)[-3:-1]), 'float32')),
        x.dtype,
    )
    kspace = ifftshift(fft2d(fftshift(x[..., idx], axes=axes)), axes=axes)
    return tf.expand_dims(kspace, axis=-1) / scaling_norm
def tf_unmasked_adj_op(x, idx=0):
    """Centered, orthonormally scaled inverse 2-D FFT of one channel of ``x``.

    Channel ``idx`` of the last axis is ifftshift-ed over the two axes just
    before the channel axis. It is then inverse transformed with ``ifft2d``,
    fftshift-ed back, multiplied by sqrt(H * W), and given a trailing axis
    of size 1.

    Args:
        x: tensor laid out as (..., H, W, channels).
        idx: index of the channel to transform. Defaults to 0.
    """
    rank = len(x.shape)
    spatial_axes = [rank - 3, rank - 2]
    n_pixels = tf.dtypes.cast(
        tf.math.reduce_prod(tf.shape(x)[-3:-1]), 'float32')
    scaling_norm = tf.dtypes.cast(tf.math.sqrt(n_pixels), x.dtype)
    centred_kspace = ifftshift(x[..., idx], axes=spatial_axes)
    image = fftshift(ifft2d(centred_kspace), axes=spatial_axes)
    return scaling_norm * tf.expand_dims(image, axis=-1)
def ortho_ifft2d(kspace):
    """Centered, orthonormally scaled 2-D inverse FFT.

    ``kspace`` is first passed through ``_order_for_ft`` (presumably moving
    the spatial dimensions to axes 2 and 3 — confirm against that helper).
    It is then ifftshift-ed over axes 2 and 3, inverse transformed with
    ``ifft2d``, fftshift-ed back, and multiplied by the factor returned by
    ``_compute_scaling_norm``. ``_order_after_ft`` is applied last.
    """
    axes = [2, 3]
    reordered = _order_for_ft(kspace)
    norm = _compute_scaling_norm(reordered)
    image = fftshift(ifft2d(ifftshift(reordered, axes=axes)), axes=axes)
    return _order_after_ft(image * norm)
def ortho_fft2d(image):
    """Centered, orthonormally scaled 2-D forward FFT.

    ``image`` is first passed through ``_order_for_ft`` (presumably moving
    the spatial dimensions to axes 2 and 3 — confirm against that helper).
    It is then fftshift-ed over axes 2 and 3, transformed with ``fft2d``,
    ifftshift-ed back, and divided by the factor returned by
    ``_compute_scaling_norm``. ``_order_after_ft`` is applied last.
    """
    axes = [2, 3]
    reordered = _order_for_ft(image)
    norm = _compute_scaling_norm(reordered)
    kspace = ifftshift(fft2d(fftshift(reordered, axes=axes)), axes=axes)
    return _order_after_ft(kspace / norm)
def test_placeholder(self, axes):
    """Shifts on a shape-less placeholder match numpy.fft at run time."""
    # Placeholders only exist in graph mode, so skip under eager execution.
    if context.executing_eagerly():
        return
    x = array_ops.placeholder(shape=[None, None, None], dtype="float32")
    fetches = [
        fft_ops.fftshift(x, axes=axes),
        fft_ops.ifftshift(x, axes=axes),
    ]
    x_np = np.random.rand(16, 256, 256)
    with self.session() as sess:
        got_fftshift, got_ifftshift = sess.run(fetches, feed_dict={x: x_np})
        self.assertAllClose(got_fftshift, np.fft.fftshift(x_np, axes=axes))
        self.assertAllClose(got_ifftshift, np.fft.ifftshift(x_np, axes=axes))
def testPlaceholder(self):
    """Shifts on a shape-less placeholder match numpy.fft for several axes."""
    x = array_ops.placeholder(shape=[None, None, None], dtype="float32")
    for axes in (None, 1, [1, 2]):
        fetches = [
            fft_ops.fftshift(x, axes=axes),
            fft_ops.ifftshift(x, axes=axes),
        ]
        with self.session() as sess:
            x_np = np.random.rand(16, 256, 256)
            got_fftshift, got_ifftshift = sess.run(
                fetches,
                feed_dict={x: x_np},
            )
            self.assertAllClose(got_fftshift,
                                np.fft.fftshift(x_np, axes=axes))
            self.assertAllClose(got_ifftshift,
                                np.fft.ifftshift(x_np, axes=axes))
def extract_smaps(kspace, low_freq_percentage=8, background_thresh=4e-6):
    """Extract raw sensitivity maps for kspaces.

    The function works in three steps:

    1. Keep only a centered low-frequency rectangle of every kspace; all
       other samples are zeroed.
    2. Inverse Fourier transform the result coil by coil (centered, over
       axes 2 and 3).
    3. Divide each coil image by the root sum-of-squares over the coil axis.

    kspace has to be of shape: nslices x ncoils x height x width

    Arguments:
        kspace (tf.Tensor): the kspace whose sensitivity maps you want
            extracted.
        low_freq_percentage (int): size of the low-frequency region to keep,
            as a percentage of the height and of the width of the kspace. In
            fastMRI, it's 8 for an acceleration factor of 4, and 4 for an
            acceleration factor of 8. Defaults to 8.
        background_thresh (float): unused for now, will later allow to have
            thresholded sensitivity maps.

    Returns:
        tf.Tensor: extracted raw sensitivity maps, with the same shape as
        ``kspace``.
    """
    spatial_shape = tf.shape(kspace)[-2:]
    n_low_freq = tf.cast(spatial_shape * low_freq_percentage / 100, tf.int32)
    center_dimension = tf.cast(spatial_shape / 2, tf.int32)
    half_n_low_freq = tf.cast(n_low_freq / 2, tf.int32)
    low_freq_lower_locations = center_dimension - half_n_low_freq
    low_freq_upper_locations = center_dimension + half_n_low_freq
    ###
    # NOTE: the following stands for in numpy:
    # low_freq_mask = np.zeros(kspace.shape[-2:])
    # low_freq_mask[
    #     low_freq_lower_locations[0]:low_freq_upper_locations[0],
    #     low_freq_lower_locations[1]:low_freq_upper_locations[1]
    # ] = 1
    # The mask is built only over (height, width), as the outer product of a
    # row-band and a column-band indicator. Broadcasting then applies it to
    # every slice and coil. This avoids materialising a full 4-D mask plus
    # a scatter_nd updates tensor of size n_points x nslices x ncoils.
    rows = tf.range(spatial_shape[0])
    cols = tf.range(spatial_shape[1])
    row_in_band = tf.cast(
        tf.logical_and(
            rows >= low_freq_lower_locations[0],
            rows < low_freq_upper_locations[0],
        ),
        tf.float32,
    )
    col_in_band = tf.cast(
        tf.logical_and(
            cols >= low_freq_lower_locations[1],
            cols < low_freq_upper_locations[1],
        ),
        tf.float32,
    )
    low_freq_mask = row_in_band[:, None] * col_in_band[None, :]
    ###
    low_freq_kspace = kspace * tf.cast(low_freq_mask, kspace.dtype)
    shifted_kspace = ifftshift(low_freq_kspace, axes=[2, 3])
    coil_image_low_freq_shifted = ifft2d(shifted_kspace)
    coil_image_low_freq = fftshift(coil_image_low_freq_shifted, axes=[2, 3])
    # no need to norm this since they all have the same norm
    low_freq_rss = tf.norm(coil_image_low_freq, axis=1)
    # tf.norm reduced the coil axis away; re-insert it so the division
    # broadcasts over coils.
    # NOTE(review): pixels where the RSS is exactly zero produce NaN (0 / 0).
    coil_smap = coil_image_low_freq / low_freq_rss[:, None]
    # for now we do not perform background removal based on low_freq_rss
    # could be done with 1D k-means or fixed background_thresh, with tf.where
    return coil_smap
def adj_op(self, inputs):
    """Adjoint operator: map (optionally masked) kspace back to an image.

    The layout of ``inputs`` depends on the flags:

    - multicoil and masked: ``(kspace, mask, smaps)``
    - multicoil only: ``(kspace, smaps)``
    - masked only: ``(kspace, mask)``
    - neither: ``kspace`` alone

    When masked, ``_mask_tf`` is applied first. The trailing channel axis is
    then dropped, and a centered 2-D inverse FFT is taken over
    ``self.shift_axes`` and scaled by ``_compute_scaling_norm``. In the
    multicoil case, the coil images are combined as
    ``sum(image * conj(smaps), axis=1)`` (axis 1 is presumably the coil
    axis). The result gets a trailing axis of size 1.
    """
    if self.multicoil:
        if self.masked:
            kspace, mask, smaps = inputs
        else:
            kspace, smaps = inputs
    elif self.masked:
        kspace, mask = inputs
    else:
        kspace = inputs
    if self.masked:
        kspace = _mask_tf([kspace, mask])
    kspace = kspace[..., 0]
    scaling_norm = _compute_scaling_norm(kspace)
    axes = self.shift_axes
    image = fftshift(ifft2d(ifftshift(kspace, axes=axes)), axes=axes)
    image = image * scaling_norm
    if self.multicoil:
        image = tf.reduce_sum(image * tf.math.conj(smaps), axis=1)
    return image[..., None]
def op(self, inputs):
    """Forward operator: map an image to (optionally masked) kspace.

    The layout of ``inputs`` depends on the flags:

    - multicoil and masked: ``(image, mask, smaps)``
    - multicoil only: ``(image, smaps)``
    - masked only: ``(image, mask)``
    - neither: ``image`` alone

    The trailing channel axis is dropped, and the scaling factor is taken
    from ``_compute_scaling_norm`` on that image. In the multicoil case, the
    image gets a new axis 1 and is multiplied by ``smaps``. A centered 2-D
    FFT is then taken over ``self.shift_axes``. The result gets a trailing
    axis of size 1, is divided by the scaling factor, and is passed through
    ``_mask_tf`` when masked.
    """
    if self.multicoil and self.masked:
        image, mask, smaps = inputs
    elif self.multicoil:
        image, smaps = inputs
    elif self.masked:
        image, mask = inputs
    else:
        image = inputs
    image = image[..., 0]
    scaling_norm = _compute_scaling_norm(image)
    if self.multicoil:
        image = tf.expand_dims(image, axis=1) * smaps
    axes = self.shift_axes
    kspace = ifftshift(fft2d(fftshift(image, axes=axes)), axes=axes)
    kspace = kspace[..., None] / scaling_norm
    if self.masked:
        kspace = _mask_tf([kspace, mask])
    return kspace
def extract_smaps(kspace, low_freq_percentage=8, background_thresh=4e-6):
    """Extract raw sensitivity maps for kspaces

    This function will first select a low frequency region in all the kspaces,
    then Fourier invert it, and finally perform a normalisation by the root
    sum-of-square over the coil axis.
    kspace has to be of shape: nslices x ncoils x height x width

    Arguments:
        kspace (tf.Tensor): the kspace whose sensitivity maps you want extracted.
        low_freq_percentage (int): the low frequency region to consider for
            sensitivity maps extraction, given as a percentage of the height
            and of the width of the kspace. In fastMRI, it's 8 for an
            acceleration factor of 4, and 4 for an acceleration factor of 8.
            Defaults to 8.
        background_thresh (float): unused for now, will later allow to have
            thresholded sensitivity maps.

    Returns:
        tf.Tensor: extracted raw sensitivity maps, with the same shape as
        ``kspace``.
    """
    spatial_shape = tf.shape(kspace)[-2:]
    n_low_freq = tf.cast(spatial_shape * low_freq_percentage / 100, tf.int32)
    center_dimension = tf.cast(spatial_shape / 2, tf.int32)
    half_n_low_freq = tf.cast(n_low_freq / 2, tf.int32)
    low_freq_lower_locations = center_dimension - half_n_low_freq
    low_freq_upper_locations = center_dimension + half_n_low_freq
    ###
    # NOTE: the following stands for in numpy:
    # low_freq_mask = np.zeros(kspace.shape[-2:])
    # low_freq_mask[
    #     low_freq_lower_locations[0]:low_freq_upper_locations[0],
    #     low_freq_lower_locations[1]:low_freq_upper_locations[1]
    # ] = 1
    # The mask is built only over (height, width), as the outer product of a
    # row-band and a column-band indicator. Broadcasting then applies it to
    # every slice and coil. This avoids materialising a full 4-D mask plus
    # a scatter_nd updates tensor of size n_points x nslices x ncoils.
    rows = tf.range(spatial_shape[0])
    cols = tf.range(spatial_shape[1])
    row_in_band = tf.cast(
        tf.logical_and(
            rows >= low_freq_lower_locations[0],
            rows < low_freq_upper_locations[0],
        ),
        tf.float32,
    )
    col_in_band = tf.cast(
        tf.logical_and(
            cols >= low_freq_lower_locations[1],
            cols < low_freq_upper_locations[1],
        ),
        tf.float32,
    )
    low_freq_mask = row_in_band[:, None] * col_in_band[None, :]
    ###
    low_freq_kspace = kspace * tf.cast(low_freq_mask, kspace.dtype)
    shifted_kspace = ifftshift(low_freq_kspace, axes=[2, 3])
    coil_image_low_freq_shifted = ifft2d(shifted_kspace)
    coil_image_low_freq = fftshift(coil_image_low_freq_shifted, axes=[2, 3])
    # no need to norm this since they all have the same norm
    low_freq_rss = tf.norm(coil_image_low_freq, axis=1)
    # tf.norm reduced the coil axis away; re-insert it so the division
    # broadcasts over coils.
    # NOTE(review): pixels where the RSS is exactly zero produce NaN (0 / 0).
    coil_smap = coil_image_low_freq / low_freq_rss[:, None]
    # for now we do not perform background removal based on low_freq_rss
    # could be done with 1D k-means or fixed background_thresh, with tf.where
    return coil_smap