def test_Wavelet3D_PyWt(self):
    """Test the adjoint operator for the 3D Wavelet transform.

    For every channel count in ``self.num_channels`` and for
    ``self.max_iter`` random trials, check the adjoint identity
    ``<x, adj_op(y)> == <op(x), y>`` for a random complex volume ``x`` of
    side ``self.N`` and random complex coefficients ``y``.
    """
    for ch in self.num_channels:
        print("Testing with Num Channels : " + str(ch))
        for i in range(self.max_iter):
            # Previously ``i`` was passed as a second print argument, so the
            # literal "{0}" was printed; format the iteration index in.
            print("Process Wavelet3D PyWt test '{0}'...".format(i))
            wavelet_op_adj = WaveletN(
                wavelet_name="sym8",
                nb_scale=4,
                dim=3,
                padding_mode='periodization',
                n_coils=ch,
                n_jobs=-1,
            )
            # np.squeeze drops the leading channel axis when ch == 1.
            Img = np.squeeze(
                np.random.randn(ch, self.N, self.N, self.N)
                + 1j * np.random.randn(ch, self.N, self.N, self.N)
            )
            f_p = wavelet_op_adj.op(Img)
            f = (np.random.randn(*f_p.shape)
                 + 1j * np.random.randn(*f_p.shape))
            I_p = wavelet_op_adj.adj_op(f)
            # np.vdot conjugates its first argument: <Img, A^H f> vs <A Img, f>.
            x_d = np.vdot(Img, I_p)
            x_ad = np.vdot(f_p, f)
            np.testing.assert_allclose(x_d, x_ad, rtol=1e-5)
    print(" Wavelet3 adjoint test passes")
def test_weighted_sparse_threshold_weights(self):
    """Check the weights that WeightedSparseThreshold builds per weight_type.

    All three configurations use the coefficient layout of a 3-scale
    ``sym8`` WaveletN applied to a 128x128 image:

    * scalar weight (default weight type): the first band is zero and every
      other coefficient carries the scalar;
    * ``weight_type='scale_based'`` with ``zero_weight_coarse=False``: each
      unique band shape receives the matching entry of the weight vector;
    * ``weight_type='custom'``: the first band is zero and the remainder
      copies the supplied weights.
    """
    n_scales = 3
    wavelet = WaveletN('sym8', nb_scales=n_scales)
    coeffs = wavelet.op(np.zeros((128, 128)))
    band_shapes = wavelet.coeffs_shape
    unique_shapes = np.unique(band_shapes, axis=0)
    first_band_size = np.prod(band_shapes[0])

    # --- Scalar weight -----------------------------------------------------
    scalar_prox = WeightedSparseThreshold(
        weights=1e-10,
        coeffs_shape=band_shapes,
    )
    scalar_prox.op(np.random.random(coeffs.shape))
    assert np.all(scalar_prox.weights[:first_band_size] == 0)
    assert np.all(scalar_prox.weights[first_band_size:] == 1e-10)

    # --- Scale-based weights -----------------------------------------------
    per_scale = np.arange(n_scales + 1)
    scale_prox = WeightedSparseThreshold(
        weights=per_scale,
        coeffs_shape=band_shapes,
        weight_type='scale_based',
        zero_weight_coarse=False,
    )
    scale_prox.op(np.random.random(coeffs.shape))
    offset = 0
    for idx, shape in enumerate(unique_shapes):
        # NOTE(review): ``shape == band_shapes`` compares element-wise, so
        # this sum counts matching *dimensions*, not matching bands (a
        # square 2D band contributes 2). Confirm it mirrors the indexing
        # used inside WeightedSparseThreshold.
        n_matches = np.sum(shape == band_shapes)
        end = offset + np.prod(shape) * n_matches
        np.testing.assert_equal(
            scale_prox.weights[offset:end],
            per_scale[idx],
        )
        offset = end

    # --- Custom weights ----------------------------------------------------
    given = np.random.random(coeffs.shape)
    custom_prox = WeightedSparseThreshold(
        weights=given,
        coeffs_shape=band_shapes,
        weight_type='custom',
    )
    custom_prox.op(np.random.random(coeffs.shape))
    assert np.all(custom_prox.weights[:first_band_size] == 0)
    np.testing.assert_equal(
        custom_prox.weights[first_band_size:],
        given[first_band_size:],
    )
