Ejemplo n.º 1
0
 def test_Wavelet3D_PyWt(self):
     """Test the adjoint operator for the 3D Wavelet transform.

     For every channel count in ``self.num_channels`` and for
     ``self.max_iter`` random complex draws of side ``self.N``, checks the
     dot-product identity ``vdot(Img, adj_op(f)) == vdot(op(Img), f)``
     with ``rtol=1e-5``.
     """
     for ch in self.num_channels:
         print("Testing with Num Channels : " + str(ch))
         for i in range(self.max_iter):
             # Fill the placeholder explicitly; passing ``i`` as a second
             # print argument left the literal '{0}' in the output.
             print("Process Wavelet3D PyWt test '{0}'...".format(i))
             wavelet_op_adj = WaveletN(
                 wavelet_name="sym8",
                 nb_scale=4,
                 dim=3,
                 padding_mode='periodization',
                 n_coils=ch,
                 n_jobs=-1,
             )
             # squeeze drops the leading axis when ch == 1
             Img = np.squeeze(
                 np.random.randn(ch, self.N, self.N, self.N) +
                 1j * np.random.randn(ch, self.N, self.N, self.N)
             )
             f_p = wavelet_op_adj.op(Img)
             # Random vector living in the coefficient space
             f = (np.random.randn(*f_p.shape) +
                  1j * np.random.randn(*f_p.shape))
             I_p = wavelet_op_adj.adj_op(f)
             x_d = np.vdot(Img, I_p)
             x_ad = np.vdot(f_p, f)
             np.testing.assert_allclose(x_d, x_ad, rtol=1e-5)
     print(" Wavelet3 adjoint test passes")
Ejemplo n.º 2
0
    def test_weighted_sparse_threshold_weights(self):
        """Check the weights built by WeightedSparseThreshold.

        Three configurations are exercised on the coefficients of a
        128x128 zero image: a constant scalar weight, per-scale weights
        (``weight_type='scale_based'``) and a full custom weight array
        (``weight_type='custom'``).
        """
        n_scales = 3
        wavelet = WaveletN('sym8', nb_scales=n_scales)
        coeffs = wavelet.op(np.zeros((128, 128)))
        band_shapes = wavelet.coeffs_shape
        unique_band_shapes = np.unique(band_shapes, axis=0)
        # Number of coefficients in the first band; the assertions below
        # expect its weights to be zeroed in the default configuration.
        first_band_size = np.prod(band_shapes[0])

        # Constant scalar weight
        const_thresh = WeightedSparseThreshold(
            weights=1e-10,
            coeffs_shape=band_shapes,
        )
        const_thresh.op(np.random.random(coeffs.shape))
        assert np.all(const_thresh.weights[:first_band_size] == 0)
        assert np.all(const_thresh.weights[first_band_size:] == 1e-10)

        # One weight per scale, first band not zeroed
        per_scale_weights = np.arange(n_scales + 1)
        scale_thresh = WeightedSparseThreshold(
            weights=per_scale_weights,
            coeffs_shape=band_shapes,
            weight_type='scale_based',
            zero_weight_coarse=False,
        )
        scale_thresh.op(np.random.random(coeffs.shape))
        lo = 0
        for idx, band_shape in enumerate(unique_band_shapes):
            band_size = np.prod(band_shape)
            hi = lo + band_size * np.sum(band_shape == band_shapes)
            np.testing.assert_equal(
                scale_thresh.weights[lo:hi],
                per_scale_weights[idx],
            )
            lo = hi

        # Fully custom weight array
        user_weights = np.random.random(coeffs.shape)
        custom_thresh = WeightedSparseThreshold(
            weights=user_weights,
            coeffs_shape=band_shapes,
            weight_type='custom',
        )
        custom_thresh.op(np.random.random(coeffs.shape))
        assert np.all(custom_thresh.weights[:first_band_size] == 0)
        np.testing.assert_equal(
            custom_thresh.weights[first_band_size:],
            user_weights[first_band_size:],
        )
