Example #1
0
    def update_patcher(self):
        """Create the patcher for the current entry when its labels are missing.

        The current entry is ``self.database[self.index]``.  When it has no
        ``"labels"`` key, or its labels hold no boxes at all, a new ``Patcher``
        is built from the entry's ``"fname"`` and ``"lname"`` and the labels it
        returns are stored back on the entry.
        """
        def need_reget_label():
            """Return True when labels have to be fetched from a patcher."""
            entry = self.database[self.index]
            # Labels were never fetched for this entry.
            if "labels" not in entry:
                return True
            # Labels were fetched but contain no boxes (e.g. the patcher was
            # built without a label file).  Note: a label file that really
            # holds no labels also lands here and is re-read every time.
            total_boxes = sum(
                len(boxes) for _, boxes in entry["labels"].items())
            return total_boxes == 0

        # A patcher initialised without a label file has to be rebuilt.
        if need_reget_label():
            entry = self.database[self.index]
            self.patcher = Patcher(entry["fname"], entry["lname"])
            entry["labels"] = self.patcher.get_labels()
Example #2
0
    def __init__(self):
        """Load the data sets and store their patches as flat row matrices.

        The first 500 training images and the first 4 testing images returned
        by ``getData`` are passed through ``add_Usp_nous``, reshaped to
        256x256 and cut into patches by a ``Patcher`` (patch size 28, step
        14).  The patches are stored as arrays with 784 columns in
        ``self.trainptchs`` and ``self.testptchs``.
        """
        self.trainptchs = []
        self.traincount = 1
        self.testptchs = []
        self.testcount = 1

        train_images = getData(trnTst='training')
        print(train_images.shape)
        test_images = getData(trnTst='testing')

        print('orgtraindata', train_images.dtype, type(train_images))

        patch_maker = Patcher(
            imsize=[train_images.shape[1], train_images.shape[2]],
            patchsize=28,
            step=int(28 / 2),
            nopartials=True,
            contatedges=True)

        def gather_patches(images, how_many):
            """Collect the patches of the first ``how_many`` processed images."""
            gathered = []
            for idx in range(how_many):
                processed = add_Usp_nous(images[idx, :, :, :])
                gathered += patch_maker.im2patches(
                    processed.reshape([256, 256]))
            return gathered

        # Both sets are gathered before any conversion, as in the original
        # statement order.
        train_list = gather_patches(train_images, 500)
        test_list = gather_patches(test_images, 4)

        train_matrix = np.array(train_list).reshape(-1, 784)
        print(train_matrix.shape)
        self.trainptchs = train_matrix

        self.testptchs = np.array(test_list).reshape(-1, 784)
Example #3
0
# Emit one enum entry per long option.  `arg_enum`, `first` and `opts_long`
# are initialised earlier in the script (outside this excerpt); each
# `opts_long` item is unpacked as (option, enum_name, requires_parameter).
for (option, enum_name, requires_parameter) in opts_long:
    if first:
        # Only the first entry gets an explicit value (1000); the following
        # entries are emitted without one.
        arg_enum.append("	%s = 1000," % (enum_name))
        first = False
    else:
        arg_enum.append("	%s," % (enum_name))
arg_enum.append("};")
# Collapse the collected lines into one newline-terminated text block.
arg_enum = "\n".join(arg_enum) + "\n"

# Build the C source text declaring `short_options` and the `long_options[]`
# table (the entries look like getopt_long `struct option` initialisers).
cmd_def = ["	const char *short_options = \"%s\";" % (short_string)]
cmd_def.append("	struct option long_options[] = {")
for (option, enum_name, requires_parameter) in opts_long:
    # Quoted option name plus trailing comma, left-aligned to 30 columns below.
    param = "\"%s\"," % (option)
    if requires_parameter:
        cmd_def.append("		{ %-30s required_argument, 0, %s }," %
                       (param, enum_name))
    else:
        cmd_def.append("		{ %-30s no_argument,       0, %s }," %
                       (param, enum_name))
# All-zero entry closing the table.
cmd_def.append("		{ 0 }")
cmd_def.append("	};")
cmd_def = "\n".join(cmd_def) + "\n"

# Hand the generated text to the patcher under named sections of the C file.
# NOTE(review): presumably Patcher.patch replaces a marked region of the
# target file -- confirm against the Patcher implementation.
patcher = Patcher("../pgmopts.c")
patcher.patch("help page", help_code)
patcher.patch("command definition enum", arg_enum)
patcher.patch("command definition", cmd_def)