def test_Wavelet2D_PyWt(self):
    """Test the adjoint operator for the 2D Wavelet transform.

    For ``self.max_iter`` random trials, check the adjoint identity
    ``<x, adj_op(y)> == <op(x), y>`` for a random complex ``self.N`` x
    ``self.N`` image ``x`` and random complex coefficients ``y``.
    """
    for i in range(self.max_iter):
        # Previously ``i`` was passed as a second print argument, so the
        # literal "{0}" was printed; format the iteration index in.
        print("Process Wavelet2D PyWt test '{0}'...".format(i))
        wavelet_op_adj = WaveletN(wavelet_name="sym8", nb_scale=4)
        Img = (np.random.randn(self.N, self.N)
               + 1j * np.random.randn(self.N, self.N))
        f_p = wavelet_op_adj.op(Img)
        f = (np.random.randn(*f_p.shape)
             + 1j * np.random.randn(*f_p.shape))
        I_p = wavelet_op_adj.adj_op(f)
        # np.vdot conjugates its first argument: <Img, A^H f> vs <A Img, f>.
        x_d = np.vdot(Img, I_p)
        x_ad = np.vdot(f_p, f)
        np.testing.assert_allclose(x_d, x_ad, rtol=1e-6)
    print(" Wavelet2 adjoint test passes")
def get_operators(
    kspace_data,
    loc,
    mask,
    fourier_type=1,
    max_iter=80,
    regularisation=None,
    linear=None,
):
    """Create the various operators from the config file.

    Parameters
    ----------
    kspace_data : numpy.ndarray
        K-space data. A 2D array is treated as single-coil; otherwise the
        first axis is taken as the number of coils.
    loc
        Column locations, forwarded as ``mask_cols`` to the online k-space
        generators (``fourier_type`` 1 and 2).
    mask
        Sampling mask, used by the offline generator and by ``FFT``.
    fourier_type : int
        0: offline reconstruction, 1: online type I, 2: online type II.
    max_iter : int
        Forwarded to ``KspaceGeneratorBase`` (offline mode only).
    regularisation : dict or None
        Must contain a ``"class"`` key ("LASSO", "GroupLASSO", "OWL" or
        "IdentityProx") plus that class's arguments. The dict is not
        modified.
    linear : dict or None
        Optional ``"class"`` key ("WaveletN" or "Identity") plus that
        class's keyword arguments. The dict is not modified.

    Returns
    -------
    tuple
        ``(kspace_generator, fourier_op, linear_op, prox_op)``.

    Raises
    ------
    NotImplementedError
        For an unknown ``fourier_type``, linear class or regularisation
        class.
    KeyError
        If ``regularisation`` has no ``"class"`` key.
    """
    n_coils = 1 if kspace_data.ndim == 2 else kspace_data.shape[0]
    shape = kspace_data.shape[-2:]
    if fourier_type == 0:  # offline reconstruction
        kspace_generator = KspaceGeneratorBase(
            full_kspace=kspace_data, mask=mask, max_iter=max_iter)
        fourier_op = FFT(shape=shape, n_coils=n_coils, mask=mask)
    elif fourier_type == 1:  # online type I reconstruction
        kspace_generator = Column2DKspaceGenerator(
            full_kspace=kspace_data, mask_cols=loc)
        fourier_op = FFT(shape=shape, n_coils=n_coils, mask=mask)
    elif fourier_type == 2:  # online type II reconstruction
        kspace_generator = DataOnlyKspaceGenerator(
            full_kspace=kspace_data, mask_cols=loc)
        fourier_op = ColumnFFT(shape=shape, n_coils=n_coils)
    else:
        raise NotImplementedError

    if linear is None:
        linear_op = Identity()
    else:
        # Work on a copy: pop() on the caller's dict would strip "class"
        # and make a second call with the same config fail.
        linear_kwargs = dict(linear)
        lin_cls = linear_kwargs.pop("class", None)
        if lin_cls == "WaveletN":
            linear_op = WaveletN(n_coils=n_coils, n_jobs=4, **linear_kwargs)
            # Apply once on zeros; presumably this populates coeffs_shape,
            # which the OWL branch below reads.
            linear_op.op(np.zeros_like(kspace_data))
        elif lin_cls == "Identity":
            linear_op = Identity()
        else:
            raise NotImplementedError

    prox_op = IdentityProx()
    if regularisation is not None:
        # Same reason as above: do not mutate the caller's config.
        reg_kwargs = dict(regularisation)
        reg_cls = reg_kwargs.pop("class")
        if reg_cls == "LASSO":
            prox_op = LASSO(weights=reg_kwargs["weights"])
        elif reg_cls == "GroupLASSO":
            prox_op = GroupLASSO(weights=reg_kwargs["weights"])
        elif reg_cls == "OWL":
            prox_op = OWL(**reg_kwargs,
                          n_coils=n_coils,
                          bands_shape=linear_op.coeffs_shape)
        elif reg_cls == "IdentityProx":
            prox_op = IdentityProx()
            linear_op = Identity()
        else:
            # Consistent with the linear/fourier branches: an unknown class
            # is reported instead of silently falling back to IdentityProx.
            raise NotImplementedError
    return kspace_generator, fourier_op, linear_op, prox_op
