def get_linear_n_regularization_operator(self,
                                          gradient_formulation,
                                          wavelet_name,
                                          dimension=2,
                                          nb_scale=3,
                                          n_coils=1,
                                          n_jobs=1,
                                          verbose=0):
     # A helper function to obtain the linear and regularization operators
     try:
         linear_op = WaveletN(
             nb_scale=nb_scale,
             wavelet_name=wavelet_name,
             dim=dimension,
             n_coils=n_coils,
             n_jobs=n_jobs,
             verbose=verbose,
         )
     except ValueError:
         # TODO this is a hack and we need to have a separate WaveletUD2.
         # For Undecimated wavelets, the wavelet_name is wavelet_id
         linear_op = WaveletUD2(
             wavelet_id=wavelet_name,
             nb_scale=nb_scale,
             n_coils=n_coils,
             n_jobs=n_jobs,
             verbose=verbose,
         )
     if gradient_formulation == 'synthesis':
         regularizer_op = SparseThreshold(Identity(), 0, thresh_type="soft")
     elif gradient_formulation == "analysis":
         regularizer_op = SparseThreshold(linear_op, 0, thresh_type="soft")
     else:
         raise ValueError(
             "Unknown gradient formulation: " + str(gradient_formulation))
     return linear_op, regularizer_op
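
A hedged usage sketch for the helper above, assuming it is bound to a reconstructor-style object (here called reconstructor, an illustrative name) and that pysap-mri's WaveletN/WaveletUD2 are importable:

linear_op, regularizer_op = reconstructor.get_linear_n_regularization_operator(
    gradient_formulation='synthesis',
    wavelet_name='sym8',  # a decimated wavelet; an undecimated wavelet id falls back to WaveletUD2
    dimension=2,
    nb_scale=3,
)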
Example #2
class simplex_threshold(object):
    """Simplex Threshold proximity operator

    This class stacks the Simplex and SparseThreshold proximity operators.

    Calls:

    * :func:`proxs.Simplex`

    """
    def __init__(self, linop, weights, mass=None, pos_en=False):
        self.linop = linop
        self.weights = weights
        self.thresh = SparseThreshold(self.linop, self.weights)
        self.simplex = Simplex(mass=mass, pos_en=pos_en)

    def update_weights(self, weights):
        """Update weights

        This method updates the values of the weights.

        Parameters
        ----------
        weights : np.ndarray
            Input array of weights

        """
        self.weights = weights
        self.thresh = SparseThreshold(self.linop, weights)

    def op(self, data, extra_factor=1.0):
        return np.array([
            self.simplex.op(data[0]),
            self.thresh.op(data[1], extra_factor=extra_factor),
        ])
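
A minimal, hedged usage sketch: the stacked operator applies the simplex projection to the first block of the input and soft thresholding to the second. The Identity linear operator and shapes below are illustrative assumptions, and Simplex is assumed importable from the snippet's own proxs module:

import numpy as np
from modopt.opt.linear import Identity

prox = simplex_threshold(linop=Identity(), weights=1e-3 * np.ones(16))
stacked = np.array([np.random.rand(16), np.random.rand(16)])
result = prox.op(stacked, extra_factor=1.0)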
Example #3
 def set_objective(self, X, y, lmbd):
     self.X, self.y, self.lmbd = X, y, lmbd
     n_features = self.X.shape[1]
     if self.restart_strategy == 'greedy':
         min_beta = 1.0
         s_greedy = 1.1
         p_lazy = 1.0
         q_lazy = 1.0
     else:
         min_beta = None
         s_greedy = None
         p_lazy = 1 / 30
         q_lazy = 1 / 10
     self.fb = ForwardBackward(
         x=np.zeros(n_features),  # this is the coefficient w
         grad=GradBasic(
             input_data=y,
             op=lambda w: self.X@w,
             trans_op=lambda res: self.X.T@res,
         ),
         prox=SparseThreshold(Identity(), lmbd),
         beta_param=1.0,
         min_beta=min_beta,
         metric_call_period=None,
         restart_strategy=self.restart_strategy,
         xi_restart=0.96,
         s_greedy=s_greedy,
         p_lazy=p_lazy,
         q_lazy=q_lazy,
         auto_iterate=False,
         progress=False,
         cost=None,
     )
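
With auto_iterate=False the solver is only configured here; in the benchopt convention this snippet appears to follow (an assumption), a companion run method drives the iterations:

 def run(self, n_iter):
     # sketch of the companion method, not part of the original snippet
     self.fb.iterate(max_iter=n_iter)
     self.w = self.fb.x_final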
Example #4
    def update_weights(self, weights):
        """Update weights

        This method updates the values of the weights.

        Parameters
        ----------
        weights : np.ndarray
            Input array of weights

        """
        self.weights = weights
        self.thresh = SparseThreshold(self.linop, weights)
Example #5
def reco_wav(kspace,
             gradient_op,
             mu=1 * 1e-8,
             max_iter=10,
             nb_scales=4,
             wavelet_name='db4'):
    # For now this only works with my fork of pysap-fastMRI.
    # I will get it changed soon so that we don't need to ask for a
    # specific pysap-mri install.
    from ..wavelets import WaveletDecimated
    from mri.numerics.reconstruct import sparse_rec_fista

    linear_op = WaveletDecimated(
        nb_scale=nb_scales,
        wavelet_name=wavelet_name,
        padding='periodization',
    )

    prox_op = LinearCompositionProx(
        linear_op=linear_op,
        prox_op=SparseThreshold(Identity(), None, thresh_type="soft"),
    )
    gradient_op.obs_data = kspace
    cost_op = None
    x_final, _, _, _ = sparse_rec_fista(
        gradient_op=gradient_op,
        linear_op=Identity(),
        prox_op=prox_op,
        cost_op=cost_op,
        xi_restart=0.96,
        s_greedy=1.1,
        mu=mu,
        restart_strategy='greedy',
        pov='analysis',
        max_nb_of_iter=max_iter,
        metrics=None,
        metric_call_period=1,
        verbose=0,
        progress=False,
    )
    x_final = np.abs(x_final)
    x_final = crop_center(x_final, 320)
    return x_final
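
A hedged call sketch (build_gradient_op is a hypothetical stand-in for however the surrounding code assembles the data-fidelity gradient operator):

gradient_op = build_gradient_op()  # hypothetical helper wired to the sampling mask
image = reco_wav(kspace, gradient_op, mu=1e-8, max_iter=50, wavelet_name='db4')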
Example #6
 def set_objective(self, X, y, lmbd):
     self.X, self.y, self.lmbd = X, y, lmbd
     n_features = self.X.shape[1]
     sigma_bar = 0.96
     var_init = np.zeros(n_features)
     self.pogm = POGM(
         x=var_init,  # this is the coefficient w
         u=var_init,
         y=var_init,
         z=var_init,
         grad=GradBasic(
             op=lambda w: self.X @ w,
             trans_op=lambda res: self.X.T @ res,
             data=y,
         ),
         prox=SparseThreshold(Identity(), lmbd),
         beta_param=1.0,
         metric_call_period=None,
         sigma_bar=sigma_bar,
         auto_iterate=False,
         progress=False,
         cost=None,
     )
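
As in Example #3, auto_iterate=False means a companion run method (again assuming the benchopt convention) performs the actual optimization:

 def run(self, n_iter):
     # sketch only; mirrors the ForwardBackward variant above
     self.pogm.iterate(max_iter=n_iter)
     self.w = self.pogm.x_final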
Example #7
def generate_operators(data,
                       wavelet_name,
                       samples,
                       nb_scales=4,
                       non_cartesian=False,
                       uniform_data_shape=None,
                       gradient_space="analysis"):
    """ Function that ease the creation of a set of common operators.

    .. note:: At the moment, supports only 2D data.

    Parameters
    ----------
    data: ndarray
        the data to reconstruct: observations are expected in Fourier space.
    wavelet_name: str
        the wavelet name to be used during the decomposition.
    samples: np.ndarray
        the mask samples in the Fourier domain.
    nb_scales: int, default 4
        the number of scales in the wavelet decomposition.
    non_cartesian: bool (optional, default False)
        if set, use the NFFT rather than the FFT. Expects a 1D input dataset.
    uniform_data_shape: tuple (optional, default None)
        the shape of the matrix containing the uniform data. Only required
        for non-cartesian reconstructions.
    gradient_space: str (optional, default 'analysis')
        the space where the gradient operator is defined: 'analysis' or
        'synthesis'

    Returns
    -------
    gradient_op: instance of class GradBase
        the gradient operator.
    linear_op: instance of LinearBase
        the linear operator: seeks the sparsity, i.e. a wavelet transform.
    prox_op: instance of ProximityParent
        the proximal operator.
    cost_op: instance of costObj
        the cost function used to check for convergence during the
        optimization.
    """
    # Local imports
    from mri.numerics.cost import DualGapCost
    from mri.numerics.linear import Wavelet2
    from mri.numerics.fourier import FFT2
    from mri.numerics.fourier import NFFT
    from mri.numerics.gradient import GradAnalysis2
    from mri.numerics.gradient import GradSynthesis2
    from modopt.opt.proximity import SparseThreshold

    # Check input parameters
    if gradient_space not in ("analysis", "synthesis"):
        raise ValueError(
            "Unsupported gradient space '{0}'.".format(gradient_space))
    if non_cartesian and data.ndim != 1:
        raise ValueError("Expected 1D data with the non-cartesian option.")
    elif non_cartesian and uniform_data_shape is None:
        raise ValueError("Need to set the 'uniform_data_shape' parameter with "
                         "the non-cartesian option.")
    elif not non_cartesian and data.ndim != 2:
        raise ValueError("At the moment, this function only supports 2D "
                         "data.")

    # Define the gradient/linear/fourier operators
    linear_op = Wavelet2(nb_scale=nb_scales, wavelet_name=wavelet_name)
    if non_cartesian:
        fourier_op = NFFT(samples=samples, shape=uniform_data_shape)
    else:
        fourier_op = FFT2(samples=samples, shape=data.shape)
    if gradient_space == "synthesis":
        gradient_op = GradSynthesis2(data=data,
                                     linear_op=linear_op,
                                     fourier_op=fourier_op)
    else:
        gradient_op = GradAnalysis2(data=data, fourier_op=fourier_op)

    # Define the proximity dual/primal operator
    prox_op = SparseThreshold(linear_op, None, thresh_type="soft")

    # Define the cost function
    if gradient_space == "synthesis":
        cost_op = None
    else:
        cost_op = DualGapCost(linear_op=linear_op,
                              initial_cost=1e6,
                              tolerance=1e-4,
                              cost_interval=1,
                              test_range=4,
                              verbose=0,
                              plot_output=None)

    return gradient_op, linear_op, prox_op, cost_op
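
A hedged call sketch based on the docstring above (kspace_data and kspace_loc are illustrative names assumed to be in scope):

gradient_op, linear_op, prox_op, cost_op = generate_operators(
    data=kspace_data,      # 2D observed Fourier data
    wavelet_name='sym8',
    samples=kspace_loc,    # Fourier-domain sampling locations
    nb_scales=4,
    gradient_space='analysis',
)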
Example #8
#############################################################################
# FISTA optimization
# ------------------
#
# We now want to refine the zero order solution using a FISTA optimization.
# The cost function is set to Proximity Cost + Gradient Cost

# Setup the operators
linear_op = WaveletN(
    wavelet_name="sym8",
    nb_scale=4,
    dim=3,
    padding_mode="periodization",
)
regularizer_op = SparseThreshold(Identity(), 2 * 1e-11, thresh_type="soft")
# Setup Reconstructor
reconstructor = SingleChannelReconstructor(
    fourier_op=fourier_op,
    linear_op=linear_op,
    regularizer_op=regularizer_op,
    gradient_formulation='synthesis',
    verbose=1,
)
# Start Reconstruction
x_final, costs, metrics = reconstructor.reconstruct(
    kspace_data=kspace_data,
    optimization_alg='fista',
    num_iterations=200,
)
image_rec = pysap.Image(data=np.abs(x_final))
Example #9
 def __init__(self, linop, weights, mass=None, pos_en=False):
     self.linop = linop
     self.weights = weights
     self.thresh = SparseThreshold(self.linop, self.weights)
     self.simplex = Simplex(mass=mass, pos_en=pos_en)
Example #10
base_ssim = ssim(image_rec0, image)
print('The Base SSIM is : ' + str(base_ssim))

#############################################################################
# FISTA optimization
# ------------------
#
# We now want to refine the zero order solution using a FISTA optimization.
# The cost function is set to Proximity Cost + Gradient Cost

# Setup the operators
linear_op = WaveletN(
    wavelet_name='sym8',
    nb_scale=4,
)
regularizer_op = SparseThreshold(Identity(), 1.5e-8, thresh_type="soft")
# Setup Reconstructor
reconstructor = SelfCalibrationReconstructor(
    fourier_op=fourier_op,
    linear_op=linear_op,
    regularizer_op=regularizer_op,
    gradient_formulation='synthesis',
    kspace_portion=0.01,
    verbose=1,
)
x_final, costs, metrics = reconstructor.reconstruct(
    kspace_data=kspace_obs,
    optimization_alg='fista',
    num_iterations=10,
)
image_rec = pysap.Image(data=x_final)
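
Mirroring the baseline check above, the refined solution can be scored the same way (assumes the reference image is still in scope):

recon_ssim = ssim(image_rec, image)
print('The Reconstruction SSIM is : ' + str(recon_ssim))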
Example #11
def sparse_deconv_condatvu(data, psf, n_iter=300, n_reweights=1):
    """Sparse Deconvolution with Condat-Vu

    Parameters
    ----------
    data : np.ndarray
        Input data, 2D image
    psf : np.ndarray
        Input PSF, 2D image
    n_iter : int, optional
        Maximum number of iterations
    n_reweights : int, optional
        Number of reweightings

    Returns
    -------
    np.ndarray
        Deconvolved image

    """

    # Print the algorithm set-up
    print(condatvu_logo())

    # Define the wavelet filters
    filters = (get_cospy_filters(
        data.shape, transform_name='LinearWaveletTransformATrousAlgorithm'))

    # Set the reweighting scheme
    reweight = cwbReweight(get_weights(data, psf, filters))

    # Set the initial variable values
    primal = np.ones(data.shape)
    dual = np.ones(filters.shape)

    # Set the gradient operators
    grad_op = GradBasic(data, lambda x: psf_convolve(x, psf),
                        lambda x: psf_convolve(x, psf, psf_rot=True))

    # Set the linear operator
    linear_op = WaveletConvolve2(filters)

    # Set the proximity operators
    prox_op = Positivity()
    prox_dual_op = SparseThreshold(linear_op, reweight.weights)

    # Set the cost function
    cost_op = costObj([grad_op, prox_op, prox_dual_op],
                      tolerance=1e-6,
                      cost_interval=1,
                      plot_output=True,
                      verbose=False)

    # Set the optimisation algorithm
    alg = Condat(primal,
                 dual,
                 grad_op,
                 prox_op,
                 prox_dual_op,
                 linear_op,
                 cost_op,
                 rho=0.8,
                 sigma=0.5,
                 tau=0.5,
                 auto_iterate=False)

    # Run the algorithm
    alg.iterate(max_iter=n_iter)

    # Implement reweighting
    for rw_num in range(n_reweights):
        print(' - Reweighting: {}'.format(rw_num + 1))
        reweight.reweight(linear_op.op(alg.x_final))
        alg.iterate(max_iter=n_iter)

    # Return the final result
    return alg.x_final
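
A hedged call sketch with synthetic placeholders (a real image/PSF pair would replace the random arrays):

import numpy as np

np.random.seed(0)
data = np.random.rand(64, 64)  # placeholder observed image
psf = np.random.rand(64, 64)   # placeholder PSF
deconvolved = sparse_deconv_condatvu(data, psf, n_iter=50, n_reweights=1)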
Example #12
def sparse_rec_fista(data, wavelet_name, samples, mu, nb_scales=4,
                     lambda_init=1.0, max_nb_of_iter=300, atol=1e-4,
                     non_cartesian=False, uniform_data_shape=None, verbose=0):
    """ The FISTA sparse reconstruction without reweightings.

    .. note:: At the moment, supports only 2D data.

    Parameters
    ----------
    data: ndarray
        the data to reconstruct (observations are expected in the Fourier
        space).
    wavelet_name: str
        the wavelet name to be used during the decomposition.
    samples: np.ndarray
        the mask samples in the Fourier domain.
    mu: float
       coefficient of regularization.
    nb_scales: int, default 4
        the number of scales in the wavelet decomposition.
    lambda_init: float, (default 1.0)
        initial value for the FISTA step.
    max_nb_of_iter: int (optional, default 300)
        the maximum number of iterations in the FISTA algorithm.
    atol: float (optional, default 1e-4)
        tolerance threshold for convergence.
    non_cartesian: bool (optional, default False)
        if set, use the NFFT rather than the FFT. Expects a 1D input dataset.
    uniform_data_shape: tuple (optional, default None)
        the shape of the matrix containing the uniform data. Only required
        for non-cartesian reconstructions.
    verbose: int (optional, default 0)
        the verbosity level.

    Returns
    -------
    x_final: ndarray
        the estimated FISTA solution.
    transform: a WaveletTransformBase derived instance
        the wavelet transformation instance.
    """
    # Check inputs
    start = time.perf_counter()
    if non_cartesian and data.ndim != 1:
        raise ValueError("Expected 1D data with the non-cartesian option.")
    elif non_cartesian and uniform_data_shape is None:
        raise ValueError("Need to set the 'uniform_data_shape' parameter with "
                         "the non-cartesian option.")
    elif not non_cartesian and data.ndim != 2:
        raise ValueError("At the moment, this function only supports 2D "
                         "data.")

    # Define the gradient/linear/fourier operators
    linear_op = Wavelet2(
        nb_scale=nb_scales,
        wavelet_name=wavelet_name)
    if non_cartesian:
        fourier_op = NFFT2(
            samples=samples,
            shape=uniform_data_shape)
    else:
        fourier_op = FFT2(
            samples=samples,
            shape=data.shape)
    gradient_op = GradSynthesis2(
        data=data,
        linear_op=linear_op,
        fourier_op=fourier_op)

    # Define the initial primal and dual solutions
    x_init = np.zeros(fourier_op.shape, dtype=complex)
    alpha = linear_op.op(x_init)
    alpha[...] = 0.0

    # Welcome message
    if verbose > 0:
        print(fista_logo())
        print(" - mu: ", mu)
        print(" - lipschitz constant: ", gradient_op.spec_rad)
        print(" - data: ", data.shape)
        print(" - wavelet: ", wavelet_name, "-", nb_scales)
        print(" - max iterations: ", max_nb_of_iter)
        print(" - image variable shape: ", x_init.shape)
        print(" - alpha variable shape: ", alpha.shape)
        print("-" * 40)

    # Define the proximity dual operator
    weights = copy.deepcopy(alpha)
    weights[...] = mu
    prox_op = SparseThreshold(linear_op, weights, thresh_type="soft")

    # Define the optimizer
    cost_op = None
    opt = ForwardBackward(
        x=alpha,
        grad=gradient_op,
        prox=prox_op,
        cost=cost_op,
        auto_iterate=False)

    # Perform the reconstruction
    end = time.perf_counter()
    if verbose > 0:
        print("Starting optimization...")
    opt.iterate(max_iter=max_nb_of_iter)
    if verbose > 0:
        # cost_op.plot_cost()
        # print(" - final iteration number: ", cost_op._iteration)
        # print(" - final log10 cost value: ", np.log10(cost_op.cost))
        print(" - converged: ", opt.converge)
        print("Done.")
        print("Execution time: ", end - start, " seconds")
        print("-" * 40)
    x_final = linear_op.adj_op(opt.x_final)

    return x_final, linear_op.transform
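
A hedged usage sketch for the Cartesian case (kspace_data and kspace_loc are illustrative names assumed to be in scope, as in Example #7):

x_final, transform = sparse_rec_fista(
    data=kspace_data,     # 2D observed Fourier data
    wavelet_name='sym8',
    samples=kspace_loc,   # Fourier-domain mask samples
    mu=1e-6,
    nb_scales=4,
    max_nb_of_iter=200,
    verbose=1,
)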
Example #13
def sparse_rec_condatvu(data, wavelet_name, samples, nb_scales=4,
                        std_est=None, std_est_method=None, std_thr=2.,
                        mu=1e-6, tau=None, sigma=None, relaxation_factor=1.0,
                        nb_of_reweights=1, max_nb_of_iter=150,
                        add_positivity=False, atol=1e-4, non_cartesian=False,
                        uniform_data_shape=None, verbose=0):
    """ The Condat-Vu sparse reconstruction with reweightings.

    .. note:: At the moment, supports only 2D data.

    Parameters
    ----------
    data: ndarray
        the data to reconstruct: observations are expected in Fourier space.
    wavelet_name: str
        the wavelet name to be used during the decomposition.
    samples: np.ndarray
        the mask samples in the Fourier domain.
    nb_scales: int, default 4
        the number of scales in the wavelet decomposition.
    std_est: float, default None
        the noise std estimate.
        If None use the MAD as a consistent estimator for the std.
    std_est_method: str, default None
        if the standard deviation is not set, estimate this parameter using
        the mad routine in the image ('primal') or in the sparse wavelet
        decomposition ('dual') domain.
    std_thr: float, default 2.
        use this threshold expressed as a number of sigma in the residual
        proximity operator during the thresholding.
    mu: float, default 1e-6
        regularization hyperparameter.
    tau, sigma: float, default None
        parameters of the Condat-Vu proximal-dual splitting algorithm.
        If None estimates these parameters.
    relaxation_factor: float, default 1.0
        parameter of the Condat-Vu proximal-dual splitting algorithm.
        If 1, no relaxation.
    nb_of_reweights: int, default 1
        the number of reweightings.
    max_nb_of_iter: int, default 150
        the maximum number of iterations in the Condat-Vu proximal-dual
        splitting algorithm.
    add_positivity: bool, default False
        by setting this option, set the proximity operator to identity or
        positive.
    atol: float, default 1e-4
        tolerance threshold for convergence.
    non_cartesian: bool (optional, default False)
        if set, use the NFFT rather than the FFT. Expects a 1D input dataset.
    uniform_data_shape: tuple (optional, default None)
        the shape of the matrix containing the uniform data. Only required
        for non-cartesian reconstructions.
    verbose: int, default 0
        the verbosity level.

    Returns
    -------
    x_final: ndarray
        the estimated CONDAT-VU solution.
    transform: a WaveletTransformBase derived instance
        the wavelet transformation instance.
    """
    # Check inputs
    start = time.perf_counter()
    if non_cartesian and data.ndim != 1:
        raise ValueError("Expected 1D data with the non-cartesian option.")
    elif non_cartesian and uniform_data_shape is None:
        raise ValueError("Need to set the 'uniform_data_shape' parameter with "
                         "the non-cartesian option.")
    elif not non_cartesian and data.ndim != 2:
        raise ValueError("At the moment, this function only supports 2D "
                         "data.")
    if std_est_method not in (None, "primal", "dual"):
        raise ValueError(
            "Unrecognized std estimation method '{0}'.".format(std_est_method))

    # Define the gradient/linear/fourier operators
    linear_op = Wavelet2(
        nb_scale=nb_scales,
        wavelet_name=wavelet_name)
    if non_cartesian:
        data_shape = uniform_data_shape
        fourier_op = NFFT2(
            samples=samples,
            shape=uniform_data_shape)
    else:
        data_shape = data.shape
        fourier_op = FFT2(
            samples=samples,
            shape=data.shape)
    gradient_op = GradAnalysis2(
        data=data,
        fourier_op=fourier_op)

    # Define the initial primal and dual solutions
    x_init = np.zeros(fourier_op.shape, dtype=complex)
    weights = linear_op.op(x_init)

    # Define the weights used during the thresholding in the dual domain,
    # the reweighting strategy, and the prox dual operator

    # Case1: estimate the noise std in the image domain
    if std_est_method == "primal":
        if std_est is None:
            std_est = sigma_mad(gradient_op.MtX(data))
        weights[...] = std_thr * std_est
        reweight_op = cwbReweight(weights)
        prox_dual_op = SparseThreshold(linear_op, reweight_op.weights,
                                       thresh_type="soft")

    # Case2: estimate the noise std in the sparse wavelet domain
    elif std_est_method == "dual":
        if std_est is None:
            std_est = 0.0
        weights[...] = std_thr * std_est
        reweight_op = mReweight(weights, linear_op, thresh_factor=std_thr)
        prox_dual_op = SparseThreshold(linear_op, weights=reweight_op.weights,
                                       thresh_type="soft")

    # Case3: manual regularization mode, no reweighting
    else:
        weights[...] = mu
        reweight_op = None
        prox_dual_op = SparseThreshold(linear_op, weights, thresh_type="soft")
        nb_of_reweights = 0

    # Define the Condat Vu optimizer: define the tau and sigma in the
    # Condat-Vu proximal-dual splitting algorithm if not already provided.
    # Check also that the combination of values will lead to convergence.
    norm = linear_op.l2norm(data_shape)
    lipschitz_cst = gradient_op.spec_rad
    if sigma is None:
        sigma = 0.5
    if tau is None:
        # to avoid numerics troubles with the convergence bound
        eps = 1.0e-8
        # due to the convergence bound
        tau = 1.0 / (lipschitz_cst/2 + sigma * norm**2 + eps)
    convergence_test = (
        1.0 / tau - sigma * norm ** 2 >= lipschitz_cst / 2.0)

    # Define initial primal and dual solutions
    primal = np.zeros(data_shape, dtype=complex)
    dual = linear_op.op(primal)
    dual[...] = 0.0

    # Welcome message
    if verbose > 0:
        print(condatvu_logo())
        print(" - mu: ", mu)
        print(" - lipschitz constant: ", gradient_op.spec_rad)
        print(" - tau: ", tau)
        print(" - sigma: ", sigma)
        print(" - rho: ", relaxation_factor)
        print(" - std: ", std_est)
        print(" - 1/tau - sigma||L||^2 >= beta/2: ", convergence_test)
        print(" - data: ", data.shape)
        print(" - wavelet: ", wavelet_name, "-", nb_scales)
        print(" - max iterations: ", max_nb_of_iter)
        print(" - number of reweights: ", nb_of_reweights)
        print(" - primal variable shape: ", primal.shape)
        print(" - dual variable shape: ", dual.shape)
        print("-" * 40)

    # Define the proximity operator
    if add_positivity:
        prox_op = Positive()
    else:
        prox_op = Identity()

    # Define the cost function
    cost_op = DualGapCost(
        linear_op=linear_op,
        initial_cost=1e6,
        tolerance=1e-4,
        cost_interval=1,
        test_range=4,
        verbose=0,
        plot_output=None)

    # Define the optimizer
    opt = Condat(
        x=primal,
        y=dual,
        grad=gradient_op,
        prox=prox_op,
        prox_dual=prox_dual_op,
        linear=linear_op,
        cost=cost_op,
        rho=relaxation_factor,
        sigma=sigma,
        tau=tau,
        rho_update=None,
        sigma_update=None,
        tau_update=None,
        auto_iterate=False)

    # Perform the first reconstruction
    if verbose > 0:
        print("Starting optimization...")
    opt.iterate(max_iter=max_nb_of_iter)

    # Loop through the number of reweightings
    for reweight_index in range(nb_of_reweights):

        # Generate the new weights following reweighting prescription
        if std_est_method == "primal":
            reweight_op.reweight(linear_op.op(opt._x_new))
        else:
            std_est = reweight_op.reweight(opt._x_new)

        # Welcome message
        if verbose > 0:
            print(" - reweight: ", reweight_index + 1)
            print(" - std: ", std_est)

        # Update the weights in the dual proximity operator
        prox_dual_op.weights = reweight_op.weights

        # Perform optimisation with new weights
        opt.iterate(max_iter=max_nb_of_iter)

    # Goodbye message
    end = time.perf_counter()
    if verbose > 0:
        print(" - final iteration number: ", cost_op._iteration)
        print(" - final cost value: ", cost_op.cost)
        print(" - converged: ", opt.converge)
        print("Done.")
        print("Execution time: ", end - start, " seconds")
        print("-" * 40)

    # Get the final solution
    x_final = opt.x_final
    linear_op.transform.analysis_data = unflatten(
        opt.y_final, linear_op.coeffs_shape)

    return x_final, linear_op.transform
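
The default tau above is chosen so that 1/tau - sigma*||L||^2 >= beta/2, the Condat-Vu convergence bound checked by convergence_test. A standalone sketch of that arithmetic with illustrative numbers:

def condatvu_params_ok(lipschitz_cst, norm, sigma=0.5, eps=1.0e-8):
    # reproduces the default tau choice used above and verifies the bound
    tau = 1.0 / (lipschitz_cst / 2 + sigma * norm ** 2 + eps)
    return 1.0 / tau - sigma * norm ** 2 >= lipschitz_cst / 2.0

assert condatvu_params_ok(lipschitz_cst=1.0, norm=1.0)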