# The README receives the markdown rendering of the help page.
patcher = Patcher("../README.md", filetype="markdown")
patcher.patch("help page", markdown_help_page)
Example #4
0
def vaerecon(us_ksp_r2,
             sensmaps,
             dcprojiter,
             n=10,
             lat_dim=60,
             patchsize=28,
             contRec='',
             parfact=10,
             num_iter=302,
             rescaled=False,
             half=False,
             regiter=15,
             reglmb=0.1,
             regtype='reg2_dc',
             usemeth=1,
             stepsize=1e-4,
             optScale=False,
             mode=None,
             chunks40=False,
             Melmodels='',
             N4BFcorr=False,
             z_multip=1.0,
             directapprox=0,
             vae_model='',
             logdir='',
             directapp=0,
             gt=None):
    """Alternate a TV gradient step with a data-consistency projection.

    The reconstruction starts from the coil-combined zero-filled image (or
    from the last column of a previous result stored in ``contRec``).  Each
    pass of the main loop fills two columns of ``recs``:

    1. a gradient step on the TV norm of the magnitude image (the VAE
       likelihood gradient is evaluated and its norm printed, but it is
       weighted by 0 in the update), and
    2. a data-consistency projection in which the sampled k-space locations
       are replaced by the measured data.

    Args:
        us_ksp_r2: undersampled k-space data, indexed as ``[row, col, coil]``.
            Its first two dimensions set the image size, and the non-zero
            entries of coil 0 define the sampling mask.
        sensmaps: coil sensitivity maps; ``sensmaps.shape[2]`` is used as the
            number of coils.
        contRec: path of a previous reconstruction (pickle, with a NumPy
            ``.npy`` fallback) whose last column initialises the run; ``''``
            starts from the zero-filled image.
        parfact: number of patches per network evaluation (the network batch
            size is ``parfact * 50``).
        num_iter: iteration budget; the loop runs over
            ``range(0, num_iter - 2, 2)``.
        usemeth: likelihood-gradient method; only ``1`` is implemented.
        stepsize: step size of the gradient step.
        lat_dim, patchsize, mode, vae_model: forwarded to ``definevae``
            (``patchsize`` also configures the patcher).  ``mode=None`` is
            forwarded as an empty list.
        z_multip: value fed to the ``z_std_multip`` placeholder.
        directapp: when truthy, ``grd_dir`` replaces ``grd0`` as the
            likelihood gradient op.
        gt: optional ground-truth image, reshaped to ``(rows, cols)`` and
            used only to print an NMSE per iteration; ``None`` skips it.
        dcprojiter, n, rescaled, half, regiter, reglmb, regtype, optScale,
        chunks40, Melmodels, N4BFcorr, directapprox, logdir: accepted for
            interface compatibility; not read by this function.

    Returns:
        ``(recs, 0, phaseregvals, n4biasfields)`` where ``recs`` is a complex
        array of shape ``(rows * cols, num_iter + 2)`` holding one flattened
        image per column; the last two items are always empty lists here.
    """
    print('xxxxxxxxxxxxxxxxxxx contRec is ' + contRec)
    print('xxxxxxxxxxxxxxxxxxx parfact is ' + str(parfact))
    import pickle

    # A fresh list per call instead of a shared mutable default argument.
    if mode is None:
        mode = []

    # set parameters
    # ==============================================================================
    np.random.seed(seed=1)

    imsizer = us_ksp_r2.shape[0]
    imrizec = us_ksp_r2.shape[1]

    # number of samples each patch is tiled to per network evaluation
    nsampl = 50

    # make a network and a patcher to use later
    # ==============================================================================

    x_rec, x_inp, funop, grd0, grd_dir, sess, grd_p_x_z0, grd_p_z0, grd_q_z_x0, grd20, y_out, y_out_prec, z_std_multip, op_q_z_x, mu, std, grd_q_zpl_x_az0, op_q_zpl_x, z_pl, z = definevae(
        lat_dim=lat_dim,
        patchsize=patchsize,
        mode=mode,
        vae_model=vae_model,
        batchsize=parfact * nsampl)

    if directapp:
        print('_____DIRECT APPROX_____')
        grd0 = grd_dir

    Ptchr = Patcher(imsize=[imsizer, imrizec],
                    patchsize=patchsize,
                    step=int(patchsize / 2),
                    nopartials=True,
                    contatedges=True)

    nopatches = len(Ptchr.genpatchsizes)
    print("KCT-INFO: there will be in total " + str(nopatches) + " patches.")

    # define the necessary functions
    # ==============================================================================

    def FT(x):
        """Weight the image by each coil map, then apply a centred 2-D FFT."""
        # inp: [nx, ny]
        # out: [nx, ny, ns]
        coil_images = sensmaps * np.tile(x[:, :, np.newaxis],
                                         [1, 1, sensmaps.shape[2]])
        return np.fft.fftshift(np.fft.fft2(coil_images, axes=(0, 1)),
                               axes=(0, 1))

    def tFT(x):
        """Inverse centred 2-D FFT followed by a coil combination."""
        # inp: [nx, ny, ns]
        # out: [nx, ny]
        temp = np.fft.ifft2(np.fft.ifftshift(x, axes=(0, 1)), axes=(0, 1))
        # The small constant keeps the division finite where all maps are 0.
        return np.sum(temp * np.conjugate(sensmaps), axis=2) / (
            np.sum(sensmaps * np.conjugate(sensmaps), axis=2) + 0.00000001)

    def UFT(x, uspat):
        """Apply ``FT`` and mask the result with ``uspat`` on every coil."""
        # inp: [nx, ny], [nx, ny]
        # out: [nx, ny, ns]
        return np.tile(uspat[:, :, np.newaxis],
                       [1, 1, sensmaps.shape[2]]) * FT(x)

    def tUFT(x, uspat):
        """Mask ``x`` with ``uspat`` on every coil, then apply ``tFT``."""
        # out: [nx, ny]
        tmp1 = np.tile(uspat[:, :, np.newaxis], [1, 1, sensmaps.shape[2]])
        return tFT(tmp1 * x)

    def dconst(us):
        """Squared norm of the k-space residual on the sampling mask."""
        # inp: [nx, ny]
        return np.linalg.norm(UFT(us, uspat) - data)**2

    def dconst_grad(us):
        """Gradient image of the data-consistency term."""
        # inp: [nx, ny]
        # out: [nx, ny]
        return 2 * tUFT(UFT(us, uspat) - data, uspat)

    def likelihood(us):
        """Evaluate ``funop`` on patch magnitudes, averaged over the samples."""
        # inp: [parfact, ps*ps]
        # out: parfact
        us = np.abs(us)
        funeval = funop.eval(feed_dict={
            x_rec: np.tile(us, (nsampl, 1)),
            z_std_multip: z_multip
        })
        funeval = np.array(np.split(funeval, nsampl,
                                    axis=0))  # [nsampl x parfact x 1]
        return np.mean(funeval, axis=0).astype(np.float64)

    def likelihood_grad(us):
        """Return the sample mean and sample std of the ``grd0`` gradient."""
        # inp: [parfact, ps*ps]
        # out: [parfact, ps*ps], [parfact, ps*ps]
        usc = us.copy()
        usabs = np.abs(us)

        grd0eval = grd0.eval(feed_dict={
            x_rec: np.tile(usabs, (nsampl, 1)),
            z_std_multip: z_multip
        })
        grd0eval = np.array(np.split(grd0eval, nsampl,
                                     axis=0))  # [nsampl x parfact x 784]

        # The two evaluations below are not used in the returned values; they
        # are kept so the number of graph evaluations per call is unchanged.
        sigmaeval = y_out_prec.eval(feed_dict={
            x_rec: np.tile(usabs, (nsampl, 1)),
            z_std_multip: z_multip
        })
        sigmaeval = np.array(np.split(sigmaeval, nsampl,
                                      axis=0))  # [nsampl x parfact x 784]

        mueval = y_out.eval(feed_dict={
            x_rec: np.tile(usabs, (nsampl, 1)),
            z_std_multip: z_multip
        })
        mueval = np.array(np.split(mueval, nsampl,
                                   axis=0))  # [nsampl x parfact x 784]

        vareval = np.std(grd0eval, axis=0)  # std of the gradient samples

        grd0m = np.mean(grd0eval, axis=0)  # [parfact, 784]

        # Multiply by usc / |usc| where ``usc > 0``.
        # NOTE(review): ``div`` aliases ``usc``; entries that are not ``> 0``
        # keep their raw value instead of a unit-modulus factor -- confirm
        # this is intended.
        where_not_0 = np.where(usc > 0)
        div = usc
        div[where_not_0] = usc[where_not_0] / np.abs(usc)[where_not_0].astype(
            'float')

        grd0m = div * grd0m
        var0m = vareval

        return grd0m, var0m

    def likelihood_grad_meth3(us):
        """Alternative gradient computed from sampled latent values."""
        # inp: [parfact, ps*ps]
        # out: [parfact, ps*ps]
        usc = us.copy()
        usabs = np.abs(us)

        mueval = mu.eval(feed_dict={x_rec: np.tile(usabs, (nsampl, 1))})
        stdeval = std.eval(feed_dict={x_rec: np.tile(usabs, (nsampl, 1))})

        zvals = mueval + np.random.rand(mueval.shape[0],
                                        mueval.shape[1]) * stdeval

        y_outeval = y_out.eval(feed_dict={z: zvals})
        y_out_preceval = y_out_prec.eval(feed_dict={z: zvals})

        tmp = np.tile(usabs, (nsampl, 1)) - y_outeval
        tmp = (-1) * tmp * y_out_preceval

        grd0eval = np.array(np.split(tmp, nsampl,
                                     axis=0))  # [nsampl x parfact x 784]
        grd0m = np.mean(grd0eval, axis=0)  # [parfact, 784]

        where_not_0 = np.where(usc > 0)
        div = usc
        div[where_not_0] = usc[where_not_0] / np.abs(usc)[where_not_0]
        grd0m = div * grd0m

        return grd0m

    def likelihood_grad_patches(ptchs):
        """Run ``likelihood_grad`` over all patches in batches of ``parfact``."""
        # inp: [np, ps, ps]
        # out: [np, ps, ps], [np, ps, ps]
        # both grads are in the positive direction
        shape_orig = ptchs.shape

        ptchs = np.reshape(ptchs, [ptchs.shape[0], -1])

        n_batches = int(np.ceil(ptchs.shape[0] / parfact))
        n_padded = int(np.ceil(ptchs.shape[0] / parfact) * parfact)

        grds = np.zeros([n_padded, np.prod(ptchs.shape[1:])],
                        dtype=np.complex64)
        grds_vars = grds.copy()

        # Repeat the last patch so that every batch is full.
        extraind = n_padded - ptchs.shape[0]
        ptchs = np.pad(ptchs, ((0, extraind), (0, 0)), mode='edge')

        for ix in range(n_batches):
            if usemeth != 1:
                # Raised explicitly: an ``assert`` would be stripped under
                # ``python -O`` and silently leave the gradients at zero.
                raise AssertionError('only usemeth == 1 is implemented')
            lo = parfact * ix
            hi = lo + parfact
            grds[lo:hi, :], grds_vars[lo:hi, :] = likelihood_grad(
                ptchs[lo:hi, :])

        # Drop the rows that belong to the padding.
        grds = grds[0:shape_orig[0], :]
        grds_vars = grds_vars[0:shape_orig[0], :]

        return np.reshape(grds, shape_orig), np.reshape(grds_vars, shape_orig)

    def likelihood_patches(ptchs):
        """Mean of ``likelihood`` over all patches, in batches of ``parfact``."""
        # inp: [np, ps, ps]
        # out: 1
        n_patches = ptchs.shape[0]
        n_batches = int(np.ceil(n_patches / parfact))
        n_padded = int(np.ceil(n_patches / parfact) * parfact)

        fvls = np.zeros([n_padded])

        extraind = n_padded - n_patches
        ptchs = np.pad(ptchs, [(0, extraind), (0, 0), (0, 0)], mode='edge')

        for ix in range(n_batches):
            lo = parfact * ix
            hi = lo + parfact
            fvls[lo:hi] = likelihood(
                np.reshape(ptchs[lo:hi, :, :], [parfact, -1]))

        # Use the patch count from before padding so that the repeated edge
        # patches do not enter the mean (consistent with
        # likelihood_grad_patches).
        fvls = fvls[0:n_patches]

        return np.mean(fvls)

    def full_gradient(image):
        """Likelihood and data-consistency gradients of a flattened image."""
        # inp: [nx*nx, 1]
        # out: total, likelihood, data-consistency, likelihood-std images
        ptchs = Ptchr.im2patches(np.reshape(image, [imsizer, imrizec]))
        ptchs = np.array(ptchs)

        grd_lik, grd_lik_var = likelihood_grad_patches(ptchs)
        grd_lik = (-1) * Ptchr.patches2im(grd_lik)
        grd_lik_var = Ptchr.patches2im(grd_lik_var)

        grd_dconst = dconst_grad(np.reshape(image, [imsizer, imrizec]))

        return grd_lik + grd_dconst, grd_lik, grd_dconst, grd_lik_var

    def full_funceval(image):
        """Objective value of a flattened image: (total, likelihood, dc)."""
        # inp: [nx*nx, 1]
        # out: [1], [1], [1]
        tmpimg = np.reshape(image, [imsizer, imrizec])

        dc = dconst(tmpimg)

        ptchs = Ptchr.im2patches(np.reshape(image, [imsizer, imrizec]))
        ptchs = np.array(ptchs)

        lik = (-1) * likelihood_patches(np.abs(ptchs))

        return lik + dc, lik, dc

    def tv_proj(phs, mu=0.125, lmb=2, IT=225):
        """Thin wrapper around ``fb_tv_proj``."""
        phs = fb_tv_proj(phs, mu=mu, lmb=lmb, IT=IT)

        return phs

    def fgrad(im):
        """Forward finite differences along both axes (circular via roll)."""
        imr_x = np.roll(im, shift=-1, axis=0)
        imr_y = np.roll(im, shift=-1, axis=1)
        grd_x = imr_x - im
        grd_y = imr_y - im

        return np.array((grd_x, grd_y))

    def fdivg(im):
        """Sum of backward finite differences of a two-component field."""
        imr_x = np.roll(np.squeeze(im[0, :, :]), shift=1, axis=0)
        imr_y = np.roll(np.squeeze(im[1, :, :]), shift=1, axis=1)
        grd_x = np.squeeze(im[0, :, :]) - imr_x
        grd_y = np.squeeze(im[1, :, :]) - imr_y

        return grd_x + grd_y

    def f_st(u, lmb):
        """Soft thresholding: scale ``u`` by ``max(1 - lmb / |u|, 0)``."""
        uabs = np.squeeze(np.sqrt(np.sum(u * np.conjugate(u), axis=0)))

        tmp = 1 - lmb / uabs
        # Clamp negative shrink factors to zero.  The previous test
        # ``np.abs(tmp) < 0`` could never be true, so nothing was clamped.
        tmp[np.real(tmp) < 0] = 0

        uu = u * np.tile(tmp[np.newaxis, :, :], [u.shape[0], 1, 1])

        return uu

    def fb_tv_proj(im, u0=0, mu=0.125, lmb=1, IT=15):
        """Iterative TV projection built on ``fgrad``/``fdivg``/``f_st``."""
        sz = im.shape
        us = np.zeros((2, sz[0], sz[1], IT))
        us[:, :, :, 0] = u0

        for it in range(IT - 1):
            # grad descent step:
            tmp1 = im - fdivg(us[:, :, :, it])
            tmp2 = mu * fgrad(tmp1)

            tmp3 = us[:, :, :, it] - tmp2

            # thresholding step:
            us[:, :, :, it + 1] = tmp3 - f_st(tmp3, lmb=lmb)

        return im - fdivg(us[:, :, :, it + 1])

    def g_tv_eval(x):
        """TV gradient of ``x`` evaluated through a TensorFlow graph."""
        x_re = np.fft.fftshift(np.reshape(x, (imsizer, imrizec, 1)),
                               axes=(0, 1))

        data = tf.placeholder(tf.float64, shape=x_re.shape)

        x_tv = tf.image.total_variation(data)
        var_grad = tf.gradients(x_tv, [data])[0]

        var_grad_val = var_grad.eval(feed_dict={data: x_re})

        return np.fft.ifftshift(var_grad_val, axes=(0, 1))

    def tv_norm(x):
        """Computes the total variation norm and its gradient. From jcjohnson/cnn-vis."""
        x = np.fft.fftshift(np.reshape(x, (imsizer, imrizec, 1)), axes=(0, 1))

        x_diff = x - np.roll(x, -1, axis=1)
        y_diff = x - np.roll(x, -1, axis=0)
        grad_norm2 = x_diff**2 + y_diff**2 + np.finfo(np.float32).eps
        norm = np.sum(np.sqrt(grad_norm2))
        dgrad_norm = 0.5 / np.sqrt(grad_norm2)
        dx_diff = 2 * x_diff * dgrad_norm
        dy_diff = 2 * y_diff * dgrad_norm
        grad = dx_diff + dy_diff
        grad[:, 1:, :] -= dx_diff[:, :-1, :]
        grad[1:, :, :] -= dy_diff[:-1, :, :]

        return norm, np.reshape(np.fft.ifftshift(grad, axes=(0, 1)), [-1])

    # make the data
    # ===============================

    # Sampling mask: locations where coil 0 has non-zero k-space magnitude.
    uspat = np.abs(us_ksp_r2) > 0
    uspat = uspat[:, :, 0]
    data = us_ksp_r2

    # make the functions for POCS
    # =====================================
    numiter = num_iter

    # 0 means the measured samples simply replace the estimated ones.
    multip = 0

    alphas = stepsize * np.ones(numiter)

    def feval(im):
        return full_funceval(im)

    def geval(im):
        t1, t2, t3, t4 = full_gradient(im)
        return np.reshape(t1, [-1]), np.reshape(t2, [-1]), np.reshape(
            t3, [-1]), np.reshape(t4, [-1])

    # initialize data with the zero-filled, coil-combined image
    recs = np.zeros((imsizer * imrizec, numiter + 2), dtype=complex)
    recs[:, 0] = tUFT(data, uspat).flatten().copy()

    phaseregvals = []

    print('contRec is ' + contRec)
    if contRec != '':
        try:
            print('KCT-INFO: reading from a previous pickle file ' + contRec)
            # NOTE: unpickling can execute arbitrary code; only point contRec
            # at files from a trusted source.
            with open(contRec, 'rb') as prev_file:
                rr = pickle.load(prev_file)
            recs[:, 0] = rr[:, -1]
            print('KCT-INFO: initialized to the previous recon from pickle: ' +
                  contRec)
        except Exception:
            # Not a usable pickle: fall back to the NumPy format.  Catching
            # ``Exception`` (not a bare ``except``) lets KeyboardInterrupt and
            # SystemExit propagate.
            print('KCT-INFO: reading from a previous numpy file ' + contRec)
            rr = np.load(contRec)
            recs[:, 0] = rr[:, -1]
            print('KCT-INFO: initialized to the previous recon from numpy: ' +
                  contRec)

    n4biasfields = []

    # Weights of the gradient step: the likelihood term is switched off and
    # only the TV term moves the image.
    lambda_lik = 0
    lambda_reg = 1

    for it in range(0, numiter - 2, 2):
        alpha = alphas[it]

        # first do the prior (TV) gradient step
        # ===============================================

        recstmp = recs[:, it].copy()

        # The objective evaluation is disabled; zeros are printed instead.
        ftot, f_lik, f_dc = 0, 0, 0

        gtot, g_lik, g_dc, g_lik_var = geval(recstmp)

        tvnorm, tvgrad = tv_norm(np.abs(recstmp))

        recstmp_1 = recstmp - alpha * (lambda_lik * g_lik +
                                       lambda_reg * tvgrad)

        recs[:, it + 1] = recstmp_1.copy()

        print("it no: " + str(it) + " f_tot= " + str(ftot) + " f_lik= " +
              str(f_lik) + ' TV norm= ' + str(tvnorm) + " f_dc (1e6)= " +
              str(f_dc / 1e6) + " |g_lik|= " + str(np.linalg.norm(g_lik)) +
              " |g_dc|= " + str(np.linalg.norm(g_dc)) + ' |g_tv|= ' +
              str(np.linalg.norm(tvgrad)))

        # now do again a data consistency projection
        # ===============================================

        current = np.reshape(recs[:, it + 1], [imsizer, imrizec])
        tmp1 = UFT(current, (1 - uspat))  # estimate at unsampled locations
        tmp2 = UFT(current, (uspat))  # estimate at sampled locations
        tmp3 = data * uspat[:, :, np.newaxis]  # measured samples

        tmp = tmp1 + multip * tmp2 + (1 - multip) * tmp3
        recs[:, it + 2] = tFT(tmp).flatten()

        # NMSE check against the ground truth, when one was supplied.
        # Without this guard the default ``gt=None`` made ``np.reshape`` raise
        # on the first iteration.
        if gt is not None:
            recon_sli = np.reshape(recs[:, it + 2], (imsizer, imrizec))
            gt = np.reshape(gt, (imsizer, imrizec))

            rss = np.sqrt(
                np.sum(np.square(
                    np.abs(sensmaps * np.tile(recon_sli[:, :, np.newaxis],
                                              [1, 1, sensmaps.shape[2]]))),
                       axis=-1))

            nmse = np.sqrt(((np.fft.fftshift(rss) - gt)**2).mean()) / np.sqrt(
                ((gt)**2).mean())
            print('NMSE: ', nmse)

    return recs, 0, phaseregvals, n4biasfields
Example #5
0
#!/usr/bin/env python2
# Vocoder Patch for MD380 Firmware
# Applies to version S013.020

from Patcher import Patcher

#Match all public calls.
monitormode = False
#Match all private calls.
monitormodeprivate = False

if __name__ == '__main__':
    print("Creating patches from unwrapped.img.")
    patcher = Patcher("unwrapped.img")

    # bypass vocoder copy protection on S013.020
    # For every address in this table, nopout() is called on the address
    # itself and then on the address plus 2.
    paired_addresses = (
        0x8034a60,
        0x8034a76,
        0x8034a8c,
        0x8034aa2,
        0x8034ab8,
        0x8034ace,
        0x8049f9a,
    )
    for address in paired_addresses:
        patcher.nopout(address)
        patcher.nopout(address + 0x2)

    # The final address is patched on its own, without a +2 companion.
    patcher.nopout(0x804a820)
