def refine_s_end(self):
    """Perform the final refinement of the sources, with K = 1.

    Returns
    -------
    int
        error code
    """
    if self.verb >= 2:
        print("Final refinement of the sources with the final estimation of A...")

    if self.cstWuRegStr:
        strat = 0
    else:
        strat = 1
    c = np.min(self.c_wu)

    update_s = self.update_s(strat, c, 1, doThr=self.thrEnd, doRw=False)
    if update_s:  # error caught
        return 1

    # Initialize attributes
    self.Swtrw = np.zeros((self.n, self.p, self.nscales))
    S_old = np.zeros((self.n, self.p))
    delta_S = np.inf
    it = 0

    if not self.keepWuRegStr:
        strat = 2
        c = self.c_ref

    while delta_S >= self.eps[2] and it < 25:
        it += 1

        update_s = self.update_s(strat, c, 1, doThr=self.thrEnd)
        if update_s:  # error caught
            return 1

        delta_S = np.linalg.norm(S_old - self.S) / np.linalg.norm(self.S)
        S_old = self.S.copy()

        if self.A0 is not None and self.S0 is not None and self.verb >= 2:
            Acp, Scp, _ = utils.corr_perm(self.A0, self.S0, self.A, self.S, optInd=True)
            print("NMSE = %.2f" % utils.nmse(self.S0, Scp))
        if self.verb >= 2:
            print("delta_S = %.2e" % delta_S)

    return 0
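# Hedged sketch (not from the repo): the delta_S stopping criterion used in the
# refinement loop above, i.e. the relative change of the source matrix between
# two consecutive iterations. The loop stops once it falls below self.eps[2]
# or after 25 iterations.
import numpy as np


def relative_change_sketch(S_old, S):
    """Relative change ||S_old - S|| / ||S|| between consecutive source estimates."""
    return np.linalg.norm(S_old - S) / np.linalg.norm(S)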
def oracle_dss(self, strat, c, S=None, A0=None, iSNR0=None, Swt0=None):
    """Solve the oracle deconvolution source separation problem.

    Parameters
    ----------
    strat : int
        regularization strategy (0: constant, 1: mixing-matrix-based, 2: spectrum-based)
    c : float
        regularization hyperparameter
    S : np.ndarray
        (n, p) float array, estimated sources (in-place update, default: self.S)
    A0 : np.ndarray
        (m, n) float array, ground-truth mixing matrix (default: self.A0)
    iSNR0 : np.ndarray
        (n, p) float array, ground-truth regularization parameters for strategy #2 (default: self.iSNR0)
    Swt0 : np.ndarray
        (n, p, nscales) float array, ground-truth sources in the wavelet domain (default: self.Swt0)

    Returns
    -------
    int
        error code
    """
    update_s = self.update_s(strat, c, 1, doRw=True, S=S, A=A0, iSNR=iSNR0,
                             Swtrw=Swt0, oracle=True)
    if update_s:  # error caught
        return 1

    self.nmse = utils.nmse(self.S0, self.S)
    return 0
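# Hedged sketch (not from the repo): `utils.nmse` is assumed to be the standard
# normalized mean square error between the ground-truth sources S0 and the
# estimate S; the repo's implementation may differ (e.g. per-source averaging
# or a dB scale).
import numpy as np


def nmse_sketch(S0, S):
    """Assumed NMSE definition: ||S0 - S||^2 / ||S0||^2."""
    return np.linalg.norm(S0 - S) ** 2 / np.linalg.norm(S0) ** 2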
def vaerecon(ksp, coilmaps, mode, vae_model, gt, logdir, device, writer=False,
             norm=1, nsampl=100, boot_samples=500, k=1, patchsize=28, parfact=25,
             num_iter=200, stepsize=5e-4, lmb=0.01, num_priors=1, use_momentum=True):
    """Reconstruct undersampled multi-coil k-space with a learned (VAE) or TV prior.

    Alternates prior gradient steps on the magnitude, phase regularization,
    optional JSENSE coil-map re-estimation (mode 'JDDP') and data-consistency
    projections. Returns the final RSS image divided by `norm`.
    """
    # Init data
    imcoils, imsizer, imsizec = ksp.shape
    ksp = ksp.to(device)
    coilmaps = coilmaps.to(device)
    vae_model = vae_model.to(device)
    uspat = (torch.abs(ksp[0]) > 0).type(torch.uint8).to(device)
    recs_gpu = tUFT_pytorch(ksp, uspat, coilmaps)
    rss = rss_pytorch(ksp)

    # Init coilmaps estimation with JSENSE
    if mode == 'JDDP':
        # Polynomial order
        max_basis_order = 6
        num_coeffs = (max_basis_order + 1) ** 2

        # Create the basis functions for the sense estimation
        basis_funct = create_basis_functions(imsizer, imsizec, max_basis_order,
                                             show_plot=False)

        plot_basis = False
        if plot_basis:
            for i in range(num_coeffs):
                writer.log({"Basis funcs": [writer.Image(
                    transforms.ToPILImage()(normalize_tensor(
                        torch.from_numpy(basis_funct[i, :, :]))), caption="")]})

        basis_funct = torch.from_numpy(
            np.tile(basis_funct[np.newaxis, :, :, :],
                    [coilmaps.shape[0], 1, 1, 1])).to(device)

        coeffs_array = sense_estimation_ls(ksp, recs_gpu, basis_funct, uspat)
        coilmaps = torch.sum(
            coeffs_array[:, :, np.newaxis, np.newaxis] * basis_funct, 1).to(device)
        recs_gpu = tUFT_pytorch(ksp, uspat, coilmaps)

        if writer:
            for i in range(coilmaps.shape[0]):
                writer.log({"abs Coilmaps": [writer.Image(
                    transforms.ToPILImage()(normalize_tensor(
                        torch.abs(coilmaps[i, :, :]))), caption="")]}, step=0)
                writer.log({"phase Coilmaps": [writer.Image(
                    transforms.ToPILImage()(normalize_tensor(
                        torch.angle(coilmaps[i, :, :]))), caption="")]}, step=0)

        print("Coilmaps init done")

    # Log
    if writer:
        writer.log({"Gt rss": [writer.Image(
            transforms.ToPILImage()(normalize_tensor(gt)), caption="")]}, step=0)
        writer.log({"Restored rss": [writer.Image(
            transforms.ToPILImage()(normalize_tensor(rss)), caption="")]}, step=0)
        writer.log({"Restored abs": [writer.Image(
            transforms.ToPILImage()(normalize_tensor(torch.abs(recs_gpu))), caption="")]}, step=0)
        writer.log({"Restored Phase": [writer.Image(
            transforms.ToPILImage()(normalize_tensor(torch.angle(recs_gpu))), caption="")]}, step=0)
        writer.log({"diff rss": [writer.Image(
            transforms.ToPILImage()(normalize_tensor(
                rss.detach().cpu() / norm - gt.detach().cpu())), caption="")]}, step=0)

        ssim_v = ssim(rss[160:-160].detach().cpu().numpy() / norm,
                      gt[160:-160].detach().cpu().numpy())
        nmse_v = nmse(rss[160:-160].detach().cpu().numpy() / norm,
                      gt[160:-160].detach().cpu().numpy())
        psnr_v = psnr(rss[160:-160].detach().cpu().numpy() / norm,
                      gt[160:-160].detach().cpu().numpy())
        print('SSIM: ', ssim_v, ' NMSE: ', nmse_v, ' PSNR: ', psnr_v)
        writer.log({"SSIM": ssim_v, "NMSE": nmse_v, "PSNR": psnr_v}, step=0)

        lik, dc = prior_value(rss, ksp, uspat, coilmaps, patchsize, parfact,
                              nsampl, vae_model)
        writer.log({"ELBO": lik}, step=0)
        writer.log({"DC err": dc}, step=0)

    t = 1
    for it in range(0, num_iter, 2):
        print('Itr: ', it)

        # Magnitude prior projection step
        for _ in range(num_priors):
            # Gradient descent on the prior
            if mode == 'TV':
                tvnorm, abstvgrad = tv_norm(torch.abs(rss))
                priorgrad = abstvgrad * recs_gpu / torch.abs(recs_gpu)
                recs_gpu = recs_gpu - stepsize * priorgrad

                if writer:  # and it % 10 == 0:
                    writer.log({"TVgrad": [writer.Image(
                        transforms.ToPILImage()(normalize_tensor(abstvgrad)), caption="")]}, step=it + 1)
                    writer.log({"TV": [writer.Image(
                        transforms.ToPILImage()(normalize_tensor(tvnorm)), caption="")]}, step=it + 1)
            elif mode == 'DDP' or mode == 'JDDP':
                g_abs_lik, est_uncert, g_dc = prior_gradient(
                    rss, ksp, uspat, coilmaps, patchsize, parfact, nsampl,
                    vae_model, boot_samples, mode)
                priorgrad = g_abs_lik * recs_gpu / torch.abs(recs_gpu)

                if it > -1:
                    recs_gpu = recs_gpu - stepsize * priorgrad

                if writer:
                    # Log
                    writer.log({"VAEgrad abs": [writer.Image(
                        transforms.ToPILImage()(normalize_tensor(
                            torch.abs(g_abs_lik))), caption="")]}, step=it + 1)
                    writer.log({"STD": torch.mean(torch.abs(est_uncert))}, step=it + 1)

                    tmp1 = UFT_pytorch(recs_gpu, 1 - uspat, coilmaps)
                    tmp2 = ksp * uspat.unsqueeze(0)
                    tmp = tmp1 + tmp2
                    rss = rss_pytorch(tmp)

                    nmse_v = nmse(rss[160:-160].detach().cpu().numpy() / norm,
                                  gt[160:-160].detach().cpu().numpy())
                    ssim_v = ssim(rss[160:-160].detach().cpu().numpy() / norm,
                                  gt[160:-160].detach().cpu().numpy())
                    psnr_v = psnr(rss[160:-160].detach().cpu().numpy() / norm,
                                  gt[160:-160].detach().cpu().numpy())
                    print('SSIM: ', ssim_v, ' NMSE: ', nmse_v, ' PSNR: ', psnr_v)
                    writer.log({"SSIM": ssim_v, "NMSE": nmse_v, "PSNR": psnr_v}, step=it + 1)
            else:
                print("Error: Prior method does not exist.")
                exit()

        # Phase projection step
        if lmb > 0:
            tmpa = torch.abs(recs_gpu)
            tmpp = torch.angle(recs_gpu)
            # We apply phase regularization to prefer smooth phase images
            # tmpptv = reg2_proj(tmpp, imsizer, imsizec, alpha=lmb, niter=2)  # 0.1, 15
            tmpptv = tv_proj(tmpp, mu=0.125, lmb=lmb, IT=50)  # 0.1, 15
            # We combine back the phase and the magnitude
            recs_gpu = tmpa * torch.exp(1j * tmpptv)

        # Coilmaps estimation step (if JSENSE)
        if mode == 'JDDP':
            # Computed on CPU since pytorch cannot handle complex numbers on the GPU here...
            coeffs_array = sense_estimation_ls(ksp, recs_gpu, basis_funct, uspat)
            coilmaps = torch.sum(
                coeffs_array[:, :, np.newaxis, np.newaxis] * basis_funct, 1).to(device)

            if writer:
                writer.log({"abs Coilmaps": [writer.Image(
                    transforms.ToPILImage()(normalize_tensor(
                        torch.abs(coilmaps[0, :, :]))), caption="")]}, step=it + 1)
                writer.log({"phase Coilmaps": [writer.Image(
                    transforms.ToPILImage()(normalize_tensor(
                        torch.angle(coilmaps[0, :, :]))), caption="")]}, step=it + 1)

        # Data consistency projection
        tmp1 = UFT_pytorch(recs_gpu, 1 - uspat, coilmaps)
        tmp2 = ksp * uspat.unsqueeze(0)
        tmp = tmp1 + tmp2
        recs_gpu = tFT_pytorch(tmp, coilmaps)
        # recs[it + 2] = recs_gpu.detach().cpu().numpy()
        rss = rss_pytorch(tmp)

        # Log
        nmse_v = nmse(rss[160:-160].detach().cpu().numpy() / norm,
                      gt[160:-160].detach().cpu().numpy())
        ssim_v = ssim(rss[160:-160].detach().cpu().numpy() / norm,
                      gt[160:-160].detach().cpu().numpy())
        psnr_v = psnr(rss[160:-160].detach().cpu().numpy() / norm,
                      gt[160:-160].detach().cpu().numpy())
        print('SSIM: ', ssim_v, ' NMSE: ', nmse_v, ' PSNR: ', psnr_v)

        if writer:
            writer.log({"SSIM": ssim_v, "NMSE": nmse_v, "PSNR": psnr_v}, step=it + 1)
            writer.log({"Restored rss": [writer.Image(
                transforms.ToPILImage()(normalize_tensor(rss)), caption="")]}, step=it + 1)
            writer.log({"Restored Phase": [writer.Image(
                transforms.ToPILImage()(normalize_tensor(
                    torch.angle(recs_gpu))), caption="")]}, step=it + 1)
            writer.log({"diff rss": [writer.Image(
                transforms.ToPILImage()(normalize_tensor(
                    rss.detach().cpu() / norm - gt.detach().cpu())), caption="")]}, step=it + 1)
            writer.log({"Restored 1ch kspace": [writer.Image(
                transforms.ToPILImage()(normalize_tensor(
                    torch.log(torch.abs(tmp[0])))), caption="")]}, step=it + 1)

            lik, dc = prior_value(rss, ksp, uspat, coilmaps, patchsize, parfact,
                                  nsampl, vae_model)
            writer.log({"ELBO": lik}, step=it + 1)
            writer.log({"DC err": dc}, step=it + 1)

    return rss / norm
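# Hedged sketch (not part of the repo): the data-consistency projection used in
# vaerecon, reduced to a single coil with plain FFTs so it is self-contained.
# The repo's UFT_pytorch / tFT_pytorch operators additionally weight by coil
# sensitivity maps; only the idea of keeping acquired k-space samples and
# filling the rest from the current estimate is illustrated here.
import torch


def data_consistency_sketch(image, ksp_meas, mask):
    """image: (H, W) complex estimate; ksp_meas: (H, W) measured k-space;
    mask: (H, W) binary sampling pattern (1 = acquired)."""
    ksp_est = torch.fft.fft2(image)                    # forward model without coils
    ksp_dc = mask * ksp_meas + (1 - mask) * ksp_est    # enforce the measured samples
    return torch.fft.ifft2(ksp_dc)                     # back to image space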
def core(self):
    """Manage the separation.

    This function handles the alternate updates of S and A, as well as the two
    stages (warm-up and refinement).

    Returns
    -------
    int
        error code
    """
    stage = "wu"
    S_old = np.zeros((self.n, self.p))
    A_old = np.zeros((self.m, self.n))
    it = 0

    while True:
        it += 1

        # Get parameters of DecGMCA for the current iteration
        strat, c, K, doRw, nnegS = self.get_parameters(stage, it)

        if self.verb >= 2:
            print("Iteration #%i" % it)

        # Update S
        update_s = self.update_s(strat, c, K, doRw=doRw, nnegS=nnegS)
        if update_s:  # error caught
            return 1

        # Update A
        update_a = self.update_a()
        if update_a:  # error caught
            return 1

        # Post-processing
        delta_S = np.linalg.norm(S_old - self.S) / np.linalg.norm(self.S)
        delta_A = np.max(abs(1 - abs(np.sum(self.A * A_old, axis=0))))
        cond_A = np.linalg.cond(self.A)
        S_old = self.S.copy()
        A_old = self.A.copy()

        if self.A0 is not None and self.S0 is not None and self.verb >= 2:
            Acp, Scp, _ = utils.corr_perm(self.A0, self.S0, self.A, self.S, optInd=True)
            print("NMSE = %.2f - CA = %.2f" % (utils.nmse(self.S0, Scp), utils.ca(self.A0, Acp)))
        if self.verb >= 2:
            print("delta_S = %.2e - delta_A = %.2e - cond(A) = %.2f" % (delta_S, delta_A, cond_A))
        if self.verb >= 5:
            print("A:\n", self.A)

        # Stage update
        if stage == 'wu' and it >= self.minWuIt and (delta_S <= self.eps[0] or it >= self.minWuIt + 50):
            if self.verb >= 2:
                print("> End of the warm-up (iteration %i)" % it)
            self.lastWuIt = it
            stage = 'ref'

        if stage == 'ref' and (delta_S <= self.eps[1] or it >= self.lastWuIt + 50) and it >= self.lastWuIt + 25:
            if self.verb >= 2:
                print("> End of the refinement (iteration %i)" % it)
            self.lastRefIt = it
            return 0
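# Hedged sketch (not from the repo): the delta_A criterion computed in core()
# above. Assuming the columns of A are unit-normalized, the column-wise inner
# product between the new and old mixing matrices is the cosine of the angle
# between corresponding columns, so 1 - |cos| measures how far the most
# displaced column has moved.
import numpy as np


def delta_a_sketch(A, A_old):
    """Maximum column displacement between two mixing matrices with unit-norm columns."""
    cos = np.sum(A * A_old, axis=0)           # column-wise inner products
    return np.max(np.abs(1.0 - np.abs(cos)))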
plt.figure()
plt.pcolormesh(np.rad2deg(phi_target), freq / 1000,
               db(np.fft.rfft(h0[mm] - hhat[mm], axis=-1)).T)
plt.axis('normal')
plt.colorbar(label='dB')
plt.clim(-200, 0)
plt.xlabel(r'$\phi$ / $^\circ$')
plt.ylabel(r'$f$ / kHz')
plt.xlim(0, 360)
plt.ylim(0, fs / 2 / 1000)
plt.title('Spectral Distortion (Loudspeaker #{})'.format(idx[mm]))

# Fig. Normalized system distance in dB
plt.figure()
for m in range(M):
    plt.plot(np.rad2deg(phi_target), db(nmse(hhat[m], h0[m])))
plt.xlim(0, 360)
plt.ylim(-200, 0)
plt.xlabel(r'$\phi$ / $^\circ$')
plt.ylabel('NMSE / dB')
plt.title('Normalized Mean Square Error (Loudspeaker #{})'.format(idx[mm]))

# Fig. Desired CHT spectrum
plt.figure(figsize=(10, 4))
plt.pcolormesh(order, freq / 1000,
               db(np.fft.fftshift(np.fft.fft2(h0[mm]), axes=0)[:, :Nf]).T,
               vmin=-120)
plt.axis('normal')
plt.xlabel('CHT order')
plt.ylabel(r'$f$ / kHz')
plt.colorbar(label='dB')
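# Hedged sketch (not from the repo): stand-ins for the `db` and `nmse` helpers
# assumed by the plots above. The repo's `db` may take a power flag and its
# `nmse` may normalize differently; here the NMSE is taken per target angle,
# with the last axis being time.
import numpy as np


def db_sketch(x, power=False):
    """Decibel value: 20*log10 for amplitudes, 10*log10 when power=True (guarded against log of zero)."""
    factor = 10 if power else 20
    return factor * np.log10(np.maximum(np.abs(x), 1e-20))


def nmse_per_angle_sketch(h_hat, h_ref):
    """Normalized mean square error along the time axis, one value per angle."""
    return (np.sum(np.abs(h_hat - h_ref) ** 2, axis=-1)
            / np.sum(np.abs(h_ref) ** 2, axis=-1))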
def run_inference(subj, R, mode, k, num_sampels, num_bootsamles, batch_size,
                  num_iter, step_size, phase_step, complex_rec, use_momentum,
                  log, device):
    # Some inits of paths... Edit these
    vae_model_name = 'T2-20210415-111101/450.pth'
    vae_path = '/cluster/scratch/jonatank/logs/ddp/vae/'
    data_path = '/cluster/work/cvl/jonatank/fastMRI_T2/validation/'
    log_path = '/cluster/scratch/jonatank/logs/ddp/restore/pytorch/'
    rss = True

    # Load pretrained VAE
    path = vae_path + vae_model_name
    vae = torch.load(path, map_location=torch.device(device))
    vae.eval()

    # Data loader setup
    subj_dataset = Subject(subj, data_path, R, rss=rss)
    subj_loader = data.DataLoader(subj_dataset, batch_size=1, shuffle=False,
                                  num_workers=0)

    # Time model and init resulting matrices
    start_time = time.perf_counter()
    rec_subj = np.zeros((len(subj_loader), 320, 320))
    gt_subj = np.zeros((len(subj_loader), 320, 320))

    # Print basic parameters
    print('Subj: ', subj, ' R: ', R, ' mode: ', mode, ' k: ', k,
          ' num_sampels: ', num_sampels, ' num_bootsamles: ', num_bootsamles,
          ' batch_size: ', batch_size, ' num_iter: ', num_iter,
          ' step_size: ', step_size, ' phase_step: ', phase_step)

    # Log
    log_path = log_path + 'R' + str(R) + '_mode' + str(k) + mode \
        + '_reg2lmb0.01_' + datetime.now().strftime("%Y%m%d-%H%M%S")

    if log:
        import wandb
        wandb.login()
        wandb.init(project='JDDP' + '_T2',
                   name=vae_model_name,
                   config={
                       "num_iter": num_iter,
                       "step_size": step_size,
                       "phase_step": phase_step,
                       "mode": mode,
                       'R': R,
                       'K': k,
                       'use_momentum': use_momentum
                   })
        # wandb.watch(vae)
    else:
        wandb = False
        print("num_iter", num_iter, " step_size ", step_size, " phase_step ",
              phase_step, " mode ", mode, ' R ', R, ' K ', k,
              'use_momentum', use_momentum)

    for batch in tqdm(subj_loader, desc="Running inference"):
        ksp, coilmaps, rss, norm_fact, num_sli = batch

        rec_sli = vaerecon(ksp[0], coilmaps[0], mode, vae, rss[0], log_path,
                           device, writer=wandb, norm=norm_fact.item(),
                           nsampl=num_sampels, boot_samples=num_bootsamles,
                           k=k, patchsize=28, parfact=batch_size,
                           num_iter=num_iter, stepsize=step_size,
                           lmb=phase_step, use_momentum=use_momentum)

        rec_subj[num_sli] = np.abs(center_crop(rec_sli.detach().cpu().numpy()))
        gt_subj[num_sli] = np.abs(center_crop(rss[0]))

        nmse_sli = nmse(rec_subj[num_sli], gt_subj[num_sli])
        ssim_sli = ssim(rec_subj[num_sli], gt_subj[num_sli])
        psnr_sli = psnr(rec_subj[num_sli], gt_subj[num_sli])

        print('Slice: ', num_sli.item(), ' NMSE: ', str(nmse_sli),
              ' SSIM: ', str(ssim_sli), ' PSNR: ', str(psnr_sli))

        end_time = time.perf_counter()
        print(f"Elapsed time for {str(num_sli)} slices: {end_time - start_time}")

    nmse_v = nmse(rec_subj, gt_subj)
    ssim_v = ssim(rec_subj, gt_subj)
    psnr_v = psnr(rec_subj, gt_subj)
    print('Subject Done: ', ' NMSE: ', str(nmse_v), ' SSIM: ', str(ssim_v),
          ' PSNR: ', str(psnr_v))

    pickle.dump(rec_subj,
                open(log_path + subj + str(k) + mode + str(R), 'wb'))

    end_time = time.perf_counter()
    print(f"Elapsed time for {len(subj_loader)} slices: {end_time - start_time}")
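# Hedged sketch (not from the repo): the `center_crop` helper used above is
# assumed to take the central 320x320 region of the reconstruction, as is
# standard for fastMRI evaluation; the repo's helper may instead accept the
# target shape as an argument.
import numpy as np


def center_crop_sketch(img, shape=(320, 320)):
    """Crop the central `shape` region of the last two axes of `img`."""
    h, w = img.shape[-2:]
    top = (h - shape[0]) // 2
    left = (w - shape[1]) // 2
    return img[..., top:top + shape[0], left:left + shape[1]]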