Пример #1
0
 def test_definition(self):
   """fftshift/ifftshift must map odd- and even-length 1-D inputs as expected.

   For each (unshifted, centred) pair, fftshift(unshifted) must equal centred
   and ifftshift(centred) must give back unshifted.
   """
   cases = (
       ([0, 1, 2, 3, 4, -4, -3, -2, -1],
        [-4, -3, -2, -1, 0, 1, 2, 3, 4]),
       ([0, 1, 2, 3, 4, -5, -4, -3, -2, -1],
        [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]),
   )
   with self.session():
     for unshifted, centred in cases:
       self.assertAllEqual(fft_ops.fftshift(unshifted), centred)
       self.assertAllEqual(fft_ops.ifftshift(centred), unshifted)
Пример #2
0
 def test_negative_axes(self):
   """Negative axis indices must behave like their positive counterparts."""
   freqs = [[0, 1, 2], [3, 4, -4], [-3, -2, -1]]
   shifted = [[-1, -3, -2], [2, 0, 1], [-4, 3, 4]]
   with self.session():
     both_axes = (0, -1)
     self.assertAllEqual(fft_ops.fftshift(freqs, axes=both_axes), shifted)
     self.assertAllEqual(fft_ops.ifftshift(shifted, axes=both_axes), freqs)
     # On a 2-D input an int axis of -1 must match the explicit tuple (1,).
     last_axis_forward = fft_ops.fftshift(freqs, axes=-1)
     self.assertAllEqual(last_axis_forward,
                         fft_ops.fftshift(freqs, axes=(1,)))
     last_axis_inverse = fft_ops.ifftshift(shifted, axes=-1)
     self.assertAllEqual(last_axis_inverse,
                         fft_ops.ifftshift(shifted, axes=(1,)))
Пример #3
0
 def test_axes_keyword(self):
     """Check the ``axes`` argument of fftshift/ifftshift on a 3x3 grid.

     Shifting over both axes must give the expected centred grid, an integer
     ``axes`` must behave like the equivalent one-element tuple, and omitting
     ``axes`` must give the same result as shifting over both axes.
     """
     freqs = [[0, 1, 2], [3, 4, -4], [-3, -2, -1]]
     shifted = [[-1, -3, -2], [2, 0, 1], [-4, 3, 4]]
     forward = fft_ops.fftshift
     inverse = fft_ops.ifftshift
     with self.session():
         self.assertAllEqual(forward(freqs, axes=(0, 1)), shifted)
         self.assertAllEqual(forward(freqs, axes=0),
                             forward(freqs, axes=(0,)))
         self.assertAllEqual(inverse(shifted, axes=(0, 1)), freqs)
         self.assertAllEqual(inverse(shifted, axes=0),
                             inverse(shifted, axes=(0,)))
         self.assertAllEqual(forward(freqs), shifted)
         self.assertAllEqual(inverse(shifted), freqs)
