def tf_ortho_ifft2d(kspace, enable_multiprocessing=True):
    """Centered, orthonormally scaled 2D inverse FFT over the last two axes.

    The k-space is ``ifftshift``-ed over its last two axes, inverse
    transformed with ``ifft2d``, ``fftshift``-ed back over the same axes and
    finally multiplied by ``sqrt(H * W)``, where ``H`` and ``W`` are the sizes
    of the last two axes.

    Args:
        kspace: complex tensor whose last two axes are the k-space
            dimensions. Any number of leading axes (e.g. slices, coils) is
            accepted; its rank must be statically known.
        enable_multiprocessing: if True, all leading axes are flattened into a
            single batch axis and one ``ifft2d`` per 2D image is dispatched
            with ``tf.map_fn`` using ``multiprocessing.cpu_count()`` parallel
            iterations. Otherwise ``ifft2d`` is applied to the whole tensor
            at once.

    Returns:
        Tensor with the same shape as ``kspace`` holding the scaled image.
        (Assumes ``ifft2d`` transforms the innermost two axes and preserves
        shape, as ``tf.signal.ifft2d`` does.)
    """
    rank = len(kspace.shape)
    axes = [rank - 2, rank - 1]
    kspace_shape = tf.shape(kspace)
    scaling_norm = tf.cast(
        tf.math.sqrt(
            tf.cast(tf.math.reduce_prod(kspace_shape[-2:]), 'float32')),
        kspace.dtype)
    shifted_kspace = ifftshift(kspace, axes=axes)
    if enable_multiprocessing:
        # Flatten every leading axis into one batch axis so that map_fn can
        # dispatch one 2D inverse FFT per image.
        batched_shifted_kspace = tf.reshape(
            shifted_kspace, (-1, kspace_shape[-2], kspace_shape[-1]))
        batched_shifted_image = tf.map_fn(
            ifft2d,
            batched_shifted_kspace,
            parallel_iterations=multiprocessing.cpu_count(),
        )
        # The per-image shape is unchanged by the transform, so the input's
        # own dynamic shape restores the original layout for any rank >= 2
        # (previously only ranks 2, 3 and 4 were handled).
        shifted_image = tf.reshape(batched_shifted_image, kspace_shape)
    else:
        shifted_image = ifft2d(shifted_kspace)
    image = fftshift(shifted_image, axes=axes)
    return scaling_norm * image
Example #2
0
  def operator_and_matrix(
      self, shape_info, dtype, use_placeholder,
      ensure_self_adjoint_and_pd=False):
    """Build a random 2D circulant operator and its dense reference matrix.

    The spectrum is built as FFT2D of a real kernel, which makes it
    Hermitian. The kernel itself is the real part of IFFT2D of a spectrum
    drawn between 1 and 2, which keeps the final spectrum well conditioned.
    The operator's `parameters` dict is also checked against the values the
    operator was constructed with.

    Args:
      shape_info: object whose `.shape` is the requested operator shape.
      dtype: dtype of the random draw and of the operator's input/output.
      use_placeholder: if True, the spectrum is passed to the operator
        through `placeholder_with_default` with `shape=None`.
      ensure_self_adjoint_and_pd: if True, the operator is flagged as
        self-adjoint and positive definite; otherwise both flags are None.

    Returns:
      `(operator, mat)`, where `mat` is the dense matrix produced by
      `self._spectrum_to_circulant_2d` from the same spectrum.
    """
    target_shape = shape_info.shape
    # Flag value for the self-adjoint / positive-definite hints: True when
    # requested, otherwise None (unknown) -- never False.
    hint = True if ensure_self_adjoint_and_pd else None

    # Draw a spectrum that is bounded away from zero (entries in [1, 2]).
    raw_spectrum = linear_operator_test_util.random_uniform(
        shape=self._shape_to_spectrum_shape(target_shape),
        dtype=dtype,
        minval=1.,
        maxval=2.)

    # Real{IFFT[raw_spectrum]} = IFFT[EvenPartOf[raw_spectrum]] is the IFFT of
    # something also bounded away from zero, so the FFT of this real kernel
    # is a well-conditioned spectrum. Being the DFT of a real kernel, that
    # spectrum is Hermitian.
    kernel = math_ops.real(fft_ops.ifft2d(_to_complex(raw_spectrum)))
    spectrum = fft_ops.fft2d(_to_complex(kernel))

    if use_placeholder:
      operator_spectrum = array_ops.placeholder_with_default(
          spectrum, shape=None)
    else:
      operator_spectrum = spectrum

    operator = linalg.LinearOperatorCirculant2D(
        operator_spectrum,
        is_positive_definite=hint,
        is_self_adjoint=hint,
        input_output_dtype=dtype)

    expected_parameters = {
        "input_output_dtype": dtype,
        "is_non_singular": None,
        "is_positive_definite": hint,
        "is_self_adjoint": hint,
        "is_square": True,
        "name": "LinearOperatorCirculant2D",
        "spectrum": operator_spectrum,
    }
    self.assertEqual(operator.parameters, expected_parameters)

    dense_matrix = self._spectrum_to_circulant_2d(
        spectrum, target_shape, dtype=dtype)
    return operator, dense_matrix