def test_Wavelet2D_ISAP(self):
    """Test the adjoint operator for the 2D Wavelet transform (ISAP Haar).

    For every channel count in ``self.num_channels`` and for
    ``self.max_iter`` random trials, check the adjoint identity
    ``<x, adj_op(y)> == <op(x), y>`` for random complex ``x`` and ``y``.
    """
    for ch in self.num_channels:
        print("Testing with Num Channels : " + str(ch))
        for i in range(self.max_iter):
            # Previously ``i`` was passed as a second print argument, so the
            # literal "{0}" was printed; format the iteration index in.
            print("Process Wavelet2D_ISAP test '{0}'...".format(i))
            wavelet_op_adj = WaveletN(
                wavelet_name="HaarWaveletTransform",
                nb_scale=4,
                n_coils=ch,
                n_jobs=2,
            )
            # np.squeeze drops the leading channel axis when ch == 1.
            Img = np.squeeze(
                np.random.randn(ch, self.N, self.N)
                + 1j * np.random.randn(ch, self.N, self.N))
            f_p = wavelet_op_adj.op(Img)
            f = (np.random.randn(*f_p.shape)
                 + 1j * np.random.randn(*f_p.shape))
            I_p = wavelet_op_adj.adj_op(f)
            # np.vdot conjugates its first argument: <Img, A^H f> vs <A Img, f>.
            x_d = np.vdot(Img, I_p)
            x_ad = np.vdot(f_p, f)
            np.testing.assert_allclose(x_d, x_ad, rtol=1e-5)
    print(" Wavelet2 adjoint test passes")
