예제 #1
0
    def generate_test_FFT(self, shape, field_scale):
        """Compare the ORC-wrapped FFT to the plain FFT for a uniform B0 shift.

        Shared helper for the 2D and 3D tests. For each combination of
        number of interpolators (``self.L``), repetition
        (``range(self.max_iter)``) and coil count (``self.n_coils``), a
        random 0/1 mask and a constant field map (``field_scale`` times a
        random integer in [-150, 150)) are drawn. The time vector is all
        ones, so the wrapper applied to data multiplied by
        ``exp(-2j*pi*shift)`` (forward) or ``exp(2j*pi*shift)`` (adjoint)
        must match the reference FFT within ``self.rtol``.
        """
        def draw_complex(n_channels):
            # Real part is drawn before the imaginary part; a singleton coil
            # axis is squeezed away.
            real = np.random.randn(n_channels, *shape)
            imag = np.random.randn(n_channels, *shape)
            return np.squeeze(real + 1j * imag)

        grid = product(self.L, range(self.max_iter), self.n_coils)
        for num_interp, _, coils in grid:
            sampling_mask = np.random.randint(2, size=shape)
            shift = field_scale * np.random.randint(-150, 150)
            b0_map = shift * np.ones(shape)

            reference_op = FFT(mask=sampling_mask, shape=shape, n_coils=coils)
            wrapped_op = ORCFFTWrapper(
                reference_op,
                field_map=b0_map,
                time_vec=np.ones(shape[0]),
                mask=np.ones(shape),
                num_interpolators=num_interp,
                n_bins=self.n_bins,
            )

            # Forward direction: demodulate the image by the constant phase.
            image = draw_complex(coils)
            expected_ksp = reference_op.op(image)
            observed_ksp = wrapped_op.op(image * np.exp(-2j * np.pi * shift))
            np.testing.assert_allclose(expected_ksp, observed_ksp,
                                       rtol=self.rtol)

            # Adjoint direction: modulate the k-space by the opposite phase.
            kspace = draw_complex(coils)
            expected_img = reference_op.adj_op(kspace)
            observed_img = wrapped_op.adj_op(kspace * np.exp(2j * np.pi * shift))
            np.testing.assert_allclose(expected_img, observed_img,
                                       rtol=self.rtol)