Example #3
0
def tf_unmasked_adj_op(x, idx=0):
    """Centered, scaled 2D inverse FFT of one channel of ``x``.

    Channel ``idx`` of the last axis is selected. ``ifftshift`` -> ``ifft2d``
    -> ``fftshift`` is then applied over the two axes that precede the
    channel axis. The result is multiplied by the square root of the product
    of those two axis sizes, and a trailing singleton axis is appended again.

    Args:
        x: tensor laid out as ``(..., H, W, channels)`` (layout inferred from
            the ``[-3:-1]`` / ``[..., idx]`` indexing -- confirm with callers).
        idx: index along the last axis of the channel to transform.

    Returns:
        The scaled image with a trailing axis of size 1.
    """
    rank = len(x.shape)
    spatial_axes = [rank - 3, rank - 2]
    n_pixels = tf.math.reduce_prod(tf.shape(x)[-3:-1])
    scaling_norm = tf.dtypes.cast(
        tf.math.sqrt(tf.dtypes.cast(n_pixels, 'float32')), x.dtype)
    channel = x[..., idx]
    image = fftshift(
        ifft2d(ifftshift(channel, axes=spatial_axes)), axes=spatial_axes)
    return scaling_norm * tf.expand_dims(image, axis=-1)
def ortho_ifft2d(kspace):
    """Centered, scaled 2D inverse FFT over axes 2 and 3.

    The input is first rearranged with ``_order_for_ft``. It is then
    ``ifftshift``-ed, transformed with ``ifft2d`` and ``fftshift``-ed over
    axes 2 and 3, and multiplied by ``_compute_scaling_norm`` of the
    rearranged input. The result is rearranged back with ``_order_after_ft``.
    (Presumably ``_order_for_ft`` moves the spatial axes to positions 2 and 3
    -- confirm against its definition.)
    """
    ordered = _order_for_ft(kspace)
    fft_axes = [2, 3]
    norm = _compute_scaling_norm(ordered)
    centered = fftshift(
        ifft2d(ifftshift(ordered, axes=fft_axes)), axes=fft_axes)
    return _order_after_ft(centered * norm)
