Exemple #1
0
    def refine_s_end(self):
        """Run the final refinement of the sources, always with K = 1.

        The sources are first updated once without reweighting. They are then
        updated repeatedly until their relative variation falls below
        ``self.eps[2]`` or 25 passes have been done.

        Returns
        -------
        int
            error code (1 as soon as a source update reports an error, else 0)
        """

        if self.verb >= 2:
            print(
                "Finale refinement of the sources with the finale estimation of A..."
            )

        # Regularization strategy and hyperparameter of the first update
        strat = 0 if self.cstWuRegStr else 1
        c = np.min(self.c_wu)

        if self.update_s(strat, c, 1, doThr=self.thrEnd, doRw=False):
            return 1  # error caught

        # Reset the reweighting attribute and the convergence bookkeeping
        self.Swtrw = np.zeros((self.n, self.p, self.nscales))
        previous_S = np.zeros((self.n, self.p))
        rel_change = np.inf
        n_passes = 0

        # Unless the warm-up strategy is kept, switch to strategy #2
        if not self.keepWuRegStr:
            strat, c = 2, self.c_ref

        while rel_change >= self.eps[2] and n_passes < 25:
            n_passes += 1

            if self.update_s(strat, c, 1, doThr=self.thrEnd):
                return 1  # error caught

            rel_change = (np.linalg.norm(previous_S - self.S) /
                          np.linalg.norm(self.S))
            previous_S = self.S.copy()

            # Compare with the ground truth only when it is available
            if self.A0 is not None and self.S0 is not None and self.verb >= 2:
                _, Scp, _ = utils.corr_perm(self.A0,
                                            self.S0,
                                            self.A,
                                            self.S,
                                            optInd=True)
                print("NMSE = %.2f" % utils.nmse(self.S0, Scp))

            if self.verb >= 2:
                print("delta_S = %.2e" % rel_change)

        return 0
Exemple #2
0
    def oracle_dss(self, strat, c, S=None, A0=None, iSNR0=None, Swt0=None):
        """Solve the oracle deconvolution source separation problem.

        A single source update is run in oracle mode with K = 1 and
        reweighting enabled, then the NMSE between ``self.S0`` and ``self.S``
        is stored in ``self.nmse``.

        Parameters
        ----------
        strat: int
            regularization strategy (0: constant, 1: mixing-matrix-based, 2: spectrum-based)
        c: float
            regularization hyperparameter
        S: np.ndarray
            (n,p) float array, estimated sources (in-place update, default: self.S)
        A0: np.ndarray
            (m,n) float array, ground truth mixing matrix (default: self.A0)
        iSNR0: np.ndarray
            (n,p) float array, ground truth regularization parameters for strategy #2 (default: self.iSNR0)
        Swt0: np.ndarray
            (n,p,nscales) float array, ground truth sources in the wavelet domain (default: self.Swt0)

        Returns
        -------
        int
            error code
        """

        # The ground-truth quantities are forwarded to the source update
        oracle_kwargs = dict(doRw=True,
                             S=S,
                             A=A0,
                             iSNR=iSNR0,
                             Swtrw=Swt0,
                             oracle=True)
        failed = self.update_s(strat, c, 1, **oracle_kwargs)
        if failed:  # error caught
            return 1

        self.nmse = utils.nmse(self.S0, self.S)

        return 0