def tf_ortho_ifft2d(kspace, enable_multiprocessing=True):
    """Centred, orthonormally scaled inverse 2D FFT over the last two axes.

    ``kspace`` is ``ifftshift``-ed over its last two axes, passed through
    ``ifft2d``, ``fftshift``-ed back and multiplied by ``sqrt(nx * ny)``,
    where ``nx`` and ``ny`` are the sizes of the last two axes. Assuming
    ``ifft2d`` is ``tf.signal.ifft2d`` (which divides by ``nx * ny``), this
    yields the orthonormal inverse transform -- TODO confirm the helper.

    Arguments:
        kspace (tf.Tensor): tensor of rank >= 2 whose last two axes are
            transformed. Its rank must be known statically, because the
            shift axes are derived from ``len(kspace.shape)``.
        enable_multiprocessing (bool): if True, all leading axes are folded
            into a single batch axis and ``ifft2d`` is applied to each 2D
            slice with ``tf.map_fn`` (``multiprocessing.cpu_count()``
            parallel iterations). Otherwise ``ifft2d`` is called on the
            whole tensor at once.

    Returns:
        tf.Tensor: the image, with the same shape as ``kspace``.
    """
    axes = [len(kspace.shape) - 2, len(kspace.shape) - 1]
    kspace_shape = tf.shape(kspace)
    scaling_norm = tf.cast(
        tf.math.sqrt(
            tf.cast(tf.math.reduce_prod(kspace_shape[-2:]), 'float32')),
        kspace.dtype)
    shifted_kspace = ifftshift(kspace, axes=axes)
    if enable_multiprocessing:
        # Fold every leading axis into one batch axis so map_fn can run the
        # per-slice 2D transforms concurrently.
        batched_shifted_kspace = tf.reshape(
            shifted_kspace, (-1, kspace_shape[-2], kspace_shape[-1]))
        batched_shifted_image = tf.map_fn(
            ifft2d,
            batched_shifted_kspace,
            parallel_iterations=multiprocessing.cpu_count(),
        )
        # Shifts and the per-slice transform preserve shape, so restoring
        # the input's own dynamic shape works for any rank >= 2. A per-rank
        # shape list would break for rank >= 5.
        shifted_image = tf.reshape(batched_shifted_image, kspace_shape)
    else:
        shifted_image = ifft2d(shifted_kspace)
    image = fftshift(shifted_image, axes=axes)
    return scaling_norm * image
Пример #5
0
 def test_numpy_compatibility(self):
     """fftshift/ifftshift must agree with numpy.fft on 1-D and 2-D inputs."""
     one_d_cases = (
         ([0, 1, 2, 3, 4, -4, -3, -2, -1],
          [-4, -3, -2, -1, 0, 1, 2, 3, 4]),
         ([0, 1, 2, 3, 4, -5, -4, -3, -2, -1],
          [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]),
     )
     freqs = [[0, 1, 2], [3, 4, -4], [-3, -2, -1]]
     shifted = [[-1, -3, -2], [2, 0, 1], [-4, 3, 4]]
     grid_axes = (0, 1)
     with self.session():
         for x, y in one_d_cases:
             self.assertAllEqual(fft_ops.fftshift(x), np.fft.fftshift(x))
             self.assertAllEqual(fft_ops.ifftshift(y), np.fft.ifftshift(y))
         self.assertAllEqual(fft_ops.fftshift(freqs, axes=grid_axes),
                             np.fft.fftshift(freqs, axes=grid_axes))
         self.assertAllEqual(fft_ops.ifftshift(shifted, axes=grid_axes),
                             np.fft.ifftshift(shifted, axes=grid_axes))
Пример #6
0
def ortho_fft2d(image):
    """Centred, orthonormally scaled 2D FFT over the last two axes.

    ``image`` is first cast to complex64, so complex128 input is downcast.
    It is then ``fftshift``-ed over its last two axes and folded into a batch
    of 2D slices. ``fft2d`` is applied per slice with ``tf.map_fn``
    (``multiprocessing.cpu_count()`` parallel iterations). The result is
    reshaped back, ``ifftshift``-ed and divided by ``sqrt(nx * ny)``, where
    ``nx`` and ``ny`` are the sizes of the last two axes.

    Arguments:
        image (tf.Tensor): tensor of rank >= 2 whose last two axes are
            transformed. Its rank must be known statically, because the
            shift axes are derived from ``len(image.shape)``.

    Returns:
        tf.Tensor: complex64 k-space with the same shape as ``image``.
    """
    image = tf.cast(image, 'complex64')
    axes = [len(image.shape) - 2, len(image.shape) - 1]
    image_shape = tf.shape(image)
    scaling_norm = tf.cast(
        tf.math.sqrt(
            tf.cast(tf.math.reduce_prod(image_shape[-2:]), 'float32')),
        image.dtype)
    shifted_image = fftshift(image, axes=axes)
    # Fold every leading axis into one batch axis so map_fn can run the
    # per-slice 2D transforms concurrently.
    batched_shifted_image = tf.reshape(
        shifted_image, (-1, image_shape[-2], image_shape[-1]))
    batched_shifted_kspace = tf.map_fn(
        fft2d,
        batched_shifted_image,
        parallel_iterations=multiprocessing.cpu_count(),
    )
    # Shifts and the per-slice transform preserve shape, so restoring the
    # input's own dynamic shape works for any rank >= 2. A per-rank shape
    # list would break for rank >= 5.
    shifted_kspace = tf.reshape(batched_shifted_kspace, image_shape)
    kspace = ifftshift(shifted_kspace, axes=axes)
    return kspace / scaling_norm