Example #6
0
def vaerecon(us_ksp_r2,
             dcprojiter,
             onlydciter=0,
             lat_dim=60,
             patchsize=28,
             contRec='',
             lowresmodel=False,
             parfact=10,
             num_iter=1,
             regiter=15,
             reglmb=0.05,
             regtype='TV'):

    print('KCT-INFO: contRec is ' + contRec)
    print('KCT-INFO: parfact is ' + str(parfact))
    print("............", us_ksp_r2.shape)
    #print(sensmaps.shape)
    # set parameters
    #==============================================================================
    np.random.seed(seed=1)

    imsizer = us_ksp_r2.shape[0]  #252#256#252
    imrizec = us_ksp_r2.shape[1]  #308#256#308
    print(imsizer, imrizec)
    nsampl = 50

    #make a network and a patcher to use later
    #==============================================================================
    x_rec, x_inp, funop, grd0, _, _, _, _, _, _, _, _, _, _, _, _, _, _ = definevae(
        lat_dim=lat_dim, patchsize=patchsize, batchsize=parfact * nsampl)

    Ptchr = Patcher(imsize=[imsizer, imrizec],
                    patchsize=patchsize,
                    step=int(patchsize / 2),
                    nopartials=True,
                    contatedges=True)
    nopatches = len(Ptchr.genpatchsizes)
    print("KCT-INFO: there will be in total " + str(nopatches) + " patches.")

    #define the necessary functions
    #==============================================================================

    def FT(x):
        # coil expansion followed by Fourier transform
        #inp: [nx, ny]
        #out: [nx, ny, ns]
        return np.fft.fftshift(np.fft.fft2(x[:, :, np.newaxis], axes=(0, 1)),
                               axes=(0, 1))

    def tFT(x):
        # inverse Fourier transform and coil combination
        #inp: [nx, ny, ns]
        #out: [nx, ny]

        temp = np.fft.ifft2(np.fft.ifftshift(x, axes=(0, 1)), axes=(0, 1))
        return np.sum(temp, axis=2)

    def UFT(x, uspat):
        # Encoding: undersampling +  FT
        #inp: [nx, ny], [nx, ny]
        #out: [nx, ny, ns]

        return uspat[:, :, np.newaxis] * FT(x)

    def tUFT(x, uspat):
        # transposed Encoding: inverse FT + undersampling
        #inp: [nx, ny], [nx, ny]
        #out: [nx, ny]

        tmp1 = uspat[:, :, np.newaxis]

        return tFT(tmp1 * x)

    def dconst(us):
        #inp: [nx, ny]
        #out: [nx, ny]

        return np.linalg.norm(UFT(us, uspat) - data)**2
        #np.linalg.norm ji suan l2 fanshu
    def dconst_grad(us):
        #inp: [nx, ny]
        #out: [nx, ny]
        return 2 * tUFT(UFT(us, uspat) - data, uspat)

    def prior(us):
        #inp: [parfact,ps*ps]
        #out: parfact

        us = np.abs(us)
        funeval = funop.eval(feed_dict={x_rec: np.tile(us, (nsampl, 1))})  #
        funeval = np.array(np.split(funeval, nsampl,
                                    axis=0))  # [nsampl x parfact x 1]
        return np.mean(funeval, axis=0).astype(np.float64)

    def prior_grad(us):
        #inp: [parfact, ps*ps]
        #out: [parfact, ps*ps]

        usc = us.copy()
        usabs = np.abs(us)

        grd0eval = grd0.eval(feed_dict={x_rec: np.tile(usabs, (nsampl, 1))
                                        })  # ,x_inp: np.tile(usabs,(nsampl,1))

        #grd0eval: [500x784]
        grd0eval = np.array(np.split(grd0eval, nsampl,
                                     axis=0))  # [nsampl x parfact x 784]
        grd0m = np.mean(grd0eval, axis=0)  #[parfact,784]
        print("11111111111111111111111111111111")
        print(np.abs(usc))
        print("22222222222222222222222222222222")
        print(grd0m)
        grd0m = usc / np.abs(usc) * grd0m

        return grd0m  #.astype(np.float64)

    def prior_grad_patches(ptchs):
        #inp: [np, ps, ps]
        #out: [np, ps, ps]
        #takes set of patches as input and returns a set of their grad.s
        #both grads are in the positive direction

        shape_orig = ptchs.shape

        ptchs = np.reshape(ptchs, [ptchs.shape[0], -1])

        grds = np.zeros([
            int(np.ceil(ptchs.shape[0] / parfact) * parfact),
            np.prod(ptchs.shape[1:])
        ],
                        dtype=np.complex64)

        extraind = int(
            np.ceil(ptchs.shape[0] / parfact) * parfact) - ptchs.shape[0]
        ptchs = np.pad(ptchs, ((0, extraind), (0, 0)), mode='edge')

        for ix in range(int(np.ceil(ptchs.shape[0] / parfact))):
            grds[parfact * ix:parfact * ix + parfact, :] = prior_grad(
                ptchs[parfact * ix:parfact * ix + parfact, :])

        grds = grds[0:shape_orig[0], :]

        return np.reshape(grds, shape_orig)
        #np.ceil  xiang shang quzhengshu

    def prior_patches(ptchs):
        #inp: [np, ps, ps]
        #out: 1

        fvls = np.zeros([int(np.ceil(ptchs.shape[0] / parfact) * parfact)])

        extraind = int(
            np.ceil(ptchs.shape[0] / parfact) * parfact) - ptchs.shape[0]
        ptchs = np.pad(ptchs, [(0, extraind), (0, 0), (0, 0)], mode='edge')

        for ix in range(int(np.ceil(ptchs.shape[0] / parfact))):
            fvls[parfact * ix:parfact * ix + parfact] = prior(
                np.reshape(ptchs[parfact * ix:parfact * ix + parfact, :, :],
                           [parfact, -1]))

        fvls = fvls[0:ptchs.shape[0]]

        return np.mean(fvls)

    def full_gradient(image):
        #inp: [nx*nx, 1]
        #out: [nx, ny], [nx, ny]

        #returns both gradients in the respective positive direction.
        #i.e. must

        ptchs = Ptchr.im2patches(np.reshape(image, [imsizer, imrizec]))
        ptchs = np.array(ptchs)

        grd_prior = prior_grad_patches(ptchs)
        grd_prior = (-1) * Ptchr.patches2im(grd_prior)

        grd_dconst = dconst_grad(np.reshape(image, [imsizer, imrizec]))

        return grd_prior + grd_dconst, grd_prior, grd_dconst

    def full_funceval(image):
        #inp: [nx*nx, 1]
        #out: [1], [1], [1]

        tmpimg = np.reshape(image, [imsizer, imrizec])

        dc = dconst(tmpimg)

        ptchs = Ptchr.im2patches(np.reshape(image, [imsizer, imrizec]))
        ptchs = np.array(ptchs)

        priorval = (-1) * prior_patches(np.abs(ptchs))

        return priorval + dc, priorval, dc

    #define the phase regularization functions
    #==============================================================================

    def tv_proj(phs, mu=0.125, lmb=2, IT=225):
        # Total variation based projection

        phs = fb_tv_proj(phs, mu=mu, lmb=lmb, IT=IT)

        return phs

    def fgrad(im):
        # gradient operation with 1st order finite differences
        imr_x = np.roll(im, shift=-1, axis=0)
        imr_y = np.roll(im, shift=-1, axis=1)
        grd_x = imr_x - im
        grd_y = imr_y - im

        return np.array((grd_x, grd_y))

    def fdivg(im):
        # divergence operator with 1st order finite differences
        imr_x = np.roll(np.squeeze(im[0, :, :]), shift=1, axis=0)
        imr_y = np.roll(np.squeeze(im[1, :, :]), shift=1, axis=1)
        grd_x = np.squeeze(im[0, :, :]) - imr_x
        grd_y = np.squeeze(im[1, :, :]) - imr_y

        return grd_x + grd_y

    def f_st(u, lmb):
        # soft thresholding

        uabs = np.squeeze(np.sqrt(np.sum(u * np.conjugate(u), axis=0)))

        tmp = 1 - lmb / uabs
        tmp[np.abs(tmp) < 0] = 0

        uu = u * np.tile(tmp[np.newaxis, :, :], [u.shape[0], 1, 1])

        return uu

    def fb_tv_proj(im, u0=0, mu=0.125, lmb=1, IT=15):

        sz = im.shape
        us = np.zeros((2, sz[0], sz[1], IT))
        us[:, :, :, 0] = u0

        for it in range(IT - 1):

            #grad descent step:
            tmp1 = im - fdivg(us[:, :, :, it])
            tmp2 = mu * fgrad(tmp1)

            tmp3 = us[:, :, :, it] - tmp2

            #thresholding step:
            us[:, :, :, it + 1] = tmp3 - f_st(tmp3, lmb=lmb)

        #endfor

        return im - fdivg(us[:, :, :, it + 1])

    def reg2_proj(usph, niter=100, alpha=0.05):
        #A smoothness based based projection. Regularization method 2 from
        #"Separate Magnitude and Phase Regularization via Compressed Sensing",  Feng Zhao et al, IEEE TMI, 2012

        usph = usph + np.pi

        ims = np.zeros((imsizer, imrizec, niter))
        ims[:, :, 0] = usph.copy()
        for ix in range(niter - 1):
            ims[:, :, ix + 1] = ims[:, :, ix] - 2 * alpha * np.real(
                1j * np.exp(-1j * ims[:, :, ix]) *
                fdivg(fgrad(np.exp(1j * ims[:, :, ix]))))

        return ims[:, :, -1] - np.pi

    #make the data
    #===============================

    uspat = np.abs(us_ksp_r2) > 0
    uspat = uspat[:, :, 0]
    data = us_ksp_r2

    import pickle

    #make the functions for POCS
    #=====================================
    #number of iterations
    numiter = num_iter

    # if you want to do an affine data consistency projection
    multip = 0  # 0 means simply replacing the measured values

    # step size for the prior iterations
    alphas = np.logspace(-4, -4, numiter)

    # some funtions for simpler coding
    def feval(im):
        """Log the dtype of ``im`` and return whatever ``full_funceval`` yields for it."""
        print('im.dtype..............', im.dtype)
        values = full_funceval(im)
        return values

    def geval(im):
        """Return the three arrays produced by ``full_gradient``, each flattened to 1-D."""
        first, second, third = full_gradient(im)
        return tuple(np.reshape(grad, [-1]) for grad in (first, second, third))

    # initialize data with the zero-filled image
    recs = np.zeros((imsizer * imrizec, numiter), dtype=complex)
    recs[:, 0] = tUFT(data, uspat).flatten().copy()

    # if you want to instead continue reconstruction from an existing image
    print(' KCT-INFO: contRec is ' + contRec)
    if contRec != '':
        print('KCT-INFO: reading from a previous file ' + contRec)
        rr = pickle.load(open(contRec, 'rb'))
        recs[:, 0] = rr[:, -1]
        print('KCT-INFO: initialized to the previous recon: ' + contRec)

    import time
    # the itertaion loop
    for it in range(numiter - 1):
        start = time.time()
        # get the step size for the Piteration
        alpha = alphas[it]

        # if you want to do some data consistency iterations before starting the prior projections
        # this can be helpful when doing recon with multiple coils, e.g. you can do pure SENSE in the beginning...
        # or only do phase projections in the beginning.
        if it > onlydciter:

            # get the gradients for the prior projection and the likelihood values
            ftot, f_prior, f_dc = feval(recs[:, it])
            gtot, g_prior, g_dc = geval(recs[:, it])

            #                print("it no: " + str(it) + " f_tot= " + str(ftot) + " f_prior= " + str(-f_prior) + " f_dc (x1e6)= " + str(f_dc/1e6) + " |g_prior|= " + str(np.linalg.norm(g_prior)) + " |g_dc|= " + str(np.linalg.norm(g_dc)) )
            print(
                "it no: {0}, f_tot= {1:.2f},f_prior= {2:.2f}, f_dc (x1e6)= {3:.2f}, |g_prior|= {4:.2f}, |g_dc|= {5:.2f}"
                .format(it, ftot, f_prior, f_dc / 1e6, np.linalg.norm(g_prior),
                        np.linalg.norm(g_dc)))

            # update the image with the prior gradient
            recs[:, it + 1] = recs[:, it] - alpha * g_prior
            print("............................................")
            print(recs[:, it + 1])
            # seperate the magnitude from the phase and do the phase projection
            tmpa = np.abs(np.reshape(recs[:, it + 1], [imsizer, imrizec]))
            tmpp = np.angle(np.reshape(recs[:, it + 1], [imsizer, imrizec]))

            tmpaf = tmpa.copy().flatten()

            if reglmb == 0:
                print("KCT-info: skipping phase proj")
                tmpptv = tmpp.copy().flatten()
            else:
                if regtype == 'TV':
                    tmpptv = tv_proj(tmpp, mu=0.125, lmb=reglmb,
                                     IT=regiter).flatten()  #0.1, 15
                elif regtype == 'reg2':
                    tmpptv = reg2_proj(tmpp, alpha=reglmb,
                                       niter=regiter).flatten()  #0.1, 15
                else:
                    raise (TypeError)

            # combine back the phase and the magnitude
            recs[:, it + 1] = tmpaf * np.exp(1j * tmpptv)

        else:  # the case where you do only data consistency iterations (also iteration 0)
            if not it == 0:
                print(
                    'KCT-info: skipping prior proj for the first onlydciters iter.s, doing only phase proj (then maybe DC proj as well) !!!'
                )

            recs[:, it + 1] = recs[:, it].copy()

            # seperate the magnitude from the phase and do the phase projection
            tmpa = np.abs(np.reshape(recs[:, it + 1], [imsizer, imrizec]))
            tmpp = np.angle(np.reshape(recs[:, it + 1], [imsizer, imrizec]))

            tmpaf = tmpa.copy().flatten()

            if reglmb == 0:
                print("KCT-info: skipping phase proj")
                tmpptv = tmpp.copy().flatten()
            else:
                if regtype == 'TV':
                    tmpptv = tv_proj(tmpp, mu=0.125, lmb=reglmb,
                                     IT=regiter).flatten()  #0.1, 15
                elif regtype == 'reg2':
                    tmpptv = reg2_proj(tmpp, alpha=reglmb,
                                       niter=regiter).flatten()  #0.1, 15
                else:
                    raise (TypeError)

            # combine back the phase and the magnitude
            recs[:, it + 1] = tmpaf * np.exp(1j * tmpptv)
            print(recs.shape)
        #do the DC projection every 'dcprojiter' iterations
        if it < onlydciter + 1 or it % dcprojiter == 0:  #

            tmp1 = UFT(np.reshape(recs[:, it + 1], [imsizer, imrizec]),
                       (1 - uspat))
            tmp2 = UFT(np.reshape(recs[:, it + 1], [imsizer, imrizec]),
                       (uspat))
            tmp3 = data * uspat[:, :, np.newaxis]

            #combine the measured data with the projected image affinely (multip=0 for replacing the values)
            tmp = tmp1 + multip * tmp2 + (1 - multip) * tmp3
            recs[:, it + 1] = tFT(tmp).flatten()

            ftot, f_lik, f_dc = feval(recs[:, it + 1])
            print(ftot.shape, f_lik.shape, f_dc.shape)
        end = time.time()

        print('one iteration time is', end - start)
    return recs  #YourDatasetModuleHere