# Calculate SSIM base_ssim = ssim(image_rec0, image) print(base_ssim) ############################################################################# # FISTA optimization # ------------------ # # We now want to refine the zero order solution using a FISTA optimization. # The cost function is set to Proximity Cost + Gradient Cost # Setup the operators linear_op = WaveletN( wavelet_name="sym8", nb_scales=4, dim=3, padding_mode="periodization", ) regularizer_op = SparseThreshold(Identity(), 2 * 1e-11, thresh_type="soft") # Setup Reconstructor reconstructor = SingleChannelReconstructor( fourier_op=fourier_op, linear_op=linear_op, regularizer_op=regularizer_op, gradient_formulation='synthesis', verbose=1, ) # Start Reconstruction x_final, costs, metrics = reconstructor.reconstruct( kspace_data=kspace_data, optimization_alg='fista',
def E(**kwargs):
    """Compute the upper-level cost of a mask or of a mask parametrisation.

    The cost is ``P(.)`` (weighted by ``param["beta"]``) plus the mean over
    all images of ``L(u_k, u0, param["c"])``, where ``u_k`` is the lower-level
    PDHG reconstruction of the matching k-space data.

    Keyword arguments
    -----------------
    images : list of ndarray
        All images used to evaluate the reconstruction.
    kspace_data : list of ndarray
        Noisy k-space data, one entry per image.
    param : dict
        Lower- and upper-level parameters. Must contain:
        - "epsilon": weight of the L2 norm in the lower-level reconstruction
        - "gamma": parameter for the approximation of the L1 norm
        - "c": weight of L in the upper-level cost function
        - "beta": weight of P in the upper-level cost function
        Remark: choose these parameters so that E(pk) ~ 1 at the beginning
        (otherwise L-BFGS-B may stop too early).
    samples
        Sample locations for the NonCartesianFFT used by the lower level.
    wavelet_name : str
        Wavelet used by the WaveletN regularisation operator.
    wavelet_scale : int
        Number of wavelet scales (default 1).
    mask_type : str, optional
        "cartesian": learn a cartesian mask from ``lk``;
        "radial_CO": learn a radial mask from ``lk`` and ``n_rad``;
        anything else: each point of ``pk`` is independent.
    pk : ndarray
        Mask handed to ``pdhg`` as ``p``. Mandatory unless ``mask_type`` is
        "cartesian" or "radial_CO" (then it is computed from ``lk``).
    lk : ndarray
        Mask parametrisation. Mandatory if ``mask_type`` is "cartesian" or
        "radial_CO".
    n_rad : int
        Required when ``mask_type == "radial_CO"``.
    parallel : bool, optional
        Run one PDHG per image with joblib (default False).
    verbose : int, optional
        Printing level (default 0; negative values silence output).

    Returns
    -------
    float
        The value of E.

    Raises
    ------
    ValueError
        If ``images`` is missing/empty, ``param`` is missing, or
        ``kspace_data`` is missing or has a different length than ``images``.
    """
    # Getting parameters
    images = kwargs.get("images", None)
    kspace_data = kwargs.get("kspace_data", None)
    param = kwargs.get("param", None)
    samples = kwargs.get("samples", [])
    wavelet_name = kwargs.get("wavelet_name", "")
    wavelet_scale = kwargs.get("wavelet_scale", 1)
    parallel = kwargs.get("parallel", False)
    mask_type = kwargs.get("mask_type", "")
    lk = kwargs.get("lk", None)
    pk = kwargs.get("pk", None)
    verbose = kwargs.get("verbose", 0)

    if verbose >= 0:
        print("\n\nEVALUATING E(p)")

    # Validate inputs before anything dereferences them (the Fourier
    # operator needs images[0].shape).
    if images is None or len(images) < 1:
        raise ValueError("At least one image is needed")
    if param is None:
        raise ValueError("Lower level parameters must be given")
    if kspace_data is None or len(images) != len(kspace_data):
        raise ValueError("Need as many images and kspace data")

    # Compute P(pk/lk)
    if mask_type == "cartesian":
        pk = pcart(lk)
        Ep = P(lk, param["beta"])
    elif mask_type == "radial_CO":
        n_rad = kwargs.get("n_rad")
        pk = pradCO(lk, n_rad)
        Ep = P(lk, param["beta"])
    else:
        Ep = P(pk, param["beta"])

    # Compute L(pk/lk), averaged over the images. Both paths run PDHG for
    # the same number of iterations so `parallel` only changes how the work
    # is scheduled, not the result.
    n_images = len(images)
    if parallel:
        # Workers rebuild their own operators from plain parameters.
        uk_list = Parallel(n_jobs=-1, verbose=0)(
            delayed(pdhg)(kspace_data[i],
                          pk,
                          samples=samples,
                          shape=images[0].shape,
                          wavelet_name=wavelet_name,
                          wavelet_scale=wavelet_scale,
                          param=param,
                          mask_type=mask_type,
                          const=kwargs.get("const", {}),
                          maxit=50,
                          verbose=-1) for i in range(n_images))
        Ep += np.sum([
            L(uk_list[i][0], images[i], param["c"]) for i in range(n_images)
        ]) / n_images
    else:
        # The operators are only needed on the sequential path.
        fourier_op = NonCartesianFFT(samples=samples,
                                     shape=images[0].shape,
                                     implementation='cpu')
        linear_op = WaveletN(wavelet_name=wavelet_name,
                             nb_scale=wavelet_scale,
                             padding_mode="periodization")
        for i in range(n_images):
            if verbose >= 0:
                print(f"\nImage {i+1}:")
            u0_mat, y = images[i], kspace_data[i]
            if verbose > 0:
                print("\nStarting PDHG")
            uk, _ = pdhg(y,
                         pk,
                         maxit=50,
                         fourier_op=fourier_op,
                         linear_op=linear_op,
                         **kwargs)
            Ep += L(uk, u0_mat, param["c"]) / n_images
    return Ep