Пример #7
0
def tf_unmasked_op(x, idx=0):
    """Centred, orthonormally scaled 2D FFT of channel ``idx`` of ``x``.

    ``x[..., idx]`` is ``fftshift``-ed over its last two axes, passed through
    ``fft2d``, ``ifftshift``-ed back and divided by ``sqrt(nx * ny)``, where
    ``nx`` and ``ny`` are the sizes of axes -3 and -2 of ``x``. A trailing
    axis of size 1 is added back to the result.

    Arguments:
        x (tf.Tensor): tensor whose last axis indexes channels. Its rank must
            be known statically.
        idx (int): channel of ``x`` to transform. Defaults to 0.

    Returns:
        tf.Tensor: k-space of shape ``x.shape[:-1] + (1,)``.
    """
    # The spatial axes of x[..., idx] are its last two axes, i.e. axes
    # rank-3 and rank-2 of x. Deriving them from the end of the tensor is
    # identical to the former hard-coded [1, 2] for rank-4 input. It also
    # stays aligned with fft2d for inputs with extra leading axes (assuming
    # fft2d transforms the innermost two axes -- TODO confirm).
    axes = [len(x.shape) - 3, len(x.shape) - 2]
    scaling_norm = tf.dtypes.cast(
        tf.math.sqrt(
            tf.dtypes.cast(tf.math.reduce_prod(tf.shape(x)[-3:-1]),
                           'float32')),
        x.dtype)
    kspace = ifftshift(fft2d(fftshift(x[..., idx], axes=axes)), axes=axes)
    return tf.expand_dims(kspace, axis=-1) / scaling_norm
Пример #8
0
def tf_unmasked_adj_op(x, idx=0):
    """Centred, orthonormally scaled inverse 2D FFT of channel ``idx`` of ``x``.

    ``x[..., idx]`` is ``ifftshift``-ed over axes ``rank - 3`` and
    ``rank - 2`` of ``x`` (the last two axes of the selected channel). It is
    then passed through ``ifft2d``, ``fftshift``-ed back and multiplied by
    ``sqrt(nx * ny)``, where ``nx`` and ``ny`` are the sizes of axes -3 and
    -2 of ``x``. The result gets a trailing axis of size 1.
    """
    rank = len(x.shape)
    spatial_axes = [rank - 3, rank - 2]
    n_pixels = tf.math.reduce_prod(tf.shape(x)[-3:-1])
    scaling_norm = tf.dtypes.cast(
        tf.math.sqrt(tf.dtypes.cast(n_pixels, 'float32')), x.dtype)
    channel = x[..., idx]
    unshifted_kspace = ifftshift(channel, axes=spatial_axes)
    image = fftshift(ifft2d(unshifted_kspace), axes=spatial_axes)
    return scaling_norm * tf.expand_dims(image, axis=-1)
def ortho_ifft2d(kspace):
    """Centred, scaled inverse 2D FFT over axes 2 and 3 of the reordered input.

    The input is reordered with ``_order_for_ft``, presumably so that the
    spatial dimensions sit at axes 2 and 3 (the shift axes used here). It is
    then ``ifftshift``-ed, passed through ``ifft2d``, ``fftshift``-ed back,
    multiplied by ``_compute_scaling_norm`` of the reordered input, and
    finally reordered with ``_order_after_ft``.
    """
    spatial_axes = [2, 3]
    ordered_kspace = _order_for_ft(kspace)
    norm = _compute_scaling_norm(ordered_kspace)
    centred_image = fftshift(
        ifft2d(ifftshift(ordered_kspace, axes=spatial_axes)),
        axes=spatial_axes,
    )
    return _order_after_ft(centred_image * norm)