Exemple #3
0
def vaerecon(ksp,
             coilmaps,
             mode,
             vae_model,
             gt,
             logdir,
             device,
             writer=False,
             norm=1,
             nsampl=100,
             boot_samples=500,
             k=1,
             patchsize=28,
             parfact=25,
             num_iter=200,
             stepsize=5e-4,
             lmb=0.01,
             num_priors=1,
             use_momentum=True):
    """Iteratively reconstruct an image from undersampled k-space data.

    The outer loop runs over ``range(0, num_iter, 2)``. Each pass applies, in
    order: ``num_priors`` prior gradient steps (total variation for
    ``mode='TV'``, VAE-based prior for ``'DDP'`` and ``'JDDP'``), a phase
    projection when ``lmb > 0``, a coil map re-estimation (``'JDDP'`` only)
    and a data consistency projection. SSIM / NMSE / PSNR against ``gt`` are
    printed after every pass and sent to ``writer`` when one is given.

    Parameters
    ----------
    ksp : torch.Tensor
        Measured k-space with three dimensions; the sampling mask is taken
        from the non-zero entries of ``ksp[0]``.
        (presumably (coils, rows, cols) -- TODO confirm)
    coilmaps : torch.Tensor
        Initial coil sensitivity maps; replaced by a polynomial estimate when
        ``mode == 'JDDP'``.
    mode : str
        ``'TV'``, ``'DDP'`` or ``'JDDP'``.
    vae_model : torch.nn.Module
        Prior model handed to ``prior_gradient`` / ``prior_value``.
    gt : torch.Tensor
        Reference image used for the metrics and the difference image.
    logdir
        Not used by this function.
    device
        Device the tensors and the model are moved to.
    writer
        Logger exposing ``log`` and ``Image`` (the caller in this file passes
        the ``wandb`` module), or a falsy value to disable logging.
    norm : float
        Divisor applied to the reconstruction before comparing it with ``gt``
        and before returning it.
    nsampl, boot_samples, patchsize, parfact
        Forwarded unchanged to ``prior_gradient`` / ``prior_value``.
    k
        Not used by this function.
    num_iter : int
        Upper bound of the outer loop, which advances in steps of two.
    stepsize : float
        Step size of the prior gradient descent.
    lmb : float
        Phase regularization weight; the phase projection is skipped when it
        is not positive.
    num_priors : int
        Number of prior gradient steps per outer pass.
    use_momentum
        Not used by this function.

    Returns
    -------
    torch.Tensor
        Root-sum-of-squares image of the last data-consistent k-space,
        divided by ``norm``.

    Raises
    ------
    SystemExit
        With status 1 when ``mode`` is not a known prior method and at least
        one prior step is executed.
    """

    def _log_image(key, tensor, step):
        """Normalize ``tensor``, convert it to a PIL image and log it."""
        picture = transforms.ToPILImage()(normalize_tensor(tensor))
        writer.log({key: [writer.Image(picture, caption="")]}, step=step)

    def _log_coilmap(maps, index, step):
        """Log magnitude and phase of one coil sensitivity map."""
        _log_image("abs Coilmaps", torch.abs(maps[index, :, :]), step)
        _log_image("phase Coilmaps", torch.angle(maps[index, :, :]), step)

    def _evaluate(image):
        """Print and return (ssim, nmse, psnr) of ``image`` against ``gt``."""
        # Only indices 160:-160 along the first axis are compared
        # (presumably a crop to the central rows -- TODO confirm).
        estimate = image[160:-160].detach().cpu().numpy() / norm
        reference = gt[160:-160].detach().cpu().numpy()
        ssim_v = ssim(estimate, reference)
        nmse_v = nmse(estimate, reference)
        psnr_v = psnr(estimate, reference)
        print('SSIM: ', ssim_v, ' NMSE: ', nmse_v, ' PSNR: ', psnr_v)
        return ssim_v, nmse_v, psnr_v

    def _log_metrics(scores, step):
        """Send an (ssim, nmse, psnr) triple to the writer."""
        ssim_v, nmse_v, psnr_v = scores
        writer.log({"SSIM": ssim_v, "NMSE": nmse_v, "PSNR": psnr_v},
                   step=step)

    def _log_prior_value(image, maps, step):
        """Log the two values returned by ``prior_value`` as ELBO / DC err."""
        lik, dc = prior_value(image, ksp, uspat, maps, patchsize, parfact,
                              nsampl, vae_model)
        writer.log({"ELBO": lik}, step=step)
        writer.log({"DC err": dc}, step=step)

    def _consistent_kspace(image, maps):
        """Combine the transform of ``image`` on unsampled positions with
        the measured samples."""
        predicted = UFT_pytorch(image, 1 - uspat, maps)
        measured = ksp * uspat.unsqueeze(0)
        return predicted + measured

    def _estimate_coilmaps(image):
        """Fit the basis coefficients to ``image`` and rebuild the maps."""
        coeffs_array = sense_estimation_ls(ksp, image, basis_funct, uspat)
        return torch.sum(
            coeffs_array[:, :, np.newaxis, np.newaxis] * basis_funct,
            1).to(device)

    # Init data (the unpacking also enforces a three-dimensional k-space)
    _, imsizer, imsizec = ksp.shape
    ksp = ksp.to(device)
    coilmaps = coilmaps.to(device)
    vae_model = vae_model.to(device)
    uspat = (torch.abs(ksp[0]) > 0).type(torch.uint8).to(device)
    recs_gpu = tUFT_pytorch(ksp, uspat, coilmaps)
    rss = rss_pytorch(ksp)

    # Init coilmaps estimation with JSENSE
    if mode == 'JDDP':
        # Polynomial order of the basis functions
        max_basis_order = 6
        basis_funct = create_basis_functions(imsizer,
                                             imsizec,
                                             max_basis_order,
                                             show_plot=False)
        # One copy of the basis per coil
        basis_funct = torch.from_numpy(
            np.tile(basis_funct[np.newaxis, :, :, :],
                    [coilmaps.shape[0], 1, 1, 1])).to(device)

        coilmaps = _estimate_coilmaps(recs_gpu)
        recs_gpu = tUFT_pytorch(ksp, uspat, coilmaps)

        if writer:
            for i in range(coilmaps.shape[0]):
                _log_coilmap(coilmaps, i, 0)
        print("Coilmaps init done")

    # Log the starting point
    if writer:
        _log_image("Gt rss", gt, 0)
        _log_image("Restored rss", rss, 0)
        _log_image("Restored abs", torch.abs(recs_gpu), 0)
        _log_image("Restored Phase", torch.angle(recs_gpu), 0)
        _log_image("diff rss",
                   (rss.detach().cpu() / norm - gt.detach().cpu()), 0)
        _log_metrics(_evaluate(rss), 0)
        _log_prior_value(rss, coilmaps, 0)

    for it in range(0, num_iter, 2):
        print('Itr: ', it)
        step = it + 1

        # Magnitude prior projection step
        for _ in range(num_priors):
            # Gradient descent of Prior
            if mode == 'TV':
                tvnorm, abstvgrad = tv_norm(torch.abs(rss))
                priorgrad = abstvgrad * recs_gpu / (torch.abs(recs_gpu))
                recs_gpu = recs_gpu - stepsize * priorgrad

                if writer:
                    _log_image("TVgrad", abstvgrad, step)
                    _log_image("TV", tvnorm, step)

            elif mode == 'DDP' or mode == 'JDDP':
                g_abs_lik, est_uncert, _ = prior_gradient(
                    rss, ksp, uspat, coilmaps, patchsize, parfact, nsampl,
                    vae_model, boot_samples, mode)
                priorgrad = g_abs_lik * recs_gpu / (torch.abs(recs_gpu))
                recs_gpu = recs_gpu - stepsize * priorgrad

                if writer:  # Log
                    _log_image("VAEgrad abs", torch.abs(g_abs_lik), step)
                    writer.log({"STD": torch.mean(torch.abs(est_uncert))},
                               step=step)

                    # NOTE(review): rss is rebound only when logging is on,
                    # so with num_priors > 1 the next prior step sees a
                    # different input with and without a writer. Kept as in
                    # the original -- confirm which behavior is intended.
                    rss = rss_pytorch(_consistent_kspace(recs_gpu, coilmaps))
                    _log_metrics(_evaluate(rss), step)
            else:
                print("Error: Prior method does not exists.")
                # Explicit SystemExit instead of the site builtin exit():
                # same exception type, but with a failing exit status.
                raise SystemExit(1)

        # Phase projection step
        if lmb > 0:
            tmpa = torch.abs(recs_gpu)
            tmpp = torch.angle(recs_gpu)

            # We apply phase regularization to prefer smooth phase images
            tmpptv = tv_proj(tmpp, mu=0.125, lmb=lmb, IT=50)
            # We combine back the phase and the magnitude
            recs_gpu = tmpa * torch.exp(1j * tmpptv)

        # Coilmaps estimation step (if JSENSE)
        if mode == 'JDDP':
            coilmaps = _estimate_coilmaps(recs_gpu)

            if writer:
                # Only the first coil is logged inside the loop
                _log_coilmap(coilmaps, 0, step)

        # Data consistency projection
        tmp = _consistent_kspace(recs_gpu, coilmaps)
        recs_gpu = tFT_pytorch(tmp, coilmaps)
        rss = rss_pytorch(tmp)

        # Log
        scores = _evaluate(rss)

        if writer:
            _log_metrics(scores, step)
            _log_image("Restored rss", rss, step)
            _log_image("Restored Phase", torch.angle(recs_gpu), step)
            _log_image("diff rss",
                       (rss.detach().cpu() / norm - gt.detach().cpu()), step)
            _log_image("Restored 1ch kspace", torch.log(torch.abs(tmp[0])),
                       step)
            _log_prior_value(rss, coilmaps, step)

    return rss / norm