Example #7
0
import os
import re
import subprocess

from Patcher import Patcher


def _get_output(cmd):
    """Run ``cmd`` and return its output as text.

    stderr is merged into stdout; leading and trailing CR/LF characters are
    removed from the decoded output.
    """
    completed = subprocess.run(cmd,
                               stdout=subprocess.PIPE,
                               stderr=subprocess.STDOUT)
    return completed.stdout.decode().strip("\r\n")


# Regenerate the usage sections of README.md from the tool's live output.
# `os` must be imported at the top of this script for the chdir below.
os.chdir("..")
patcher = Patcher("README.md", filetype="markdown")

# overall usage summary
stdout = _get_output(["./x509sak.py"])
text = "\n```\n$ ./x509sak.py\n%s\n```\n" % (stdout)
patcher.patch("summary", text)

# every "Begin of cmd-<name>" marker found in the README names a sub-command
command_re = re.compile(r"Begin of cmd-(?P<cmdname>[a-z]+)")
commands = [
    match.group("cmdname") for match in command_re.finditer(patcher.read())
]

# refresh the --help text of each sub-command section
for command in commands:
    stdout = _get_output(["./x509sak.py", command, "--help"])
    text = "\n```\n%s\n```\n" % (stdout)
    patcher.patch("cmd-%s" % (command), text)
Example #8
0
    arg_enum.append("	%s = '%s'," % (enum_name, optchar))
# --- enum entries for the long options ------------------------------------
arg_enum.append("")
for position, (option, enum_name, requires_parameter) in enumerate(opts_long):
    # the first long option is emitted with the explicit value 1000,
    # the following ones without a value
    template = "	%s = 1000," if position == 0 else "	%s,"
    arg_enum.append(template % (enum_name))
arg_enum.append("};")
arg_enum = "\n".join(arg_enum) + "\n"

# --- short option string and long option table ----------------------------
cmd_def = [
    "	const char *short_options = \"%s\";" % (short_string),
    "	struct option long_options[] = {",
]
for (option, enum_name, requires_parameter) in opts_long:
    param = "\"%s\"," % (option)
    if requires_parameter:
        row = "		{ %-30s required_argument, 0, %s },"
    else:
        row = "		{ %-30s no_argument,       0, %s },"
    cmd_def.append(row % (param, enum_name))
cmd_def.extend(["		{ 0 }", "	};"])
cmd_def = "\n".join(cmd_def) + "\n"

# --- write the generated sections into the C source -----------------------
patcher = Patcher("../pgmopts.c")
for section, code in (("help page", help_code),
                      ("command definition enum", arg_enum),
                      ("command definition", cmd_def)):
    patcher.patch(section, code)