def ortho_fft2d(image):
    """Centred, scaled 2D FFT over axes 2 and 3 of the reordered input.

    The input is reordered with ``_order_for_ft``, presumably so that the
    spatial dimensions sit at axes 2 and 3 (the shift axes used here). It is
    then ``fftshift``-ed, passed through ``fft2d``, ``ifftshift``-ed back,
    divided by ``_compute_scaling_norm`` of the reordered input, and finally
    reordered with ``_order_after_ft``.
    """
    spatial_axes = [2, 3]
    ordered_image = _order_for_ft(image)
    norm = _compute_scaling_norm(ordered_image)
    centred_kspace = ifftshift(
        fft2d(fftshift(ordered_image, axes=spatial_axes)),
        axes=spatial_axes,
    )
    return _order_after_ft(centred_kspace / norm)
Пример #11
0
 def test_placeholder(self, axes):
     """Shift ops on an unknown-shape placeholder must match numpy.fft.

     The test builds a placeholder graph and runs it in a session, so it is
     skipped under eager execution.
     """
     if context.executing_eagerly():
         return
     x = array_ops.placeholder(shape=[None, None, None], dtype="float32")
     shift_ops = [
         fft_ops.fftshift(x, axes=axes),
         fft_ops.ifftshift(x, axes=axes),
     ]
     x_np = np.random.rand(16, 256, 256)
     with self.session() as sess:
         forward_res, inverse_res = sess.run(shift_ops, feed_dict={x: x_np})
     self.assertAllClose(forward_res, np.fft.fftshift(x_np, axes=axes))
     self.assertAllClose(inverse_res, np.fft.ifftshift(x_np, axes=axes))
Пример #12
0
 def testPlaceholder(self):
     """Shift ops on an unknown-shape placeholder must match numpy.fft.

     Each ``axes`` setting (None, a single int, a list) is checked against
     ``np.fft.fftshift`` / ``np.fft.ifftshift`` on a random array fed at run
     time.
     """
     x = array_ops.placeholder(shape=[None, None, None], dtype="float32")
     for axes in (None, 1, [1, 2]):
         shift_ops = [
             fft_ops.fftshift(x, axes=axes),
             fft_ops.ifftshift(x, axes=axes),
         ]
         with self.session() as sess:
             x_np = np.random.rand(16, 256, 256)
             forward_res, inverse_res = sess.run(
                 shift_ops,
                 feed_dict={x: x_np},
             )
             self.assertAllClose(forward_res,
                                 np.fft.fftshift(x_np, axes=axes))
             self.assertAllClose(inverse_res,
                                 np.fft.ifftshift(x_np, axes=axes))