Example #5
0
def extract_smaps(kspace, low_freq_percentage=8, background_thresh=4e-6):
    """Estimate raw coil sensitivity maps from the center of k-space.

    A centered rectangle covering ``low_freq_percentage`` percent of each of
    the last two axes is kept and the rest of k-space is zeroed. The result
    is inverse-Fourier transformed (centered, over axes 2 and 3), and every
    coil image is divided by the root sum-of-squares over the coil axis
    (axis 1).

    Args:
        kspace: rank-4 tensor laid out as (slices, coils, height, width).
        low_freq_percentage: size of the kept central region, as a
            percentage of the height and of the width. Defaults to 8.
        background_thresh: currently unused.

    Returns:
        Tensor of raw sensitivity maps, same shape as ``kspace``.
    """
    kspace_shape = tf.shape(kspace)
    spatial_shape = kspace_shape[-2:]
    n_low_freq = tf.cast(spatial_shape * low_freq_percentage / 100, tf.int32)
    center = tf.cast(spatial_shape / 2, tf.int32)
    half_width = tf.cast(n_low_freq / 2, tf.int32)
    lower = center - half_width
    upper = center + half_width

    # Equivalent numpy: mask[..., lower[0]:upper[0], lower[1]:upper[1]] = 1
    rows, cols = tf.meshgrid(
        tf.range(lower[0], upper[0]), tf.range(lower[1], upper[1]))
    rows = tf.reshape(rows, (-1,))
    cols = tf.reshape(cols, (-1,))
    indices = tf.stack([rows, cols], axis=-1)

    # scatter_nd can only address leading dimensions, so the mask is built as
    # (height, width, slices, coils) and transposed back afterwards.
    perm = [2, 3, 0, 1]
    mask = tf.scatter_nd(
        indices=indices,
        updates=tf.ones([tf.size(rows), kspace_shape[0], kspace_shape[1]]),
        shape=[kspace_shape[axis] for axis in perm],
    )
    mask = tf.transpose(mask, perm=perm)

    masked_kspace = kspace * tf.cast(mask, kspace.dtype)
    coil_images = fftshift(
        ifft2d(ifftshift(masked_kspace, axes=[2, 3])), axes=[2, 3])
    # No normalisation of the FFT is needed: the same factor would cancel in
    # the division below. Background removal based on the RSS is not done.
    rss = tf.norm(coil_images, axis=1)
    return coil_images / rss[:, None]
  def operator_and_matrix(self,
                          shape_info,
                          dtype,
                          use_placeholder,
                          ensure_self_adjoint_and_pd=False):
    """Return a random Hermitian-spectrum 2D circulant operator and matrix.

    A spectrum drawn between 1 and 2 is inverse transformed, and its real
    part is used as a real convolution kernel. The FFT2D of that kernel is
    Hermitian and serves as the operator's spectrum.

    Args:
      shape_info: object whose `.shape` is the requested operator shape.
      dtype: dtype of the random draw and of the operator's input/output.
      use_placeholder: if True, hand the spectrum to the operator through
        `placeholder_with_default` with `shape=None`.
      ensure_self_adjoint_and_pd: if True, flag the operator as self-adjoint
        and positive definite; otherwise both flags are None.

    Returns:
      `(operator, mat)`, where `mat` comes from
      `self._spectrum_to_circulant_2d` applied to the same spectrum.
    """
    requested_shape = shape_info.shape
    pd_hint = True if ensure_self_adjoint_and_pd else None

    # Bounded away from zero by construction.
    base_spectrum = _to_complex(
        linear_operator_test_util.random_uniform(
            shape=self._shape_to_spectrum_shape(requested_shape),
            dtype=dtype,
            minval=1.,
            maxval=2.))

    # Real{IFFT[base]} = IFFT[EvenPartOf[base]], the IFFT of something still
    # bounded away from zero, so the FFT of this real kernel is a
    # well-conditioned spectrum; a real kernel makes that spectrum Hermitian.
    real_kernel = _to_complex(math_ops.real(fft_ops.ifft2d(base_spectrum)))
    spectrum = fft_ops.fft2d(real_kernel)

    operator_spectrum = (
        array_ops.placeholder_with_default(spectrum, shape=None)
        if use_placeholder else spectrum)

    operator = linalg.LinearOperatorCirculant2D(
        operator_spectrum,
        is_positive_definite=pd_hint,
        is_self_adjoint=pd_hint,
        input_output_dtype=dtype)

    return operator, self._spectrum_to_circulant_2d(
        spectrum, requested_shape, dtype=dtype)
 def adj_op(self, inputs):
     """Map k-space inputs to the image domain.

     ``inputs`` is unpacked according to the instance flags:
       * masked and multicoil: (kspace, mask, smaps)
       * masked only:          (kspace, mask)
       * multicoil only:       (kspace, smaps)
       * neither:              kspace
     When masked, ``_mask_tf([kspace, mask])`` is applied first. Channel 0
     of the last axis is then ``ifftshift``-ed, passed through ``ifft2d``,
     ``fftshift``-ed over ``self.shift_axes`` and multiplied by
     ``_compute_scaling_norm`` of that channel. When multicoil, the coil
     images are combined as ``sum(image * conj(smaps))`` over axis 1. A
     trailing singleton axis is appended to the result.
     """
     if self.masked:
         if self.multicoil:
             kspace, mask, smaps = inputs
         else:
             kspace, mask = inputs
         kspace = _mask_tf([kspace, mask])
     elif self.multicoil:
         kspace, smaps = inputs
     else:
         kspace = inputs

     # Only channel 0 of the trailing axis is transformed.
     kspace = kspace[..., 0]
     scaling_norm = _compute_scaling_norm(kspace)
     image = fftshift(
         ifft2d(ifftshift(kspace, axes=self.shift_axes)),
         axes=self.shift_axes,
     ) * scaling_norm
     if self.multicoil:
         # Combine coils (axis 1) with the conjugated sensitivity maps.
         image = tf.reduce_sum(image * tf.math.conj(smaps), axis=1)
     return image[..., None]
    def _spectrum_to_circulant_2d(self, spectrum, shape, dtype):
        """Build a dense block-circulant (batch) matrix from ``spectrum``.

        The matrix is assembled one column at a time by applying the spectrum
        (FFT -> multiply -> IFFT) to every standard basis vector of the 2D
        block. This is deliberately explicit and slow so that it can
        cross-check the reshape-based implementation under test.

        Args:
          spectrum: Float or complex `Tensor`.
          shape: Python list. Desired shape of the returned matrix.
          dtype: Type to cast the returned matrix to.

        Returns:
          Block circulant (batch) matrix of the requested `dtype`.
        """
        spectrum = _to_complex(spectrum)
        spectrum_shape = self._shape_to_spectrum_shape(shape)
        if not spectrum_shape[-1]:
            # Empty domain: return an all-zero matrix of the requested shape.
            return array_ops.zeros(shape, dtype)

        block_shape = spectrum_shape[-2:]

        def _response_to_basis_vector(n0, n1):
            # Flattened action of the operator on the (n0, n1) basis vector.
            basis = np.zeros(block_shape)
            basis[n0, n1] = 1.0
            fft_basis = fft_ops.fft2d(math_ops.cast(basis, spectrum.dtype))
            response = fft_ops.ifft2d(spectrum * fft_basis)
            return array_ops.reshape(response, shape[:-1])

        # Stacking on the last axis makes response k the k-th column.
        columns = [
            _response_to_basis_vector(n0, n1)
            for n0 in range(block_shape[0])
            for n1 in range(block_shape[1])
        ]
        return math_ops.cast(array_ops.stack(columns, axis=-1), dtype)
  def _spectrum_to_circulant_2d(self, spectrum, shape, dtype):
    """Creates a block circulant matrix from a spectrum.

    Intentionally done in an explicit yet inefficient way.  This provides a
    cross check to the main code that uses fancy reshapes.

    Every standard basis vector of the 2D block is convolved with the kernel
    defined by `spectrum` (FFT -> multiply by spectrum -> IFFT). The
    flattened responses are stacked along the last axis, so the response to
    basis vector k becomes column k.

    Args:
      spectrum: Float or complex `Tensor`.
      shape:  Python list.  Desired shape of returned matrix.
      dtype:  Type to cast the returned matrix to.

    Returns:
      Block circulant (batch) matrix of desired `dtype`.
    """
    spectrum = _to_complex(spectrum)
    spectrum_shape = self._shape_to_spectrum_shape(shape)
    domain_dimension = spectrum_shape[-1]
    if not domain_dimension:
      return array_ops.zeros(shape, dtype)

    block_shape = spectrum_shape[-2:]

    # Explicitly compute the action of spectrum on basis vectors.
    matrix_columns = []
    for n0 in range(block_shape[0]):
      for n1 in range(block_shape[1]):
        x = np.zeros(block_shape)
        # x is a basis vector.
        x[n0, n1] = 1.0
        # Cast to the spectrum's own complex dtype. A hard-coded complex64
        # cast makes `spectrum * fft_x` a dtype mismatch whenever the
        # spectrum is complex128 (presumably the case for float64 inputs
        # after _to_complex).
        fft_x = fft_ops.fft2d(math_ops.cast(x, spectrum.dtype))
        h_convolve_x = fft_ops.ifft2d(spectrum * fft_x)
        # We want the flat version of the action of the operator on a basis
        # vector, not the block version.
        h_convolve_x = array_ops.reshape(h_convolve_x, shape[:-1])
        matrix_columns.append(h_convolve_x)
    matrix = array_ops.stack(matrix_columns, axis=-1)
    return math_ops.cast(matrix, dtype)