def grad_L(**kwargs):
    """Compute the gradient of L with respect to p.

    Runs the lower-level PDHG solver to get ``uk``. It then solves
    ``D2_u E_tot(uk, pk) w = D_u L(uk)`` with conjugate gradient and returns
    ``-D_pu E_tot(uk, pk) w``. Complex unknowns are split into stacked
    real/imaginary parts for CG.

    Keyword arguments
    -----------------
    pk : ndarray
        Point where the gradient is computed.
    u0_mat : ndarray
        Ground-truth image. Assumed square n x n (it is reshaped as such).
    y : ndarray
        K-space data associated with ``u0_mat``.
    param : dict
        Lower- and upper-level parameters. Must contain:
        - "epsilon": weight of the L2 norm in the lower-level reconstruction
        - "gamma": parameter for the approximation of the L1 norm
        - "c": weight of L for the upper-level cost function
        - "beta": weight of P for the upper-level cost function
        Remark: choose these parameters so that E(pk) ~ 1 at the beginning
        (otherwise L-BFGS-B may stop too early).
    samples, wavelet_name, wavelet_scale
        Used to build the NonCartesianFFT and WaveletN operators.
    mask_type, const
        Forwarded to ``pdhg``.
    learn_mask, learn_alpha : bool, optional
        Forwarded to ``Dpu_Etot`` (default True).
    max_cgiter : int, optional
        Maximum number of conjugate gradient iterations (default: 4000).
    cgtol : float, optional
        Relative tolerance of the conjugate gradient (default: 1e-6).
    compute_conv : bool, optional
        Record and plot the CG convergence if True (default: False).
    verbose : int, optional
        Printing level (default 0; negative values silence output).

    Returns
    -------
    ndarray
        The gradient. Zeros of ``pk.shape`` if CG ended with a relative
        residual above 1e-3.

    Raises
    ------
    ValueError
        If ``u0_mat`` or ``y`` is missing.
    """
    import inspect  # local: only used to pick the cg tolerance keyword

    # -- Getting parameters
    max_cgiter = kwargs.get("max_cgiter", 4000)
    cgtol = kwargs.get("cgtol", 1e-6)
    compute_conv = kwargs.get("compute_conv", False)
    mask_type = kwargs.get("mask_type", "")
    learn_mask = kwargs.get("learn_mask", True)
    learn_alpha = kwargs.get("learn_alpha", True)
    u0_mat = kwargs.get("u0_mat", None)
    param = kwargs.get("param", None)
    y = kwargs.get("y", None)
    pk = kwargs.get("pk", None)
    samples = kwargs.get("samples", [])
    wavelet_name = kwargs.get("wavelet_name", "")
    wavelet_scale = kwargs.get("wavelet_scale", 1)
    const = kwargs.get("const", {})
    verbose = kwargs.get("verbose", 0)
    cg_conv = []

    # Validate before u0_mat.shape is needed to build the Fourier operator.
    if u0_mat is None:
        raise ValueError("A ground truth image u0_mat is needed")
    if y is None:
        raise ValueError("kspace data y are needed")

    fourier_op = NonCartesianFFT(samples=samples,
                                 shape=u0_mat.shape,
                                 implementation='cpu')
    linear_op = WaveletN(wavelet_name=wavelet_name,
                         nb_scale=wavelet_scale,
                         padding_mode="periodization")
    n = len(u0_mat)

    # -- Compute uk from pk with the lower-level solver
    if verbose >= 0:
        print("\nStarting PDHG")
    uk, _ = pdhg(y,
                 pk,
                 mask_type=mask_type,
                 fourier_op=fourier_op,
                 linear_op=linear_op,
                 param=param,
                 maxit=50,
                 verbose=verbose,
                 const=const)

    # -- Linear operator w -> D2_u E_tot(uk, pk) w on stacked real/imag parts
    def mv(w):
        w_complex = np.reshape(w[:n**2] + 1j * w[n**2:], (n, n))
        fx = np.reshape(
            Du2_Etot(uk,
                     pk,
                     w_complex,
                     eps=param["epsilon"],
                     fourier_op=fourier_op,
                     y=y,
                     linear_op=linear_op,
                     gamma=param["gamma"]), (n**2, ))
        return np.concatenate([np.real(fx), np.imag(fx)])

    lin = LinearOperator((2 * n**2, 2 * n**2), matvec=mv)

    if verbose >= 0:
        print("\nStarting Conjugate Gradient method")
    t1 = time.time()
    B = np.reshape(Du_L(uk, u0_mat, param["c"]), (n**2, ))
    B_real = np.concatenate([np.real(B), np.imag(B)])
    B_norm = np.linalg.norm(B_real)

    def cgcall(x):
        # CG callback: record the relative residual for the convergence plot.
        if compute_conv:
            cg_conv.append(np.linalg.norm(lin(x) - B_real) / B_norm)

    # SciPy 1.12 renamed cg's ``tol`` to ``rtol`` and 1.14 removed ``tol``.
    tol_kw = "rtol" if "rtol" in inspect.signature(cg).parameters else "tol"
    x_inter, _ = cg(lin,
                    B_real,
                    maxiter=max_cgiter,
                    callback=cgcall,
                    **{tol_kw: cgtol})

    # One extra matvec, reused for both the report and the acceptance test.
    rel_residual = np.linalg.norm(lin(x_inter) - B_real) / B_norm
    if verbose >= 0:
        print(
            f"Finished in {time.time()-t1}s - ||Ax-b||/||b||: {rel_residual}"
        )

    # -- Plotting
    if compute_conv:
        plt.plot(cg_conv)
        plt.yscale("log")
        plt.title("Convergence of the conjugate gradient")
        plt.xlabel("Number of iterations")
        plt.ylabel("||Ax-b||/||b||")

    if rel_residual > 1e-3:
        # CG residual too large: fall back to a zero gradient.
        return np.zeros(pk.shape)
    return -Dpu_Etot(uk,
                     pk,
                     np.reshape(x_inter[:n**2] + 1j * x_inter[n**2:], (n, n)),
                     eps=param["epsilon"],
                     fourier_op=fourier_op,
                     y=y,
                     linear_op=linear_op,
                     gamma=param["gamma"],
                     learn_mask=learn_mask,
                     learn_alpha=learn_alpha)