Exemple #4
0
    def core(self):
        """Drive the separation.

        S and A are updated alternately, first in a warm-up stage and then in
        a refinement stage; the loop returns when the refinement ends or when
        an update reports an error.

        Returns
        -------
        int
            error code
        """

        stage = "wu"
        it = 0
        prev_S = np.zeros((self.n, self.p))
        prev_A = np.zeros((self.m, self.n))

        while True:
            it += 1
            # Parameters of DecGMCA for this iteration
            strat, c, K, doRw, nnegS = self.get_parameters(stage, it)

            if self.verb >= 2:
                print("Iteration #%i" % it)

            # Update S, then A; stop at the first reported error
            if self.update_s(strat, c, K, doRw=doRw, nnegS=nnegS):
                return 1
            if self.update_a():
                return 1

            # Post processing: variation of S and A since the last iteration
            delta_S = np.linalg.norm(prev_S - self.S) / np.linalg.norm(self.S)
            column_products = np.sum(self.A * prev_A, axis=0)
            delta_A = np.max(abs(1 - abs(column_products)))
            cond_A = np.linalg.cond(self.A)
            prev_S, prev_A = self.S.copy(), self.A.copy()

            # Scores against the ground truth, when it is available
            if self.A0 is not None and self.S0 is not None and self.verb >= 2:
                Acp, Scp, _ = utils.corr_perm(self.A0,
                                              self.S0,
                                              self.A,
                                              self.S,
                                              optInd=True)
                scores = (utils.nmse(self.S0, Scp), utils.ca(self.A0, Acp))
                print("NMSE = %.2f  -  CA = %.2f" % scores)

            if self.verb >= 2:
                print("delta_S = %.2e - delta_A = %.2e - cond(A) = %.2f" %
                      (delta_S, delta_A, cond_A))
            if self.verb >= 5:
                print("A:\n", self.A)

            # Stage update: leave the warm-up once the minimum number of
            # iterations is reached and S has converged (or 50 more
            # iterations have passed)
            if stage == 'wu' and it >= self.minWuIt:
                if delta_S <= self.eps[0] or it >= self.minWuIt + 50:
                    if self.verb >= 2:
                        print("> End of the warm-up (iteration %i)" % it)
                    self.lastWuIt = it
                    stage = 'ref'

            # The refinement lasts at least 25 iterations after the warm-up
            if stage == 'ref':
                settled = (delta_S <= self.eps[1]
                           or it >= self.lastWuIt + 50)
                if settled and it >= self.lastWuIt + 25:
                    if self.verb >= 2:
                        print("> End of the refinement (iteration %i)" % it)
                    self.lastRefIt = it
                    return 0