def tf_unmasked_adj_op(x, idx=0):
    """Centered, scaled 2D inverse FFT of one channel of ``x``.

    Channel ``idx`` of the last axis is selected. ``ifftshift`` -> ``ifft2d``
    -> ``fftshift`` is then applied over the two axes immediately preceding
    the channel axis. The result is multiplied by the square root of the
    product of those two axis sizes, and a trailing singleton axis is
    appended again.

    Args:
        x: tensor laid out as ``(..., H, W, channels)`` (assumed layout --
            confirm with callers); its rank must be statically known.
        idx: index along the last axis of the channel to transform.

    Returns:
        The scaled image with a trailing axis of size 1.
    """
    rank = len(x.shape)
    # After dropping the channel axis, H and W are the last two axes -- the
    # same axes ifft2d transforms. The former hard-coded [1, 2] only matched
    # them for rank-4 inputs; for e.g. (batch, coil, H, W, C) it shifted the
    # coil axis instead of W.
    spatial_axes = [rank - 3, rank - 2]
    n_pixels = tf.math.reduce_prod(tf.shape(x)[-3:-1])
    scaling_norm = tf.dtypes.cast(
        tf.math.sqrt(tf.dtypes.cast(n_pixels, 'float32')), x.dtype)
    image = fftshift(
        ifft2d(ifftshift(x[..., idx], axes=spatial_axes)), axes=spatial_axes)
    return scaling_norm * tf.expand_dims(image, axis=-1)