# Zero-filled reconstruction: adjoint Fourier transform of the observed data.
zero_filled = fourier_op.adj_op(kspace_obs)
# Root-sum-of-squares over axis 0 (presumably the coil axis — confirm).
image_rec0 = pysap.Image(data=np.sqrt(np.sum(np.abs(zero_filled)**2, axis=0)))
# image_rec0.show()
base_ssim = ssim(image_rec0, image)
print('The Base SSIM is : ' + str(base_ssim))
#############################################################################
# FISTA optimization
# ------------------
#
# We now want to refine the zero order solution using a FISTA optimization.
# The cost function is set to Proximity Cost + Gradient Cost

# Setup the operators
# 2D sym8 wavelet with 4 scales (no n_coils is passed here).
linear_op = WaveletN(
    wavelet_name='sym8',
    nb_scale=4,
)
# Soft thresholding with a fixed weight of 1.5e-8.
regularizer_op = SparseThreshold(Identity(), 1.5e-8, thresh_type="soft")
# Setup Reconstructor
# kspace_portion=0.01 is forwarded to the self-calibration reconstructor;
# NOTE(review): presumably the fraction of central k-space used for
# calibration — confirm against SelfCalibrationReconstructor's docs.
reconstructor = SelfCalibrationReconstructor(
    fourier_op=fourier_op,
    linear_op=linear_op,
    regularizer_op=regularizer_op,
    gradient_formulation='synthesis',
    kspace_portion=0.01,
    verbose=1,
)
x_final, costs, metrics = reconstructor.reconstruct(
    kspace_data=kspace_obs,
    optimization_alg='fista',
    num_iterations=10,
def pdhg(data, p, **kwargs):
    """Main lower-level solver: PDHG reconstruction from k-space data.

    Parameters
    ----------
    data : ndarray
        K-space measurements.
    p : ndarray
        ``p[:-1]`` is the subsampling mask S(p) and ``p[-1]`` the
        regularisation parameter alpha(p), so ``len(p) == len(data) + 1``.

    Keyword arguments
    -----------------
    fourier_op
        Fourier operator built from a full mask with the shape of the final
        image. If omitted, a NonCartesianFFT is built from a non-empty
        ``samples`` and ``shape`` (so the function can run in workers).
    linear_op
        Linear operator used in the regularisation (for the moment only
        WaveletN is used). If omitted, a WaveletN is built from a non-empty
        ``wavelet_name`` and ``wavelet_scale`` (default 1).
    param : dict
        Lower-level energy parameters; must contain "epsilon" and "gamma".
    mask_type : str, optional
        Type of mask ("cartesian", "radial", ...), forwarded to ``step``
        (default "").
    const : dict, optional
        Algorithm constants (tau, sigma) when already known; otherwise they
        are computed by ``compute_constants``.
    compute_energy : bool, optional
        Compute and return the energy over iterations (default False).
    ground_truth : ndarray, optional
        True image; if given, the SSIM is computed over iterations.
    maxit, tol : int, float, optional
        Stop when the step-difference norm is below ``tol`` or after
        ``maxit`` iterations (defaults: 200, 1e-6).
    verbose : int, optional
        Printing level (default 1; negative values silence output).

    Returns
    -------
    tuple
        ``(uk, norms)``, extended with ``energy`` when ``compute_energy`` is
        True and/or ``ssims`` when ``ground_truth`` is given, in the order
        ``(uk, norms, energy, ssims)``.

    Raises
    ------
    ValueError
        If no Fourier operator, linear operator or ``param`` is available.
    """
    fourier_op = kwargs.get("fourier_op", None)
    linear_op = kwargs.get("linear_op", None)
    param = kwargs.get("param", None)

    # Create fourier_op and linear_op if not given (for multithreading).
    if fourier_op is None:
        # Default is None, not []: with [] the "is not None" test was always
        # true and an FFT over no samples was built instead of raising.
        samples = kwargs.get("samples", None)
        shape = kwargs.get("shape", ())
        if samples is not None and len(samples) > 0:
            fourier_op = NonCartesianFFT(samples=samples,
                                         shape=shape,
                                         implementation='cpu')
    if fourier_op is None:
        raise ValueError("A fourier operator fourier_op must be given")

    if linear_op is None:
        wavelet_name = kwargs.get("wavelet_name", "")
        wavelet_scale = kwargs.get("wavelet_scale", 1)
        if wavelet_name != "":
            linear_op = WaveletN(wavelet_name=wavelet_name,
                                 nb_scale=wavelet_scale,
                                 padding_mode="periodization")
    if linear_op is None:
        raise ValueError("A linear operator linear_op must be given")

    if param is None:
        raise ValueError("Lower level parameters must be given")

    mask_type = kwargs.get("mask_type", "")
    const = kwargs.get("const", {})
    compute_energy = kwargs.get("compute_energy", False)
    ground_truth = kwargs.get("ground_truth", None)
    maxit = kwargs.get("maxit", 200)
    tol = kwargs.get("tol", 1e-6)
    verbose = kwargs.get("verbose", 1)

    # Global parameters: split the mask from the regularisation weight.
    p, pn1 = p[:-1], p[-1]
    epsilon = param["epsilon"]
    gamma = param["gamma"]
    n_iter = 0

    # Algorithm constants
    const = compute_constants(param, const, p)
    if verbose >= 0:
        print("Sigma:", const["sigma"], "\nTau:", const["tau"])

    # Initializing from the adjoint of the masked data.
    uk = fourier_op.adj_op(p * data)
    vk = np.copy(uk)
    wk = linear_op.op(uk)
    uk_bar = np.copy(uk)
    norm = 2 * tol  # guarantees at least one iteration

    # For plots
    if compute_energy:
        energy = []
    if ground_truth is not None:
        ssims = []
    norms = []

    # Main loop
    t1 = time.time()
    while n_iter < maxit and norm > tol:
        uk, vk, wk, uk_bar, norm = step(uk, vk, wk, uk_bar, const, p, pn1,
                                        data, param, linear_op, fourier_op,
                                        mask_type)
        n_iter += 1

        # Saving information
        norms.append(norm)
        if compute_energy:
            energy.append(
                energy_wavelet(uk, p, pn1, data, gamma, epsilon, linear_op,
                               fourier_op))
        if ground_truth is not None:
            ssims.append(ssim(uk, ground_truth))

        # Printing
        if n_iter % 10 == 0 and verbose > 0:
            if compute_energy:
                print(n_iter, " iterations:\nCost:", energy[-1], "\nNorm:",
                      norm, "\n")
            else:
                print(n_iter, " iterations:\nNorm:", norm, "\n")
    if verbose >= 0:
        print("Finished in", time.time() - t1, "seconds.")

    # Return
    if compute_energy and ground_truth is not None:
        return uk, norms, energy, ssims
    elif ground_truth is not None:
        return uk, norms, ssims
    elif compute_energy:
        return uk, norms, energy
    else:
        return uk, norms
image_rec0 = pysap.Image(data=grid_soln) # image_rec0.show() base_ssim = ssim(image_rec0, image) print('The Base SSIM is : ' + str(base_ssim)) ############################################################################# # FISTA optimization # ------------------ # # We now want to refine the zero order solution using a FISTA optimization. # The cost function is set to Proximity Cost + Gradient Cost # TODO get the right mu operator # Setup the operators linear_op = WaveletN(wavelet_name="sym8", nb_scales=4, dim=3) regularizer_op = SparseThreshold(Identity(), 6 * 1e-9, thresh_type="soft") # Setup Reconstructor reconstructor = SingleChannelReconstructor( fourier_op=fourier_op, linear_op=linear_op, regularizer_op=regularizer_op, gradient_formulation='synthesis', verbose=1, ) # Start Reconstruction x_final, costs, metrics = reconstructor.reconstruct( kspace_data=kspace_obs, optimization_alg='fista', num_iterations=10, )
image_rec0 = pysap.Image(data=np.sqrt(np.sum(np.abs(zero_filled)**2, axis=0))) # image_rec0.show() base_ssim = ssim(image_rec0, image) print('The Base SSIM is : ' + str(base_ssim)) ############################################################################# # FISTA optimization # ------------------ # # We now want to refine the zero order solution using a FISTA optimization. # The cost function is set to Proximity Cost + Gradient Cost # Setup the operators linear_op = WaveletN( wavelet_name='sym8', nb_scale=4, n_coils=cartesian_ref_image.shape[0], ) regularizer_op = GroupLASSO(weights=6e-8) # Setup Reconstructor reconstructor = CalibrationlessReconstructor( fourier_op=fourier_op, linear_op=linear_op, regularizer_op=regularizer_op, gradient_formulation='synthesis', verbose=1, ) x_final, costs, metrics = reconstructor.reconstruct( kspace_data=kspace_obs, optimization_alg='fista', num_iterations=300,