Ejemplo n.º 3
0
 def test_Wavelet2D_PyWt(self):
     """Test the adjoint operator for the 2D Wavelet transform.

     For ``self.max_iter`` random complex images of side ``self.N``,
     checks the dot-product identity
     ``vdot(Img, adj_op(f)) == vdot(op(Img), f)`` with ``rtol=1e-6``.
     """
     for i in range(self.max_iter):
         # Fill the placeholder explicitly; passing ``i`` as a second print
         # argument left the literal '{0}' in the output.
         print("Process Wavelet2D PyWt test '{0}'...".format(i))
         wavelet_op_adj = WaveletN(wavelet_name="sym8", nb_scale=4)
         Img = (np.random.randn(self.N, self.N) +
                1j * np.random.randn(self.N, self.N))
         f_p = wavelet_op_adj.op(Img)
         # Random vector living in the coefficient space
         f = (np.random.randn(*f_p.shape) +
              1j * np.random.randn(*f_p.shape))
         I_p = wavelet_op_adj.adj_op(f)
         x_d = np.vdot(Img, I_p)
         x_ad = np.vdot(f_p, f)
         np.testing.assert_allclose(x_d, x_ad, rtol=1e-6)
     print(" Wavelet2 adjoint test passes")
Ejemplo n.º 4
0
def get_operators(
    kspace_data,
    loc,
    mask,
    fourier_type=1,
    max_iter=80,
    regularisation=None,
    linear=None,
):
    """Create the various operators from the config file.

    Parameters
    ----------
    kspace_data: array
        K-space data. A 2D array is treated as single coil; otherwise the
        length of the first axis is used as the number of coils. The last
        two axes give the operator shape.
    loc:
        Passed as ``mask_cols`` to the online k-space generators
        (``fourier_type`` 1 and 2).
    mask:
        Passed to ``FFT`` and, for ``fourier_type`` 0, to
        ``KspaceGeneratorBase``.
    fourier_type: int
        0 (offline), 1 (online type I) or 2 (online type II).
    max_iter: int
        Only forwarded to ``KspaceGeneratorBase`` (``fourier_type`` 0).
    regularisation: dict or None
        Must hold a ``"class"`` entry ("LASSO", "GroupLASSO", "OWL" or
        "IdentityProx"). "LASSO"/"GroupLASSO" read ``"weights"``; for "OWL"
        every remaining entry is forwarded to the constructor. Any other
        class name leaves the default ``IdentityProx`` in place.
    linear: dict or None
        ``"class"`` selects "WaveletN" or "Identity"; for "WaveletN" every
        remaining entry is forwarded to the constructor.

    Returns
    -------
    tuple
        ``(kspace_generator, fourier_op, linear_op, prox_op)``.

    Raises
    ------
    NotImplementedError
        For an unknown ``fourier_type`` or linear operator class.
    KeyError
        If ``regularisation`` has no ``"class"`` entry, or no ``"weights"``
        entry when one is read.

    Notes
    -----
    The configuration dictionaries are copied before ``"class"`` is popped,
    so the caller's dictionaries are left untouched and can be reused.
    """
    n_coils = 1 if kspace_data.ndim == 2 else kspace_data.shape[0]
    shape = kspace_data.shape[-2:]
    if fourier_type == 0:  # offline reconstruction
        kspace_generator = KspaceGeneratorBase(full_kspace=kspace_data,
                                               mask=mask,
                                               max_iter=max_iter)
        fourier_op = FFT(shape=shape, n_coils=n_coils, mask=mask)
    elif fourier_type == 1:  # online type I reconstruction
        kspace_generator = Column2DKspaceGenerator(full_kspace=kspace_data,
                                                   mask_cols=loc)
        fourier_op = FFT(shape=shape, n_coils=n_coils, mask=mask)
    elif fourier_type == 2:  # online type II reconstruction
        kspace_generator = DataOnlyKspaceGenerator(full_kspace=kspace_data,
                                                   mask_cols=loc)
        fourier_op = ColumnFFT(shape=shape, n_coils=n_coils)
    else:
        raise NotImplementedError(
            f"Unsupported fourier_type: {fourier_type!r}")

    if linear is None:
        linear_op = Identity()
    else:
        # Work on a copy: popping "class" from the caller's dict would make
        # a second call with the same config fail.
        linear_cfg = dict(linear)
        lin_cls = linear_cfg.pop("class", None)
        if lin_cls == "WaveletN":
            linear_op = WaveletN(n_coils=n_coils, n_jobs=4, **linear_cfg)
            # Run the operator once so that ``coeffs_shape`` is available
            # below (it is read for the OWL regulariser).
            linear_op.op(np.zeros_like(kspace_data))
        elif lin_cls == "Identity":
            linear_op = Identity()
        else:
            raise NotImplementedError(
                f"Unsupported linear operator class: {lin_cls!r}")

    prox_op = IdentityProx()
    if regularisation is not None:
        # Same reasoning as above: do not mutate the caller's config.
        reg_cfg = dict(regularisation)
        reg_cls = reg_cfg.pop("class")
        if reg_cls == "LASSO":
            prox_op = LASSO(weights=reg_cfg["weights"])
        elif reg_cls == "GroupLASSO":
            prox_op = GroupLASSO(weights=reg_cfg["weights"])
        elif reg_cls == "OWL":
            prox_op = OWL(**reg_cfg,
                          n_coils=n_coils,
                          bands_shape=linear_op.coeffs_shape)
        elif reg_cls == "IdentityProx":
            prox_op = IdentityProx()
            # No regularisation: the linear operator is reset to Identity.
            linear_op = Identity()
        # Any other class name keeps the default IdentityProx (unchanged
        # historical behaviour).
    return kspace_generator, fourier_op, linear_op, prox_op