Exemple #5
0
# Fig. Spectrum of the difference h0 - hhat for the selected loudspeaker
plt.figure()
plt.pcolormesh(np.rad2deg(phi_target), freq / 1000,
               db(np.fft.rfft(h0[mm] - hhat[mm], axis=-1)).T)
# The 'normal' option was deprecated in Matplotlib 3.1 and later removed;
# 'auto' is its documented replacement.
plt.axis('auto')
plt.colorbar(label='dB')
plt.clim(-200, 0)
plt.xlabel(r'$\phi$ / $^\circ$')
plt.ylabel(r'$f$ / kHz')
plt.xlim(0, 360)
plt.ylim(0, fs / 2 / 1000)
plt.title('Spectral Distortion (Loudspeaker #{})'.format(idx[mm]))

# Fig. Normalized system distance in dB
plt.figure()
for m in range(M):
    plt.plot(np.rad2deg(phi_target), db(nmse(hhat[m], h0[m])))
plt.xlim(0, 360)
plt.ylim(-200, 0)
plt.xlabel(r'$\phi$ / $^\circ$')
plt.ylabel('NMSE / dB')
# NOTE(review): the curves of all M loudspeakers are drawn, but the title
# names only loudspeaker idx[mm] -- confirm this is intended.
plt.title('Normalized Mean Square Error (Loudspeaker #{})'.format(idx[mm]))