Example #9
0
def vaerecon(us_ksp_r2,
             uspat,
             orim,
             i,
             result_all,
             method,
             dcprojiter,
             onlydciter=0,
             lat_dim=60,
             patchsize=28,
             contRec='',
             lowresmodel=False,
             parfact=10,
             num_iter=1,
             regiter=15,
             reglmb=0.05,
             regtype='TV'):
    def write_Data(result_all):
        """Write one line per row of ``result_all`` to ``psnr_<method>.txt`` in ``savepath``.

        Columns 0, 1 and 2 of each row are written as PSNR, SSIM and HFEN.
        The file is rewritten from scratch on every call.
        """
        report_path = os.path.join(savepath, "psnr_" + method + ".txt")
        with open(report_path, "w+") as report:
            for idx, scores in enumerate(result_all):
                line = "".join([
                    'current image {} PSNR : '.format(idx),
                    str(scores[0]),
                    "    SSIM : ",
                    str(scores[1]),
                    "    HFEN : ",
                    str(scores[2]),
                ])
                report.write(line + '\n')

    print('KCT-INFO: contRec is ' + contRec)
    print('KCT-INFO: parfact is ' + str(parfact))
    # set parameters
    #==============================================================================
    # fixed seed so that NumPy's random state is the same on every call
    np.random.seed(seed=1)

    # image size (rows, columns) is taken from the k-space array itself
    imsizer = us_ksp_r2.shape[0]  #252#256#252
    imrizec = us_ksp_r2.shape[1]  #308#256#308
    # number of times every patch batch is replicated (np.tile) before it is
    # fed to the network in prior() / prior_grad()
    nsampl = 50

    #make a network and a patcher to use later
    #==============================================================================
    # only the first four outputs of definevae are used here; the network batch
    # size is parfact patches times nsampl replications
    x_rec, x_inp, funop, grd0, _, _, _, _, _, _, _, _, _, _, _, _, _, _ = definevae(
        lat_dim=lat_dim, patchsize=patchsize, batchsize=parfact * nsampl)

    # patches are extracted with a stride of half the patch size
    Ptchr = Patcher(imsize=[imsizer, imrizec],
                    patchsize=patchsize,
                    step=int(patchsize / 2),
                    nopartials=True,
                    contatedges=True)
    nopatches = len(Ptchr.genpatchsizes)
    print("KCT-INFO: there will be in total " + str(nopatches) + " patches.")

    def dconst(us):
        """Data-consistency cost of an image.

        inp: [nx, ny]
        out: scalar, the squared norm of the masked k-space residual
        """
        residual = np.fft.fft2(us) * uspat - data
        return np.linalg.norm(residual)**2

    def dconst_grad(us):
        """Gradient of the data-consistency cost.

        inp: [nx, ny]
        out: [nx, ny]
        """
        # residual on the sampled k-space locations, masked once more and
        # brought back to image space
        residual = np.fft.fft2(us) * uspat - data
        return 2 * np.fft.ifft2(uspat * residual)

    def prior(us):
        """Evaluate ``funop`` on the magnitudes of a batch of flattened patches.

        inp: [parfact, ps*ps]
        out: parfact values (mean over the nsampl replications), as float64
        """
        magnitudes = np.abs(us)

        # the batch is replicated nsampl times and evaluated in one network call
        replicated = np.tile(magnitudes, (nsampl, 1))
        raw = funop.eval(feed_dict={x_rec: replicated})

        # regroup into [nsampl x parfact x ...] and average over the replications
        per_replication = np.array(np.split(raw, nsampl, axis=0))
        return np.mean(per_replication, axis=0).astype(np.float64)

    def prior_grad(us):
        """Gradient of the prior for a batch of flattened (possibly complex) patches.

        inp: [parfact, ps*ps]
        out: [parfact, ps*ps]

        ``grd0`` is evaluated on the patch magnitudes, averaged over the nsampl
        replications of the batch, and multiplied by the unit phasor
        ``us / |us|`` of every pixel.
        """
        us = np.asarray(us)
        usabs = np.abs(us)

        # the batch is replicated nsampl times and evaluated in one network call
        grd0eval = grd0.eval(feed_dict={x_rec: np.tile(usabs, (nsampl, 1))})

        # regroup into [nsampl x parfact x ps*ps] and average over replications
        grd0eval = np.array(np.split(grd0eval, nsampl, axis=0))
        grd0m = np.mean(grd0eval, axis=0)  #[parfact, ps*ps]

        # us / |us| is 0/0 = NaN for pixels that are exactly zero, and a single
        # NaN would spread through the whole reconstruction. Those pixels get
        # the phasor 1 (zero phase) instead.
        phasor = np.ones(us.shape, dtype=np.result_type(us, np.float64))
        np.divide(us, usabs, out=phasor, where=usabs > 0)

        return phasor * grd0m

    def prior_grad_patches(ptchs):
        """Prior gradients for a set of patches, computed in batches of ``parfact``.

        inp: [np, ps, ps]
        out: [np, ps, ps]

        The patch list is padded (by repeating the last patch) up to a multiple
        of parfact so every call of ``prior_grad`` receives a full batch; the
        gradients of the padding are dropped again before returning.
        """
        shape_orig = ptchs.shape
        n_patches = shape_orig[0]

        flat = np.reshape(ptchs, [n_patches, -1])

        # number of rows after rounding up to a whole number of batches
        n_padded = int(np.ceil(n_patches / parfact) * parfact)
        grds = np.zeros([n_padded, np.prod(flat.shape[1:])],
                        dtype=np.complex128)

        flat = np.pad(flat, ((0, n_padded - n_patches), (0, 0)), mode='edge')

        for ix in range(int(np.ceil(flat.shape[0] / parfact))):
            lo = parfact * ix
            hi = lo + parfact
            grds[lo:hi, :] = prior_grad(flat[lo:hi, :])

        # discard the rows that belong to the padding
        result = np.reshape(grds[0:n_patches, :], shape_orig)

        print('np.reshape(grds, shape_orig).shape', result.shape)
        return result

    def prior_patches(ptchs):
        """Mean prior value over a set of patches, evaluated in batches of ``parfact``.

        inp: [np, ps, ps]
        out: 1

        The patch list is padded (by repeating the last patch) up to a multiple
        of parfact so every call of ``prior`` receives a full batch. Only the
        values of the original patches enter the returned mean.
        """
        # remember the true patch count: `ptchs` is re-bound to the padded array
        # below, and slicing with the padded length (as was done before) let the
        # duplicated last patch bias the mean
        n_orig = ptchs.shape[0]
        n_batches = int(np.ceil(n_orig / parfact))
        n_padded = n_batches * parfact

        fvls = np.zeros([n_padded])

        ptchs = np.pad(ptchs, [(0, n_padded - n_orig), (0, 0), (0, 0)],
                       mode='edge')

        for ix in range(n_batches):
            lo = parfact * ix
            hi = lo + parfact
            fvls[lo:hi] = prior(np.reshape(ptchs[lo:hi, :, :], [parfact, -1]))

        return np.mean(fvls[0:n_orig])

    def full_gradient(image):
        """Prior and data-consistency gradients of a flattened image.

        inp: flattened image with imsizer*imrizec entries
        out: (total, prior, data-consistency) gradients, each [nx, ny]

        The prior gradient is computed patch-wise, reassembled with
        ``Ptchr.patches2im`` and negated; the total is the sum of both parts.
        """
        img2d = np.reshape(image, [imsizer, imrizec])

        # patch-wise prior gradient, put back together and sign-flipped
        patches = np.array(Ptchr.im2patches(img2d))
        grd_prior = (-1) * Ptchr.patches2im(prior_grad_patches(patches))

        grd_dconst = dconst_grad(img2d)

        return grd_prior + grd_dconst, grd_prior, grd_dconst

    def full_funceval(image):
        """Objective value of a flattened image.

        inp: flattened image with imsizer*imrizec entries
        out: (total, negated prior value, data-consistency cost)
        """
        img2d = np.reshape(image, [imsizer, imrizec])

        dc = dconst(img2d)

        # the prior is evaluated on the magnitudes of the patches and negated
        patches = np.array(Ptchr.im2patches(img2d))
        priorval = (-1) * prior_patches(np.abs(patches))

        return priorval + dc, priorval, dc

    #define the phase regularization functions
    #==============================================================================

    def tv_proj(phs, mu=0.125, lmb=2, IT=225):
        """Total variation based projection; forwards all arguments to ``fb_tv_proj``."""
        return fb_tv_proj(phs, mu=mu, lmb=lmb, IT=IT)

    def fgrad(im):
        """Gradient of a 2-D array with 1st order finite differences.

        Forward differences along axis 0 and axis 1 (``np.roll`` wraps around
        at the border), stacked into an array of shape [2, nx, ny].
        """
        forward_diffs = [
            np.roll(im, shift=-1, axis=axis) - im for axis in (0, 1)
        ]
        return np.array((forward_diffs[0], forward_diffs[1]))

    def fdivg(im):
        """Divergence of a two-component field with 1st order finite differences.

        ``im[0]`` is differenced along axis 0 and ``im[1]`` along axis 1, each
        against its own copy rolled forward by one sample (``np.roll`` wraps
        around at the border); the two differences are summed.
        """
        comp_x = np.squeeze(im[0, :, :])
        comp_y = np.squeeze(im[1, :, :])

        # backward difference of every component along its own axis
        diff_x = comp_x - np.roll(comp_x, shift=1, axis=0)
        diff_y = comp_y - np.roll(comp_y, shift=1, axis=1)

        return diff_x + diff_y

    def f_st(u, lmb):
        """Soft thresholding of a vector field.

        u:   array whose axis 0 indexes the vector components of every pixel
        lmb: threshold

        Every pixel's vector is scaled by ``max(1 - lmb / |u|, 0)`` where
        ``|u|`` is the magnitude over axis 0. Returns an array shaped like u.
        """
        # per-pixel magnitude over the component axis
        uabs = np.sqrt(np.sum(np.abs(u)**2, axis=0))

        # a zero magnitude gives -inf (or nan when lmb is 0 as well); both are
        # mapped to a zero scale below, so the warnings are silenced here
        with np.errstate(divide='ignore', invalid='ignore'):
            scale = 1 - lmb / uabs

        # the previous test `np.abs(tmp) < 0` could never be true, so negative
        # scales were kept and nothing was ever thresholded
        scale[~(scale > 0)] = 0

        return u * scale[np.newaxis, ...]

    def fb_tv_proj(im, u0=0, mu=0.125, lmb=1, IT=15):
        """Iterative TV-type projection of a 2-D image.

        im:  2-D array
        u0:  initial value of the two-component auxiliary field (broadcast)
        mu:  step size of the gradient step
        lmb: threshold passed to ``f_st``
        IT:  iteration budget; ``IT - 1`` updates are performed, so ``IT <= 1``
             returns ``im - fdivg(u0)`` without iterating

        Returns ``im`` minus the divergence of the final auxiliary field.
        """
        sz = im.shape

        # Only the current iterate is needed; the former (2, nx, ny, IT) history
        # array was never read back and failed for IT == 0.
        u = np.zeros((2, sz[0], sz[1]))
        u[...] = u0

        for _ in range(IT - 1):
            #grad descent step:
            stepped = u - mu * fgrad(im - fdivg(u))

            #thresholding step: keep what the soft threshold removes
            u[...] = stepped - f_st(stepped, lmb=lmb)

        return im - fdivg(u)

    def reg2_proj(usph, niter=100, alpha=0.05):
        """Smoothness based phase projection.

        Regularization method 2 from "Separate Magnitude and Phase
        Regularization via Compressed Sensing", Feng Zhao et al, IEEE TMI, 2012.

        usph:  2-D phase image
        niter: iteration budget; ``niter - 1`` update steps are performed
        alpha: step size of every update

        Returns the updated phase image (float64, same shape as ``usph``).
        """
        # the phase is shifted by 3*pi/2 for the iterations and shifted back at the end
        offset = np.pi * 3 / 2

        # Only the running iterate is kept; the former (nx, ny, niter) history
        # array held every intermediate image although only the last was used.
        cur = np.array(usph + offset, dtype=np.float64)
        for _ in range(niter - 1):
            cur = cur - 2 * alpha * np.real(
                1j * np.exp(-1j * cur) * fdivg(fgrad(np.exp(1j * cur))))

        return cur - offset

    #make the data
    #===============================

    # NOTE: the sampling mask is re-derived from the non-zero entries of the
    # k-space data; this overwrites the `uspat` argument passed by the caller.
    uspat = np.abs(us_ksp_r2) > 0
    # full slice: selects the whole mask, contents unchanged
    uspat = uspat[:, :]
    data = us_ksp_r2

    import pickle

    #make the functions for POCS
    #=====================================
    #number of iterations (the main loop performs numiter - 1 updates)
    numiter = num_iter

    # if you want to do an affine data consistency projection
    multip = 0  # 0 means simply replacing the measured values

    # step size for the prior iterations; start and stop exponent are both -4,
    # so every iteration uses the same step of 1e-4
    alphas = np.logspace(-4, -4, numiter)

    #print('step size for the prior iterations',alphas)
    #assert False

    def feval(im):
        """Objective of ``im``: the (total, prior, data-consistency) triple from ``full_funceval``."""
        values = full_funceval(im)
        return values

    def geval(im):
        """Total, prior and data-consistency gradients of ``im``, each flattened to 1-D."""
        g_total, g_prior, g_dc = full_gradient(im)
        return tuple(np.reshape(grad, [-1]) for grad in (g_total, g_prior, g_dc))

    # initialize data with the zero-filled image
    recs = np.zeros((imsizer * imrizec, numiter), dtype=complex)
    #recs[:,0] = tUFT(data, uspat).flatten().copy()

    recs[:, 0] = np.fft.ifft2(data * uspat).flatten().copy()
    recs[:, 0] = abs(recs[:, 0])
    print(
        'Zerosfilled Psnr :',
        compare_psnr(255 * np.abs(np.reshape(recs[:, 0], [imsizer, imrizec])),
                     255 * np.abs(orim),
                     data_range=255))
    #plt.figure(444444444)
    #plt.imshow(np.abs(np.reshape(recs[:,0],[imsizer,imrizec])),cmap='gray')
    #plt.show()
    # if you want to instead continue reconstruction from an existing image
    print(' KCT-INFO: contRec is ' + contRec)
    if contRec != '':
        print('KCT-INFO: reading from a previous file ' + contRec)
        rr = pickle.load(open(contRec, 'rb'))
        recs[:, 0] = rr[:, -1]
        print('KCT-INFO: initialized to the previous recon: ' + contRec)

    import time
    # the itertaion loop
    max_psnr = 0
    max_ssim = 0
    min_hfen = 100
    for it in range(numiter - 1):
        start = time.time()
        # get the step size for the Piteration
        alpha = alphas[it]

        # if you want to do some data consistency iterations before starting the prior projections
        # this can be helpful when doing recon with multiple coils, e.g. you can do pure SENSE in the beginning...
        # or only do phase projections in the beginning.
        if it > onlydciter:

            # get the gradients for the prior projection and the likelihood values
            #ftot, f_prior, f_dc = feval(recs[:,it])
            print('recs dtype Input ', recs.dtype)
            gtot, g_prior, g_dc = geval(recs[:, it])

            #                print("it no: " + str(it) + " f_tot= " + str(ftot) + " f_prior= " + str(-f_prior) + " f_dc (x1e6)= " + str(f_dc/1e6) + " |g_prior|= " + str(np.linalg.norm(g_prior)) + " |g_dc|= " + str(np.linalg.norm(g_dc)) )
            #print("it no: {0}, f_tot= {1:.2f},f_prior= {2:.2f}, f_dc (x1e6)= {3:.2f}, |g_prior|= {4:.2f}, |g_dc|= {5:.2f}".format(it,ftot,f_prior,f_dc/1e6,np.linalg.norm(g_prior),np.linalg.norm(g_dc)))

            # update the image with the prior gradient
            recs[:, it + 1] = recs[:, it] - alpha * g_prior

            print(
                'current Phase prePsnr :',
                compare_psnr(
                    255 *
                    np.abs(np.reshape(recs[:, it + 1], [imsizer, imrizec])),
                    255 * np.abs(orim),
                    data_range=255))
            print(
                'current Phase Real Psnr :',
                compare_psnr(255 * np.abs(
                    np.reshape(recs[:, it + 1], [imsizer, imrizec]).real),
                             255 * np.abs(orim.real),
                             data_range=255))
            print(
                'current Phase Imag Psnr :',
                compare_psnr(255 * np.abs(
                    np.reshape(recs[:, it + 1], [imsizer, imrizec]).imag),
                             255 * np.abs(orim.imag),
                             data_range=255))

            #plt.figure(1111)
            #plt.imshow(np.abs(np.reshape(recs[:,it+1],[imsizer,imrizec])),cmap='gray')
            #plt.show()
            #print("............................................")
            #print(recs[:,it+1])
            # seperate the magnitude from the phase and do the phase projection
            tmpa = np.abs(np.reshape(recs[:, it + 1], [imsizer, imrizec]))
            tmpp = np.angle(np.reshape(recs[:, it + 1],
                                       [imsizer, imrizec]))  #+(np.pi*3/2)
            print(np.max(tmpp), np.min(tmpp))
            tmpaf = tmpa.copy().flatten()

            #plt.figure(1111)
            #plt.imshow(tmpp,cmap='gray')
            #plt.show()

            if reglmb == 0:
                print("KCT-info: skipping phase proj")
                tmpptv = tmpp.copy().flatten()
            else:
                if regtype == 'TV':
                    tmpptv = tv_proj(tmpp, mu=0.125, lmb=reglmb,
                                     IT=regiter).flatten()  #0.1, 15
                elif regtype == 'reg2':
                    tmpptv = reg2_proj(tmpp, alpha=reglmb,
                                       niter=regiter).flatten()  #0.1, 15
                else:
                    raise (TypeError)

            # combine back the phase and the magnitude
            #tmpptv -= (np.pi*3/2)
            recs[:, it + 1] = tmpaf * np.exp(1j * tmpptv)  #####here

        else:  # the case where you do only data consistency iterations (also iteration 0)
            if not it == 0:
                print(
                    'KCT-info: skipping prior proj for the first onlydciters iter.s, doing only phase proj (then maybe DC proj as well) !!!'
                )

            recs[:, it + 1] = recs[:, it].copy()

            print(
                'current Phase prePsnr :',
                compare_psnr(
                    255 *
                    np.abs(np.reshape(recs[:, it + 1], [imsizer, imrizec])),
                    255 * np.abs(orim),
                    data_range=255))
            print(
                'current Phase Real Psnr :',
                compare_psnr(255 * np.abs(
                    np.reshape(recs[:, it + 1], [imsizer, imrizec]).real),
                             255 * np.abs(orim.real),
                             data_range=255))
            print(
                'current Phase Imag Psnr :',
                compare_psnr(255 * np.abs(
                    np.reshape(recs[:, it + 1], [imsizer, imrizec]).imag),
                             255 * np.abs(orim.imag),
                             data_range=255))

            # seperate the magnitude from the phase and do the phase projection
            tmpa = np.abs(np.reshape(recs[:, it + 1], [imsizer, imrizec]))
            tmpp = np.angle(np.reshape(recs[:, it + 1], [imsizer, imrizec]))

            tmpaf = tmpa.copy().flatten()

            if reglmb == 0:
                print("KCT-info: skipping phase proj")
                tmpptv = tmpp.copy().flatten()
            else:
                if regtype == 'TV':
                    tmpptv = tv_proj(tmpp, mu=0.125, lmb=reglmb,
                                     IT=regiter).flatten()  #0.1, 15
                elif regtype == 'reg2':
                    tmpptv = reg2_proj(tmpp, alpha=reglmb,
                                       niter=regiter).flatten()  #0.1, 15
                else:
                    raise (TypeError)

            # combine back the phase and the magnitude
            recs[:, it + 1] = tmpaf * np.exp(1j * tmpptv)
            #plt.figure(22222)
            #plt.imshow(np.abs(np.reshape(recs[:,it+1],[imsizer,imrizec])),cmap='gray')
            #plt.show()
        #do the DC projection every 'dcprojiter' iterations
        if it < onlydciter + 1 or it % dcprojiter == 0:  #

            tmp_rec = np.reshape(recs[:, it + 1], [imsizer, imrizec])
            #plt.figure(3333333)
            #plt.imshow(np.abs(tmp1),cmap='gray')
            #plt.show()
            print(np.max(np.abs(tmp_rec)), np.max(255 * np.abs(orim)))
            print(
                'current Prior Psnr :',
                compare_psnr(255 * np.abs(tmp_rec),
                             255 * np.abs(orim),
                             data_range=255))
            print(
                'current Prior Real Psnr :',
                compare_psnr(255 * np.abs(
                    np.reshape(recs[:, it + 1], [imsizer, imrizec]).real),
                             255 * np.abs(orim.real),
                             data_range=255))
            print(
                'current Prior Imag Psnr :',
                compare_psnr(255 * np.abs(
                    np.reshape(recs[:, it + 1], [imsizer, imrizec]).imag),
                             255 * np.abs(orim.imag),
                             data_range=255))

            #tmp1 = np.fft.fft2(tmp_rec)
            #print(uspat)
            #tmp1[uspat==1] = us_ksp_r2[uspat==1]
            #tmp1 = np.fft.ifft2(tmp1)
            #recs[:,it+1] = tmp1.flatten()
            #plt.figure(1)
            #plt.imshow(np.abs(tmp1),cmap='gray')
            #plt.show()

            tmp1 = np.fft.fft2(np.reshape(recs[:, it + 1],
                                          [imsizer, imrizec])) * (1 - uspat)
            tmp2 = np.fft.fft2(np.reshape(recs[:, it + 1],
                                          [imsizer, imrizec])) * (uspat)

            tmp3 = data * uspat  #[:,:,np.newaxis]

            #combine the measured data with the projected image affinely (multip=0 for replacing the values)
            tmp = tmp1 + multip * tmp2 + (1 - multip) * tmp3
            recs[:, it + 1] = np.fft.ifft2(tmp).flatten()  #np.abs(
            print('recs dtype DC ', recs.dtype)
            print(
                'current DC Psnr :',
                compare_psnr(
                    255 *
                    np.abs(np.reshape(recs[:, it + 1], [imsizer, imrizec])),
                    255 * np.abs(orim),
                    data_range=255))
            print(
                'current DC Real Psnr :',
                compare_psnr(255 * np.abs(
                    np.reshape(recs[:, it + 1], [imsizer, imrizec]).real),
                             255 * np.abs(orim.real),
                             data_range=255))
            print(
                'current DC Imag Psnr :',
                compare_psnr(255 * np.abs(
                    np.reshape(recs[:, it + 1], [imsizer, imrizec]).imag),
                             255 * np.abs(orim.imag),
                             data_range=255))

            #ftot, f_lik, f_dc = feval(recs[:,it+1])

            #print(ftot.shape,f_lik.shape,f_dc.shape)
        x_complex = np.reshape(recs[:, it + 1], [256, 256])
        psnr = compare_psnr(255 * abs(x_complex),
                            255 * abs(orim),
                            data_range=255)
        ssim = compare_ssim(abs(x_complex), abs(orim), data_range=1)
        hfen = compare_hfen(abs(x_complex), abs(orim))
        print('current %d image %d iteration PSNR:' % (i, it), psnr, ' SSIM :',
              ssim, ' HFEN :', hfen)
        if max_psnr < psnr:
            result_all[i, 0] = psnr
            max_psnr = psnr
            result_all[31, 0] = sum(result_all[:31, 0]) / 31
            savemat(
                os.path.join(savepath,
                             'img_{}_Rec_'.format(i) + method + '.mat'),
                {
                    'data':
                    np.array(np.reshape(recs[:, it + 1], [256, 256]),
                             dtype=np.complex)
                })
            cv2.imwrite(
                os.path.join(savepath,
                             'img_{}_Rec_'.format(i) + method + '.png'),
                np.array(255 * np.abs(np.reshape(recs[:, it + 1], [256, 256])),
                         dtype=np.uint8))
        if max_ssim < ssim:
            result_all[i, 1] = ssim
            max_ssim = ssim
            result_all[31, 1] = sum(result_all[:31, 1]) / 31

        if min_hfen > hfen:
            result_all[i, 2] = hfen
            min_hfen = hfen
            result_all[31, 2] = sum(result_all[:31, 2]) / 31
        write_Data(result_all)
        #print('recs dtype DC ',recs.dtype)
        end = time.time()
        print('{} iteration time is'.format(it), end - start)
    return recs  #YourDatasetModuleHere