Ejemplo n.º 5
0
 def test_Wavelet2D_ISAP(self):
     """Test the adjoint operator for the 2D Wavelet transform.

     Uses the "HaarWaveletTransform" wavelet. For every channel count in
     ``self.num_channels`` and for ``self.max_iter`` random complex draws
     of side ``self.N``, checks the dot-product identity
     ``vdot(Img, adj_op(f)) == vdot(op(Img), f)`` with ``rtol=1e-5``.
     """
     for ch in self.num_channels:
         print("Testing with Num Channels : " + str(ch))
         for i in range(self.max_iter):
             # Fill the placeholder explicitly; passing ``i`` as a second
             # print argument left the literal '{0}' in the output.
             print("Process Wavelet2D_ISAP test '{0}'...".format(i))
             wavelet_op_adj = WaveletN(wavelet_name="HaarWaveletTransform",
                                       nb_scale=4,
                                       n_coils=ch,
                                       n_jobs=2)
             # squeeze drops the leading axis when ch == 1
             Img = np.squeeze(
                 np.random.randn(ch, self.N, self.N) +
                 1j * np.random.randn(ch, self.N, self.N))
             f_p = wavelet_op_adj.op(Img)
             # Random vector living in the coefficient space
             f = (np.random.randn(*f_p.shape) +
                  1j * np.random.randn(*f_p.shape))
             I_p = wavelet_op_adj.adj_op(f)
             x_d = np.vdot(Img, I_p)
             x_ad = np.vdot(f_p, f)
             np.testing.assert_allclose(x_d, x_ad, rtol=1e-5)
     print(" Wavelet2 adjoint test passes")
# Calculate the SSIM between ``image_rec0`` and the reference ``image``
# (both are defined earlier in the script, outside this excerpt).
base_ssim = ssim(image_rec0, image)
print(base_ssim)

#############################################################################
# FISTA optimization
# ------------------
#
# We now want to refine the zero order solution using a FISTA optimization.
# The cost function is set to Proximity Cost + Gradient Cost