예제 #2
0
def get_operators(
    kspace_data,
    loc,
    mask,
    fourier_type=1,
    max_iter=80,
    regularisation=None,
    linear=None,
):
    """Create the various operators from the config file.

    Parameters
    ----------
    kspace_data: numpy.ndarray
        Full k-space data. A 2D array is treated as single-coil; otherwise
        the first axis is taken as the coil axis. The last two axes give
        the operator shape.
    loc:
        Column locations, passed as ``mask_cols`` to the online k-space
        generators (``fourier_type`` 1 and 2).
    mask:
        Sampling mask, passed to ``FFT`` (types 0 and 1) and to the
        offline generator (type 0).
    fourier_type: int, default 1
        0: offline reconstruction, 1: online type I, 2: online type II.
    max_iter: int, default 80
        Only used by the offline generator (type 0).
    regularisation: dict or None
        Must contain a ``"class"`` key ("LASSO", "GroupLASSO", "OWL" or
        "IdentityProx"); the remaining keys configure the proximity
        operator. ``None`` selects ``IdentityProx``. Not modified.
    linear: dict or None
        Optional ``"class"`` key ("WaveletN" or "Identity"); the remaining
        keys are forwarded to ``WaveletN``. ``None`` selects ``Identity``.
        Not modified.

    Returns
    -------
    tuple
        ``(kspace_generator, fourier_op, linear_op, prox_op)``.

    Raises
    ------
    NotImplementedError
        For an unknown ``fourier_type``, linear class or regularisation
        class.
    KeyError
        If ``regularisation`` has no ``"class"`` key.
    """
    n_coils = 1 if kspace_data.ndim == 2 else kspace_data.shape[0]
    shape = kspace_data.shape[-2:]
    if fourier_type == 0:  # offline reconstruction
        kspace_generator = KspaceGeneratorBase(full_kspace=kspace_data,
                                               mask=mask,
                                               max_iter=max_iter)
        fourier_op = FFT(shape=shape, n_coils=n_coils, mask=mask)
    elif fourier_type == 1:  # online type I reconstruction
        kspace_generator = Column2DKspaceGenerator(full_kspace=kspace_data,
                                                   mask_cols=loc)
        fourier_op = FFT(shape=shape, n_coils=n_coils, mask=mask)
    elif fourier_type == 2:  # online type II reconstruction
        kspace_generator = DataOnlyKspaceGenerator(full_kspace=kspace_data,
                                                   mask_cols=loc)
        fourier_op = ColumnFFT(shape=shape, n_coils=n_coils)
    else:
        raise NotImplementedError(
            "Unknown fourier_type: {!r}".format(fourier_type))

    if linear is None:
        linear_op = Identity()
    else:
        # Work on a copy so the caller's config keeps its "class" key and
        # the same config can be passed to this function again.
        linear_kwargs = dict(linear)
        lin_cls = linear_kwargs.pop("class", None)
        if lin_cls == "WaveletN":
            linear_op = WaveletN(n_coils=n_coils, n_jobs=4, **linear_kwargs)
            # Apply the operator once on zeros shaped like the data;
            # presumably this initialises ``coeffs_shape`` (read by the OWL
            # branch below) -- confirm against WaveletN.
            linear_op.op(np.zeros_like(kspace_data))
        elif lin_cls == "Identity":
            linear_op = Identity()
        else:
            raise NotImplementedError(
                "Unknown linear operator class: {!r}".format(lin_cls))

    prox_op = IdentityProx()
    if regularisation is not None:
        # Same reasoning as above: never mutate the caller's config.
        reg_kwargs = dict(regularisation)
        reg_cls = reg_kwargs.pop("class")
        if reg_cls == "LASSO":
            prox_op = LASSO(weights=reg_kwargs["weights"])
        elif reg_cls == "GroupLASSO":
            prox_op = GroupLASSO(weights=reg_kwargs["weights"])
        elif reg_cls == "OWL":
            # NOTE(review): requires a linear operator exposing
            # ``coeffs_shape`` (e.g. WaveletN), not Identity.
            prox_op = OWL(**reg_kwargs,
                          n_coils=n_coils,
                          bands_shape=linear_op.coeffs_shape)
        elif reg_cls == "IdentityProx":
            prox_op = IdentityProx()
            linear_op = Identity()
        else:
            raise NotImplementedError(
                "Unknown regularisation class: {!r}".format(reg_cls))
    return kspace_generator, fourier_op, linear_op, prox_op
예제 #3
0
 def test_FFT(self):
     """Test the adjoint of the 2D Cartesian FFT operator.

     For every iteration and coil count, a random 0/1 mask is converted to
     sample locations, and the dot-product test
     ``vdot(x, A^H y) == vdot(A x, y)`` is checked on random complex data
     using the same operator for ``A`` and ``A^H``.
     """
     for i, num_coil in product(range(self.max_iter), self.num_channels):
         _mask = np.random.randint(2, size=(self.N, self.N))
         _samples = convert_mask_to_locations(_mask)
         # The original passed ``i`` as a second print argument, leaving the
         # "{0}" placeholder unfilled; format it into the message instead.
         print("Process FFT test '{0}'...".format(i))
         fourier_op = FFT(samples=_samples,
                          shape=(self.N, self.N),
                          n_coils=num_coil)
         # Random image-domain and k-space data; the coil axis is squeezed
         # away when num_coil == 1.
         Img = np.squeeze(
             np.random.randn(num_coil, self.N, self.N) +
             1j * np.random.randn(num_coil, self.N, self.N))
         f = np.squeeze(
             np.random.randn(num_coil, self.N, self.N) +
             1j * np.random.randn(num_coil, self.N, self.N))
         f_p = fourier_op.op(Img)
         I_p = fourier_op.adj_op(f)
         # np.vdot conjugates its first argument, so these are the two
         # sides of the inner-product identity <x, A^H y> == <A x, y>.
         x_d = np.vdot(Img, I_p)
         x_ad = np.vdot(f_p, f)
         np.testing.assert_allclose(x_d, x_ad, rtol=1e-10)
     print(" FFT adjoint test passes")