def vaerecon(us_ksp_r2, # undersampled k space
             sensmaps,
             dcprojiter,
             onlydciter = 0,
             lat_dim = 60,
             patchsize = 28,
             contRec = '',
             parfact = 10,
             num_iter = 302,
             rescaled = False,
             half = False,
             regiter = 15,
             reglmb = 0.05,
             regtype = 'TV',
             usemeth = 1,
             stepsize = 1e-4,
             optScale = False,
             mode = [],
             chunks40 = False,
             Melmodels = '',
             N4BFcorr = False,
             z_multip = 1.0,
             n1 = 5,
             n2 = 5,
             log_dir = ''):
     """Reconstruct an image from undersampled k-space data with a VAE patch prior.

     The function builds the VAE graph with ``definevae`` and a ``Patcher`` for
     the image size, defines helper closures (data consistency, ELBO and its
     gradient, phase projections, N4 bias field correction) and then runs the
     iterative reconstruction loop at the end of the function body.

     Args:
          us_ksp_r2: undersampled k-space data. Its first two axes define the
               image size; a third axis is indexed when the sampling mask is
               built (presumably the coil axis -- TODO confirm).
          sensmaps: forwarded to the ``utils.*_with_sensmaps`` transforms
               (presumably coil sensitivity maps -- TODO confirm).
          dcprojiter, onlydciter, usemeth, optScale: accepted but not
               referenced anywhere in this function.
          lat_dim, rescaled, half, mode, chunks40, Melmodels, log_dir:
               forwarded unchanged to ``definevae``.
          patchsize: patch side length; patches are extracted with a step of
               ``patchsize/2``.
          contRec: path of a previous reconstruction (pickle or numpy file)
               whose last column is used as the starting image; '' starts from
               the transform of the measured data.
          parfact: number of patches per VAE batch (the VAE batch size is
               ``parfact * nsampl``).
          num_iter: iteration budget of the reconstruction loop.
          regiter: iteration count of the 'TV' and 'reg2_ls' phase projections.
          reglmb: phase regularisation strength; 0 skips the phase projection.
          regtype: phase projection type, one of 'TV', 'reg2', 'reg2_dc',
               'abs', 'reg2_ls'.
          stepsize: step size of the image update along the ELBO gradient.
          N4BFcorr: if true, estimate and apply an N4 bias field.
          z_multip: value fed to the ``z_std_multip`` placeholder of the VAE.
          n1: normalizer update steps per outer iteration; the normalizer is
               only enabled in the VAE when ``n1 > 0``.
          n2: image update steps per outer iteration.

     Returns:
          Tuple ``(recs, 0, phaseregvals, n4biasfields)``. ``recs`` is a complex
          array with ``imsizer*imrizec`` rows holding one flattened image per
          column; the two lists collect phase regulariser values and estimated
          bias fields.
     """
     
     logging.info('xxxxxxxxxxxxxxxxxxx contRec is ' + contRec)
     logging.info('xxxxxxxxxxxxxxxxxxx parfact is ' + str(parfact) )
     
     # ==============================================================================
     # set parameters
     # ==============================================================================
     # NOTE: this seeds numpy's global random generator for the whole process.
     np.random.seed(seed = 1)     
     imsizer = us_ksp_r2.shape[0] #252#256#252
     imrizec = us_ksp_r2.shape[1] #308#256#308
     # every batch of patches is replicated nsampl times before it is fed to the
     # VAE; the results are averaged over the replicas.
     nsampl = 50 #0
          
     # ==============================================================================
     # make a network and a patcher to use later
     # ==============================================================================     
     # =================================
     # get the output from the VAE
     # funop: ELBO
     # grd0: gradient of the ELBO wrt the image
     # grd_p_x_z0, grd_p_z0, grd_q_z_x0: different gradients
     # =================================
     vae_outputs = definevae(lat_dim = lat_dim,
                             patchsize = patchsize,
                             batchsize = parfact*nsampl,
                             rescaled = rescaled,
                             half = half,
                             mode = mode,
                             chunks40 = chunks40,
                             Melmodels = Melmodels,
                             use_normalizer = bool(n1>0),
                             log_dir = log_dir)
     
     # The outputs are unpacked purely by position, so the slices below must stay
     # in sync with the order in which definevae returns its 34 values.
     x_rec, x_inp, funop, grd0, sess = vae_outputs[0:5]
     grd_p_x_z0, grd_p_z0, grd_q_z_x0, grd20 = vae_outputs[5:9] # these are not being used in the code now.
     y_out, y_out_prec, z_std_multip, op_q_z_x = vae_outputs[9:13] # of these, only z_std_multip is fed by the closures below.
     mu, std = vae_outputs[13:15] # these are not being used in the code now.
     grd_q_zpl_x_az0, op_q_zpl_x, z_pl, z = vae_outputs[15:19] # these are not being used in the code now.
     # Most of these outputs were being used in the function likelihood_grad_meth3, which is no longer being used.
     # ops and placeholder for the gradient-accumulating normalizer update
     norm_accum_grads_zero_op, norm_accum_grads_op, norm_accum_grads_mean_op, num_accum_steps_pl, norm_update_op, x_norm = vae_outputs[19:25]
     # placeholders / summary ops for logging function values ...
     f_elbo, f_data_consistency, f_total, f_summary = vae_outputs[25:29]
     # ... and gradient norms, plus the writer that receives both
     g_elbo, g_data_consistency, g_total, g_summary, summary_writer = vae_outputs[29:34]
     
     # =================================
     # used to go from image to patches and back
     # =================================
     Ptchr = Patcher(imsize = [imsizer, imrizec],
                     patchsize = patchsize,
                     step = int(patchsize/2),
                     nopartials = True,
                     contatedges = True)
     
     nopatches = len(Ptchr.genpatchsizes)
     logging.info("There will be in total " + str(nopatches) + " patches.")
     
     # =================================
     # functions for data consistency
     # =================================
     def dconst(us):
          """Return the data consistency cost of the image ``us`` (a scalar).

          ``us`` is transformed with ``utils.UFT_with_sensmaps`` using the
          sampling pattern ``uspat`` and ``sensmaps``; the cost is the squared
          norm of the difference between that result and the measured ``data``.
          """
          residual = utils.UFT_with_sensmaps(us, uspat, sensmaps) - data
          return np.linalg.norm(residual) ** 2
     
     def dconst_grad(us):
          """Return the gradient of the data consistency cost at the image ``us``.

          The k-space residual (transform of ``us`` minus ``data``) is mapped
          back with ``utils.tUFT_with_sensmaps`` and scaled by 2.
          """
          residual = utils.UFT_with_sensmaps(us, uspat, sensmaps) - data
          back_projected = utils.tUFT_with_sensmaps(residual, uspat, sensmaps)
          return 2 * back_projected

     # =================================  
     # function for running the normalization op on a batch of patches
     # =================================  
     def update_normalizer(im,
                           sess,
                           accum_gradients_zero_op,
                           accum_gradients_op,
                           accum_gradients_mean_op,
                           num_accum_steps_pl,
                           update_normalizer_op):
          """Run one update of the normalization module on the patches of ``im``.

          The magnitude patches of the image are fed to the VAE in batches of
          ``parfact`` patches (each batch tiled ``nsampl`` times) while
          ``accum_gradients_op`` accumulates gradients. The accumulated
          gradients are then averaged over the number of batches and applied
          with ``update_normalizer_op``. Always returns 0.
          """
          # magnitude patches of the image
          patch_stack = np.abs(np.array(Ptchr.im2patches(np.reshape(im, [imsizer, imrizec]))))

          # start from empty gradient accumulators
          sess.run(accum_gradients_zero_op)

          # one flattened patch per row, padded (by repeating the last patch)
          # so that the number of rows is a multiple of the batch size
          patch_stack = np.reshape(patch_stack, [patch_stack.shape[0], -1])
          n_missing = int(np.ceil(patch_stack.shape[0] / parfact) * parfact) - patch_stack.shape[0]
          patch_stack = np.pad(patch_stack, ((0, n_missing), (0,0)), mode = 'edge')

          # accumulate the gradients batch by batch
          steps_taken = 0
          n_batches = int(np.ceil(patch_stack.shape[0] / parfact))
          for batch_idx in range(n_batches):
               first_row = parfact * batch_idx
               batch_of_patches = patch_stack[first_row : first_row + parfact, :]
               batch_feed = {x_rec: np.tile(batch_of_patches, (nsampl, 1)), z_std_multip: z_multip}
               sess.run(accum_gradients_op, feed_dict = batch_feed)
               steps_taken += 1

          # turn the accumulated sum into a mean
          sess.run(accum_gradients_mean_op, feed_dict = {num_accum_steps_pl: steps_taken})

          # Apply the mean gradient. The values fed here do not matter, but the
          # placeholders need something, so the last batch is passed again.
          last_feed = {x_rec: np.tile(batch_of_patches, (nsampl, 1)), z_std_multip: z_multip}
          sess.run(update_normalizer_op, feed_dict = last_feed)

          return 0
     
     # =================================
     # functions for computing the ELBO and its derivatives
     # =================================     
     
     # =================================     
     # returns the elbo of a batch of patches
     # =================================     
     def likelihood(us):
          """Evaluate ``funop`` (the ELBO) for one batch of flattened patches.

          inp: [parfact, ps*ps]
          out: one value per patch, averaged over the ``nsampl`` replicas

          Only the magnitude of ``us`` is fed to the network.
          """
          tiled_magnitudes = np.tile(np.abs(us), (nsampl,1))
          raw_values = funop.eval(feed_dict = {x_rec: tiled_magnitudes, z_std_multip: z_multip})
          # regroup as [nsampl x parfact x ...] so the replicas can be averaged
          per_replica = np.array(np.split(raw_values, nsampl, axis=0))
          return per_replica.mean(axis=0).astype(np.float64)
      
     # =================================    
     # returns the elbo of the input patches
     # =================================    
     def likelihood_patches(ptchs):
          """Return the mean ELBO over a stack of patches.

          inp: [np, ps, ps]
          out: scalar

          The stack is padded (by repeating the last patch) to a multiple of
          ``parfact`` so that it can be evaluated in complete batches with
          ``likelihood``. The values of the padding patches are discarded
          before averaging, so the result is the mean over the ``np`` input
          patches only.
          """
          # remember the true patch count BEFORE padding; it is needed to drop
          # the padded entries again
          n_patches = ptchs.shape[0]
          n_padded = int(np.ceil(n_patches / parfact) * parfact)
          ptchs = np.pad(ptchs, [(0, n_padded - n_patches), (0,0), (0,0)], mode='edge')

          fvls = np.zeros([n_padded])
          for start in range(0, n_padded, parfact):
               batch = np.reshape(ptchs[start : start + parfact, :, :], [parfact, -1])
               fvls[start : start + parfact] = likelihood(batch)

          return np.mean(fvls[:n_patches])
      
     # =================================
     # returns the ELBO of the image 
     # =================================     
     def full_funceval(image):
          """Evaluate the cost of a (flattened) image.

          inp: image with imsizer*imrizec elements
          out: (lik + dc, lik, dc) where ``dc`` is the data consistency cost
               from ``dconst`` and ``lik`` is the negated mean patch ELBO from
               ``likelihood_patches``.
          """
          img2d = np.reshape(image, [imsizer,imrizec])

          # agreement with the measured undersampled k-space data
          dc = dconst(img2d)

          # negated ELBO of the magnitude patches of the image
          patch_magnitudes = np.abs(np.array(Ptchr.im2patches(img2d)))
          lik = (-1) * likelihood_patches(patch_magnitudes)

          return lik + dc, lik, dc
     
     # =================================  
     # returns the gradient of the elbo of a batch of patches wrt the batch of patches
     # =================================  
     def likelihood_grad(us):
          """Return the gradient of the ELBO wrt a batch of (complex) patches.

          inp: [parfact, ps*ps]
          out: [parfact, ps*ps]

          ``grd0`` is evaluated on the patch magnitudes (tiled ``nsampl`` times)
          and averaged over the replicas. The result is then multiplied by the
          unit phasor ``us / |us|`` of every pixel to account for the network
          only seeing the magnitude. Pixels with zero magnitude have no defined
          phasor; their gradient is set to 0 instead of the NaN that 0/0 would
          produce, so a single empty pixel cannot contaminate the image update.
          """
          magnitude = np.abs(us)
          grd0eval = grd0.eval(feed_dict = {x_rec: np.tile(magnitude, (nsampl, 1)), z_std_multip: z_multip}) # [nsampl*parfact, ps*ps]
          grd0eval = np.array(np.split(grd0eval, nsampl, axis=0)) # [nsampl x parfact x ps*ps]
          grd0m = np.mean(grd0eval, axis=0) # [parfact, ps*ps]

          # correction for the complex image; dividing by 1 where |us| == 0
          # leaves those (zero-valued) pixels with a zero phasor
          safe_magnitude = np.where(magnitude > 0, magnitude, 1)
          return us / safe_magnitude * grd0m
     
     # =================================  
     # returns the gradient of the elbo of the input patches wrt the input patches
     # =================================  
     def likelihood_grad_patches(ptchs):
          """Return the ELBO gradient for every patch of a stack.

          inp: [np, ps, ps]
          out: [np, ps, ps] (complex64)

          The patches are flattened, padded with copies of the last patch to a
          multiple of ``parfact``, pushed through ``likelihood_grad`` batch by
          batch, and the gradients of the padding patches are dropped again.
          Gradients are returned in the positive direction.
          """
          original_shape = ptchs.shape
          n_patches = original_shape[0]

          flat = np.reshape(ptchs, [n_patches, -1])
          n_padded = int(np.ceil(n_patches / parfact) * parfact)
          gradients = np.zeros([n_padded, flat.shape[1]], dtype = np.complex64)

          # pad so that only complete batches are sent to the VAE
          flat = np.pad(flat, ((0, n_padded - n_patches), (0,0)), mode='edge')

          for start in range(0, flat.shape[0], parfact):
               stop = start + parfact
               gradients[start:stop, :] = likelihood_grad(flat[start:stop, :])

          return np.reshape(gradients[0:n_patches, :], original_shape)
     

     # =================================
     # returns the gradient of the ELBO wrt the image 
     # =================================     
     def full_gradient(image):
          """Return the gradient of the cost wrt a (flattened) image.

          inp: image with imsizer*imrizec elements
          out: (grd_lik + grd_dconst, grd_lik, grd_dconst)

          ``grd_lik`` is the negated ELBO gradient, computed per patch and
          assembled back into an image with ``Ptchr.patches2im``;
          ``grd_dconst`` is the data consistency gradient from ``dconst_grad``.
          """
          img2d = np.reshape(image, [imsizer,imrizec])

          # per-patch ELBO gradients, merged back into image space
          patch_stack = np.array(Ptchr.im2patches(img2d))
          patch_gradients = likelihood_grad_patches(patch_stack)
          grd_lik = (-1) * Ptchr.patches2im(patch_gradients)

          grd_dconst = dconst_grad(img2d)

          return grd_lik + grd_dconst, grd_lik, grd_dconst
     
     # =================================
     # phase projection functions
     # =================================     
     def tv_proj(phs,mu=0.125,lmb=2,IT=225):
          """Total variation projection of ``phs``; thin wrapper around ``fb_tv_proj``."""
          return fb_tv_proj(phs, mu=mu, lmb=lmb, IT=IT)
     
     def fgrad(im):
          """Forward differences of ``im`` along axes 0 and 1 with wrap-around.

          Returns an array whose first entry is ``roll(im, -1, axis=0) - im``
          and whose second entry is ``roll(im, -1, axis=1) - im``.
          """
          return np.array([np.roll(im, shift=-1, axis=axis) - im for axis in (0, 1)])
     
     def fdivg(im):
          """Divergence of a 2-component field using backward differences.

          ``im[0]`` is differenced along axis 0 and ``im[1]`` along axis 1
          (both with wrap-around, ``x - roll(x, 1)``); the two results are
          summed.
          """
          comp_x = np.squeeze(im[0,:,:])
          comp_y = np.squeeze(im[1,:,:])
          back_diff_x = comp_x - np.roll(comp_x, shift=1, axis=0)
          back_diff_y = comp_y - np.roll(comp_y, shift=1, axis=1)
          return back_diff_x + back_diff_y
     
     def f_st(u,lmb):
          """Soft-threshold the vector field ``u`` by ``lmb``.

          ``u`` holds one vector per pixel along axis 0. Every vector is scaled
          by ``max(1 - lmb/|u|, 0)``, where ``|u|`` is its Euclidean length:
          vectors longer than ``lmb`` are shortened by ``lmb``, all others
          become zero. The result has the shape of ``u``.

          The previous test ``np.abs(tmp) < 0`` could never be true, so
          negative factors were not clipped, and zero-length vectors produced
          NaN through the division by zero. Both cases now yield 0.
          """
          magnitude = np.sqrt(np.sum(np.abs(u) ** 2, axis=0))
          # zero-length vectors stay zero whatever the factor is, so dividing
          # by 1 there just avoids the 0-division
          safe_magnitude = np.where(magnitude > 0, magnitude, 1)
          shrink = np.maximum(1 - lmb / safe_magnitude, 0)
          return u * shrink[np.newaxis, ...]
       
     def fb_tv_proj(im, u0=0, mu=0.125, lmb=1, IT=15):
          """Iterative total variation projection of the 2D image ``im``.

          A 2-component field ``u`` (initialised with ``u0``) is updated
          ``IT - 1`` times: a gradient step of size ``mu`` on
          ``im - fdivg(u)`` followed by subtracting the soft-thresholded
          (``f_st`` with ``lmb``) result. The function returns
          ``im - fdivg(u)`` for the final ``u``.

          Only the current iterate is kept, so memory use does not grow with
          ``IT``; with ``IT <= 1`` no update is made and ``im - fdivg(u0)`` is
          returned.
          """
          sz = im.shape
          u = np.zeros((2, sz[0], sz[1]))
          u[...] = u0
          for _ in range(IT - 1):
               # grad descent step:
               tmp1 = im - fdivg(u)
               tmp3 = u - mu * fgrad(tmp1)
               # thresholding step (written in place to keep u's dtype):
               u[...] = tmp3 - f_st(tmp3, lmb=lmb)
          return im - fdivg(u)
          
     def tikh_proj(usph, niter=100, alpha=0.05):
          """Smooth ``usph`` with ``niter - 1`` steps of ``x += 2*alpha*fdivg(fgrad(x))``.

          All iterates are stored; the last one is returned.
          """
          history = np.zeros((imsizer, imrizec, niter))
          history[:,:,0] = usph.copy()
          for step in range(1, niter):
               previous = history[:,:,step-1]
               history[:,:,step] = previous + alpha*2*fdivg(fgrad(previous))
          return history[:,:,-1]
     
     def reg2_proj(usph, niter=100, alpha=0.05):
          """Regularise the phase image ``usph`` with fixed-step gradient updates.

          from  Separate Magnitude and Phase Regularization via Compressed Sensing,  Feng Zhao

          The phase is shifted by +pi, updated ``niter - 1`` times with
          ``alpha * reg2grd(...)`` and shifted back by -pi before it is
          returned. Works for the image size of the current reconstruction
          (``imsizer`` x ``imrizec``) instead of a fixed 252 x 308 grid.
          """
          usph = usph + np.pi
          ims = np.zeros((imsizer, imrizec, niter))
          ims[:,:,0] = usph.copy()

          for ix in range(niter-1):
               step = reg2grd(ims[:,:,ix].flatten()).reshape([imsizer, imrizec])
               ims[:,:,ix+1] = ims[:,:,ix] + alpha*step

          return ims[:,:,-1] - np.pi
     
     def reg2_dcproj(usph, magim, bfestim, niter=100, alpha_reg=0.05, alpha_dc=0.05):
          """Regularise the phase ``usph`` while staying consistent with the data.

          from  Separate Magnitude and Phase Regularization via Compressed Sensing,  Feng Zhao

          Each of the ``niter - 1`` iterations adds ``alpha_reg`` times the
          regulariser term from ``reg2grd`` and subtracts ``alpha_dc`` times the
          data consistency gradient from ``reg2_dcgrd`` (evaluated with the
          magnitude ``magim`` and the bias field estimate ``bfestim``). The last
          iterate is returned. Works for the image size of the current
          reconstruction (``imsizer`` x ``imrizec``).
          """
          ims = np.zeros((imsizer, imrizec, niter))
          ims[:,:,0] = usph.copy()

          # set to True to log per-iteration diagnostics; they are only
          # computed when requested because dconst needs a full forward
          # transform of the image
          print_info = False

          for ix in range(niter-1):
               current = ims[:,:,ix].flatten()
               grd_reg = reg2grd(current).reshape([imsizer, imrizec])
               grd_dc = reg2_dcgrd(current, magim, bfestim).reshape([imsizer, imrizec])

               ims[:,:,ix+1] = ims[:,:,ix] + alpha_reg*grd_reg - alpha_dc*grd_dc

               if print_info:
                    regval = reg2eval(ims[:,:,ix+1].flatten())
                    f_dc = dconst(magim*np.exp(1j*ims[:,:,ix+1])*bfestim)
                    logging.info("norm grad reg: " + str(np.linalg.norm(grd_reg)))
                    logging.info("norm grad dc: " + str(np.linalg.norm(grd_dc)))
                    logging.info("regval: " + str(regval))
                    logging.info("fdc: (*1e9) {0:.6f}".format(f_dc/1e9))

          return ims[:,:,-1]
     
     def reg2eval(im):
          """Return the phase regulariser value of the phase image ``im``.

          takes in an array with imsizer*imrizec elements, returns scalar:
          the norm of the finite differences (``fgrad``) of ``exp(1j*im)``.
          Uses the image size of the current reconstruction instead of a fixed
          252 x 308 grid.
          """
          im = im.reshape([imsizer, imrizec])
          phs = np.exp(1j*im)
          return np.linalg.norm(fgrad(phs).flatten())
     
     def reg2grd(im):
          """Return the update direction of the phase regulariser at ``im``.

          takes in an array with imsizer*imrizec elements, returns 1d:
          ``-2*real(1j*exp(-1j*im) * fdivg(fgrad(exp(1j*im))))``. Uses the image
          size of the current reconstruction instead of a fixed 252 x 308 grid.
          """
          im = im.reshape([imsizer, imrizec])
          return -2*np.real(1j*np.exp(-1j*im) * fdivg(fgrad(np.exp(1j * im)))).flatten()
     
     def reg2_dcgrd(phim, magim, bfestim):
          """Return the data consistency gradient wrt the phase image ``phim``.

          takes in arrays with imsizer*imrizec elements, returns 1d. The image
          ``bfestim * exp(1j*phim) * magim`` is transformed to k-space, the
          measured ``data`` is subtracted and the residual is mapped back with
          ``utils.tUFT_with_sensmaps``. Uses the image size of the current
          reconstruction instead of a fixed 252 x 308 grid.
          """
          phim = phim.reshape([imsizer, imrizec])
          magim = magim.reshape([imsizer, imrizec])
          tmp = utils.UFT_with_sensmaps(bfestim * np.exp(1j * phim) * magim, uspat, sensmaps) - data
          return -2 * np.real(1j * np.exp(-1j*phim) * magim * bfestim * utils.tUFT_with_sensmaps(tmp, uspat, sensmaps)).flatten()
     
     def reg2_proj_ls(usph, niter=100):
          """Regularise the phase image ``usph`` using a line search per step.

          from  Separate Magnitude and Phase Regularization via Compressed Sensing,  Feng Zhao

          Same scheme as a fixed-step update along ``reg2grd``, but the step
          length of every iteration is chosen by a golden-section search
          (``sop.minimize_scalar``) on ``reg2eval``. The phase is shifted by +pi
          on entry and by -pi on return. Works for the image size of the
          current reconstruction (``imsizer`` x ``imrizec``).
          """
          usph = usph + np.pi
          ims = np.zeros((imsizer, imrizec, niter))
          ims[:,:,0] = usph.copy()
          regval = reg2eval(ims[:,:,0].flatten())
          logging.info(regval)
          for ix in range(niter-1):
               current = ims[:,:,ix].flatten()
               currgrd = reg2grd(current)
               # the lambda is evaluated immediately, so capturing the loop
               # variables is safe
               res = sop.minimize_scalar(lambda alpha: reg2eval(current + alpha * currgrd), method='Golden')
               alphaopt = res.x
               logging.info("optimal alpha: " + str(alphaopt) )
               ims[:,:,ix+1] = ims[:,:,ix] + alphaopt*currgrd.reshape([imsizer, imrizec])
               regval = reg2eval(ims[:,:,ix+1].flatten())
               logging.info("regval: " + str(regval))
          return ims[:,:,-1] - np.pi

     def N4corrf(im):
          """Estimate a bias field of ``im`` with the N4 filter of ``sitk``.

          The filter runs on the magnitude of ``im``. Returns
          ``(bias_field, corrected)`` where ``bias_field`` is the magnitude
          divided by the corrected magnitude (plus 1e-9 against division by
          zero). If ``im`` has a non-zero imaginary part anywhere, the original
          phase is re-attached to the corrected magnitude.
          """
          magnitude = np.abs(im)
          phase = np.angle(im)

          sitk_image = sitk.GetImageFromArray(magnitude, isVector=False)
          n4_filter = sitk.N4BiasFieldCorrectionImageFilter()
          sitk_image = sitk.Cast(sitk_image, sitk.sitkFloat32)
          corrected = sitk.GetArrayFromImage(n4_filter.Execute(sitk_image))

          bias_field = magnitude/(corrected+1e-9)

          if not np.isreal(im).all():
               return bias_field, corrected*np.exp(1j*phase)
          return bias_field, corrected
     
     # ===============================
     # make the data
     # ===============================     
     # Sampling mask: True wherever the measured k-space is non-zero.
     uspat = np.abs(us_ksp_r2) > 0
     # Keep the mask at index 0 of the third axis only
     # (presumably the first coil, with all coils sharing one pattern -- TODO confirm).
     uspat = uspat[:,:,0]
     data = us_ksp_r2
     # NOTE(review): trpat (columns 120..135 set) is not referenced anywhere else in this function.
     trpat = np.zeros_like(uspat)
     trpat[:, 120:136] = 1
          
     logging.info(uspat)
     
     # ===================================== 
     # initialize counters
     # ===================================== 
     numiter = num_iter
     # weight of the reconstructed (vs. measured) samples at sampled k-space
     # locations in the data consistency projection; 0 keeps the measured data.
     multip = 0 # 0.1
     # step size per iteration index (constant)
     alphas = stepsize*np.ones(numiter) # np.logspace(-4,-4,numiter)
     
     # ===================================== 
     # functions for POCS
     # ===================================== 
     def feval(im):
          """Return ``(total, likelihood, data consistency)`` costs of ``im`` via ``full_funceval``."""
          total, lik, dc = full_funceval(im)
          return total, lik, dc
     
     def geval(im):
          """Return the three gradients of ``full_gradient(im)``, each flattened to 1d."""
          grad_total, grad_lik, grad_dc = full_gradient(im)
          return tuple(np.reshape(grad, [-1]) for grad in (grad_total, grad_lik, grad_dc))
     
     # =====================================
     # initialize data
     # =====================================
     # One flattened image per column. The loop below writes up to n1 + n2 + 3
     # columns past its last start index (< numiter), so the buffer keeps the
     # historical 30 spare columns and grows when n1 + n2 + 3 needs more.
     recs = np.zeros((imsizer*imrizec, numiter + max(30, n1 + n2 + 3)), dtype=complex)
     phaseregvals = []
     n4biasfields = []

     # =====================================
     # first image in the 'recs' list is the initial undersampled image
     # =====================================
     recs[:, 0] = utils.tUFT_with_sensmaps(data, uspat, sensmaps).flatten().copy()

     if N4BFcorr:
          n4bf, N4bf_image = N4corrf( np.reshape(recs[:,0],[imsizer,imrizec]) )
          recs[:,0] = N4bf_image.flatten()
     else:
          n4bf = 1

     # =====================================
     # If the optimization is to be continued from a previous run,
     # load the final image from the previous optimization as the first image for the current optimization.
     # =====================================
     logging.info('contRec is ' + contRec)
     if contRec != '':
          try:
               logging.info('Reading from a previous pickle file ' + contRec)
               # NOTE: unpickling can execute arbitrary code; contRec must
               # point to a trusted file written by a previous run.
               with open(contRec, 'rb') as pickle_file:
                    rr = pickle.load(pickle_file)
               recs[:, 0] = rr[:, -1]
               logging.info('Initialized to the previous recon from pickle: ' + contRec)
          except Exception:
               # Not a usable pickle: fall back to the numpy format. Errors of
               # this second attempt (e.g. a missing file) are raised to the
               # caller; KeyboardInterrupt/SystemExit are no longer swallowed.
               logging.info('Reading from a previous numpy file ' + contRec)
               rr = np.load(contRec)
               recs[:, 0] = rr[:, -1]
               logging.info('Initialized to the previous recon from numpy: ' + contRec)
     
     # =====================================
     # main loop
     # we don't do GD, we do POCS (projection onto convex sets)
     #     o. N1 times gradient updates for min_|phi| -ELBO(|x|)
     #     a. N2 times gradient updates for min_|x| -ELBO(|x|)
     #     b. do data consistency projection into a set of x such that || Ex - y ||_2^2 = 0.
     #     c. N gradient updates for the phase image ps: min_px
     # =====================================
     for it in range(0, numiter-1, 13):         
        
          alpha = alphas[it]
          
          # ===============================================
          # first do N1 times magnitude prior iterations wrt normalization module
          # ===============================================      
          recstmp = recs[:, it].copy()          
          for ix in range(n1):
          
               update_normalizer(recstmp,
                                 sess,
                                 norm_accum_grads_zero_op,
                                 norm_accum_grads_op,
                                 norm_accum_grads_mean_op,
                                 num_accum_steps_pl,
                                 norm_update_op)
                              
               ftot, f_lik, f_dc = feval(recstmp)
          
               if N4BFcorr:
                    f_dc = dconst(recstmp.reshape([imsizer, imrizec]) * n4bf)
               
               gtot, g_lik, g_dc = geval(recstmp) # gradient evaluation: total, wrt_vae, wrt_data_consistency
               
               logging.info("----------- updating normalization module, iteration number: " + str(it+ix))
               f_summary_msg = sess.run(f_summary, feed_dict = {f_elbo: f_lik, f_data_consistency: f_dc/1e6, f_total: ftot})
               summary_writer.add_summary(f_summary_msg, it+ix)
               g_summary_msg = sess.run(g_summary, feed_dict = {g_elbo: np.linalg.norm(g_lik), g_data_consistency: np.linalg.norm(g_dc), g_total: np.linalg.norm(g_lik) + np.linalg.norm(g_dc)})
               summary_writer.add_summary(g_summary_msg, it+ix)

               # recstmp will not change in this loop. Only the normalization module will.
               recs[:, it+ix+1] = recstmp.copy()
               
          # ===============================================
          # now do N2 times magnitude prior iterations wrt the image itself
          # ===============================================      
          recstmp = recs[:, it + n1].copy()          
          for ix in range(n2):
          
               ftot, f_lik, f_dc = feval(recstmp)
          
               if N4BFcorr:
                    f_dc = dconst(recstmp.reshape([imsizer, imrizec]) * n4bf)
               
               gtot, g_lik, g_dc = geval(recstmp) # gradient evaluation : total, wrt_vae, wrt_data_consistency
               
               logging.info("----------- updating image, iteration number: " + str(it+n1+ix))
               f_summary_msg = sess.run(f_summary, feed_dict = {f_elbo: f_lik, f_data_consistency: f_dc/1e6, f_total: ftot})
               summary_writer.add_summary(f_summary_msg, it+n1+ix)
               g_summary_msg = sess.run(g_summary, feed_dict = {g_elbo: np.linalg.norm(g_lik), g_data_consistency: np.linalg.norm(g_dc), g_total: np.linalg.norm(g_lik) + np.linalg.norm(g_dc)})
               summary_writer.add_summary(g_summary_msg, it+n1+ix)
     
               # recstmp will change in this loop. The normalization module will stay fixed.          
               recstmp = recstmp - alpha * g_lik # g_lik
               recs[:, it + ix + n1 + 1] = recstmp.copy()
     
          # =============================================== 
          # Now do a  DC projection.... ACTUALLY, skip the DC projection for now.
          # ===============================================
          logging.info("dummy step, iteration number: " + str(it+n1+n2+1))
          recs[:, it + n1 + n2 + 1] = recs[:, it + n1 + n2] 
          f_summary_msg = sess.run(f_summary, feed_dict = {f_elbo: f_lik, f_data_consistency: f_dc/1e6, f_total: ftot})
          summary_writer.add_summary(f_summary_msg, it+n1+n2+1)
          g_summary_msg = sess.run(g_summary, feed_dict = {g_elbo: np.linalg.norm(g_lik), g_data_consistency: np.linalg.norm(g_dc), g_total: np.linalg.norm(g_lik) + np.linalg.norm(g_dc)})
          summary_writer.add_summary(g_summary_msg, it+n1+n2+1)
           
          # ===============================================
          # now do a phase projection
          # ===============================================
          tmpa = np.abs(np.reshape(recs[:, it + n1 + n2 + 1], [imsizer, imrizec]))
          tmpp = np.angle(np.reshape(recs[:, it + n1 + n2 + 1], [imsizer, imrizec]))
          tmpatv = tmpa.copy().flatten()
           
          if reglmb == 0:
               logging.info("skipping phase proj, iteration number: " + str(it+n1+n2+2))
               tmpptv = tmpp.copy().flatten()
               
          else:
               logging.info("doing phase proj, iteration number: " + str(it+n1+n2+2))
               if regtype == 'TV': # Total variation
                    tmpptv = tv_proj(tmpp, mu=0.125,lmb=reglmb,IT=regiter).flatten() # 0.1, 15

               elif regtype == 'reg2': # Tikhonov
                    tmpptv = reg2_proj(tmpp, alpha=reglmb, niter=100).flatten() # 0.1, 15
                    regval = reg2eval(tmpp)
                    phaseregvals.append(regval)
                    logging.info("KCT-dbg: phase reg value is " + str(regval))
                
               elif regtype == 'reg2_dc': # Tikhonov, with additional constraint from data consistency
                    tmpptv = reg2_dcproj(tmpp, tmpa, n4bf, alpha_reg=reglmb, alpha_dc=reglmb, niter=100).flatten()
                
               elif regtype == 'abs':
                    tmpptv=np.zeros_like(tmpp).flatten()
               
               elif regtype == 'reg2_ls':
                    tmpptv = reg2_proj_ls(tmpp, niter=regiter).flatten() #0.1, 15
                    regval = reg2eval(tmpp)
                    phaseregvals.append(regval)
                    logging.info("KCT-dbg: phase reg value is " + str(regval))
                
               else:
                    logging.info("hey mistake!!!!!!!!!!")
           
          # recombine magnitude and updated phase.
          recs[:, it + n1 + n2 + 2] = tmpatv*np.exp(1j*tmpptv)

          # add summary after phase projection step
          recstmp = recs[:, it + n1 + n2 + 2].copy()        
          ftot, f_lik, f_dc = feval(recstmp)
          gtot, g_lik, g_dc = geval(recstmp)
          f_summary_msg = sess.run(f_summary, feed_dict = {f_elbo: f_lik, f_data_consistency: f_dc/1e6, f_total: ftot})
          summary_writer.add_summary(f_summary_msg, it+n1+n2+2)
          g_summary_msg = sess.run(g_summary, feed_dict = {g_elbo: np.linalg.norm(g_lik), g_data_consistency: np.linalg.norm(g_dc), g_total: np.linalg.norm(g_lik) + np.linalg.norm(g_dc)})
          summary_writer.add_summary(g_summary_msg, it+n1+n2+2)

          # ===============================================      
          # now do a data consistency projection
          # take the measured part of the k space from the measured data and the remaining part of the k space from the updated image.
          # ===============================================
          logging.info("doing data consistency step, iteration number: " + str(it+n1+n2+3))
          if not N4BFcorr:  
               tmp1 = utils.UFT_with_sensmaps(np.reshape(recs[:, it + n1 + n2 + 2], [imsizer,imrizec]), (1-uspat), sensmaps)
               tmp2 = utils.UFT_with_sensmaps(np.reshape(recs[:, it + n1 + n2 + 2], [imsizer,imrizec]), (uspat), sensmaps)
               tmp3 = data * uspat[:,:,np.newaxis]
               
               tmp = tmp1 + multip * tmp2 + (1 - multip) * tmp3
               recs[:, it + n1 + n2 + 3] = utils.tFT_with_sensmaps(tmp, sensmaps).flatten()
               
               # ftot, f_lik, f_dc = feval(recs[:, it + n1 + n2 + 3])
               # logging.info('f_dc (1e6): ' + str(f_dc/1e6) + '  perc: ' + str(100*f_dc/np.linalg.norm(data)**2))
               
          elif N4BFcorr:               
               
               n4bf_prev = n4bf.copy()
               
               imgtmp = np.reshape(recs[:, it + n1 + n2 + 2], [imsizer,imrizec]) # biasfree
               
               imgtmp_bf = imgtmp * n4bf_prev # img with bf
               
               n4bf, N4bf_image = N4corrf(imgtmp_bf) # correct the bf, this correction is supposed to be better now.
               
               imgtmp_new = imgtmp * n4bf
               
               n4biasfields.append(n4bf)
               
               tmp1 = utils.UFT_with_sensmaps(imgtmp_new, (1-uspat), sensmaps)
               tmp3 = data * uspat[:, :, np.newaxis]
               tmp = tmp1 + (1 - multip) * tmp3 # multip=0 by default
               recs[:, it + n1 + n2 + 3] = (utils.tFT_with_sensmaps(tmp, sensmaps) / n4bf).flatten()
               
               # ftot, f_lik, f_dc = feval(recs[:, it + n1 + n2 + 3])
               # if N4BFcorr: f_dc = dconst(recs[:, it + n1 + n2 + 3].reshape([imsizer,imrizec])*n4bf)               
               # logging.info('f_dc (1e6): ' + str(f_dc/1e6) + '  perc: ' + str(100*f_dc / np.linalg.norm(data) ** 2))
               
          # add summary after data consistency step
          recstmp = recs[:, it + n1 + n2 + 3].copy()        
          ftot, f_lik, f_dc = feval(recstmp)
          gtot, g_lik, g_dc = geval(recstmp)
          f_summary_msg = sess.run(f_summary, feed_dict = {f_elbo: f_lik, f_data_consistency: f_dc/1e6, f_total: ftot})
          summary_writer.add_summary(f_summary_msg, it+n1+n2+3)
          g_summary_msg = sess.run(g_summary, feed_dict = {g_elbo: np.linalg.norm(g_lik), g_data_consistency: np.linalg.norm(g_dc), g_total: np.linalg.norm(g_lik) + np.linalg.norm(g_dc)})
          summary_writer.add_summary(g_summary_msg, it+n1+n2+3)
          summary_writer.flush()
              
     return recs, 0, phaseregvals, n4biasfields