# Setup the operators: a 3D "sym8" wavelet transform on 4 scales with
# periodization padding.
linear_op = WaveletN(
    wavelet_name="sym8",
    nb_scales=4,
    dim=3,
    padding_mode="periodization",
)
# Soft thresholding built on an Identity operator; the second positional
# argument (2e-11) is presumably the threshold weight -- confirm against the
# SparseThreshold signature.
regularizer_op = SparseThreshold(Identity(), 2 * 1e-11, thresh_type="soft")
# Setup Reconstructor: single-channel reconstruction in synthesis
# formulation; ``fourier_op`` comes from earlier in the script.
reconstructor = SingleChannelReconstructor(
    fourier_op=fourier_op,
    linear_op=linear_op,
    regularizer_op=regularizer_op,
    gradient_formulation='synthesis',
    verbose=1,
)
# Start Reconstruction
x_final, costs, metrics = reconstructor.reconstruct(
    kspace_data=kspace_data,
    optimization_alg='fista',
Ejemplo n.º 7
0
def E(**kwargs):
    """Compute the cost of a given mask or parametrisation of mask.

    The returned value is ``P(...) + mean_i L(u_i, images[i], param["c"])``
    where ``u_i`` is the first output of ``pdhg`` run on ``kspace_data[i]``.

    Keyword arguments
    -----------------
    images:
        Sequence of images used to evaluate the reconstruction (at least
        one). ``images[0].shape`` is used as the operator shape.
    kspace_data:
        Sequence of k-space data, one entry per image.
    param: dict
        Lower and upper level parameters. This function reads ``"beta"``
        (passed to ``P``) and ``"c"`` (passed to ``L``); the whole dict is
        forwarded to ``pdhg``.
        Remark: you should choose these parameters to have E(pk)~1 at the
        beginning (otherwise, L-BFGS-B may stop too early).
    samples:
        Sampling locations given to ``NonCartesianFFT`` (default ``[]``).
    wavelet_name, wavelet_scale:
        Given to ``WaveletN`` (defaults ``""`` and ``1``).
    mask_type: str
        ``"cartesian"``: ``pk = pcart(lk)``; ``"radial_CO"``:
        ``pk = pradCO(lk, n_rad)``; anything else: ``pk`` is used as given.
    lk:
        Mask parametrisation, required for the two parametrised mask types.
    pk:
        Mask, required for every other mask type.
    n_rad:
        Second argument of ``pradCO`` (``"radial_CO"`` only).
    parallel: bool
        If True, the reconstructions are dispatched through
        ``Parallel(n_jobs=-1)`` and each worker builds its own operators.
    const: dict
        Forwarded to ``pdhg``.
    verbose: int
        Messages are printed when ``verbose >= 0``.

    Returns
    -------
    The accumulated cost ``Ep``.

    Raises
    ------
    ValueError
        If ``images`` is missing or empty, ``param`` is missing, or
        ``kspace_data`` is missing or differs in length from ``images``.

    Notes
    -----
    In the serial path all keyword arguments are forwarded to ``pdhg``, but
    ``maxit``, ``fourier_op`` and ``linear_op`` are always replaced by the
    values set here (50 iterations and the locally built operators).
    """
    # Getting parameters
    images = kwargs.get("images", None)
    kspace_data = kwargs.get("kspace_data", None)
    param = kwargs.get("param", None)

    samples = kwargs.get("samples", [])
    wavelet_name = kwargs.get("wavelet_name", "")
    wavelet_scale = kwargs.get("wavelet_scale", 1)

    parallel = kwargs.get("parallel", False)
    mask_type = kwargs.get("mask_type", "")
    lk = kwargs.get("lk", None)
    pk = kwargs.get("pk", None)

    verbose = kwargs.get("verbose", 0)
    if verbose >= 0:
        print("\n\nEVALUATING E(p)")

    # Checking inputs errors. This must happen before images[0] is
    # dereferenced, otherwise a missing input surfaces as a TypeError or
    # IndexError instead of the intended ValueError.
    if images is None or len(images) < 1:
        raise ValueError("At least one image is needed")
    if param is None:
        raise ValueError("Lower level parameters must be given")
    if kspace_data is None or len(images) != len(kspace_data):
        raise ValueError("Need as many images and kspace data")

    # Compute P(pk/lk)
    if mask_type == "cartesian":
        pk = pcart(lk)
        Ep = P(lk, param["beta"])
    elif mask_type == "radial_CO":
        n_rad = kwargs.get("n_rad")
        pk = pradCO(lk, n_rad)
        Ep = P(lk, param["beta"])
    else:
        Ep = P(pk, param["beta"])

    # Compute L(pk/lk)
    Nimages = len(images)
    if parallel:
        uk_list = Parallel(n_jobs=-1, verbose=0)(
            delayed(pdhg)(kspace_data[i],
                          pk,
                          samples=samples,
                          shape=images[0].shape,
                          wavelet_name=wavelet_name,
                          wavelet_scale=wavelet_scale,
                          param=param,
                          mask_type=mask_type,
                          const=kwargs.get("const", {}),
                          verbose=-1) for i in range(Nimages))
        Ep += np.sum(
            [L(uk_list[i][0], images[i], param["c"])
             for i in range(Nimages)]) / Nimages
    else:
        # Operators are only needed here; the parallel workers build theirs.
        fourier_op = NonCartesianFFT(samples=samples,
                                     shape=images[0].shape,
                                     implementation='cpu')
        linear_op = WaveletN(wavelet_name=wavelet_name,
                             nb_scale=wavelet_scale,
                             padding_mode="periodization")

        # Merge instead of mixing explicit keywords with **kwargs: a caller
        # passing maxit/fourier_op/linear_op would otherwise trigger
        # "got multiple values for keyword argument".
        pdhg_kwargs = dict(kwargs)
        pdhg_kwargs.update(maxit=50,
                           fourier_op=fourier_op,
                           linear_op=linear_op)

        for i, (u0_mat, y) in enumerate(zip(images, kspace_data)):
            if verbose >= 0:
                print(f"\nImage {i+1}:")

            if verbose > 0:
                print("\nStarting PDHG")
            # pdhg returns more than two values when compute_energy or
            # ground_truth is forwarded; only the image is needed.
            uk = pdhg(y, pk, **pdhg_kwargs)[0]

            Ep += L(uk, u0_mat, param["c"]) / Nimages

    return Ep
Ejemplo n.º 8
0
def grad_L(**kwargs):
    """Compute the gradient of L with respect to p.

    The lower-level image ``uk`` is obtained with ``pdhg`` (50 iterations),
    then the linear system defined by ``Du2_Etot`` (operator) and ``Du_L``
    (right-hand side) is solved with conjugate gradient, and the solution
    is fed to ``Dpu_Etot``.

    Keyword arguments
    -----------------
    pk:
        Point where we want to compute the gradient.
    u0_mat:
        Ground truth image. ``len(u0_mat)`` is used as the side ``n`` of an
        ``(n, n)`` image.
    y:
        k-space data associated to ``u0_mat``.
    param: dict
        Lower and upper level parameters. This function reads
        ``"epsilon"``, ``"gamma"`` and ``"c"``; the whole dict is forwarded
        to ``pdhg``.
        Remark: you should choose these parameters to have E(pk)~1 at the
        beginning (otherwise, L-BFGS-B may stop too early).
    samples:
        Sampling locations given to ``NonCartesianFFT`` (default ``[]``).
    wavelet_name, wavelet_scale:
        Given to ``WaveletN`` (defaults ``""`` and ``1``).
    mask_type, const:
        Forwarded to ``pdhg``.
    learn_mask, learn_alpha: bool
        Forwarded to ``Dpu_Etot`` (default True).
    max_cgiter: int
        Maximum number of conjugate gradient iterations (default: 4000).
    cgtol: float
        Tolerance of the conjugate gradient (default: 1e-6).
    compute_conv: bool
        Record and plot the CG convergence if True (default: False).
    verbose: int
        Messages are printed when ``verbose >= 0``.

    Returns
    -------
    ``np.zeros(pk.shape)`` when the relative CG residual exceeds 1e-3,
    otherwise ``-Dpu_Etot(...)`` evaluated at the CG solution.

    Raises
    ------
    ValueError
        If ``u0_mat`` or ``y`` is missing.
    """
    # -- Getting parameters
    max_cgiter = kwargs.get("max_cgiter", 4000)
    cgtol = kwargs.get("cgtol", 1e-6)
    compute_conv = kwargs.get("compute_conv", False)
    mask_type = kwargs.get("mask_type", "")
    learn_mask = kwargs.get("learn_mask", True)
    learn_alpha = kwargs.get("learn_alpha", True)

    u0_mat = kwargs.get("u0_mat", None)
    param = kwargs.get("param", None)
    y = kwargs.get("y", None)
    pk = kwargs.get("pk", None)

    samples = kwargs.get("samples", [])
    wavelet_name = kwargs.get("wavelet_name", "")
    wavelet_scale = kwargs.get("wavelet_scale", 1)
    const = kwargs.get("const", {})
    verbose = kwargs.get("verbose", 0)

    # Validate before u0_mat.shape is read; otherwise a missing image
    # surfaces as an AttributeError instead of the intended ValueError.
    if u0_mat is None:
        raise ValueError("A ground truth image u0_mat is needed")
    if y is None:
        raise ValueError("kspace data y are needed")

    fourier_op = NonCartesianFFT(samples=samples,
                                 shape=u0_mat.shape,
                                 implementation='cpu')
    linear_op = WaveletN(wavelet_name=wavelet_name,
                         nb_scale=wavelet_scale,
                         padding_mode="periodization")

    cg_conv = []
    n = len(u0_mat)

    # -- Compute uk from pk with lower level solver
    if verbose >= 0:
        print("\nStarting PDHG")
    uk, _ = pdhg(y,
                 pk,
                 mask_type=mask_type,
                 fourier_op=fourier_op,
                 linear_op=linear_op,
                 param=param,
                 maxit=50,
                 verbose=verbose,
                 const=const)

    # -- Defining linear operator from pk and uk
    def mv(w):
        # w is a real vector of length 2*n**2: real parts first, imaginary
        # parts second. The output uses the same layout.
        w_complex = np.reshape(w[:n**2] + 1j * w[n**2:], (n, n))
        fx = np.reshape(
            Du2_Etot(uk,
                     pk,
                     w_complex,
                     eps=param["epsilon"],
                     fourier_op=fourier_op,
                     y=y,
                     linear_op=linear_op,
                     gamma=param["gamma"]), (n**2, ))
        return np.concatenate([np.real(fx), np.imag(fx)])

    lin = LinearOperator((2 * n**2, 2 * n**2), matvec=mv)

    if verbose >= 0:
        print("\nStarting Conjugate Gradient method")
    t1 = time.time()
    B = np.reshape(Du_L(uk, u0_mat, param["c"]), (n**2, ))
    B_real = np.concatenate([np.real(B), np.imag(B)])

    def cgcall(x):
        # CG callback function to record convergence
        if compute_conv:
            cg_conv.append(
                np.linalg.norm(lin(x) - B_real) / np.linalg.norm(B_real))

    # NOTE(review): recent SciPy releases rename ``tol`` to ``rtol`` --
    # confirm against the SciPy version this project pins.
    x_inter, _ = cg(lin,
                    B_real,
                    tol=cgtol,
                    maxiter=max_cgiter,
                    callback=cgcall)
    elapsed = time.time() - t1

    # Each application of ``lin`` evaluates Du2_Etot, so compute the
    # relative residual once and reuse it for the message and the test.
    rel_residual = (np.linalg.norm(lin(x_inter) - B_real) /
                    np.linalg.norm(B_real))
    if verbose >= 0:
        print(f"Finished in {elapsed}s - ||Ax-b||/||b||: {rel_residual}")

    # -- Plotting
    if compute_conv:
        plt.plot(cg_conv)
        plt.yscale("log")
        plt.title("Convergence of the conjugate gradient")
        plt.xlabel("Number of iterations")
        plt.ylabel("||Ax-b||/||b||")

    # CG solution too inaccurate: return a zero gradient
    if rel_residual > 1e-3:
        return np.zeros(pk.shape)

    x_complex = np.reshape(x_inter[:n**2] + 1j * x_inter[n**2:], (n, n))
    return -Dpu_Etot(uk,
                     pk,
                     x_complex,
                     eps=param["epsilon"],
                     fourier_op=fourier_op,
                     y=y,
                     linear_op=linear_op,
                     gamma=param["gamma"],
                     learn_mask=learn_mask,
                     learn_alpha=learn_alpha)
# Zero-filled solution: apply the adjoint Fourier operator to the observed
# k-space data, then combine along axis 0 with a root sum of squares
# (axis 0 is presumably the coil axis -- confirm against fourier_op).
zero_filled = fourier_op.adj_op(kspace_obs)
image_rec0 = pysap.Image(data=np.sqrt(np.sum(np.abs(zero_filled)**2, axis=0)))
# image_rec0.show()
# SSIM of the zero-filled image against the reference ``image``
base_ssim = ssim(image_rec0, image)
print('The Base SSIM is : ' + str(base_ssim))

#############################################################################
# FISTA optimization
# ------------------
#
# We now want to refine the zero order solution using a FISTA optimization.
# The cost function is set to Proximity Cost + Gradient Cost

# Setup the operators: "sym8" wavelet transform on 4 scales (all other
# WaveletN arguments are left at their defaults).
linear_op = WaveletN(
    wavelet_name='sym8',
    nb_scale=4,
)
# Soft thresholding built on an Identity operator; the second positional
# argument (1.5e-8) is presumably the threshold weight -- confirm against
# the SparseThreshold signature.
regularizer_op = SparseThreshold(Identity(), 1.5e-8, thresh_type="soft")
# Setup Reconstructor: self-calibrating reconstruction in synthesis
# formulation, configured with kspace_portion=0.01.
reconstructor = SelfCalibrationReconstructor(
    fourier_op=fourier_op,
    linear_op=linear_op,
    regularizer_op=regularizer_op,
    gradient_formulation='synthesis',
    kspace_portion=0.01,
    verbose=1,
)
x_final, costs, metrics = reconstructor.reconstruct(
    kspace_data=kspace_obs,
    optimization_alg='fista',
    num_iterations=10,
Ejemplo n.º 10
0
def pdhg(data, p, **kwargs):
    """Main lower level function.

    Parameters
    ----------
    data:
        k-space measurements.
    p:
        ``p[:-1]`` is the subsampling mask S(p) and ``p[-1]`` the
        regularisation parameter alpha(p), so ``len(p) == len(data) + 1``.

    Keyword arguments
    -----------------
    fourier_op:
        Fourier operator. If absent, a ``NonCartesianFFT`` is built from
        ``samples`` and ``shape`` (useful for multithreading).
    linear_op:
        Linear operator used in the regularisation. If absent and
        ``wavelet_name`` is not empty, a ``WaveletN`` is built from
        ``wavelet_name`` and ``wavelet_scale``.
    param: dict
        Lower level energy parameters; must contain the keys ``"epsilon"``
        and ``"gamma"``.
    mask_type: str
        Forwarded to ``step`` (default ``""``).
    const: dict
        Algorithm constants, passed to ``compute_constants`` (default
        ``{}``).
    compute_energy: bool
        Record the energy at every iteration (default: False).
    ground_truth:
        If not None, the SSIM against it is recorded at every iteration.
    maxit, tol:
        The loop stops after ``maxit`` iterations or once the norm returned
        by ``step`` is no longer greater than ``tol`` (defaults: 200,
        1e-6).
    verbose: int
        Default 1; summary messages are printed when ``verbose >= 0`` and
        progress every 10 iterations when ``verbose > 0``.

    Returns
    -------
    tuple
        ``(uk, norms)``, followed by ``energy`` when ``compute_energy`` is
        True, followed by ``ssims`` when ``ground_truth`` is not None.

    Raises
    ------
    ValueError
        If no Fourier operator, no linear operator or no ``param`` is
        available.
    """
    param = kwargs.get("param", None)

    # Build the Fourier operator when the caller did not provide one
    fourier_op = kwargs.get("fourier_op", None)
    if fourier_op is None:
        samples = kwargs.get("samples", [])
        shape = kwargs.get("shape", ())
        if samples is not None:
            fourier_op = NonCartesianFFT(samples=samples,
                                         shape=shape,
                                         implementation='cpu')
        if fourier_op is None:
            raise ValueError("A fourier operator fourier_op must be given")

    # Same for the linear operator
    linear_op = kwargs.get("linear_op", None)
    if linear_op is None:
        wavelet_name = kwargs.get("wavelet_name", "")
        wavelet_scale = kwargs.get("wavelet_scale", 1)
        if wavelet_name != "":
            linear_op = WaveletN(wavelet_name=wavelet_name,
                                 nb_scale=wavelet_scale,
                                 padding_mode="periodization")
        if linear_op is None:
            raise ValueError("A linear operator linear_op must be given")

    if param is None:
        raise ValueError("Lower level parameters must be given")

    # Remaining options
    mask_type = kwargs.get("mask_type", "")
    user_const = kwargs.get("const", {})
    compute_energy = kwargs.get("compute_energy", False)
    ground_truth = kwargs.get("ground_truth", None)
    maxit = kwargs.get("maxit", 200)
    tol = kwargs.get("tol", 1e-6)
    verbose = kwargs.get("verbose", 1)
    track_ssim = ground_truth is not None

    # Split p into mask and regularisation parameter
    mask, alpha = p[:-1], p[-1]
    epsilon, gamma = param["epsilon"], param["gamma"]

    # Algorithm constants
    const = compute_constants(param, user_const, mask)
    if verbose >= 0:
        print("Sigma:", const["sigma"], "\nTau:", const["tau"])

    # Initial iterates
    uk = fourier_op.adj_op(mask * data)
    vk = np.copy(uk)
    wk = linear_op.op(uk)
    uk_bar = np.copy(uk)
    # Start above tol so that the loop is entered
    norm = 2 * tol

    # Histories
    norms = []
    energy = []
    ssims = []

    # Main loop
    n_iter = 0
    start_time = time.time()
    while n_iter < maxit and norm > tol:
        uk, vk, wk, uk_bar, norm = step(uk, vk, wk, uk_bar, const, mask,
                                        alpha, data, param, linear_op,
                                        fourier_op, mask_type)
        n_iter += 1

        norms.append(norm)
        if compute_energy:
            energy.append(
                energy_wavelet(uk, mask, alpha, data, gamma, epsilon,
                               linear_op, fourier_op))
        if track_ssim:
            ssims.append(ssim(uk, ground_truth))

        # Progress report every 10 iterations
        if n_iter % 10 == 0 and verbose > 0:
            if compute_energy:
                print(n_iter, " iterations:\nCost:", energy[-1], "\nNorm:",
                      norm, "\n")
            else:
                print(n_iter, " iterations:\nNorm:", norm, "\n")
    if verbose >= 0:
        print("Finished in", time.time() - start_time, "seconds.")

    # Assemble the outputs: optional histories are appended in a fixed order
    outputs = [uk, norms]
    if compute_energy:
        outputs.append(energy)
    if track_ssim:
        outputs.append(ssims)
    return tuple(outputs)
# Wrap the gridded solution ``grid_soln`` (computed earlier in the script,
# outside this excerpt) and compare it with the reference ``image``.
image_rec0 = pysap.Image(data=grid_soln)
# image_rec0.show()
base_ssim = ssim(image_rec0, image)
print('The Base SSIM is : ' + str(base_ssim))

#############################################################################
# FISTA optimization
# ------------------
#
# We now want to refine the zero order solution using a FISTA optimization.
# The cost function is set to Proximity Cost + Gradient Cost

# TODO get the right mu operator
# Setup the operators: 3D "sym8" wavelet transform on 4 scales.
linear_op = WaveletN(wavelet_name="sym8", nb_scales=4, dim=3)
# Soft thresholding built on an Identity operator; the second positional
# argument (6e-9) is presumably the threshold weight -- confirm against the
# SparseThreshold signature.
regularizer_op = SparseThreshold(Identity(), 6 * 1e-9, thresh_type="soft")
# Setup Reconstructor: single-channel reconstruction in synthesis
# formulation; ``fourier_op`` comes from earlier in the script.
reconstructor = SingleChannelReconstructor(
    fourier_op=fourier_op,
    linear_op=linear_op,
    regularizer_op=regularizer_op,
    gradient_formulation='synthesis',
    verbose=1,
)
# Start Reconstruction: 10 FISTA iterations on the observed k-space data.
# The call returns the final image, the costs and the metrics.
x_final, costs, metrics = reconstructor.reconstruct(
    kspace_data=kspace_obs,
    optimization_alg='fista',
    num_iterations=10,
)
Ejemplo n.º 12
0
# Combine ``zero_filled`` (computed earlier in the script, outside this
# excerpt) along axis 0 with a root sum of squares (axis 0 is presumably the
# coil axis -- confirm), then compare with the reference ``image``.
image_rec0 = pysap.Image(data=np.sqrt(np.sum(np.abs(zero_filled)**2, axis=0)))
# image_rec0.show()
base_ssim = ssim(image_rec0, image)
print('The Base SSIM is : ' + str(base_ssim))

#############################################################################
# FISTA optimization
# ------------------
#
# We now want to refine the zero order solution using a FISTA optimization.
# The cost function is set to Proximity Cost + Gradient Cost

# Setup the operators: "sym8" wavelet transform on 4 scales, with the number
# of coils taken from the first axis of ``cartesian_ref_image``.
linear_op = WaveletN(
    wavelet_name='sym8',
    nb_scale=4,
    n_coils=cartesian_ref_image.shape[0],
)
# Group-LASSO regulariser with weight 6e-8
regularizer_op = GroupLASSO(weights=6e-8)
# Setup Reconstructor: calibrationless reconstruction in synthesis
# formulation; ``fourier_op`` comes from earlier in the script.
reconstructor = CalibrationlessReconstructor(
    fourier_op=fourier_op,
    linear_op=linear_op,
    regularizer_op=regularizer_op,
    gradient_formulation='synthesis',
    verbose=1,
)
x_final, costs, metrics = reconstructor.reconstruct(
    kspace_data=kspace_obs,
    optimization_alg='fista',
    num_iterations=300,