Пример #13
0
def extract_smaps(kspace, low_freq_percentage=8, background_thresh=4e-6):
    """Extract raw sensitivity maps for kspaces.

    This function first keeps only a central low-frequency box of every
    kspace, then inverse Fourier transforms it, and finally normalises the
    coil images by their root sum-of-squares over the coil axis.
    kspace has to be of shape: nslices x ncoils x height x width

    Arguments:
        kspace (tf.Tensor): the kspace whose sensitivity maps you want extracted.
        low_freq_percentage (int): size of the central low-frequency box used
            for sensitivity map extraction, given as a percentage of the
            kspace size along each of the last two axes (height and width).
            In fastMRI, it's 8 for an acceleration factor of 4, and 4 for an
            acceleration factor of 8. Defaults to 8.
        background_thresh (float): unused for now, will later allow to have
            thresholded sensitivity maps.

    Returns:
        tf.Tensor: extracted raw sensitivity maps, same shape as ``kspace``.
            Pixels whose low-frequency root sum-of-squares is exactly zero
            are divided by zero and will not be finite.
    """
    spatial_shape = tf.shape(kspace)[-2:]
    n_low_freq = tf.cast(spatial_shape * low_freq_percentage / 100, tf.int32)
    center_dimension = tf.cast(spatial_shape / 2, tf.int32)
    half_n_low_freq = tf.cast(n_low_freq / 2, tf.int32)
    low_freq_lower_locations = center_dimension - half_n_low_freq
    low_freq_upper_locations = center_dimension + half_n_low_freq
    ###
    # NOTE: the mask below stands for, in numpy:
    # low_freq_mask = np.zeros_like(kspace)
    # low_freq_mask[
    #     ...,
    #     low_freq_lower_locations[0]:low_freq_upper_locations[0],
    #     low_freq_lower_locations[1]:low_freq_upper_locations[1]
    # ] = 1
    # It is built as a (height, width) box from two 1D range tests and then
    # broadcast over slices and coils. This avoids scattering ones into a
    # full (height, width, nslices, ncoils) tensor and transposing it.
    rows = tf.range(spatial_shape[0])
    cols = tf.range(spatial_shape[1])
    rows_in_box = tf.logical_and(
        rows >= low_freq_lower_locations[0],
        rows < low_freq_upper_locations[0])
    cols_in_box = tf.logical_and(
        cols >= low_freq_lower_locations[1],
        cols < low_freq_upper_locations[1])
    low_freq_mask = tf.logical_and(rows_in_box[:, None], cols_in_box[None, :])
    ###
    # Cast bool -> float32 first, then to the kspace dtype.
    low_freq_kspace = kspace * tf.cast(
        tf.cast(low_freq_mask, tf.float32), kspace.dtype)
    shifted_kspace = ifftshift(low_freq_kspace, axes=[2, 3])
    coil_image_low_freq_shifted = ifft2d(shifted_kspace)
    coil_image_low_freq = fftshift(coil_image_low_freq_shifted, axes=[2, 3])
    # no need to norm this since they all have the same norm
    low_freq_rss = tf.norm(coil_image_low_freq, axis=1)
    coil_smap = coil_image_low_freq / low_freq_rss[:, None]
    # for now we do not perform background removal based on low_freq_rss
    # could be done with 1D k-means or fixed background_thresh, with tf.where
    return coil_smap
 def adj_op(self, inputs):
     """Adjoint operator: k-space -> image.

     ``inputs`` is unpacked according to ``self.masked`` and
     ``self.multicoil`` as ``(kspace, mask, smaps)``, ``(kspace, mask)``,
     ``(kspace, smaps)`` or just ``kspace``.
     - When masked, ``_mask_tf`` is applied to the k-space first.
     - Only index 0 of the last axis is kept.
     - A centred inverse 2D FFT is taken over ``self.shift_axes`` and the
       result is multiplied by ``_compute_scaling_norm`` of that k-space.
     - When multicoil, the coil images are combined by summing over axis 1
       after multiplying by the conjugated sensitivity maps.
     A trailing axis of size 1 is added before returning.
     """
     if self.masked and self.multicoil:
         kspace, mask, smaps = inputs
     elif self.masked:
         kspace, mask = inputs
     elif self.multicoil:
         kspace, smaps = inputs
     else:
         kspace = inputs
     if self.masked:
         kspace = _mask_tf([kspace, mask])
     kspace = kspace[..., 0]
     norm = _compute_scaling_norm(kspace)
     centred_image = fftshift(
         ifft2d(ifftshift(kspace, axes=self.shift_axes)),
         axes=self.shift_axes,
     )
     image = centred_image * norm
     if self.multicoil:
         image = tf.reduce_sum(image * tf.math.conj(smaps), axis=1)
     return image[..., None]
 def op(self, inputs):
     """Forward operator: image -> k-space.

     ``inputs`` is unpacked according to ``self.multicoil`` and
     ``self.masked`` as ``(image, mask, smaps)``, ``(image, smaps)``,
     ``(image, mask)`` or just ``image``.
     - Only index 0 of the last axis is kept, and the scaling norm is
       computed from that image.
     - When multicoil, the image gets a new axis 1 and is multiplied by the
       sensitivity maps.
     - A centred 2D FFT is taken over ``self.shift_axes``, a trailing axis of
       size 1 is added, and the result is divided by the scaling norm.
     - When masked, ``_mask_tf`` is applied last.
     """
     if self.multicoil and self.masked:
         image, mask, smaps = inputs
     elif self.multicoil:
         image, smaps = inputs
     elif self.masked:
         image, mask = inputs
     else:
         image = inputs
     image = image[..., 0]
     norm = _compute_scaling_norm(image)
     if self.multicoil:
         image = tf.expand_dims(image, axis=1) * smaps
     centred_kspace = ifftshift(
         fft2d(fftshift(image, axes=self.shift_axes)),
         axes=self.shift_axes,
     )
     kspace = centred_kspace[..., None] / norm
     if self.masked:
         kspace = _mask_tf([kspace, mask])
     return kspace