# View Input
# image.show()
# mask.show()

#############################################################################
# Generate the kspace
# -------------------
#
# Starting from the 3D Orange volume, we retrospectively undersample its
# k-space using a Cartesian acquisition mask.
# We then reconstruct the zero-order solution as a baseline.

# Get the locations of the kspace samples from the acquisition mask
kspace_loc = convert_mask_to_locations(mask.data)
# Generate the subsampled kspace: build the Fourier operator on those
# sample locations and apply it to the reference image
fourier_op = FFT(samples=kspace_loc, shape=image.shape)
kspace_data = fourier_op.op(image)

# Zero order solution: apply the adjoint operator directly to the sampled
# k-space, reusing the reference image metadata.
# NOTE(review): adj_op presumably returns complex values; confirm that the
# ssim call below compares what is intended (magnitude vs. complex data).
image_rec0 = pysap.Image(data=fourier_op.adj_op(kspace_data),
                         metadata=image.metadata)
# image_rec0.show()

# Calculate SSIM between the zero order solution and the reference image
base_ssim = ssim(image_rec0, image)
print(base_ssim)

#############################################################################
# FISTA optimization
# ------------------
#
# View Input
# image.show()
# mask.show()

#############################################################################
# Generate the kspace
# -------------------
#
# Starting from the 2D brain slice, we retrospectively undersample its
# k-space using a Cartesian acquisition mask.
# We then reconstruct the zero-order solution as a baseline.

# Build a multi-coil Fourier operator on the k-space sample locations
# (kspace_loc must already be defined at this point) and compute the
# associated observations; the coil count is read from the first axis of
# cartesian_ref_image
fourier_op = FFT(samples=kspace_loc,
                 shape=image.shape,
                 n_coils=cartesian_ref_image.shape[0])
kspace_obs = fourier_op.op(cartesian_ref_image)

# Zero Filled reconstruction: apply the adjoint to the observations, then
# combine the coils (axis 0) as the square root of the sum of squared
# magnitudes
zero_filled = fourier_op.adj_op(kspace_obs)
image_rec0 = pysap.Image(data=np.sqrt(np.sum(np.abs(zero_filled)**2, axis=0)))
# image_rec0.show()
base_ssim = ssim(image_rec0, image)
print('The Base SSIM is : ' + str(base_ssim))

#############################################################################
# FISTA optimization
# ------------------
#
# We now want to refine the zero order solution using a FISTA optimization.
예제 #6
0
from mri.operators.utils import convert_mask_to_locations
from mri.reconstructors import SingleChannelReconstructor
from mri.scripts.gridsearch import launch_grid

from pysap.data import get_sample_data

from modopt.math.metrics import ssim
from modopt.opt.proximity import SparseThreshold
from modopt.opt.linear import Identity
import numpy as np