def extract_smaps(kspace, low_freq_percentage=8, background_thresh=4e-6):
    """Extract raw sensitivity maps for kspaces

    This function will first select a low frequency region in all the kspaces,
    then Fourier invert it, and finally perform a normalisation by the root
    sum-of-square over the coil axis.
    kspace has to be of shape: nslices x ncoils x height x width

    Arguments:
        kspace (tf.Tensor): the kspace whose sensitivity maps you want extracted.
        low_freq_percentage (int): size of the centered low frequency region
            to consider for sensitivity maps extraction, given as a percentage
            of BOTH the height and the width of the kspace (the region is a
            centered rectangle). Expected to lie in [0, 100]. In fastMRI, it's
            8 for an acceleration factor of 4, and 4 for an acceleration factor
            of 8. Defaults to 8.
        background_thresh (float): unused for now, will later allow to have
            thresholded sensitivity maps.

    Returns:
        tf.Tensor: extracted raw sensitivity maps, same shape as kspace.
    """
    spatial_shape = tf.shape(kspace)[-2:]
    n_low_freq = tf.cast(spatial_shape * low_freq_percentage / 100, tf.int32)
    center_dimension = tf.cast(spatial_shape / 2, tf.int32)
    half_low_freq = tf.cast(n_low_freq / 2, tf.int32)
    low_freq_lower_locations = center_dimension - half_low_freq
    low_freq_upper_locations = center_dimension + half_low_freq
    ###
    # NOTE: for 0 <= low_freq_percentage <= 100 the mask below is the TF
    # equivalent of this numpy code:
    # low_freq_mask = np.zeros_like(kspace)
    # low_freq_mask[
    #     ...,
    #     low_freq_lower_locations[0]:low_freq_upper_locations[0],
    #     low_freq_lower_locations[1]:low_freq_upper_locations[1]
    # ] = 1
    # It is the outer product of two 1D band indicators. Only a
    # (height, width) mask is materialized, and it broadcasts over the slice
    # and coil axes when multiplied with kspace.
    rows = tf.range(spatial_shape[0])
    cols = tf.range(spatial_shape[1])
    row_in_band = tf.cast(
        tf.logical_and(rows >= low_freq_lower_locations[0],
                       rows < low_freq_upper_locations[0]),
        tf.float32)
    col_in_band = tf.cast(
        tf.logical_and(cols >= low_freq_lower_locations[1],
                       cols < low_freq_upper_locations[1]),
        tf.float32)
    low_freq_mask = row_in_band[:, None] * col_in_band[None, :]
    ###
    low_freq_kspace = kspace * tf.cast(low_freq_mask, kspace.dtype)
    shifted_kspace = ifftshift(low_freq_kspace, axes=[2, 3])
    coil_image_low_freq_shifted = ifft2d(shifted_kspace)
    coil_image_low_freq = fftshift(coil_image_low_freq_shifted, axes=[2, 3])
    # no need to norm this since they all have the same norm
    low_freq_rss = tf.norm(coil_image_low_freq, axis=1)
    # NOTE(review): pixels where low_freq_rss is exactly zero yield NaN here.
    coil_smap = coil_image_low_freq / low_freq_rss[:, None]
    # for now we do not perform background removal based on low_freq_rss
    # could be done with 1D k-means or fixed background_thresh, with tf.where
    return coil_smap