# Fig. Desired CHT spectrum
plt.figure(figsize=(10, 4))
plt.pcolormesh(order, freq / 1000,
               db(np.fft.fftshift(np.fft.fft2(h0[mm]), axes=0)[:, :Nf]).T,
               vmin=-120)
plt.axis('auto')
plt.xlabel('CHT order')
plt.ylabel(r'$f$ / kHz')
plt.colorbar(label='dB')
def run_inference(subj, R, mode, k, num_sampels, num_bootsamles, batch_size,
                  num_iter, step_size, phase_step, complex_rec, use_momentum,
                  log, device):
    """Reconstruct every slice of one subject and pickle the result.

    A pretrained VAE is loaded, each slice delivered by the data loader is
    reconstructed with ``vaerecon``, per-slice and subject-level metrics are
    printed, and the stack of reconstructions is written to disk.

    Parameters
    ----------
    subj : str
        Subject identifier; also part of the output file name.
    R
        Value forwarded to ``Subject`` and used in the log / output names.
    mode : str
        Reconstruction mode forwarded to ``vaerecon``.
    k
        Forwarded to ``vaerecon`` and used in the log / output names.
    num_sampels, num_bootsamles
        Forwarded to ``vaerecon`` as ``nsampl`` and ``boot_samples``.
    batch_size
        Forwarded to ``vaerecon`` as ``parfact``.
    num_iter, step_size, phase_step
        Forwarded to ``vaerecon`` as ``num_iter``, ``stepsize`` and ``lmb``.
    complex_rec
        Not used by this function.
    use_momentum
        Forwarded to ``vaerecon`` and stored in the run configuration.
    log : bool
        When true, a Weights & Biases run is started and used as the writer.
    device
        Device used to load the model and run the reconstruction.
    """
    # Some inits of paths... Edit these
    vae_model_name = 'T2-20210415-111101/450.pth'
    vae_path = '/cluster/scratch/jonatank/logs/ddp/vae/'
    data_path = '/cluster/work/cvl/jonatank/fastMRI_T2/validation/'
    log_path = '/cluster/scratch/jonatank/logs/ddp/restore/pytorch/'
    use_rss = True

    # Load pretrained VAE
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    path = vae_path + vae_model_name
    vae = torch.load(path, map_location=torch.device(device))
    vae.eval()

    # Data loader setup
    subj_dataset = Subject(subj, data_path, R, rss=use_rss)
    subj_loader = data.DataLoader(subj_dataset,
                                  batch_size=1,
                                  shuffle=False,
                                  num_workers=0)

    # Time model and init resulting matrices
    start_time = time.perf_counter()
    rec_subj = np.zeros((len(subj_loader), 320, 320))
    gt_subj = np.zeros((len(subj_loader), 320, 320))

    # Set basic parameters
    print('Subj: ', subj, ' R: ', R, ' mode: ', mode, ' k: ', k,
          ' num_sampels: ', num_sampels, ' num_bootsamles: ', num_bootsamles,
          ' batch_size: ', batch_size, ' num_iter: ', num_iter, ' step_size: ',
          step_size, ' phase_step: ', phase_step)

    # Log
    log_path = log_path + 'R' + str(R) + '_mode' + str(
        k) + mode + '_reg2lmb0.01_' + datetime.now().strftime("%Y%m%d-%H%M%S")
    if log:
        import wandb
        wandb.login()
        wandb.init(project='JDDP' + '_T2',
                   name=vae_model_name,
                   config={
                       "num_iter": num_iter,
                       "step_size": step_size,
                       "phase_step": phase_step,
                       "mode": mode,
                       'R': R,
                       'K': k,
                       'use_momentum': use_momentum
                   })
        writer = wandb
    else:
        writer = False

    print("num_iter", num_iter, " step_size ", step_size, " phase_step ",
          phase_step, " mode ", mode, ' R ', R, ' K ', k, 'use_momentum',
          use_momentum)

    for batch in tqdm(subj_loader, desc="Running inference"):
        ksp, coilmaps, gt_rss, norm_fact, num_sli = batch

        rec_sli = vaerecon(ksp[0],
                           coilmaps[0],
                           mode,
                           vae,
                           gt_rss[0],
                           log_path,
                           device,
                           writer=writer,
                           norm=norm_fact.item(),
                           nsampl=num_sampels,
                           boot_samples=num_bootsamles,
                           k=k,
                           patchsize=28,
                           parfact=batch_size,
                           num_iter=num_iter,
                           stepsize=step_size,
                           lmb=phase_step,
                           use_momentum=use_momentum)

        rec_subj[num_sli] = np.abs(center_crop(rec_sli.detach().cpu().numpy()))
        gt_subj[num_sli] = np.abs(center_crop(gt_rss[0]))

        # NOTE(review): the values labelled RMSE come from nmse().
        rmse_sli = nmse(rec_subj[num_sli], gt_subj[num_sli])
        ssim_sli = ssim(rec_subj[num_sli], gt_subj[num_sli])
        psnr_sli = psnr(rec_subj[num_sli], gt_subj[num_sli])
        print('Slice: ', num_sli.item(), ' RMSE: ', str(rmse_sli), ' SSIM: ',
              str(ssim_sli), ' PSNR: ', str(psnr_sli))
        end_time = time.perf_counter()

        print(f"Elapsed time for {str(num_sli)} slices: {end_time-start_time}")

    # Subject-level metrics over the whole stack, each with its own function
    rmse_v = nmse(rec_subj, gt_subj)
    ssim_v = ssim(rec_subj, gt_subj)
    psnr_v = psnr(rec_subj, gt_subj)
    print('Subject Done: ', 'RMSE: ', str(rmse_v), ' SSIM: ', str(ssim_v),
          ' PSNR: ', str(psnr_v))

    # NOTE(review): restore_sense is neither a parameter nor a local of this
    # function. It is read from the module globals when present so the file
    # name is unchanged in that case; otherwise it contributes nothing.
    sense_tag = str(globals().get('restore_sense', ''))
    out_path = log_path + subj + str(k) + mode + sense_tag + str(R)
    with open(out_path, 'wb') as out_file:
        pickle.dump(rec_subj, out_file)

    end_time = time.perf_counter()

    print(f"Elapsed time for {len(subj_loader)} slices: {end_time-start_time}")