def extract_smaps(kspace, low_freq_percentage=8, background_thresh=4e-6):
    """Extract raw sensitivity maps for kspaces.

    This function first keeps only a central low-frequency box of every
    kspace, then inverse Fourier transforms it, and finally normalises the
    coil images by their root sum-of-squares over the coil axis.
    kspace has to be of shape: nslices x ncoils x height x width

    Arguments:
        kspace (tf.Tensor): the kspace whose sensitivity maps you want extracted.
        low_freq_percentage (int): size of the central low-frequency box used
            for sensitivity map extraction, given as a percentage of the
            kspace size along each of the last two axes (height and width).
            In fastMRI, it's 8 for an acceleration factor of 4, and 4 for an
            acceleration factor of 8. Defaults to 8.
        background_thresh (float): unused for now, will later allow to have
            thresholded sensitivity maps.

    Returns:
        tf.Tensor: extracted raw sensitivity maps, same shape as ``kspace``.
            Pixels whose low-frequency root sum-of-squares is exactly zero
            are divided by zero and will not be finite.
    """
    spatial_shape = tf.shape(kspace)[-2:]
    n_low_freq = tf.cast(spatial_shape * low_freq_percentage / 100, tf.int32)
    center_dimension = tf.cast(spatial_shape / 2, tf.int32)
    half_n_low_freq = tf.cast(n_low_freq / 2, tf.int32)
    low_freq_lower_locations = center_dimension - half_n_low_freq
    low_freq_upper_locations = center_dimension + half_n_low_freq
    ###
    # NOTE: the mask below stands for, in numpy:
    # low_freq_mask = np.zeros_like(kspace)
    # low_freq_mask[
    #     ...,
    #     low_freq_lower_locations[0]:low_freq_upper_locations[0],
    #     low_freq_lower_locations[1]:low_freq_upper_locations[1]
    # ] = 1
    # It is built as a (height, width) box from two 1D range tests and then
    # broadcast over slices and coils. This avoids scattering ones into a
    # full (height, width, nslices, ncoils) tensor and transposing it.
    rows = tf.range(spatial_shape[0])
    cols = tf.range(spatial_shape[1])
    rows_in_box = tf.logical_and(
        rows >= low_freq_lower_locations[0],
        rows < low_freq_upper_locations[0])
    cols_in_box = tf.logical_and(
        cols >= low_freq_lower_locations[1],
        cols < low_freq_upper_locations[1])
    low_freq_mask = tf.logical_and(rows_in_box[:, None], cols_in_box[None, :])
    ###
    # Cast bool -> float32 first, then to the kspace dtype.
    low_freq_kspace = kspace * tf.cast(
        tf.cast(low_freq_mask, tf.float32), kspace.dtype)
    shifted_kspace = ifftshift(low_freq_kspace, axes=[2, 3])
    coil_image_low_freq_shifted = ifft2d(shifted_kspace)
    coil_image_low_freq = fftshift(coil_image_low_freq_shifted, axes=[2, 3])
    # no need to norm this since they all have the same norm
    low_freq_rss = tf.norm(coil_image_low_freq, axis=1)
    coil_smap = coil_image_low_freq / low_freq_rss[:, None]
    # for now we do not perform background removal based on low_freq_rss
    # could be done with 1D k-means or fixed background_thresh, with tf.where
    return coil_smap