# Load MR data and obtain kspace
# NOTE(review): FFT is used below but does not appear in the import block
# shown above; confirm it is imported elsewhere in this script.
image = get_sample_data('2d-mri')
mask = get_sample_data("cartesian-mri-mask")
# Convert the sampling mask to k-space sample locations, build the Fourier
# operator on them, and compute the sampled k-space of the image data
kspace_loc = convert_mask_to_locations(mask.data)
fourier_op = FFT(samples=kspace_loc, shape=image.shape)
kspace_data = fourier_op.op(image.data)
# Define the keyword dictionaries based on convention
ref = image
metrics = {
    'ssim': {
        'metric': ssim,
        'mapping': {
            'x_new': 'test',
            'y_new': None
        },
        'cst_kwargs': {
            'ref': image,
            'mask': None
        },
        'early_stopping': True,
예제 #7
0
def create_cartesian_metrics(online_pb,
                             real_img,
                             final_mask,
                             final_k,
                             estimates=None):
    """Build the metrics configuration used to monitor an online reconstruction.

    Parameters
    ----------
    online_pb:
        Online problem object; its ``gradient_op``, ``linear_op``,
        ``prox_op`` and ``opt`` attributes are read by the metrics.
    real_img: numpy.ndarray
        Reference image for PSNR/SSIM. Both metrics are restricted to a
        centred square of side ``min(real_img.shape)``.
    final_mask:
        Sampling mask of the complete acquisition, passed to ``FFT``.
    final_k: numpy.ndarray
        Complete k-space data. A 3D array is treated as multi-coil with the
        coil axis first; its last two axes give the operator shape.
    estimates:
        Stored as ``estimate_call_period`` in the returned configuration.

    Returns
    -------
    dict
        Keys ``metrics``, ``cost_op_kwargs``, ``metric_call_period`` and
        ``estimate_call_period``. ``metrics`` holds ``psnr``, ``ssim``,
        ``data_res_off``, ``data_res_on``, ``x_new`` and ``reg_res``.
    """
    # Offline data-consistency term: a gradient operator bound to the
    # complete k-space, independent of the online problem's partial data.
    metrics_fourier_op = FFT(
        shape=final_k.shape[-2:],
        n_coils=final_k.shape[0] if final_k.ndim == 3 else 1,
        mask=final_mask)
    metrics_gradient_op = OnlineGradAnalysis(fourier_op=metrics_fourier_op)
    metrics_gradient_op.obs_data = final_k

    # Centred square region of side min(shape). Each start index is
    # computed once and the stop is start + side, so the region always
    # spans exactly `side` pixels per axis. The former
    # `c - side // 2 : c + side // 2` dropped a row/column when `side`
    # was odd.
    side = min(real_img.shape)
    square_mask = np.zeros(real_img.shape)
    row_start = real_img.shape[0] // 2 - side // 2
    col_start = real_img.shape[1] // 2 - side // 2
    square_mask[row_start:row_start + side,
                col_start:col_start + side] = 1

    def data_res_on(x):
        # Synthesis gradients operate on coefficients, so map the image
        # through the linear operator first.
        if isinstance(online_pb.gradient_op, OnlineGradSynthesis):
            return online_pb.gradient_op.cost(online_pb.linear_op.op(x))
        return online_pb.gradient_op.cost(x)

    # The one-argument lambdas below take the value named in 'mapping'
    # (keyword ``x`` / ``y``) and forward it positionally to the wrapped
    # callable. Presumably the metric machinery calls them by keyword;
    # confirm before replacing them with bare references.
    metrics = {
        'psnr': {
            'metric': psnr_ssos,
            'mapping': {
                'x_new': 'test'
            },
            'early_stopping': False,
            'cst_kwargs': {
                'ref': real_img,
                'mask': square_mask
            },
        },
        'ssim': {
            'metric': ssim_ssos,
            'mapping': {
                'x_new': 'test'
            },
            'cst_kwargs': {
                'ref': real_img,
                'mask': square_mask
            },
            'early_stopping': False,
        },
        'data_res_off': {
            'metric': lambda x: metrics_gradient_op.cost(x),
            'mapping': {
                'x_new': 'x'
            },
            'early_stopping': False,
            'cst_kwargs': dict(),
        },
        'data_res_on': {
            'metric': data_res_on,
            'mapping': {
                'x_new': 'x'
            },
            'early_stopping': False,
            'cst_kwargs': dict(),
        },
        'x_new': {
            'metric': lambda x: ssos(x),
            'mapping': {
                'x_new': 'x'
            },
            'early_stopping': False,
            'cst_kwargs': dict(),
        },
    }
    if online_pb.opt == 'condatvu':
        # Condat-Vu exposes the dual variable, so the regularisation cost
        # is evaluated on it directly.
        metrics['reg_res'] = {
            'metric': lambda y: online_pb.prox_op.cost(y),
            'mapping': {
                'y_new': 'y'
            },
            'early_stopping': False,
            'cst_kwargs': dict(),
        }
    else:
        metrics['reg_res'] = {
            'metric':
            lambda x: online_pb.prox_op.cost(online_pb.linear_op.op(x)),
            'mapping': {
                'x_new': 'x'
            },
            'early_stopping': False,
            'cst_kwargs': dict(),
        }
    metrics_config = {
        'metrics': metrics,
        'cost_op_kwargs': {
            "cost_interval": 1
        },
        'metric_call_period': 1,
        'estimate_call_period': estimates,
    }
    return metrics_config