Example #1
    def update_patcher(self):
        """ initialize and update patcher """
        def need_reget_label():
            """ check whether labels need to be (re)fetched from the patcher """
            # no labels have been fetched from the patcher yet
            if "labels" not in self.database[self.index]:
                return True
            # labels were fetched, but the patcher had no label file to read.
            # note: this can trigger a redundant re-fetch when patcher.labels
            # is genuinely empty.
            count = 0
            for boxes in self.database[self.index]["labels"].values():
                count += len(boxes)
            return count == 0

        # if the patcher was initialized without a label file, reinitialize it and re-fetch the labels
        if need_reget_label():
            self.patcher = Patcher(self.database[self.index]["fname"], self.database[self.index]["lname"])
            self.database[self.index]["labels"] = self.patcher.get_labels()
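For illustration, a minimal sketch of the check above on a hypothetical database entry (the file names and class labels here are made up, not part of the original code):

entry = {
    "fname": "img_001.jpg",                 # image file (placeholder)
    "lname": "img_001.txt",                 # label file (placeholder)
    "labels": {"car": [], "person": []},    # class -> list of boxes, all empty
}

def labels_are_empty(entry):
    # mirrors the count == 0 test in need_reget_label()
    return "labels" not in entry or sum(len(b) for b in entry["labels"].values()) == 0

print(labels_are_empty(entry))  # True -> the Patcher would be re-created and labels re-fetched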
Example #2
    def __init__(self):
        self.trainptchs = []
        self.traincount = 1
        self.testptchs = []
        self.testcount = 1
        sumptraintchs = []
        sumptesttchs = []

        orgtraindata = getData(trnTst='training')
        print(orgtraindata.shape)
        orgtestdata = getData(trnTst='testing')

        print('orgtraindata', orgtraindata.dtype, type(orgtraindata))

        Ptchr = Patcher(imsize=[orgtraindata.shape[1], orgtraindata.shape[2]],
                        patchsize=28,
                        step=int(28 / 2),
                        nopartials=True,
                        contatedges=True)

        #Ptchr1=Patcher(imsize=[orgtestdata.shape[1],orgtestdata.shape[2]],patchsize=28,step=int(28/2), nopartials=True, contatedges=True)
        for i in range(500):
            ptchs_1 = Ptchr.im2patches(
                add_Usp_nous(orgtraindata[i, :, :, :]).reshape([256, 256]))
            sumptraintchs += ptchs_1
            #orgtraindata[i,:,:,0] = np.reshape(add_Usp_nous(orgtraindata[i,:,:,:]).reshape([64,64]), [64,64])
        for i in range(4):
            ptchs_1 = Ptchr.im2patches(
                add_Usp_nous(orgtestdata[i, :, :, :]).reshape([256, 256]))
            sumptesttchs += ptchs_1
        sumptraintchs = np.array(sumptraintchs)
        sumptraintchs = sumptraintchs.reshape(-1, 784)
        print(sumptraintchs.shape)
        self.trainptchs = sumptraintchs

        sumptesttchs = np.array(sumptesttchs)
        sumptesttchs = sumptesttchs.reshape(-1, 784)
        self.testptchs = sumptesttchs
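For reference, the patch extraction done by Patcher.im2patches can be approximated with plain NumPy. A minimal sketch, assuming it simply slides a 28x28 window with step 14 over the image (the real Patcher also handles partial patches and edge continuation, which this sketch ignores):

import numpy as np

def im2patches_sketch(img, patchsize=28, step=14):
    # slide a patchsize x patchsize window over img; partial windows at the
    # border are dropped, roughly like nopartials=True
    patches = []
    for r in range(0, img.shape[0] - patchsize + 1, step):
        for c in range(0, img.shape[1] - patchsize + 1, step):
            patches.append(img[r:r + patchsize, c:c + patchsize])
    return patches

img = np.random.rand(256, 256)
patches = np.array(im2patches_sketch(img)).reshape(-1, 784)  # 784 = 28*28, as above
print(patches.shape)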
Example #3
#!/usr/bin/env python2
# Promiscuous Mode Patch for MD380 Firmware
# Applies to version 2.032

from Patcher import Patcher

#Match all public calls.
monitormode=False;
#Match all private calls.
monitormodeprivate=False;

if __name__ == '__main__':
    print "Creating patches from unwrapped.img.";
    patcher=Patcher("unwrapped.img");
    

    #These aren't quite enough to skip the Color Code check.  Not sure why.
    patcher.nopout(0x0803ea62,0xf040);  #Main CC check.
    patcher.nopout(0x0803ea64,0x80fd);
    patcher.nopout(0x0803e994,0xf040);  #Late Entry CC check.
    patcher.nopout(0x0803e996,0x8164);
    patcher.nopout(0x0803fd98);  #dmr_dll_parser CC check.
    patcher.nopout(0x0803fd9a);
    patcher.sethword(0x0803fd8e,0xe02d, #Check in dmr_dll_parser().
                     0xd02d);
    patcher.nopout(0x0803eafe,0xf100); #Disable CRC check, in case CC is included.
    patcher.nopout(0x0803eb00,0x80af);
    
        
    # Patches after here allow for an included applet.
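The Patcher used by these md380tools scripts is not shown here; conceptually, each call overwrites a 16-bit Thumb halfword in the raw firmware dump. A rough, hypothetical sketch of that idea (the class, the 0x0800C000 load address and the 0x46C0 NOP encoding are our assumptions, not the project's actual implementation):

class TinyPatcher(object):
    """Hypothetical stand-in for the md380tools Patcher, for illustration only."""
    BASE = 0x0800C000  # assumed load address of the application image

    def __init__(self, fname):
        with open(fname, "rb") as f:
            self.img = bytearray(f.read())

    def sethword(self, adr, new, old=None):
        # overwrite the little-endian 16-bit halfword at flash address adr
        off = adr - self.BASE
        if old is not None:
            assert (self.img[off] | (self.img[off + 1] << 8)) == old
        self.img[off] = new & 0xFF
        self.img[off + 1] = (new >> 8) & 0xFF

    def nopout(self, adr, old=None):
        # replace the halfword at adr with a Thumb NOP (mov r8, r8)
        self.sethword(adr, 0x46C0, old)

    def export(self, fname):
        with open(fname, "wb") as f:
            f.write(self.img)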
    
Example #4
def vaerecon(us_ksp_r2,
             sensmaps,
             dcprojiter,
             n=10,
             lat_dim=60,
             patchsize=28,
             contRec='',
             parfact=10,
             num_iter=302,
             rescaled=False,
             half=False,
             regiter=15,
             reglmb=0.1,
             regtype='reg2_dc',
             usemeth=1,
             stepsize=1e-4,
             optScale=False,
             mode=[],
             chunks40=False,
             Melmodels='',
             N4BFcorr=False,
             z_multip=1.0,
             directapprox=0,
             vae_model='',
             logdir='',
             directapp=0,
             gt=None):
    print('xxxxxxxxxxxxxxxxxxx contRec is ' + contRec)
    print('xxxxxxxxxxxxxxxxxxx parfact is ' + str(parfact))
    import pickle

    # set parameters
    # ==============================================================================
    np.random.seed(seed=1)

    imsizer = us_ksp_r2.shape[0]
    imrizec = us_ksp_r2.shape[1]

    nsampl = 50  # 0

    # make a network and a patcher to use later
    # ==============================================================================

    x_rec, x_inp, funop, grd0, grd_dir, sess, grd_p_x_z0, grd_p_z0, grd_q_z_x0, grd20, y_out, y_out_prec, z_std_multip, op_q_z_x, mu, std, grd_q_zpl_x_az0, op_q_zpl_x, z_pl, z = definevae(
        lat_dim=lat_dim,
        patchsize=patchsize,
        mode=mode,
        vae_model=vae_model,
        batchsize=parfact * nsampl)

    if directapp:
        print('_____DIRECT APPROX_____')
        grd0 = grd_dir

    Ptchr = Patcher(imsize=[imsizer, imrizec],
                    patchsize=patchsize,
                    step=int(patchsize / 2),
                    nopartials=True,
                    contatedges=True)

    nopatches = len(Ptchr.genpatchsizes)
    print("KCT-INFO: there will be in total " + str(nopatches) + " patches.")

    # define the necessary functions
    # ==============================================================================

    def FT(x):
        # inp: [nx, ny]
        # out: [nx, ny, ns]
        return np.fft.fftshift(np.fft.fft2(
            sensmaps * np.tile(x[:, :, np.newaxis], [1, 1, sensmaps.shape[2]]),
            axes=(0, 1)),
                               axes=(0, 1))

    # def tFT(x):
    #        # inp: [nx, ny, ns]
    #        # out: [nx, ny]
    #        tft_x = np.fft.ifft2(np.fft.ifftshift(x, axes=(0, 1)), axes=(0, 1)) * np.conjugate(sensmaps)
    #
    #        rss = np.sqrt(np.sum(np.square(tft_x), axis=2))
    #
    #        rss = rss / (np.sqrt(np.sum(np.square(sensmaps*np.conjugate(sensmaps)),axis=2)) + 0.00000001)
    #
    #        return rss # root-sum-squared

    def tFT(x):
        # inp: [nx, ny, ns]
        # out: [nx, ny]

        temp = np.fft.ifft2(np.fft.ifftshift(x, axes=(0, 1)), axes=(0, 1))
        return np.sum(temp * np.conjugate(sensmaps), axis=2) / (
            np.sum(sensmaps * np.conjugate(sensmaps), axis=2) + 0.00000001)

    def UFT(x, uspat):
        # inp: [nx, ny], [nx, ny]
        # out: [nx, ny, ns]

        return np.tile(uspat[:, :, np.newaxis],
                       [1, 1, sensmaps.shape[2]]) * FT(x)

    def tUFT(x, uspat):
        # inp: [nx, ny], [nx, ny]
        # out: [nx, ny]

        tmp1 = np.tile(uspat[:, :, np.newaxis], [1, 1, sensmaps.shape[2]])

        return tFT(tmp1 * x)

    def dconst(us):
        # inp: [nx, ny]
        # out: [nx, ny]

        return np.linalg.norm(UFT(us, uspat) - data)**2

    def dconst_grad(us):
        # inp: [nx, ny]
        # out: [nx, ny]
        return 2 * tUFT(UFT(us, uspat) - data, uspat)

    def likelihood(us):
        # inp: [parfact,ps*ps]
        # out: parfact
        us = np.abs(us)
        funeval = funop.eval(feed_dict={
            x_rec: np.tile(us, (nsampl, 1)),
            z_std_multip: z_multip
        })  # ,x_inp: np.tile(us,(nsampl,1))
        # funeval: [500x1]
        funeval = np.array(np.split(funeval, nsampl,
                                    axis=0))  # [nsampl x parfact x 1]
        return np.mean(funeval, axis=0).astype(np.float64)

    def likelihood_grad(us):
        # inp: [parfact, ps*ps]
        # out: [parfact, ps*ps]
        usc = us.copy()
        usabs = np.abs(us)

        grd0eval = grd0.eval(feed_dict={
            x_rec: np.tile(usabs, (nsampl, 1)),
            z_std_multip: z_multip
        })  # ,x_inp: np.tile(usabs,(nsampl,1))
        # grd0eval: [500x784]
        grd0eval = np.array(np.split(grd0eval, nsampl,
                                     axis=0))  # [nsampl x parfact x 784]

        sigmaeval = y_out_prec.eval(feed_dict={
            x_rec: np.tile(usabs, (nsampl, 1)),
            z_std_multip: z_multip
        })  # ,x_inp: np.tile(usabs,(nsampl,1))
        sigmaeval = np.array(np.split(sigmaeval, nsampl,
                                      axis=0))  # [nsampl x parfact x 784]

        mueval = y_out.eval(feed_dict={
            x_rec: np.tile(usabs, (nsampl, 1)),
            z_std_multip: z_multip
        })  # ,x_inp: np.tile(usabs,(nsampl,1))
        mueval = np.array(np.split(mueval, nsampl,
                                   axis=0))  # [nsampl x parfact x 784]

        #vareval = np.std(mueval, axis=0)  # V(MU(X))
        #vareval = np.mean(1/sigmaeval, axis=0)  # M(SIGMA)
        vareval = np.std(grd0eval, axis=0)  # V(SIGMA (X-MU(X)))

        # grd0_var = np.std(grd0eval, axis=0)
        grd0m = np.mean(grd0eval, axis=0)  # [parfact,784]

        #grd0m = usc / np.abs(usc) * grd0m
        # usc is complex, so compare magnitudes to find the nonzero entries
        where_not_0 = np.where(np.abs(usc) > 0)
        div = usc
        div[where_not_0] = usc[where_not_0] / np.abs(usc)[where_not_0].astype(
            'float')

        grd0m = div * grd0m
        var0m = vareval

        return grd0m, var0m  # .astype(np.float64)

    def likelihood_grad_meth3(us):
        # inp: [parfact, ps*ps]
        # out: [parfact, ps*ps]
        usc = us.copy()
        usabs = np.abs(us)

        mueval = mu.eval(feed_dict={x_rec: np.tile(usabs, (nsampl, 1))
                                    })  # ,x_inp: np.tile(usabs,(nsampl,1))

        #          print(mueval)
        #          print(mueval.shape)
        #          print(len(mueval))

        stdeval = std.eval(feed_dict={x_rec: np.tile(usabs, (nsampl, 1))
                                      })  # ,x_inp: np.tile(usabs,(nsampl,1))

        # reparameterization: z = mu + eps * std with eps ~ N(0, 1)
        zvals = mueval + np.random.randn(mueval.shape[0],
                                         mueval.shape[1]) * stdeval

        y_outeval = y_out.eval(feed_dict={z: zvals})
        y_out_preceval = y_out_prec.eval(feed_dict={z: zvals})

        tmp = np.tile(usabs, (nsampl, 1)) - y_outeval
        tmp = (-1) * tmp * y_out_preceval

        # grd0eval: [500x784]
        grd0eval = np.array(np.split(tmp, nsampl,
                                     axis=0))  # [nsampl x parfact x 784]
        grd0m = np.mean(grd0eval, axis=0)  # [parfact,784]

        # usc is complex, so compare magnitudes to find the nonzero entries
        where_not_0 = np.where(np.abs(usc) > 0)
        div = usc
        div[where_not_0] = usc[where_not_0] / np.abs(usc)[where_not_0]
        grd0m = div * grd0m

        return grd0m  # .astype(np.float64)

    def likelihood_grad_patches(ptchs):
        # inp: [np, ps, ps]
        # out: [np, ps, ps]
        # takes set of patches as input and returns a set of their grad.s
        # both grads are in the positive direction

        shape_orig = ptchs.shape

        ptchs = np.reshape(ptchs, [ptchs.shape[0], -1])

        grds = np.zeros([
            int(np.ceil(ptchs.shape[0] / parfact) * parfact),
            np.prod(ptchs.shape[1:])
        ],
                        dtype=np.complex64)

        grds_vars = grds.copy()

        extraind = int(
            np.ceil(ptchs.shape[0] / parfact) * parfact) - ptchs.shape[0]
        ptchs = np.pad(ptchs, ((0, extraind), (0, 0)), mode='edge')

        for ix in range(int(np.ceil(ptchs.shape[0] / parfact))):
            if usemeth == 1:
                grds[parfact * ix:parfact * ix +
                     parfact, :], grds_vars[parfact * ix:parfact * ix +
                                            parfact, :] = likelihood_grad(
                                                ptchs[parfact *
                                                      ix:parfact * ix +
                                                      parfact, :])
            else:
                assert False, "only usemeth == 1 is supported in this code path"

        grds = grds[0:shape_orig[0], :]

        grds_vars = grds_vars[0:shape_orig[0], :]

        return np.reshape(grds, shape_orig), np.reshape(grds_vars, shape_orig)

    def likelihood_patches(ptchs):
        # inp: [np, ps, ps]
        # out: 1

        fvls = np.zeros([int(np.ceil(ptchs.shape[0] / parfact) * parfact)])

        extraind = int(
            np.ceil(ptchs.shape[0] / parfact) * parfact) - ptchs.shape[0]
        ptchs = np.pad(ptchs, [(0, extraind), (0, 0), (0, 0)], mode='edge')

        for ix in range(int(np.ceil(ptchs.shape[0] / parfact))):
            fvls[parfact * ix:parfact * ix + parfact] = likelihood(
                np.reshape(ptchs[parfact * ix:parfact * ix + parfact, :, :],
                           [parfact, -1]))

        fvls = fvls[0:ptchs.shape[0]]

        return np.mean(fvls)

    def full_gradient(image):
        # inp: [nx*nx, 1]
        # out: [nx, ny], [nx, ny]

        # returns both gradients in the respective positive direction.
        # i.e. must

        ptchs = Ptchr.im2patches(np.reshape(image, [imsizer, imrizec]))
        ptchs = np.array(ptchs)

        grd_lik, grd_lik_var = likelihood_grad_patches(ptchs)
        grd_lik = (-1) * Ptchr.patches2im(grd_lik)
        grd_lik_var = Ptchr.patches2im(grd_lik_var)

        grd_dconst = dconst_grad(np.reshape(image, [imsizer, imrizec]))

        return grd_lik + grd_dconst, grd_lik, grd_dconst, grd_lik_var

    def full_funceval(image):
        # inp: [nx*nx, 1]
        # out: [1], [1], [1]

        tmpimg = np.reshape(image, [imsizer, imrizec])

        dc = dconst(tmpimg)

        ptchs = Ptchr.im2patches(np.reshape(image, [imsizer, imrizec]))
        ptchs = np.array(ptchs)

        lik = (-1) * likelihood_patches(np.abs(ptchs))

        return lik + dc, lik, dc

    def tv_proj(phs, mu=0.125, lmb=2, IT=225):
        phs = fb_tv_proj(phs, mu=mu, lmb=lmb, IT=IT)

        return phs

    def fgrad(im):
        imr_x = np.roll(im, shift=-1, axis=0)
        imr_y = np.roll(im, shift=-1, axis=1)
        grd_x = imr_x - im
        grd_y = imr_y - im

        return np.array((grd_x, grd_y))

    def fdivg(im):
        imr_x = np.roll(np.squeeze(im[0, :, :]), shift=1, axis=0)
        imr_y = np.roll(np.squeeze(im[1, :, :]), shift=1, axis=1)
        grd_x = np.squeeze(im[0, :, :]) - imr_x
        grd_y = np.squeeze(im[1, :, :]) - imr_y

        return grd_x + grd_y

    def f_st(u, lmb):

        uabs = np.squeeze(np.sqrt(np.sum(u * np.conjugate(u), axis=0)))

        tmp = 1 - lmb / uabs
        tmp[tmp < 0] = 0  # clip negatives: soft-thresholding keeps max(1 - lmb/|u|, 0)

        uu = u * np.tile(tmp[np.newaxis, :, :], [u.shape[0], 1, 1])

        return uu

    def fb_tv_proj(im, u0=0, mu=0.125, lmb=1, IT=15):
        sz = im.shape
        us = np.zeros((2, sz[0], sz[1], IT))
        us[:, :, :, 0] = u0

        for it in range(IT - 1):
            # grad descent step:
            tmp1 = im - fdivg(us[:, :, :, it])
            tmp2 = mu * fgrad(tmp1)

            tmp3 = us[:, :, :, it] - tmp2

            # thresholding step:
            us[:, :, :, it + 1] = tmp3 - f_st(tmp3, lmb=lmb)

            # endfor

        return im - fdivg(us[:, :, :, it + 1])

    def g_tv_eval(x):
        x_re = np.fft.fftshift(np.reshape(x, (imsizer, imrizec, 1)),
                               axes=(0, 1))

        data = tf.placeholder(tf.float64, shape=x_re.shape)

        x_tv = tf.image.total_variation(data)
        var_grad = tf.gradients(x_tv, [data])[0]

        var_grad_val = var_grad.eval(feed_dict={data: x_re})

        return np.fft.ifftshift(var_grad_val, axes=(0, 1))

    def tv_norm(x):
        """Computes the total variation norm and its gradient. From jcjohnson/cnn-vis."""
        x = np.fft.fftshift(np.reshape(x, (imsizer, imrizec, 1)), axes=(0, 1))

        x_diff = x - np.roll(x, -1, axis=1)
        y_diff = x - np.roll(x, -1, axis=0)
        grad_norm2 = x_diff**2 + y_diff**2 + np.finfo(np.float32).eps
        norm = np.sum(np.sqrt(grad_norm2))
        dgrad_norm = 0.5 / np.sqrt(grad_norm2)
        dx_diff = 2 * x_diff * dgrad_norm
        dy_diff = 2 * y_diff * dgrad_norm
        grad = dx_diff + dy_diff
        grad[:, 1:, :] -= dx_diff[:, :-1, :]
        grad[1:, :, :] -= dy_diff[:-1, :, :]

        return norm, np.reshape(np.fft.ifftshift(grad, axes=(0, 1)), [-1])

    # make the data
    # ===============================

    uspat = np.abs(us_ksp_r2) > 0
    uspat = uspat[:, :, 0]
    data = us_ksp_r2

    trpat = np.zeros_like(uspat)
    trpat[:, 120:136] = 1

    # lrphase = np.angle( tUFT(data*trpat[:,:,np.newaxis],uspat) )
    # lrphase = pickle.load(open('/home/ktezcan/unnecessary_stuff/lowresphase','rb'))
    # truephase = pickle.load(open('/home/ktezcan/unnecessary_stuff/truephase','rb'))
    # lrphase = pickle.load(open('/home/ktezcan/unnecessary_stuff/usphase','rb'))
    # lrphase = pickle.load(open('/home/ktezcan/unnecessary_stuff/lrusphase','rb'))
    # lrphase = pickle.load(open('/home/ktezcan/unnecessary_stuff/lrmaskphase','rb'))

    # make the functions for POCS
    # =====================================
    numiter = num_iter

    multip = 0  # 0.1

    alphas = stepsize * np.ones(numiter)  # np.logspace(-4,-4,numiter)

    #     alphas=np.ones_like(np.logspace(-4,-4,numiter))*5e-3

    def feval(im):
        return full_funceval(im)

    def geval(im):
        t1, t2, t3, t4 = full_gradient(im)
        return np.reshape(t1, [-1]), np.reshape(t2, [-1]), np.reshape(
            t3, [-1]), np.reshape(t4, [-1])

    # initialize data
    recs = np.zeros((imsizer * imrizec, numiter + 2), dtype=complex)

    #     recs[:,0] = np.abs(tUFT(data, uspat).flatten().copy()) #kct

    recs[:, 0] = tUFT(data, uspat).flatten().copy()

    #pickle.dump(recs[:, 0], open(logdir + '_rec_0', 'wb'))
    n4bf = 1

    #     recs[:,0] = np.abs(tUFT(data, uspat).flatten().copy() )*np.exp(1j*lrphase).flatten()

    phaseregvals = []

    # pickle.dump(recs[:,0],open('/scratc_','wb'))

    print('contRec is ' + contRec)
    if contRec != '':
        try:
            print('KCT-INFO: reading from a previous pickle file ' + contRec)
            import pickle
            rr = pickle.load(open(contRec, 'rb'))
            recs[:, 0] = rr[:, -1]
            print('KCT-INFO: initialized to the previous recon from pickle: ' +
                  contRec)
        except:
            print('KCT-INFO: reading from a previous numpy file ' + contRec)
            rr = np.load(contRec)
            recs[:, 0] = rr[:, -1]
            print('KCT-INFO: initialized to the previous recon from numpy: ' +
                  contRec)

    n4biasfields = []

    recsarr = []

    for it in range(0, numiter - 2, 2):
        alpha = alphas[it]

        # first do N times magnitude prior iterations
        # ===============================================
        # ===============================================

        recstmp = recs[:, it].copy()

        ftot, f_lik, f_dc = 0, 0, 0  #feval(recstmp)

        gtot, g_lik, g_dc, g_lik_var = geval(recstmp)

        tvnorm, tvgrad = tv_norm(np.abs(recstmp))

        lambda_lik = 0
        lambda_reg = 1
        recstmp_1 = recstmp - alpha * (lambda_lik * g_lik +
                                       lambda_reg * tvgrad)

        recs[:, it + 1] = recstmp_1.copy()

        print("it no: " + str(it) + " f_tot= " + str(ftot) + " f_lik= " +
              str(f_lik) + ' TV norm= ' + str(tvnorm) + " f_dc (1e6)= " +
              str(f_dc / 1e6) + " |g_lik|= " + str(np.linalg.norm(g_lik)) +
              " |g_dc|= " + str(np.linalg.norm(g_dc)) + ' |g_tv|= ' +
              str(np.linalg.norm(tvgrad)))

        # if it == 0:
        #      pickle.dump(recstmp, open(logdir + '_rec_0', 'wb'))
        #      pickle.dump(g_lik, open(logdir + '_rec_likgrad', 'wb'))
        #      pickle.dump(g_dc, open(logdir + '_rec_dcgrad', 'wb'))
        #      pickle.dump(tvgrad, open(logdir + '_rec_tvgrad', 'wb'))
        #      pickle.dump(tvgrad * g_lik_var, open(logdir + '_rec_tvmulvar', 'wb'))
        #      pickle.dump(g_lik_var, open(logdir + '_rec_var', 'wb'))
        #      exit()
        # now do again a data consistency projection
        # ===============================================
        # ===============================================

        tmp1 = UFT(np.reshape(recs[:, it + 1], [imsizer, imrizec]),
                   (1 - uspat))
        tmp2 = UFT(np.reshape(recs[:, it + 1], [imsizer, imrizec]), (uspat))
        tmp3 = data * uspat[:, :, np.newaxis]

        tmp = tmp1 + multip * tmp2 + (1 - multip) * tmp3
        recs[:, it + 2] = tFT(tmp).flatten()

        #ftot, f_lik, f_dc = feval(recs[:, it + 2])
        #print('f_dc (1e6): ' + str(f_dc / 1e6) + '  perc: ' + str(100 * f_dc / np.linalg.norm(data) ** 2))

        # MSE CHECK
        recon_sli = np.reshape(recs[:, it + 2], (imsizer, imrizec))
        gt = np.reshape(gt, (imsizer, imrizec))

        rss = np.sqrt(
            np.sum(np.square(
                np.abs(sensmaps * np.tile(recon_sli[:, :, np.newaxis],
                                          [1, 1, sensmaps.shape[2]]))),
                   axis=-1))

        nmse = np.sqrt(((np.fft.fftshift(rss) - gt)**2).mean()) / np.sqrt(
            ((gt)**2).mean())
        print('NMSE: ', nmse)

    return recs, 0, phaseregvals, n4biasfields
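For orientation, the data-consistency step near the end of the loop above (tmp1 + multip * tmp2 + (1 - multip) * tmp3 with multip = 0) amounts to replacing the measured k-space samples and keeping the image's own k-space elsewhere. A minimal single-coil sketch, without the coil sensitivities and fftshift handling used above:

import numpy as np

def dc_project_sketch(img, ksp_meas, uspat):
    # keep the image's k-space where nothing was measured (uspat == 0) and
    # copy the acquired samples where uspat == 1
    ksp = np.fft.fft2(img)
    ksp = ksp * (1 - uspat) + ksp_meas * uspat
    return np.fft.ifft2(ksp)

# toy usage with a random image and a random sampling pattern
uspat = (np.random.rand(64, 64) > 0.5).astype(float)
ksp_meas = np.fft.fft2(np.random.rand(64, 64)) * uspat
rec = dc_project_sketch(np.random.rand(64, 64), ksp_meas, uspat)
print(rec.shape)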
Example #5
#!/usr/bin/env python2
# -*- coding: utf-8 -*-

# Vocoder Patch for MD380 Firmware
# Applies to version D013.020

from Patcher import Patcher

# Match all public calls.
monitormode = False
# Match all private calls.
monitormodeprivate = False

if __name__ == '__main__':
    print("Creating patches from unwrapped.img.")
    patcher = Patcher("unwrapped.img")

    # bypass vocoder copy protection on D013.020

    patcher.nopout((0x08033f30 + 0x18))
    patcher.nopout((0x08033f30 + 0x1a))
    patcher.nopout((0x08033f30 + 0x2e))
    patcher.nopout((0x08033f30 + 0x30))
    patcher.nopout((0x08033f30 + 0x44))
    patcher.nopout((0x08033f30 + 0x46))
    patcher.nopout((0x08033f30 + 0x5a))
    patcher.nopout((0x08033f30 + 0x5c))
    patcher.nopout((0x08033f30 + 0x70))
    patcher.nopout((0x08033f30 + 0x72))
    patcher.nopout((0x08033f30 + 0x86))
    patcher.nopout((0x08033f30 + 0x88))
Example #6
#!/usr/bin/env python2
# -*- coding: utf-8 -*-

# Vocoder Patch for MD380 Firmware
# Applies to version D013.020

from Patcher import Patcher

# Match all public calls.
monitormode = False
# Match all private calls.
monitormodeprivate = True

if __name__ == '__main__':
    print("Creating patches from unwrapped.img.")
    patcher = Patcher("unwrapped.img")

    #test gps
    #patcher.nopout((0x800C278 + 0))
    #patcher.nopout((0x800C278 + 2))

    # bypass vocoder copy protection on D013.020

    #patcher.nopout((0x08011444))
    #patcher.nopout((0x08011444) + 0x2)

    #test manual dial group callable
    #patcher.sethword(0x08023170, 0x2204)
    #patcher.sethword(0x08012912, 0x2804)
    #patcher.sethword(0x080EB1B0, 0x00FF)
    #patcher.nopout((0x08028F88))
Example #7
#!/usr/bin/env python
# Promiscuous Mode Patch for MD380 Firmware
# Applies to version 2.032

from Patcher import Patcher

if __name__ == '__main__':
    print "Creating patches from unwrapped.img.";
    patcher=Patcher("unwrapped.img");
    
    #Old logo patcher, no longer used.
    #fhello=open("welcome.txt","rb");
    #hello=fhello.read();
    #patcher.str2sprite(0x08094610,hello);
    #print patcher.sprite2str(0x08094610,0x14,760);
    
    #Old patch, matching on the first talkgroup.
    #We don't use this anymore, because the new patch is better.
    #patcher.nopout(0x0803ee36,0xd1ef);
    
    # New patch for monitoring all talk groups, matched on the first
    # entry iff no other match.
    #wa mov r5, 0 @ 0x0803ee86 # So the radio thinks it matched at zero.
    patcher.sethword(0x0803ee86, 0x2500);
    #wa b 0x0803ee38 @ 0x0803ee88 # Branch back to perform that match.
    patcher.sethword(0x0803ee88,0xe7d6); #Jump back to matched condition.
    
    patcher.export("prom-public.img");
    
    # This should be changed to only show missed calls for private
    # calls directed at the user, and to decode others without
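As a sanity check on the two halfwords written above, a small sketch that recomputes the Thumb encodings (the helper names are ours; the encoding rules are the standard 16-bit Thumb ones):

def thumb_mov_imm(rd, imm8):
    # MOVS Rd, #imm8  ->  0b00100 ddd iiiiiiii
    return 0x2000 | (rd << 8) | (imm8 & 0xFF)

def thumb_b(addr, target):
    # unconditional Thumb B: 0xE000 | offset11, offset = (target - (addr + 4)) / 2,
    # encoded in 11-bit two's complement
    return 0xE000 | (((target - (addr + 4)) >> 1) & 0x7FF)

print(hex(thumb_mov_imm(5, 0)))              # 0x2500, the "mov r5, 0" at 0x0803ee86
print(hex(thumb_b(0x0803ee88, 0x0803ee38)))  # 0xe7d6, the branch back at 0x0803ee88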
Example #8
#!/usr/bin/env python2
# -*- coding: utf-8 -*-

# Vocoder Patch for MD380 Firmware
# Applies to version D013.020

from Patcher import Patcher

# Match all public calls.
monitormode = False
# Match all private calls.
monitormodeprivate = False

if __name__ == '__main__':
    print("Creating patches from unwrapped.img.")
    patcher = Patcher("unwrapped.img")

    # bypass vocoder copy protection on D013.020

    patcher.nopout((0x08033f30 + 0x18))
    patcher.nopout((0x08033f30 + 0x1a))
    patcher.nopout((0x08033f30 + 0x2e))
    patcher.nopout((0x08033f30 + 0x30))
    patcher.nopout((0x08033f30 + 0x44))
    patcher.nopout((0x08033f30 + 0x46))
    patcher.nopout((0x08033f30 + 0x5a))
    patcher.nopout((0x08033f30 + 0x5c))
    patcher.nopout((0x08033f30 + 0x70))
    patcher.nopout((0x08033f30 + 0x72))
    patcher.nopout((0x08033f30 + 0x86))
    patcher.nopout((0x08033f30 + 0x88))
Example #9
#!/usr/bin/env python2
# Vocoder Patch for MD380 Firmware
# Applies to version S013.020

from Patcher import Patcher

#Match all public calls.
monitormode = False
#Match all private calls.
monitormodeprivate = False

if __name__ == '__main__':
    print "Creating patches from unwrapped.img."
    patcher = Patcher("unwrapped.img")

    # bypass vocoder copy protection on S013.020
    patcher.nopout((0x8034a60))
    patcher.nopout((0x8034a60 + 0x2))
    patcher.nopout((0x8034a76))
    patcher.nopout((0x8034a76 + 0x2))
    patcher.nopout((0x8034a8c))
    patcher.nopout((0x8034a8c + 0x2))
    patcher.nopout((0x8034aa2))
    patcher.nopout((0x8034aa2 + 0x2))
    patcher.nopout((0x8034ab8))
    patcher.nopout((0x8034ab8 + 0x2))
    patcher.nopout((0x8034ace))
    patcher.nopout((0x8034ace + 0x2))
    patcher.nopout((0x8049f9a))
    patcher.nopout((0x8049f9a + 0x2))
    patcher.nopout((0x804a820))
Example #10
    arg_enum.append("	%s = '%s'," % (enum_name, optchar))
arg_enum.append("")
first = True
for (option, enum_name, requires_parameter) in opts_long:
    if first:
        arg_enum.append("	%s = 1000," % (enum_name))
        first = False
    else:
        arg_enum.append("	%s," % (enum_name))
arg_enum.append("};")
arg_enum = "\n".join(arg_enum) + "\n"

cmd_def = ["	const char *short_options = \"%s\";" % (short_string)]
cmd_def.append("	struct option long_options[] = {")
for (option, enum_name, requires_parameter) in opts_long:
    param = "\"%s\"," % (option)
    if requires_parameter:
        cmd_def.append("		{ %-30s required_argument, 0, %s }," %
                       (param, enum_name))
    else:
        cmd_def.append("		{ %-30s no_argument,       0, %s }," %
                       (param, enum_name))
cmd_def.append("		{ 0 }")
cmd_def.append("	};")
cmd_def = "\n".join(cmd_def) + "\n"

patcher = Patcher("../pgmopts.c")
patcher.patch("help page", help_code)
patcher.patch("command definition enum", arg_enum)
patcher.patch("command definition", cmd_def)
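Here Patcher is a source-file patcher rather than a firmware patcher: patcher.patch(label, text) evidently swaps a labeled region of ../pgmopts.c for the generated code. A minimal, hypothetical sketch of that idea using begin/end marker comments (the marker format and class are assumptions, not the real implementation):

import re

class MarkerPatcher(object):
    """Replace the text between '/* begin <label> */' and '/* end <label> */' markers."""

    def __init__(self, filename):
        self.filename = filename
        with open(filename) as f:
            self.text = f.read()

    def patch(self, label, replacement):
        pattern = re.compile(
            r"(/\* begin %s \*/\n).*?(/\* end %s \*/)" % (re.escape(label), re.escape(label)),
            re.DOTALL)
        self.text, n = pattern.subn(
            lambda m: m.group(1) + replacement + m.group(2), self.text)
        if n != 1:
            raise ValueError("marker '%s' not found exactly once" % label)

    def write(self):
        with open(self.filename, "w") as f:
            f.write(self.text)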
Example #11
def vaerecon(us_ksp_r2,
             uspat,
             orim,
             i,
             result_all,
             method,
             dcprojiter,
             onlydciter=0,
             lat_dim=60,
             patchsize=28,
             contRec='',
             lowresmodel=False,
             parfact=10,
             num_iter=1,
             regiter=15,
             reglmb=0.05,
             regtype='TV'):
    def write_Data(result_all):
        with open(os.path.join(savepath, "psnr_" + method + ".txt"),
                  "w+") as f:
            #print(len(result_all))
            for i in range(len(result_all)):
                f.writelines('current image {} PSNR : '.format(i) + str(result_all[i][0]) + \
                "    SSIM : " + str(result_all[i][1]) + "    HFEN : " + str(result_all[i][2]))
                f.write('\n')

    print('KCT-INFO: contRec is ' + contRec)
    print('KCT-INFO: parfact is ' + str(parfact))
    #print("............",us_ksp_r2.shape)
    #print(sensmaps.shape)
    # set parameters
    #==============================================================================
    np.random.seed(seed=1)

    imsizer = us_ksp_r2.shape[0]  #252#256#252
    imrizec = us_ksp_r2.shape[1]  #308#256#308
    #print(imsizer,imrizec)
    nsampl = 50

    #make a network and a patcher to use later
    #==============================================================================
    x_rec, x_inp, funop, grd0, _, _, _, _, _, _, _, _, _, _, _, _, _, _ = definevae(
        lat_dim=lat_dim, patchsize=patchsize, batchsize=parfact * nsampl)

    Ptchr = Patcher(imsize=[imsizer, imrizec],
                    patchsize=patchsize,
                    step=int(patchsize / 2),
                    nopartials=True,
                    contatedges=True)
    nopatches = len(Ptchr.genpatchsizes)
    print("KCT-INFO: there will be in total " + str(nopatches) + " patches.")

    def dconst(us):
        #inp: [nx, ny]
        #out: [nx, ny]

        return np.linalg.norm(np.fft.fft2(us) * uspat - data)**2

    def dconst_grad(us):
        #inp: [nx, ny]
        #out: [nx, ny]
        return 2 * np.fft.ifft2(uspat * (np.fft.fft2(us) * uspat - data))

    def prior(us):
        #inp: [parfact,ps*ps]
        #out: parfact

        us = np.abs(us)
        funeval = funop.eval(feed_dict={x_rec: np.tile(us, (nsampl, 1))})  #
        funeval = np.array(np.split(funeval, nsampl,
                                    axis=0))  # [nsampl x parfact x 1]
        return np.mean(funeval, axis=0).astype(np.float64)

    def prior_grad(us):
        #inp: [parfact, ps*ps]
        #out: [parfact, ps*ps]
        #print('prior_grad',us.dtype,np.max(us.imag))
        usc = us.copy()
        usabs = np.abs(us)
        #print('prior_grad usabs',usabs.dtype,np.max(usabs.imag))
        #print('prior_grad usc',usc.dtype,np.max(usc.imag))
        #print('usc == us ? :',usc == us)
        grd0eval = grd0.eval(feed_dict={x_rec: np.tile(usabs, (nsampl, 1))
                                        })  # ,x_inp: np.tile(usabs,(nsampl,1))

        #grd0eval: [500x784]
        grd0eval = np.array(np.split(grd0eval, nsampl,
                                     axis=0))  # [nsampl x parfact x 784]
        grd0m = np.mean(grd0eval, axis=0)  #[parfact,784]

        #print('grd0eval.shape,grd0m.shape',grd0eval.shape,grd0m.shape)
        #print(grd0m)

        grd0m = usc / np.abs(usc) * grd0m

        #print('prior_grad grd0m',grd0m.dtype,np.max(grd0m.imag))

        return grd0m  # .astype(np.float64)

    def prior_grad_patches(ptchs):
        #inp: [np, ps, ps]
        #out: [np, ps, ps]
        #takes set of patches as input and returns a set of their grad.s
        #both grads are in the positive direction

        shape_orig = ptchs.shape

        ptchs = np.reshape(ptchs, [ptchs.shape[0], -1])

        grds = np.zeros([
            int(np.ceil(ptchs.shape[0] / parfact) * parfact),
            np.prod(ptchs.shape[1:])
        ],
                        dtype=np.complex128)

        extraind = int(
            np.ceil(ptchs.shape[0] / parfact) * parfact) - ptchs.shape[0]
        ptchs = np.pad(ptchs, ((0, extraind), (0, 0)), mode='edge')

        #print('prior_grad_patches',ptchs.dtype,np.max(ptchs.imag))
        for ix in range(int(np.ceil(ptchs.shape[0] / parfact))):
            grds[parfact * ix:parfact * ix + parfact, :] = prior_grad(
                ptchs[parfact * ix:parfact * ix + parfact, :])

        grds = grds[0:shape_orig[0], :]

        print('np.reshape(grds, shape_orig).shape',
              np.reshape(grds, shape_orig).shape)
        return np.reshape(grds, shape_orig)
        #np.ceil rounds up (ceiling)

    def prior_patches(ptchs):
        #inp: [np, ps, ps]
        #out: 1

        fvls = np.zeros([int(np.ceil(ptchs.shape[0] / parfact) * parfact)])

        extraind = int(
            np.ceil(ptchs.shape[0] / parfact) * parfact) - ptchs.shape[0]
        ptchs = np.pad(ptchs, [(0, extraind), (0, 0), (0, 0)], mode='edge')

        for ix in range(int(np.ceil(ptchs.shape[0] / parfact))):
            fvls[parfact * ix:parfact * ix + parfact] = prior(
                np.reshape(ptchs[parfact * ix:parfact * ix + parfact, :, :],
                           [parfact, -1]))

        fvls = fvls[0:ptchs.shape[0]]

        return np.mean(fvls)

    def full_gradient(image):
        #inp: [nx*nx, 1]
        #out: [nx, ny], [nx, ny]

        #returns both gradients in the respective positive direction.
        #i.e. must

        ptchs = Ptchr.im2patches(np.reshape(image, [imsizer, imrizec]))
        ptchs = np.array(ptchs)
        #print('ptchs dtype',ptchs.dtype,np.max(ptchs))

        grd_prior = prior_grad_patches(ptchs)
        grd_prior = (-1) * Ptchr.patches2im(grd_prior)

        grd_dconst = dconst_grad(np.reshape(image, [imsizer, imrizec]))

        return grd_prior + grd_dconst, grd_prior, grd_dconst

    def full_funceval(image):
        #inp: [nx*nx, 1]
        #out: [1], [1], [1]

        tmpimg = np.reshape(image, [imsizer, imrizec])

        dc = dconst(tmpimg)

        ptchs = Ptchr.im2patches(np.reshape(image, [imsizer, imrizec]))
        ptchs = np.array(ptchs)

        priorval = (-1) * prior_patches(np.abs(ptchs))

        return priorval + dc, priorval, dc

    #define the phase regularization functions
    #==============================================================================

    def tv_proj(phs, mu=0.125, lmb=2, IT=225):
        # Total variation based projection

        phs = fb_tv_proj(phs, mu=mu, lmb=lmb, IT=IT)

        return phs

    def fgrad(im):
        # gradient operation with 1st order finite differences
        imr_x = np.roll(im, shift=-1, axis=0)
        imr_y = np.roll(im, shift=-1, axis=1)
        grd_x = imr_x - im
        grd_y = imr_y - im

        return np.array((grd_x, grd_y))

    def fdivg(im):
        # divergence operator with 1st order finite differences
        imr_x = np.roll(np.squeeze(im[0, :, :]), shift=1, axis=0)
        imr_y = np.roll(np.squeeze(im[1, :, :]), shift=1, axis=1)
        grd_x = np.squeeze(im[0, :, :]) - imr_x
        grd_y = np.squeeze(im[1, :, :]) - imr_y

        return grd_x + grd_y

    def f_st(u, lmb):
        # soft thresholding

        uabs = np.squeeze(np.sqrt(np.sum(u * np.conjugate(u), axis=0)))

        tmp = 1 - lmb / uabs
        tmp[tmp < 0] = 0  # clip negatives: soft-thresholding keeps max(1 - lmb/|u|, 0)

        uu = u * np.tile(tmp[np.newaxis, :, :], [u.shape[0], 1, 1])

        return uu

    def fb_tv_proj(im, u0=0, mu=0.125, lmb=1, IT=15):

        sz = im.shape
        us = np.zeros((2, sz[0], sz[1], IT))
        us[:, :, :, 0] = u0

        for it in range(IT - 1):

            #grad descent step:
            tmp1 = im - fdivg(us[:, :, :, it])
            tmp2 = mu * fgrad(tmp1)

            tmp3 = us[:, :, :, it] - tmp2

            #thresholding step:
            us[:, :, :, it + 1] = tmp3 - f_st(tmp3, lmb=lmb)

        #endfor

        return im - fdivg(us[:, :, :, IT - 1])

    def reg2_proj(usph, niter=100, alpha=0.05):
        #A smoothness based based projection. Regularization method 2 from
        #"Separate Magnitude and Phase Regularization via Compressed Sensing",  Feng Zhao et al, IEEE TMI, 2012

        usph = usph + np.pi * 3 / 2

        ims = np.zeros((imsizer, imrizec, niter))
        ims[:, :, 0] = usph.copy()
        for ix in range(niter - 1):
            ims[:, :, ix + 1] = ims[:, :, ix] - 2 * alpha * np.real(
                1j * np.exp(-1j * ims[:, :, ix]) *
                fdivg(fgrad(np.exp(1j * ims[:, :, ix]))))

        return ims[:, :, -1] - np.pi * 3 / 2

    #make the data
    #===============================

    uspat = np.abs(us_ksp_r2) > 0
    uspat = uspat[:, :]
    data = us_ksp_r2

    import pickle

    #make the functions for POCS
    #=====================================
    #number of iterations
    numiter = num_iter

    # if you want to do an affine data consistency projection
    multip = 0  # 0 means simply replacing the measured values

    # step size for the prior iterations
    alphas = np.logspace(-4, -4, numiter)

    #print('step size for the prior iterations',alphas)
    #assert False

    def feval(im):
        return full_funceval(im)

    def geval(im):
        t1, t2, t3 = full_gradient(im)
        return np.reshape(t1, [-1]), np.reshape(t2, [-1]), np.reshape(t3, [-1])

    # initialize data with the zero-filled image
    recs = np.zeros((imsizer * imrizec, numiter), dtype=complex)
    #recs[:,0] = tUFT(data, uspat).flatten().copy()

    recs[:, 0] = np.fft.ifft2(data * uspat).flatten().copy()
    recs[:, 0] = abs(recs[:, 0])
    print(
        'Zerosfilled Psnr :',
        compare_psnr(255 * np.abs(np.reshape(recs[:, 0], [imsizer, imrizec])),
                     255 * np.abs(orim),
                     data_range=255))
    #plt.figure(444444444)
    #plt.imshow(np.abs(np.reshape(recs[:,0],[imsizer,imrizec])),cmap='gray')
    #plt.show()
    # if you want to instead continue reconstruction from an existing image
    print(' KCT-INFO: contRec is ' + contRec)
    if contRec != '':
        print('KCT-INFO: reading from a previous file ' + contRec)
        rr = pickle.load(open(contRec, 'rb'))
        recs[:, 0] = rr[:, -1]
        print('KCT-INFO: initialized to the previous recon: ' + contRec)

    import time
    # the iteration loop
    max_psnr = 0
    max_ssim = 0
    min_hfen = 100
    for it in range(numiter - 1):
        start = time.time()
        # get the step size for the Piteration
        alpha = alphas[it]

        # if you want to do some data consistency iterations before starting the prior projections
        # this can be helpful when doing recon with multiple coils, e.g. you can do pure SENSE in the beginning...
        # or only do phase projections in the beginning.
        if it > onlydciter:

            # get the gradients for the prior projection and the likelihood values
            #ftot, f_prior, f_dc = feval(recs[:,it])
            print('recs dtype Input ', recs.dtype)
            gtot, g_prior, g_dc = geval(recs[:, it])

            #                print("it no: " + str(it) + " f_tot= " + str(ftot) + " f_prior= " + str(-f_prior) + " f_dc (x1e6)= " + str(f_dc/1e6) + " |g_prior|= " + str(np.linalg.norm(g_prior)) + " |g_dc|= " + str(np.linalg.norm(g_dc)) )
            #print("it no: {0}, f_tot= {1:.2f},f_prior= {2:.2f}, f_dc (x1e6)= {3:.2f}, |g_prior|= {4:.2f}, |g_dc|= {5:.2f}".format(it,ftot,f_prior,f_dc/1e6,np.linalg.norm(g_prior),np.linalg.norm(g_dc)))

            # update the image with the prior gradient
            recs[:, it + 1] = recs[:, it] - alpha * g_prior

            print(
                'current Phase prePsnr :',
                compare_psnr(
                    255 *
                    np.abs(np.reshape(recs[:, it + 1], [imsizer, imrizec])),
                    255 * np.abs(orim),
                    data_range=255))
            print(
                'current Phase Real Psnr :',
                compare_psnr(255 * np.abs(
                    np.reshape(recs[:, it + 1], [imsizer, imrizec]).real),
                             255 * np.abs(orim.real),
                             data_range=255))
            print(
                'current Phase Imag Psnr :',
                compare_psnr(255 * np.abs(
                    np.reshape(recs[:, it + 1], [imsizer, imrizec]).imag),
                             255 * np.abs(orim.imag),
                             data_range=255))

            #plt.figure(1111)
            #plt.imshow(np.abs(np.reshape(recs[:,it+1],[imsizer,imrizec])),cmap='gray')
            #plt.show()
            #print("............................................")
            #print(recs[:,it+1])
            # separate the magnitude from the phase and do the phase projection
            tmpa = np.abs(np.reshape(recs[:, it + 1], [imsizer, imrizec]))
            tmpp = np.angle(np.reshape(recs[:, it + 1],
                                       [imsizer, imrizec]))  #+(np.pi*3/2)
            print(np.max(tmpp), np.min(tmpp))
            tmpaf = tmpa.copy().flatten()

            #plt.figure(1111)
            #plt.imshow(tmpp,cmap='gray')
            #plt.show()

            if reglmb == 0:
                print("KCT-info: skipping phase proj")
                tmpptv = tmpp.copy().flatten()
            else:
                if regtype == 'TV':
                    tmpptv = tv_proj(tmpp, mu=0.125, lmb=reglmb,
                                     IT=regiter).flatten()  #0.1, 15
                elif regtype == 'reg2':
                    tmpptv = reg2_proj(tmpp, alpha=reglmb,
                                       niter=regiter).flatten()  #0.1, 15
                else:
                    raise (TypeError)

            # combine back the phase and the magnitude
            #tmpptv -= (np.pi*3/2)
            recs[:, it + 1] = tmpaf * np.exp(1j * tmpptv)  #####here

        else:  # the case where you do only data consistency iterations (also iteration 0)
            if not it == 0:
                print(
                    'KCT-info: skipping prior proj for the first onlydciters iter.s, doing only phase proj (then maybe DC proj as well) !!!'
                )

            recs[:, it + 1] = recs[:, it].copy()

            print(
                'current Phase prePsnr :',
                compare_psnr(
                    255 *
                    np.abs(np.reshape(recs[:, it + 1], [imsizer, imrizec])),
                    255 * np.abs(orim),
                    data_range=255))
            print(
                'current Phase Real Psnr :',
                compare_psnr(255 * np.abs(
                    np.reshape(recs[:, it + 1], [imsizer, imrizec]).real),
                             255 * np.abs(orim.real),
                             data_range=255))
            print(
                'current Phase Imag Psnr :',
                compare_psnr(255 * np.abs(
                    np.reshape(recs[:, it + 1], [imsizer, imrizec]).imag),
                             255 * np.abs(orim.imag),
                             data_range=255))

            # separate the magnitude from the phase and do the phase projection
            tmpa = np.abs(np.reshape(recs[:, it + 1], [imsizer, imrizec]))
            tmpp = np.angle(np.reshape(recs[:, it + 1], [imsizer, imrizec]))

            tmpaf = tmpa.copy().flatten()

            if reglmb == 0:
                print("KCT-info: skipping phase proj")
                tmpptv = tmpp.copy().flatten()
            else:
                if regtype == 'TV':
                    tmpptv = tv_proj(tmpp, mu=0.125, lmb=reglmb,
                                     IT=regiter).flatten()  #0.1, 15
                elif regtype == 'reg2':
                    tmpptv = reg2_proj(tmpp, alpha=reglmb,
                                       niter=regiter).flatten()  #0.1, 15
                else:
                    raise (TypeError)

            # combine back the phase and the magnitude
            recs[:, it + 1] = tmpaf * np.exp(1j * tmpptv)
            #plt.figure(22222)
            #plt.imshow(np.abs(np.reshape(recs[:,it+1],[imsizer,imrizec])),cmap='gray')
            #plt.show()
        #do the DC projection every 'dcprojiter' iterations
        if it < onlydciter + 1 or it % dcprojiter == 0:  #

            tmp_rec = np.reshape(recs[:, it + 1], [imsizer, imrizec])
            #plt.figure(3333333)
            #plt.imshow(np.abs(tmp1),cmap='gray')
            #plt.show()
            print(np.max(np.abs(tmp_rec)), np.max(255 * np.abs(orim)))
            print(
                'current Prior Psnr :',
                compare_psnr(255 * np.abs(tmp_rec),
                             255 * np.abs(orim),
                             data_range=255))
            print(
                'current Prior Real Psnr :',
                compare_psnr(255 * np.abs(
                    np.reshape(recs[:, it + 1], [imsizer, imrizec]).real),
                             255 * np.abs(orim.real),
                             data_range=255))
            print(
                'current Prior Imag Psnr :',
                compare_psnr(255 * np.abs(
                    np.reshape(recs[:, it + 1], [imsizer, imrizec]).imag),
                             255 * np.abs(orim.imag),
                             data_range=255))

            #tmp1 = np.fft.fft2(tmp_rec)
            #print(uspat)
            #tmp1[uspat==1] = us_ksp_r2[uspat==1]
            #tmp1 = np.fft.ifft2(tmp1)
            #recs[:,it+1] = tmp1.flatten()
            #plt.figure(1)
            #plt.imshow(np.abs(tmp1),cmap='gray')
            #plt.show()

            tmp1 = np.fft.fft2(np.reshape(recs[:, it + 1],
                                          [imsizer, imrizec])) * (1 - uspat)
            tmp2 = np.fft.fft2(np.reshape(recs[:, it + 1],
                                          [imsizer, imrizec])) * (uspat)

            tmp3 = data * uspat  #[:,:,np.newaxis]

            #combine the measured data with the projected image affinely (multip=0 for replacing the values)
            tmp = tmp1 + multip * tmp2 + (1 - multip) * tmp3
            recs[:, it + 1] = np.fft.ifft2(tmp).flatten()  #np.abs(
            print('recs dtype DC ', recs.dtype)
            print(
                'current DC Psnr :',
                compare_psnr(
                    255 *
                    np.abs(np.reshape(recs[:, it + 1], [imsizer, imrizec])),
                    255 * np.abs(orim),
                    data_range=255))
            print(
                'current DC Real Psnr :',
                compare_psnr(255 * np.abs(
                    np.reshape(recs[:, it + 1], [imsizer, imrizec]).real),
                             255 * np.abs(orim.real),
                             data_range=255))
            print(
                'current DC Imag Psnr :',
                compare_psnr(255 * np.abs(
                    np.reshape(recs[:, it + 1], [imsizer, imrizec]).imag),
                             255 * np.abs(orim.imag),
                             data_range=255))

            #ftot, f_lik, f_dc = feval(recs[:,it+1])

            #print(ftot.shape,f_lik.shape,f_dc.shape)
        x_complex = np.reshape(recs[:, it + 1], [256, 256])
        psnr = compare_psnr(255 * abs(x_complex),
                            255 * abs(orim),
                            data_range=255)
        ssim = compare_ssim(abs(x_complex), abs(orim), data_range=1)
        hfen = compare_hfen(abs(x_complex), abs(orim))
        print('current %d image %d iteration PSNR:' % (i, it), psnr, ' SSIM :',
              ssim, ' HFEN :', hfen)
        if max_psnr < psnr:
            result_all[i, 0] = psnr
            max_psnr = psnr
            result_all[31, 0] = sum(result_all[:31, 0]) / 31
            savemat(
                os.path.join(savepath,
                             'img_{}_Rec_'.format(i) + method + '.mat'),
                {
                    'data':
                    np.array(np.reshape(recs[:, it + 1], [256, 256]),
                             dtype=complex)
                })
            cv2.imwrite(
                os.path.join(savepath,
                             'img_{}_Rec_'.format(i) + method + '.png'),
                np.array(255 * np.abs(np.reshape(recs[:, it + 1], [256, 256])),
                         dtype=np.uint8))
        if max_ssim < ssim:
            result_all[i, 1] = ssim
            max_ssim = ssim
            result_all[31, 1] = sum(result_all[:31, 1]) / 31

        if min_hfen > hfen:
            result_all[i, 2] = hfen
            min_hfen = hfen
            result_all[31, 2] = sum(result_all[:31, 2]) / 31
        write_Data(result_all)
        #print('recs dtype DC ',recs.dtype)
        end = time.time()
        print('{} iteration time is'.format(it), end - start)
    return recs

def vaerecon(us_ksp_r2, # undersampled k space
             sensmaps,
             dcprojiter,
             onlydciter = 0,
             lat_dim = 60,
             patchsize = 28,
             contRec = '',
             parfact = 10,
             num_iter = 302,
             rescaled = False,
             half = False,
             regiter = 15,
             reglmb = 0.05,
             regtype = 'TV',
             usemeth = 1,
             stepsize = 1e-4,
             optScale = False,
             mode = [],
             chunks40 = False,
             Melmodels = '',
             N4BFcorr = False,
             z_multip = 1.0,
             n1 = 5,
             n2 = 5,
             log_dir = ''):
     
     logging.info('xxxxxxxxxxxxxxxxxxx contRec is ' + contRec)
     logging.info('xxxxxxxxxxxxxxxxxxx parfact is ' + str(parfact) )
     
     # ==============================================================================
     # set parameters
     # ==============================================================================
     np.random.seed(seed = 1)     
     imsizer = us_ksp_r2.shape[0] #252#256#252
     imrizec = us_ksp_r2.shape[1] #308#256#308
     nsampl = 50 #0
          
     # ==============================================================================
     # make a network and a patcher to use later
     # ==============================================================================     
     # =================================
     # get the output from the VAE
     # funop: ELBO
     # grd0: gradient of the ELBO wrt the image
     # grd_p_x_z0, grd_p_z0, grd_q_z_x0: different gradients
     # =================================
     vae_outputs = definevae(lat_dim = lat_dim,
                             patchsize = patchsize,
                             batchsize = parfact*nsampl,
                             rescaled = rescaled,
                             half = half,
                             mode = mode,
                             chunks40 = chunks40,
                             Melmodels = Melmodels,
                             use_normalizer = bool(n1>0),
                             log_dir = log_dir)
     
     x_rec, x_inp, funop, grd0, sess = vae_outputs[0:5]
     grd_p_x_z0, grd_p_z0, grd_q_z_x0, grd20 = vae_outputs[5:9] # these are not being used in the code now.
     y_out, y_out_prec, z_std_multip, op_q_z_x = vae_outputs[9:13] # these are not being used in the code now.
     mu, std = vae_outputs[13:15] # these are not being used in the code now.
     grd_q_zpl_x_az0, op_q_zpl_x, z_pl, z = vae_outputs[15:19] # these are not being used in the code now.
     # Most of these outputs were being used in the function likelihood_grad_meth3, which is no longer being used.
     norm_accum_grads_zero_op, norm_accum_grads_op, norm_accum_grads_mean_op, num_accum_steps_pl, norm_update_op, x_norm = vae_outputs[19:25]
     f_elbo, f_data_consistency, f_total, f_summary = vae_outputs[25:29]
     g_elbo, g_data_consistency, g_total, g_summary, summary_writer = vae_outputs[29:34]
     
     # =================================
     # used to go from image to patches and back
     # =================================
     Ptchr = Patcher(imsize = [imsizer, imrizec],
                     patchsize = patchsize,
                     step = int(patchsize/2),
                     nopartials = True,
                     contatedges = True)
     
     nopatches = len(Ptchr.genpatchsizes)
     logging.info("There will be in total " + str(nopatches) + " patches.")
     
     # =================================
     # functions for data consistency
     # =================================
     def dconst(us):
          #inp: [nx, ny]
          #out: [nx, ny]
          return np.linalg.norm(utils.UFT_with_sensmaps(us, uspat, sensmaps) - data)**2
     
     def dconst_grad(us):
          #inp: [nx, ny]
          #out: [nx, ny]
          return 2 * utils.tUFT_with_sensmaps(utils.UFT_with_sensmaps(us, uspat, sensmaps) - data, uspat, sensmaps)

     # =================================  
     # function for running the normalization op on a batch of patches
     # =================================  
     def update_normalizer(im,
                           sess,
                           accum_gradients_zero_op,
                           accum_gradients_op,
                           accum_gradients_mean_op,
                           num_accum_steps_pl,
                           update_normalizer_op):
         
          # make patches
          ptchs = Ptchr.im2patches(np.reshape(im, [imsizer, imrizec]))
          ptchs = np.array(ptchs)          
          ptchs = np.abs(ptchs)
          
          # zero accumulated gradients
          sess.run(accum_gradients_zero_op)
          num_accumulation_steps = 0
          
          # convert image to patches
          ptchs = np.reshape(ptchs, [ptchs.shape[0], -1])
          # extra indices that have to be padded to ensure we have complete batches for the VAE
          extraind = int(np.ceil(ptchs.shape[0] / parfact) * parfact) - ptchs.shape[0]
          ptchs = np.pad(ptchs, ((0, extraind), (0,0)), mode = 'edge')
          
          # send batches of patches and accumulate gradients
          for ix in range(int(np.ceil(ptchs.shape[0] / parfact))):
              batch_of_patches = ptchs[parfact * ix : parfact * ix + parfact, :]
              sess.run(accum_gradients_op, feed_dict = {x_rec: np.tile(batch_of_patches, (nsampl, 1)), z_std_multip: z_multip})
              num_accumulation_steps = num_accumulation_steps + 1
              
          # run op to mean gradients accumulated so far
          sess.run(accum_gradients_mean_op, feed_dict = {num_accum_steps_pl: num_accumulation_steps})
          
          # update normalizer parameters according to the mean gradient.
          # The stuff passed in the feed_dict does not matter in this line, but something needs to be passed. So passing the last batch of patches.
          sess.run(update_normalizer_op, feed_dict = {x_rec: np.tile(batch_of_patches, (nsampl, 1)), z_std_multip: z_multip})
          
          return 0     
     
     # =================================
     # functions for computing the ELBO and its derivatives
     # =================================     
     
     # =================================     
     # returns the elbo of a batch of patches
     # =================================     
     def likelihood(us):
          # inp: [parfact,ps*ps]
          # out: parfact
          us = np.abs(us)
          funeval = funop.eval(feed_dict = {x_rec: np.tile(us, (nsampl,1)), z_std_multip: z_multip})
          funeval = np.array(np.split(funeval, nsampl, axis=0)) # [nsampl x parfact x 1]
          return np.mean(funeval, axis=0).astype(np.float64)
      
     # =================================    
     # returns the elbo of the input patches
     # =================================    
     def likelihood_patches(ptchs):
          # inp: [np, ps, ps] 
          # out: 1
          
          fvls = np.zeros([int(np.ceil(ptchs.shape[0] / parfact) * parfact) ])
          extraind = int(np.ceil(ptchs.shape[0] / parfact) * parfact) - ptchs.shape[0]
          ptchs = np.pad(ptchs, [(0,extraind), (0,0), (0,0)], mode='edge')
          
          for ix in range(int(np.ceil(ptchs.shape[0]/parfact))):
               fvls[parfact*ix : parfact*ix+parfact] = likelihood(np.reshape(ptchs[parfact*ix : parfact*ix+parfact,:,:], [parfact,-1]))
               
          fvls = fvls[0:ptchs.shape[0]]
               
          return np.mean(fvls)
      
     # =================================
     # returns the ELBO of the image 
     # =================================     
     def full_funceval(image):
          #inp: [nx*nx, 1]
          #out: [1], [1], [1]
          
          tmpimg = np.reshape(image, [imsizer,imrizec])
          
          # how consistent is this image with the measured undersampled k-space data    
          dc = dconst(tmpimg) 
          
          # convert the image into patches and measure the elbo of each patch
          ptchs = Ptchr.im2patches(np.reshape(image, [imsizer,imrizec]))
          ptchs = np.array(ptchs)          
          lik = (-1)*likelihood_patches(np.abs(ptchs))    
          
          return lik + dc, lik, dc    
     
     # =================================  
     # returns the gradient of the elbo of a batch of patches wrt the batch of patches
     # =================================  
     def likelihood_grad(us):
          # inp: [parfact, ps*ps]
          # out: [parfact, ps*ps]
          usc = us.copy()
          usabs = np.abs(us)
          grd0eval = grd0.eval(feed_dict = {x_rec: np.tile(usabs, (nsampl, 1)), z_std_multip: z_multip}) # grd0eval: [500x784]
          grd0eval = np.array(np.split(grd0eval, nsampl, axis=0)) # [nsampl x parfact x 784]
          grd0m = np.mean(grd0eval,axis=0) # [parfact,784]
          # correction for the complex image
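          # grd0 is evaluated on the magnitudes only; multiplying by usc/|usc| = exp(1j*angle(usc))
          # rotates the magnitude-gradient by the phase of the current estimate, so the returned
          # gradient acts on the complex patches along their current phase direction.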
          grd0m = usc/np.abs(usc)*grd0m
          return grd0m #.astype(np.float64)
     
     # =================================  
     # returns the gradient of the elbo of the input patches wrt the input patches
     # =================================  
     def likelihood_grad_patches(ptchs):
          # inp: [np, ps, ps] 
          # out: [np, ps, ps] 
          # takes set of patches as input and returns a set of their grad.s 
          # both grads are in the positive direction
          
          shape_orig = ptchs.shape
          ptchs = np.reshape(ptchs, [ptchs.shape[0], -1] )
          grds = np.zeros([int(np.ceil(ptchs.shape[0]/parfact)*parfact), np.prod(ptchs.shape[1:])], dtype = np.complex64)
          
          # extra indices that have to be padded to ensure we have complete batches for the VAE
          extraind = int(np.ceil(ptchs.shape[0] / parfact) * parfact) - ptchs.shape[0]
          ptchs = np.pad(ptchs, ((0, extraind), (0,0)), mode='edge')
          
          for ix in range(int(np.ceil(ptchs.shape[0] / parfact))):
              grds[parfact*ix : parfact*ix+parfact, :] = likelihood_grad(ptchs[parfact*ix : parfact*ix+parfact, :]) 

          grds = grds[0:shape_orig[0],:]

          return np.reshape(grds, shape_orig)
     

     # =================================
     # returns the gradient of the ELBO wrt the image 
     # =================================     
     def full_gradient(image):
          #inp: [nx*nx, 1]
          #out: [nx, ny], [nx, ny]
          #returns both gradients in the respective positive direction.
          #i.e. they must be subtracted from the image to take a descent step.
          
          # convert image to patches
          ptchs = Ptchr.im2patches(np.reshape(image, [imsizer,imrizec]))
          ptchs = np.array(ptchs)          
          
          # compute gradient of the ELBO wrt each patch
          grd_lik = likelihood_grad_patches(ptchs)
          grd_lik = (-1)* Ptchr.patches2im(grd_lik)
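          # Ptchr.patches2im recombines the per-patch gradients into a full-image gradient, merging
          # the overlapping regions (summing or averaging them, depending on the Patcher
          # implementation); the sign flip turns the ELBO gradient into the -ELBO gradient used below.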
          
          grd_dconst = dconst_grad(np.reshape(image, [imsizer,imrizec]))
          
          return grd_lik + grd_dconst, grd_lik, grd_dconst
     
     # =================================
     # phase projection functions
     # =================================     
     def tv_proj(phs,mu=0.125,lmb=2,IT=225):
          phs = fb_tv_proj(phs,mu=mu,lmb=lmb,IT=IT)          
          return phs
     
     def fgrad(im):
          imr_x = np.roll(im,shift=-1,axis=0)
          imr_y = np.roll(im,shift=-1,axis=1)
          grd_x = imr_x - im
          grd_y = imr_y - im          
          return np.array((grd_x, grd_y))
     
     def fdivg(im):
          imr_x = np.roll(np.squeeze(im[0,:,:]),shift=1,axis=0)
          imr_y = np.roll(np.squeeze(im[1,:,:]),shift=1,axis=1)
          grd_x = np.squeeze(im[0,:,:]) - imr_x
          grd_y = np.squeeze(im[1,:,:]) - imr_y          
          return grd_x + grd_y
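     # Note: fgrad is a forward-difference gradient with circular boundaries (via np.roll) and fdivg
     # is the matching backward-difference divergence, so for real images <fgrad(a), p> = -<a, fdivg(p)>;
     # consequently fdivg(fgrad(.)) acts as a discrete Laplacian in the smoothing projections below.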
     
     def f_st(u,lmb):          
          uabs = np.squeeze(np.sqrt(np.sum(u*np.conjugate(u),axis=0)))
          tmp = 1 - lmb/uabs
          # soft-thresholding / shrinkage: u * max(1 - lmb/|u|, 0), so zero out the negative part
          tmp[tmp < 0] = 0
          uu = u*np.tile(tmp[np.newaxis,:,:],[u.shape[0],1,1])
          return uu
       
     def fb_tv_proj(im, u0=0, mu=0.125, lmb=1, IT=15):
          sz = im.shape
          us=np.zeros((2,sz[0],sz[1],IT))
          us[:,:,:,0] = u0          
          for it in range(IT-1):               
               # grad descent step:
               tmp1 = im - fdivg(us[:,:,:,it])
               tmp2 = mu*fgrad(tmp1)
               tmp3 = us[:,:,:,it] - tmp2                 
               # thresholding step:
               us[:,:,:,it+1] = tmp3 - f_st(tmp3, lmb=lmb)                    
          return im - fdivg(us[:,:,:,it+1]) 
          
     def tikh_proj(usph, niter=100, alpha=0.05):          
          ims = np.zeros((imsizer, imrizec, niter))
          ims[:,:,0] = usph.copy()
          for ix in range(niter-1):
              ims[:,:,ix+1] = ims[:,:,ix] + alpha*2*fdivg(fgrad(ims[:,:,ix]))
          return ims[:,:,-1]
     
     def reg2_proj(usph, niter=100, alpha=0.05):          
          # from  Separate Magnitude and Phase Regularization via Compressed Sensing,  Feng Zhao
          usph = usph + np.pi
          ims = np.zeros((imsizer,imrizec,niter))
          ims[:,:,0]=usph.copy()
          regval = reg2eval(ims[:,:,0].flatten())

          for ix in range(niter-1):
              ims[:,:,ix+1] = ims[:,:,ix] +alpha*reg2grd(ims[:,:,ix].flatten()).reshape([252,308]) # *alpha*np.real(1j*np.exp(-1j*ims[:,:,ix])*    fdivg(fgrad(np.exp(  1j* ims[:,:,ix]    )))     )
              regval = reg2eval(ims[:,:,ix+1].flatten())
          
          return ims[:,:,-1] - np.pi    
     
     def reg2_dcproj(usph, magim, bfestim, niter=100, alpha_reg=0.05, alpha_dc=0.05):
          # from  Separate Magnitude and Phase Regularization via Compressed Sensing,  Feng Zhao
          # usph = usph+np.pi          
          ims = np.zeros((imsizer,imrizec,niter))
          grds_reg = np.zeros((imsizer,imrizec,niter))
          grds_dc = np.zeros((imsizer,imrizec,niter))
          ims[:,:,0]=usph.copy()
          regval = reg2eval(ims[:,:,0].flatten())
          
          for ix in range(niter-1):
               
              grd_reg = reg2grd(ims[:,:,ix].flatten()).reshape([252,308])  # *alpha*np.real(1j*np.exp(-1j*ims[:,:,ix])*    fdivg(fgrad(np.exp(  1j* ims[:,:,ix]    )))     )
              grds_reg[:,:,ix]  = grd_reg
              grd_dc = reg2_dcgrd(ims[:,:,ix].flatten() , magim, bfestim).reshape([252,308])
              grds_dc[:,:,ix]  = grd_dc
              
              ims[:,:,ix+1] = ims[:,:,ix] + alpha_reg*grd_reg  - alpha_dc*grd_dc
              regval = reg2eval(ims[:,:,ix+1].flatten())
              f_dc = dconst(magim*np.exp(1j*ims[:,:,ix+1])*bfestim)
              
              print_info = False
              if print_info is True:
                  logging.info("norm grad reg: " + str(np.linalg.norm(grd_reg)))
                  logging.info("norm grad dc: " + str(np.linalg.norm(grd_dc)))
                  logging.info("regval: " + str(regval))
                  logging.info("fdc: (*1e9) {0:.6f}".format(f_dc/1e9))
          
          return ims[:,:,-1] #-np.pi    
     
     def reg2eval(im):
          # takes in 1d, returns scalar
          im = im.reshape([252,308])
          phs = np.exp(1j*im)
          return np.linalg.norm(fgrad(phs).flatten())
     
     def reg2grd(im):
          # takes in 1d, returns 1d
          im = im.reshape([252,308])
          return -2*np.real(1j*np.exp(-1j*im) * fdivg(fgrad(np.exp(1j * im)))).flatten()
     
     def reg2_dcgrd(phim, magim, bfestim):   
          # takes in 1d, returns 1d
          phim = phim.reshape([252,308])
          magim = magim.reshape([252,308])
          tmp = utils.UFT_with_sensmaps(bfestim * np.exp(1j * phim) * magim, uspat, sensmaps) - data
          return -2 * np.real(1j * np.exp(-1j*phim) * magim * bfestim * utils.tUFT_with_sensmaps(tmp, uspat, sensmaps)).flatten()
     
     def reg2_proj_ls(usph, niter=100):
          # from  Separate Magnitude and Phase Regularization via Compressed Sensing,  Feng Zhao
          # with line search         
          usph = usph + np.pi
          ims = np.zeros((imsizer, imrizec, niter))
          ims[:,:,0] = usph.copy()
          regval = reg2eval(ims[:,:,0].flatten())
          logging.info(regval)
          for ix in range(niter-1):               
              currgrd = reg2grd(ims[:,:,ix].flatten())     
              res = sop.minimize_scalar(lambda alpha: reg2eval(ims[:,:,ix].flatten() + alpha * currgrd   ), method='Golden')
              alphaopt = res.x
              logging.info("optimal alpha: " + str(alphaopt) )               
              ims[:,:,ix+1] = ims[:,:,ix] + alphaopt*currgrd.reshape([252,308])
              regval = reg2eval(ims[:,:,ix+1].flatten())
              logging.info("regval: " + str(regval))             
          return ims[:,:,-1]-np.pi 

     def N4corrf(im):
          phasetmp = np.angle(im)
          ddimcabs = np.abs(im)
          inputImage = sitk.GetImageFromArray(ddimcabs, isVector=False)
          corrector = sitk.N4BiasFieldCorrectionImageFilter();
          inputImage = sitk.Cast(inputImage, sitk.sitkFloat32)
          output = corrector.Execute(inputImage)
          N4biasfree_output = sitk.GetArrayFromImage(output)          
          n4biasfield = ddimcabs/(N4biasfree_output+1e-9)
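          # n4biasfield is the multiplicative bias estimate: |im| is approximately
          # n4biasfield * N4biasfree_output (the 1e-9 guards against division by zero), so dividing
          # by n4bf elsewhere in the loop removes the estimated bias field from the image again.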
          
          if np.isreal(im).all():
               return n4biasfield, N4biasfree_output 
          else:
               return n4biasfield, N4biasfree_output*np.exp(1j*phasetmp)
     
     # ===============================
     # make the data
     # ===============================     
     uspat = np.abs(us_ksp_r2) > 0
     uspat = uspat[:,:,0]
     data = us_ksp_r2
     trpat = np.zeros_like(uspat)
     trpat[:, 120:136] = 1
          
     logging.info(uspat)
     
     # ===================================== 
     # initialize counters
     # ===================================== 
     numiter = num_iter
     multip = 0 # 0.1
     alphas = stepsize*np.ones(numiter) # np.logspace(-4,-4,numiter)
     
     # ===================================== 
     # functions for POCS
     # ===================================== 
     def feval(im):
          return full_funceval(im)
     
     def geval(im):
          t1, t2, t3 = full_gradient(im)
          return np.reshape(t1,[-1]), np.reshape(t2,[-1]), np.reshape(t3,[-1])
     
     # =====================================
     # initialize data
     # =====================================
     recs = np.zeros((imsizer*imrizec, numiter+30), dtype=complex) 
     phaseregvals = []
     n4biasfields = []
     
     # =====================================
     # first image in the 'recs' list is the initial undersampled image
     # =====================================
     recs[:, 0] = utils.tUFT_with_sensmaps(data, uspat, sensmaps).flatten().copy() 
     
     if N4BFcorr:
          n4bf, N4bf_image = N4corrf( np.reshape(recs[:,0],[imsizer,imrizec]) )
          recs[:,0] = N4bf_image.flatten()
     else:
          n4bf = 1
               
     # =====================================
     # If the optimization is to be continued from a previous run,
     # load the final image from the previous optimization as the first image for the current optimization.
     # =====================================
     logging.info('contRec is ' + contRec)
     if contRec != '':
          try:
               logging.info('Reading from a previous pickle file ' + contRec)
               rr = pickle.load(open(contRec, 'rb'))
               recs[:, 0] = rr[:, -1]
               logging.info('Initialized to the previous recon from pickle: ' + contRec)
          except:
               logging.info('Reading from a previous numpy file ' + contRec)
               rr = np.load(contRec)
               recs[:, 0] = rr[:, -1]
               logging.info('Initialized to the previous recon from numpy: ' + contRec)
     
     # =====================================
     # main loop
     # we don't do plain gradient descent, we do POCS (projection onto convex sets):
     #     o. N1 gradient updates of the normalization module: min_phi -ELBO(|x|)
     #     a. N2 gradient updates of the magnitude image: min_|x| -ELBO(|x|)
     #     b. a phase projection (regularization) applied to the phase image
     #     c. a data consistency projection onto the set of x with || Ex - y ||_2^2 = 0
     # =====================================
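     # The outer loop advances by 13 recs-columns per pass; each pass writes columns
     # it+1 .. it+n1+n2+3, so the hard-coded stride appears to assume n1 + n2 + 3 == 13
     # (e.g. n1 = 5 and n2 = 5; illustrative values, not taken from this file).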
     for it in range(0, numiter-1, 13):         
        
          alpha = alphas[it]
          
          # ===============================================
          # first do N1 times magnitude prior iterations wrt normalization module
          # ===============================================      
          recstmp = recs[:, it].copy()          
          for ix in range(n1):
          
               update_normalizer(recstmp,
                                 sess,
                                 norm_accum_grads_zero_op,
                                 norm_accum_grads_op,
                                 norm_accum_grads_mean_op,
                                 num_accum_steps_pl,
                                 norm_update_op)
                              
               ftot, f_lik, f_dc = feval(recstmp)
          
               if N4BFcorr:
                    f_dc = dconst(recstmp.reshape([imsizer, imrizec]) * n4bf)
               
               gtot, g_lik, g_dc = geval(recstmp) # gradient evaluation: total, wrt_vae, wrt_data_consistency
               
               logging.info("----------- updating normalization module, iteration number: " + str(it+ix))
               f_summary_msg = sess.run(f_summary, feed_dict = {f_elbo: f_lik, f_data_consistency: f_dc/1e6, f_total: ftot})
               summary_writer.add_summary(f_summary_msg, it+ix)
               g_summary_msg = sess.run(g_summary, feed_dict = {g_elbo: np.linalg.norm(g_lik), g_data_consistency: np.linalg.norm(g_dc), g_total: np.linalg.norm(g_lik) + np.linalg.norm(g_dc)})
               summary_writer.add_summary(g_summary_msg, it+ix)

               # recstmp will not change in this loop. Only the normalization module will.
               recs[:, it+ix+1] = recstmp.copy()
               
          # ===============================================
          # now do N2 times magnitude prior iterations wrt the image itself
          # ===============================================      
          recstmp = recs[:, it + n1].copy()          
          for ix in range(n2):
          
               ftot, f_lik, f_dc = feval(recstmp)
          
               if N4BFcorr:
                    f_dc = dconst(recstmp.reshape([imsizer, imrizec]) * n4bf)
               
               gtot, g_lik, g_dc = geval(recstmp) # gradient evaluation : total, wrt_vae, wrt_data_consistency
               
               logging.info("----------- updating image, iteration number: " + str(it+n1+ix))
               f_summary_msg = sess.run(f_summary, feed_dict = {f_elbo: f_lik, f_data_consistency: f_dc/1e6, f_total: ftot})
               summary_writer.add_summary(f_summary_msg, it+n1+ix)
               g_summary_msg = sess.run(g_summary, feed_dict = {g_elbo: np.linalg.norm(g_lik), g_data_consistency: np.linalg.norm(g_dc), g_total: np.linalg.norm(g_lik) + np.linalg.norm(g_dc)})
               summary_writer.add_summary(g_summary_msg, it+n1+ix)
     
               # recstmp will change in this loop. The normalization module will stay fixed.          
               recstmp = recstmp - alpha * g_lik # g_lik
               recs[:, it + ix + n1 + 1] = recstmp.copy()
     
          # =============================================== 
          # Now do a  DC projection.... ACTUALLY, skip the DC projection for now.
          # ===============================================
          logging.info("dummy step, iteration number: " + str(it+n1+n2+1))
          recs[:, it + n1 + n2 + 1] = recs[:, it + n1 + n2] 
          f_summary_msg = sess.run(f_summary, feed_dict = {f_elbo: f_lik, f_data_consistency: f_dc/1e6, f_total: ftot})
          summary_writer.add_summary(f_summary_msg, it+n1+n2+1)
          g_summary_msg = sess.run(g_summary, feed_dict = {g_elbo: np.linalg.norm(g_lik), g_data_consistency: np.linalg.norm(g_dc), g_total: np.linalg.norm(g_lik) + np.linalg.norm(g_dc)})
          summary_writer.add_summary(g_summary_msg, it+n1+n2+1)
           
          # ===============================================
          # now do a phase projection
          # ===============================================
          tmpa = np.abs(np.reshape(recs[:, it + n1 + n2 + 1], [imsizer, imrizec]))
          tmpp = np.angle(np.reshape(recs[:, it + n1 + n2 + 1], [imsizer, imrizec]))
          tmpatv = tmpa.copy().flatten()
           
          if reglmb == 0:
               logging.info("skipping phase proj, iteration number: " + str(it+n1+n2+2))
               tmpptv = tmpp.copy().flatten()
               
          else:
               logging.info("doing phase proj, iteration number: " + str(it+n1+n2+2))
               if regtype == 'TV': # Total variation
                    tmpptv = tv_proj(tmpp, mu=0.125,lmb=reglmb,IT=regiter).flatten() # 0.1, 15

               elif regtype == 'reg2': # Tikhonov
                    tmpptv = reg2_proj(tmpp, alpha=reglmb, niter=100).flatten() # 0.1, 15
                    regval = reg2eval(tmpp)
                    phaseregvals.append(regval)
                    logging.info("KCT-dbg: phase reg value is " + str(regval))
                
               elif regtype == 'reg2_dc': # Tikhonov, with additional constraint from data consistency
                    tmpptv = reg2_dcproj(tmpp, tmpa, n4bf, alpha_reg=reglmb, alpha_dc=reglmb, niter=100).flatten()
                
               elif regtype == 'abs':
                    tmpptv=np.zeros_like(tmpp).flatten()
               
               elif regtype == 'reg2_ls':
                    tmpptv = reg2_proj_ls(tmpp, niter=regiter).flatten() #0.1, 15
                    regval = reg2eval(tmpp)
                    phaseregvals.append(regval)
                    logging.info("KCT-dbg: phase reg value is " + str(regval))
                
               else:
                    logging.info("hey mistake!!!!!!!!!!")
           
          # recombine magnitude and updated phase.
          recs[:, it + n1 + n2 + 2] = tmpatv*np.exp(1j*tmpptv)

          # add summary after phase projection step
          recstmp = recs[:, it + n1 + n2 + 2].copy()        
          ftot, f_lik, f_dc = feval(recstmp)
          gtot, g_lik, g_dc = geval(recstmp)
          f_summary_msg = sess.run(f_summary, feed_dict = {f_elbo: f_lik, f_data_consistency: f_dc/1e6, f_total: ftot})
          summary_writer.add_summary(f_summary_msg, it+n1+n2+2)
          g_summary_msg = sess.run(g_summary, feed_dict = {g_elbo: np.linalg.norm(g_lik), g_data_consistency: np.linalg.norm(g_dc), g_total: np.linalg.norm(g_lik) + np.linalg.norm(g_dc)})
          summary_writer.add_summary(g_summary_msg, it+n1+n2+2)

          # ===============================================      
          # now do a data consistency projection
          # take the measured part of the k space from the measured data and the remaining part of the k space from the updated image.
          # ===============================================
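          # With multip = 0 this amounts to x <- F^H S^H [ (1-U) F S x + U y ]: keep the image's own
          # k-space where no data was acquired and replace it with the measured k-space y on the
          # sampling pattern U (S = coil sensitivities, F = Fourier transform).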
          logging.info("doing data consistency step, iteration number: " + str(it+n1+n2+3))
          if not N4BFcorr:  
               tmp1 = utils.UFT_with_sensmaps(np.reshape(recs[:, it + n1 + n2 + 2], [imsizer,imrizec]), (1-uspat), sensmaps)
               tmp2 = utils.UFT_with_sensmaps(np.reshape(recs[:, it + n1 + n2 + 2], [imsizer,imrizec]), (uspat), sensmaps)
               tmp3 = data * uspat[:,:,np.newaxis]
               
               tmp = tmp1 + multip * tmp2 + (1 - multip) * tmp3
               recs[:, it + n1 + n2 + 3] = utils.tFT_with_sensmaps(tmp, sensmaps).flatten()
               
               # ftot, f_lik, f_dc = feval(recs[:, it + n1 + n2 + 3])
               # logging.info('f_dc (1e6): ' + str(f_dc/1e6) + '  perc: ' + str(100*f_dc/np.linalg.norm(data)**2))
               
          elif N4BFcorr:               
               
               n4bf_prev = n4bf.copy()
               
               imgtmp = np.reshape(recs[:, it + n1 + n2 + 2], [imsizer,imrizec]) # biasfree
               
               imgtmp_bf = imgtmp * n4bf_prev # img with bf
               
               n4bf, N4bf_image = N4corrf(imgtmp_bf) # correct the bf, this correction is supposed to be better now.
               
               imgtmp_new = imgtmp * n4bf
               
               n4biasfields.append(n4bf)
               
               tmp1 = utils.UFT_with_sensmaps(imgtmp_new, (1-uspat), sensmaps)
               tmp3 = data * uspat[:, :, np.newaxis]
               tmp = tmp1 + (1 - multip) * tmp3 # multip=0 by default
               recs[:, it + n1 + n2 + 3] = (utils.tFT_with_sensmaps(tmp, sensmaps) / n4bf).flatten()
               
               # ftot, f_lik, f_dc = feval(recs[:, it + n1 + n2 + 3])
               # if N4BFcorr: f_dc = dconst(recs[:, it + n1 + n2 + 3].reshape([imsizer,imrizec])*n4bf)               
               # logging.info('f_dc (1e6): ' + str(f_dc/1e6) + '  perc: ' + str(100*f_dc / np.linalg.norm(data) ** 2))
               
          # add summary after data consistency step
          recstmp = recs[:, it + n1 + n2 + 3].copy()        
          ftot, f_lik, f_dc = feval(recstmp)
          gtot, g_lik, g_dc = geval(recstmp)
          f_summary_msg = sess.run(f_summary, feed_dict = {f_elbo: f_lik, f_data_consistency: f_dc/1e6, f_total: ftot})
          summary_writer.add_summary(f_summary_msg, it+n1+n2+3)
          g_summary_msg = sess.run(g_summary, feed_dict = {g_elbo: np.linalg.norm(g_lik), g_data_consistency: np.linalg.norm(g_dc), g_total: np.linalg.norm(g_lik) + np.linalg.norm(g_dc)})
          summary_writer.add_summary(g_summary_msg, it+n1+n2+3)
          summary_writer.flush()
              
     return recs, 0, phaseregvals, n4biasfields
Example #13
0
#!/usr/bin/env python2
# Vocoder Patch for MD380 Firmware
# Applies to version S013.020

from Patcher import Patcher

#Match all public calls.
monitormode=False
#Match all private calls.
monitormodeprivate=False

if __name__ == '__main__':
    print "Creating patches from unwrapped.img."
    patcher=Patcher("unwrapped.img")
    
    # bypass vocoder copy protection on S013.020
    patcher.nopout((0x8034a60))
    patcher.nopout((0x8034a60+0x2))
    patcher.nopout((0x8034a76))
    patcher.nopout((0x8034a76+0x2))
    patcher.nopout((0x8034a8c))
    patcher.nopout((0x8034a8c+0x2))
    patcher.nopout((0x8034aa2))
    patcher.nopout((0x8034aa2+0x2))
    patcher.nopout((0x8034ab8))
    patcher.nopout((0x8034ab8+0x2))
    patcher.nopout((0x8034ace))
    patcher.nopout((0x8034ace+0x2))
    patcher.nopout((0x8049f9a))
    patcher.nopout((0x8049f9a+0x2))
    patcher.nopout((0x804a820))
Example #14
0
#!/usr/bin/env python2
# Vocoder Patch for MD380 Firmware
# Applies to version S013.020

from Patcher import Patcher

#Match all public calls.
monitormode=False
#Match all private calls.
monitormodeprivate=False

if __name__ == '__main__':
    print "Creating patches from unwrapped.img."
    patcher=Patcher("unwrapped.img")
    
# # bypass vocoder copy protection on D013.020

#     patcher.nopout((0x08033f30+0x18))
#     patcher.nopout((0x08033f30+0x1a))

#     patcher.nopout((0x08033f30+0x2e))
#     patcher.nopout((0x08033f30+0x30))

#     patcher.nopout((0x08033f30+0x44))
#     patcher.nopout((0x08033f30+0x46))

#     patcher.nopout((0x08033f30+0x5a))
#     patcher.nopout((0x08033f30+0x5c))

#     patcher.nopout((0x08033f30+0x70))
#     patcher.nopout((0x08033f30+0x72))
Example #15
0
#!/usr/bin/env python2
# Promiscuous Mode Patch for MD380 Firmware
# Applies to version 2.032

from Patcher import Patcher

#Match all public calls.
monitormode=False;
#Match all private calls.
monitormodeprivate=False;

if __name__ == '__main__':
    print "Creating patches from unwrapped.img.";
    patcher=Patcher("unwrapped.img");
    

#     #These aren't quite enough to skip the Color Code check.  Not sure why.
    patcher.nopout(0x0803ea62,0xf040);  #Main CC check.
    patcher.nopout(0x0803ea64,0x80fd);
    patcher.nopout(0x0803e994,0xf040);  #Late Entry CC check.
    patcher.nopout(0x0803e996,0x8164);
    patcher.nopout(0x0803fd98);  #dmr_dll_parser CC check.
    patcher.nopout(0x0803fd9a);
    patcher.sethword(0x0803fd8e,0xe02d, #Check in dmr_dll_parser().
                     0xd02d);
    patcher.nopout(0x0803eafe,0xf100); #Disable CRC check, in case CC is included.
    patcher.nopout(0x0803eb00,0x80af);
    
        
    # Patches after here allow for an included applet.
    
Example #16
0
#!/usr/bin/env python2
# Promiscuous Mode Patch for MD380 Firmware
# Applies to version 2.032

from Patcher import Patcher

if __name__ == '__main__':
    print "Creating patches from unwrapped.img."
    patcher = Patcher("unwrapped.img")

    #Old logo patcher, no longer used.
    #fhello=open("welcome.txt","rb");
    #hello=fhello.read();
    #patcher.str2sprite(0x08094610,hello);
    #print patcher.sprite2str(0x08094610,0x14,760);

    #Old patch, matching on the first talkgroup.
    #We don't use this anymore, because the new patch is better.
    #patcher.nopout(0x0803ee36,0xd1ef);

    # New patch for monitoring all talk groups , matched on first
    # entry iff no other match.
    #wa mov r5, 0 @ 0x0803ee86 # So the radio thinks it matched at zero.
    patcher.sethword(0x0803ee86, 0x2500)
    #wa b 0x0803ee38 @ 0x0803ee88 # Branch back to perform that match.
    patcher.sethword(0x0803ee88, 0xe7d6)
    #Jump back to matched condition.
    #patcher.export("prom-public.img");

    #     #These aren't quite enough to skip the Color Code check.  Not sure why.
    #     patcher.nopout(0x0803ea62,0xf040);  #Main CC check.
Example #17
0
#!/usr/bin/env python2
# Vocoder Patch for MD380 Firmware
# Applies to version S013.020

from Patcher import Patcher

#Match all public calls.
monitormode = False
#Match all private calls.
monitormodeprivate = False

if __name__ == '__main__':
    print "Creating patches from unwrapped.img."
    patcher = Patcher("unwrapped.img")

    # # bypass vocoder copy protection on D013.020

    #     patcher.nopout((0x08033f30+0x18))
    #     patcher.nopout((0x08033f30+0x1a))

    #     patcher.nopout((0x08033f30+0x2e))
    #     patcher.nopout((0x08033f30+0x30))

    #     patcher.nopout((0x08033f30+0x44))
    #     patcher.nopout((0x08033f30+0x46))

    #     patcher.nopout((0x08033f30+0x5a))
    #     patcher.nopout((0x08033f30+0x5c))

    #     patcher.nopout((0x08033f30+0x70))
    #     patcher.nopout((0x08033f30+0x72))
Example #18
0
#!/usr/bin/env python2
# Promiscuous Mode Patch for MD380 Firmware
# Applies to version 2.032

from Patcher import Patcher

#Match all public calls.
monitormode=False;
#Match all private calls.
monitormodeprivate=False;

if __name__ == '__main__':
    print "Creating patches from unwrapped.img.";
    patcher=Patcher("unwrapped.img");
    

#     #These aren't quite enough to skip the Color Code check.  Not sure why.
#     patcher.nopout(0x0803ea62,0xf040);  #Main CC check.
#     patcher.nopout(0x0803ea64,0x80fd);
#     patcher.nopout(0x0803e994,0xf040);  #Late Entry CC check.
#     patcher.nopout(0x0803e996,0x8164);
#     patcher.nopout(0x0803fd98);  #dmr_dll_parser CC check.
#     patcher.nopout(0x0803fd9a);
#     patcher.sethword(0x0803fd8e,0xe02d, #Check in dmr_dll_parser().
#                      0xd02d);
#     patcher.nopout(0x0803eafe,0xf100); #Disable CRC check, in case CC is included.
#     patcher.nopout(0x0803eb00,0x80af);
    
        
    # Patches after here allow for an included applet.
    
Example #19
0
#!/usr/bin/env python2
# Promiscuous Mode Patch for MD380 Firmware
# Applies to version 2.032

from Patcher import Patcher

#Match all public calls.
monitormode = False
#Match all private calls.
monitormodeprivate = False

if __name__ == '__main__':
    print "Creating patches from unwrapped.img."
    patcher = Patcher("unwrapped.img")

    #     #These aren't quite enough to skip the Color Code check.  Not sure why.
    patcher.nopout(0x0803ea62, 0xf040)
    #Main CC check.
    patcher.nopout(0x0803ea64, 0x80fd)
    patcher.nopout(0x0803e994, 0xf040)
    #Late Entry CC check.
    patcher.nopout(0x0803e996, 0x8164)
    patcher.nopout(0x0803fd98)
    #dmr_dll_parser CC check.
    patcher.nopout(0x0803fd9a)
    patcher.sethword(
        0x0803fd8e,
        0xe02d,  #Check in dmr_dll_parser().
        0xd02d)
    patcher.nopout(0x0803eafe, 0xf100)
    #Disable CRC check, in case CC is included.
Example #20
0
#!/usr/bin/env python2
# -*- coding: utf-8 -*-

# Vocoder Patch for MD380 Firmware
# Applies to version S013.020

from Patcher import Patcher

# Match all public calls.
monitormode = False
# Match all private calls.
monitormodeprivate = False

if __name__ == '__main__':
    print("Creating patches from unwrapped.img.")
    patcher = Patcher("unwrapped.img")

    # bypass vocoder copy protection on S013.020
    patcher.nopout((0x8034a60))
    patcher.nopout((0x8034a60 + 0x2))
    patcher.nopout((0x8034a76))
    patcher.nopout((0x8034a76 + 0x2))
    patcher.nopout((0x8034a8c))
    patcher.nopout((0x8034a8c + 0x2))
    patcher.nopout((0x8034aa2))
    patcher.nopout((0x8034aa2 + 0x2))
    patcher.nopout((0x8034ab8))
    patcher.nopout((0x8034ab8 + 0x2))
    patcher.nopout((0x8034ace))
    patcher.nopout((0x8034ace + 0x2))
    patcher.nopout((0x8049f9a))
Example #21
0
import os
import re
import subprocess
from Patcher import Patcher


def _get_output(cmd):
    proc = subprocess.Popen(cmd,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.STDOUT)
    (stdout, stderr) = proc.communicate()
    stdout = stdout.decode().rstrip("\r\n").lstrip("\r\n")
    return stdout


os.chdir("..")
patcher = Patcher("README.md", filetype="markdown")

stdout = _get_output(["./x509sak.py"])
text = "\n```\n$ ./x509sak.py\n%s\n```\n" % (stdout)
patcher.patch("summary", text)

commands = []
command_re = re.compile("Begin of cmd-(?P<cmdname>[a-z]+)")
for match in command_re.finditer(patcher.read()):
    cmdname = match.groupdict()["cmdname"]
    commands.append(cmdname)
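# Each "Begin of cmd-<name>" marker found above presumably delimits a named region in the README;
# the loop below regenerates each region with the output of "./x509sak.py <name> --help" via
# patcher.patch().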

for command in commands:
    stdout = _get_output(["./x509sak.py", command, "--help"])
    text = "\n```\n%s\n```\n" % (stdout)
    patcher.patch("cmd-%s" % (command), text)
Example #22
0
def vaerecon(us_ksp_r2,
             dcprojiter,
             onlydciter=0,
             lat_dim=60,
             patchsize=28,
             contRec='',
             lowresmodel=False,
             parfact=10,
             num_iter=1,
             regiter=15,
             reglmb=0.05,
             regtype='TV'):

    print('KCT-INFO: contRec is ' + contRec)
    print('KCT-INFO: parfact is ' + str(parfact))
    print("............", us_ksp_r2.shape)
    #print(sensmaps.shape)
    # set parameters
    #==============================================================================
    np.random.seed(seed=1)

    imsizer = us_ksp_r2.shape[0]  #252#256#252
    imrizec = us_ksp_r2.shape[1]  #308#256#308
    print(imsizer, imrizec)
    nsampl = 50

    #make a network and a patcher to use later
    #==============================================================================
    x_rec, x_inp, funop, grd0, _, _, _, _, _, _, _, _, _, _, _, _, _, _ = definevae(
        lat_dim=lat_dim, patchsize=patchsize, batchsize=parfact * nsampl)

    Ptchr = Patcher(imsize=[imsizer, imrizec],
                    patchsize=patchsize,
                    step=int(patchsize / 2),
                    nopartials=True,
                    contatedges=True)
    nopatches = len(Ptchr.genpatchsizes)
    print("KCT-INFO: there will be in total " + str(nopatches) + " patches.")

    #define the necessary functions
    #==============================================================================

    def FT(x):
        # coil expansion followed by Fourier transform
        #inp: [nx, ny]
        #out: [nx, ny, ns]
        return np.fft.fftshift(np.fft.fft2(x[:, :, np.newaxis], axes=(0, 1)),
                               axes=(0, 1))

    def tFT(x):
        # inverse Fourier transform and coil combination
        #inp: [nx, ny, ns]
        #out: [nx, ny]

        temp = np.fft.ifft2(np.fft.ifftshift(x, axes=(0, 1)), axes=(0, 1))
        return np.sum(temp, axis=2)

    def UFT(x, uspat):
        # Encoding: undersampling +  FT
        #inp: [nx, ny], [nx, ny]
        #out: [nx, ny, ns]

        return uspat[:, :, np.newaxis] * FT(x)

    def tUFT(x, uspat):
        # transposed Encoding: inverse FT + undersampling
        #inp: [nx, ny], [nx, ny]
        #out: [nx, ny]

        tmp1 = uspat[:, :, np.newaxis]

        return tFT(tmp1 * x)

    def dconst(us):
        #inp: [nx, ny]
        #out: [nx, ny]

        return np.linalg.norm(UFT(us, uspat) - data)**2
        #np.linalg.norm computes the l2 norm
    def dconst_grad(us):
        #inp: [nx, ny]
        #out: [nx, ny]
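        # dconst(x) = || U F x - y ||^2, so its gradient is 2 * F^H U^H (U F x - y), which is
        # 2 * tUFT(UFT(x) - y) since tUFT is the adjoint of UFT (up to the FFT normalization factor).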
        return 2 * tUFT(UFT(us, uspat) - data, uspat)

    def prior(us):
        #inp: [parfact,ps*ps]
        #out: parfact

        us = np.abs(us)
        funeval = funop.eval(feed_dict={x_rec: np.tile(us, (nsampl, 1))})  #
        funeval = np.array(np.split(funeval, nsampl,
                                    axis=0))  # [nsampl x parfact x 1]
        return np.mean(funeval, axis=0).astype(np.float64)

    def prior_grad(us):
        #inp: [parfact, ps*ps]
        #out: [parfact, ps*ps]

        usc = us.copy()
        usabs = np.abs(us)

        grd0eval = grd0.eval(feed_dict={x_rec: np.tile(usabs, (nsampl, 1))
                                        })  # ,x_inp: np.tile(usabs,(nsampl,1))

        #grd0eval: [500x784]
        grd0eval = np.array(np.split(grd0eval, nsampl,
                                     axis=0))  # [nsampl x parfact x 784]
        grd0m = np.mean(grd0eval, axis=0)  #[parfact,784]
        # debug output: magnitude of the current patches and the sample-averaged gradient
        print("KCT-dbg: |usc| =")
        print(np.abs(usc))
        print("KCT-dbg: grd0m =")
        print(grd0m)
        grd0m = usc / np.abs(usc) * grd0m

        return grd0m  #.astype(np.float64)

    def prior_grad_patches(ptchs):
        #inp: [np, ps, ps]
        #out: [np, ps, ps]
        #takes set of patches as input and returns a set of their grad.s
        #both grads are in the positive direction

        shape_orig = ptchs.shape

        ptchs = np.reshape(ptchs, [ptchs.shape[0], -1])

        grds = np.zeros([
            int(np.ceil(ptchs.shape[0] / parfact) * parfact),
            np.prod(ptchs.shape[1:])
        ],
                        dtype=np.complex64)

        extraind = int(
            np.ceil(ptchs.shape[0] / parfact) * parfact) - ptchs.shape[0]
        ptchs = np.pad(ptchs, ((0, extraind), (0, 0)), mode='edge')

        for ix in range(int(np.ceil(ptchs.shape[0] / parfact))):
            grds[parfact * ix:parfact * ix + parfact, :] = prior_grad(
                ptchs[parfact * ix:parfact * ix + parfact, :])

        grds = grds[0:shape_orig[0], :]

        return np.reshape(grds, shape_orig)
        #np.ceil rounds up to the nearest integer

    def prior_patches(ptchs):
        #inp: [np, ps, ps]
        #out: 1

        fvls = np.zeros([int(np.ceil(ptchs.shape[0] / parfact) * parfact)])

        extraind = int(
            np.ceil(ptchs.shape[0] / parfact) * parfact) - ptchs.shape[0]
        ptchs = np.pad(ptchs, [(0, extraind), (0, 0), (0, 0)], mode='edge')

        for ix in range(int(np.ceil(ptchs.shape[0] / parfact))):
            fvls[parfact * ix:parfact * ix + parfact] = prior(
                np.reshape(ptchs[parfact * ix:parfact * ix + parfact, :, :],
                           [parfact, -1]))

        fvls = fvls[0:ptchs.shape[0]]

        return np.mean(fvls)

    def full_gradient(image):
        #inp: [nx*nx, 1]
        #out: [nx, ny], [nx, ny]

        #returns both gradients in the respective positive direction.
        #i.e. they must be subtracted from the image to take a descent step.

        ptchs = Ptchr.im2patches(np.reshape(image, [imsizer, imrizec]))
        ptchs = np.array(ptchs)

        grd_prior = prior_grad_patches(ptchs)
        grd_prior = (-1) * Ptchr.patches2im(grd_prior)

        grd_dconst = dconst_grad(np.reshape(image, [imsizer, imrizec]))

        return grd_prior + grd_dconst, grd_prior, grd_dconst

    def full_funceval(image):
        #inp: [nx*nx, 1]
        #out: [1], [1], [1]

        tmpimg = np.reshape(image, [imsizer, imrizec])

        dc = dconst(tmpimg)

        ptchs = Ptchr.im2patches(np.reshape(image, [imsizer, imrizec]))
        ptchs = np.array(ptchs)

        priorval = (-1) * prior_patches(np.abs(ptchs))

        return priorval + dc, priorval, dc

    #define the phase regularization functions
    #==============================================================================

    def tv_proj(phs, mu=0.125, lmb=2, IT=225):
        # Total variation based projection

        phs = fb_tv_proj(phs, mu=mu, lmb=lmb, IT=IT)

        return phs

    def fgrad(im):
        # gradient operation with 1st order finite differences
        imr_x = np.roll(im, shift=-1, axis=0)
        imr_y = np.roll(im, shift=-1, axis=1)
        grd_x = imr_x - im
        grd_y = imr_y - im

        return np.array((grd_x, grd_y))

    def fdivg(im):
        # divergence operator with 1st order finite differences
        imr_x = np.roll(np.squeeze(im[0, :, :]), shift=1, axis=0)
        imr_y = np.roll(np.squeeze(im[1, :, :]), shift=1, axis=1)
        grd_x = np.squeeze(im[0, :, :]) - imr_x
        grd_y = np.squeeze(im[1, :, :]) - imr_y

        return grd_x + grd_y

    def f_st(u, lmb):
        # soft thresholding

        uabs = np.squeeze(np.sqrt(np.sum(u * np.conjugate(u), axis=0)))

        tmp = 1 - lmb / uabs
        # soft-thresholding / shrinkage: u * max(1 - lmb/|u|, 0), so zero out the negative part
        tmp[tmp < 0] = 0

        uu = u * np.tile(tmp[np.newaxis, :, :], [u.shape[0], 1, 1])

        return uu

    def fb_tv_proj(im, u0=0, mu=0.125, lmb=1, IT=15):

        sz = im.shape
        us = np.zeros((2, sz[0], sz[1], IT))
        us[:, :, :, 0] = u0

        for it in range(IT - 1):

            #grad descent step:
            tmp1 = im - fdivg(us[:, :, :, it])
            tmp2 = mu * fgrad(tmp1)

            tmp3 = us[:, :, :, it] - tmp2

            #thresholding step:
            us[:, :, :, it + 1] = tmp3 - f_st(tmp3, lmb=lmb)

        #endfor

        return im - fdivg(us[:, :, :, it + 1])

    def reg2_proj(usph, niter=100, alpha=0.05):
        #A smoothness-based projection. Regularization method 2 from
        #"Separate Magnitude and Phase Regularization via Compressed Sensing",  Feng Zhao et al, IEEE TMI, 2012

        usph = usph + np.pi

        ims = np.zeros((imsizer, imrizec, niter))
        ims[:, :, 0] = usph.copy()
        for ix in range(niter - 1):
            ims[:, :, ix + 1] = ims[:, :, ix] - 2 * alpha * np.real(
                1j * np.exp(-1j * ims[:, :, ix]) *
                fdivg(fgrad(np.exp(1j * ims[:, :, ix]))))

        return ims[:, :, -1] - np.pi

    #make the data
    #===============================

    uspat = np.abs(us_ksp_r2) > 0
    uspat = uspat[:, :, 0]
    data = us_ksp_r2

    import pickle

    #make the functions for POCS
    #=====================================
    #number of iterations
    numiter = num_iter

    # if you want to do an affine data consistency projection
    multip = 0  # 0 means simply replacing the measured values

    # step size for the prior iterations
    alphas = np.logspace(-4, -4, numiter)

    # some functions for simpler coding
    def feval(im):
        print('im.dtype..............', im.dtype)
        return full_funceval(im)

    def geval(im):
        t1, t2, t3 = full_gradient(im)
        return np.reshape(t1, [-1]), np.reshape(t2, [-1]), np.reshape(t3, [-1])

    # initialize data with the zero-filled image
    recs = np.zeros((imsizer * imrizec, numiter), dtype=complex)
    recs[:, 0] = tUFT(data, uspat).flatten().copy()

    # if you want to instead continue reconstruction from an existing image
    print(' KCT-INFO: contRec is ' + contRec)
    if contRec != '':
        print('KCT-INFO: reading from a previous file ' + contRec)
        rr = pickle.load(open(contRec, 'rb'))
        recs[:, 0] = rr[:, -1]
        print('KCT-INFO: initialized to the previous recon: ' + contRec)

    import time
    # the iteration loop
    for it in range(numiter - 1):
        start = time.time()
        # get the step size for the prior iteration
        alpha = alphas[it]

        # if you want to do some data consistency iterations before starting the prior projections
        # this can be helpful when doing recon with multiple coils, e.g. you can do pure SENSE in the beginning...
        # or only do phase projections in the beginning.
        if it > onlydciter:

            # get the gradients for the prior projection and the likelihood values
            ftot, f_prior, f_dc = feval(recs[:, it])
            gtot, g_prior, g_dc = geval(recs[:, it])

            #                print("it no: " + str(it) + " f_tot= " + str(ftot) + " f_prior= " + str(-f_prior) + " f_dc (x1e6)= " + str(f_dc/1e6) + " |g_prior|= " + str(np.linalg.norm(g_prior)) + " |g_dc|= " + str(np.linalg.norm(g_dc)) )
            print(
                "it no: {0}, f_tot= {1:.2f},f_prior= {2:.2f}, f_dc (x1e6)= {3:.2f}, |g_prior|= {4:.2f}, |g_dc|= {5:.2f}"
                .format(it, ftot, f_prior, f_dc / 1e6, np.linalg.norm(g_prior),
                        np.linalg.norm(g_dc)))

            # update the image with the prior gradient
            recs[:, it + 1] = recs[:, it] - alpha * g_prior
            # debug output: the updated image vector after the prior step
            print("KCT-dbg: recs[:, it+1] =")
            print(recs[:, it + 1])
            # separate the magnitude from the phase and do the phase projection
            tmpa = np.abs(np.reshape(recs[:, it + 1], [imsizer, imrizec]))
            tmpp = np.angle(np.reshape(recs[:, it + 1], [imsizer, imrizec]))

            tmpaf = tmpa.copy().flatten()

            if reglmb == 0:
                print("KCT-info: skipping phase proj")
                tmpptv = tmpp.copy().flatten()
            else:
                if regtype == 'TV':
                    tmpptv = tv_proj(tmpp, mu=0.125, lmb=reglmb,
                                     IT=regiter).flatten()  #0.1, 15
                elif regtype == 'reg2':
                    tmpptv = reg2_proj(tmpp, alpha=reglmb,
                                       niter=regiter).flatten()  #0.1, 15
                else:
                    raise TypeError("unknown regtype: " + regtype)

            # combine back the phase and the magnitude
            recs[:, it + 1] = tmpaf * np.exp(1j * tmpptv)

        else:  # the case where you do only data consistency iterations (also iteration 0)
            if not it == 0:
                print(
                    'KCT-info: skipping prior proj for the first onlydciters iter.s, doing only phase proj (then maybe DC proj as well) !!!'
                )

            recs[:, it + 1] = recs[:, it].copy()

            # separate the magnitude from the phase and do the phase projection
            tmpa = np.abs(np.reshape(recs[:, it + 1], [imsizer, imrizec]))
            tmpp = np.angle(np.reshape(recs[:, it + 1], [imsizer, imrizec]))

            tmpaf = tmpa.copy().flatten()

            if reglmb == 0:
                print("KCT-info: skipping phase proj")
                tmpptv = tmpp.copy().flatten()
            else:
                if regtype == 'TV':
                    tmpptv = tv_proj(tmpp, mu=0.125, lmb=reglmb,
                                     IT=regiter).flatten()  #0.1, 15
                elif regtype == 'reg2':
                    tmpptv = reg2_proj(tmpp, alpha=reglmb,
                                       niter=regiter).flatten()  #0.1, 15
                else:
                    raise TypeError("unknown regtype: " + regtype)

            # combine back the phase and the magnitude
            recs[:, it + 1] = tmpaf * np.exp(1j * tmpptv)
            print(recs.shape)
        #do the DC projection every 'dcprojiter' iterations
        if it < onlydciter + 1 or it % dcprojiter == 0:  #

            tmp1 = UFT(np.reshape(recs[:, it + 1], [imsizer, imrizec]),
                       (1 - uspat))
            tmp2 = UFT(np.reshape(recs[:, it + 1], [imsizer, imrizec]),
                       (uspat))
            tmp3 = data * uspat[:, :, np.newaxis]

            #combine the measured data with the projected image affinely (multip=0 for replacing the values)
            tmp = tmp1 + multip * tmp2 + (1 - multip) * tmp3
            recs[:, it + 1] = tFT(tmp).flatten()

            ftot, f_lik, f_dc = feval(recs[:, it + 1])
            print(ftot.shape, f_lik.shape, f_dc.shape)
        end = time.time()

        print('one iteration time is', end - start)
    return recs
Example #23
0
#!/usr/bin/env python2
# Vocoder Patch for MD380 Firmware
# Applies to version D013.020

from Patcher import Patcher

#Match all public calls.
monitormode=False
#Match all private calls.
monitormodeprivate=False

if __name__ == '__main__':
    print "Creating patches from unwrapped.img."
    patcher=Patcher("unwrapped.img")
    
# bypass vocoder copy protection on D013.020

    patcher.nopout((0x08033f30+0x18))
    patcher.nopout((0x08033f30+0x1a))

    patcher.nopout((0x08033f30+0x2e))
    patcher.nopout((0x08033f30+0x30))

    patcher.nopout((0x08033f30+0x44))
    patcher.nopout((0x08033f30+0x46))

    patcher.nopout((0x08033f30+0x5a))
    patcher.nopout((0x08033f30+0x5c))

    patcher.nopout((0x08033f30+0x70))
    patcher.nopout((0x08033f30+0x72))
Example #24
0
#!/usr/bin/env python2
# Vocoder Patch for MD380 Firmware
# Applies to version S013.020

from Patcher import Patcher

#Match all public calls.
monitormode = False
#Match all private calls.
monitormodeprivate = False

if __name__ == '__main__':
    print "Creating patches from unwrapped.img."
    patcher = Patcher("unwrapped.img")

    # bypass vocoder copy protection on S013.020
    patcher.nopout((0x8034a60))
    patcher.nopout((0x8034a60 + 0x2))
    patcher.nopout((0x8034a76))
    patcher.nopout((0x8034a76 + 0x2))
    patcher.nopout((0x8034a8c))
    patcher.nopout((0x8034a8c + 0x2))
    patcher.nopout((0x8034aa2))
    patcher.nopout((0x8034aa2 + 0x2))
    patcher.nopout((0x8034ab8))
    patcher.nopout((0x8034ab8 + 0x2))
    patcher.nopout((0x8034ace))
    patcher.nopout((0x8034ace + 0x2))
    patcher.nopout((0x8049f9a))
    patcher.nopout((0x8049f9a + 0x2))
    patcher.nopout((0x804a820))
Example #25
0
#!/usr/bin/env python2
# Promiscuous Mode Patch for MD380 Firmware
# Applies to version 2.032

from Patcher import Patcher

if __name__ == '__main__':
    print "Creating patches from unwrapped.img.";
    patcher=Patcher("unwrapped.img");
    
    #Old logo patcher, no longer used.
    #fhello=open("welcome.txt","rb");
    #hello=fhello.read();
    #patcher.str2sprite(0x08094610,hello);
    #print patcher.sprite2str(0x08094610,0x14,760);
    
    #Old patch, matching on the first talkgroup.
    #We don't use this anymore, because the new patch is better.
    #patcher.nopout(0x0803ee36,0xd1ef);
    
    # New patch for monitoring all talk groups , matched on first
    # entry iff no other match.
    #wa mov r5, 0 @ 0x0803ee86 # So the radio thinks it matched at zero.
    patcher.sethword(0x0803ee86, 0x2500);
    #wa b 0x0803ee38 @ 0x0803ee88 # Branch back to perform that match.
    patcher.sethword(0x0803ee88,0xe7d6); #Jump back to matched condition.
    #patcher.export("prom-public.img");
    
#     #These aren't quite enough to skip the Color Code check.  Not sure why.
#     patcher.nopout(0x0803ea62,0xf040);  #Main CC check.
#     patcher.nopout(0x0803ea64,0x80fd);
Example #26
0
for (option, enum_name, requires_parameter) in opts_long:
    if first:
        arg_enum.append("	%s = 1000," % (enum_name))
        first = False
    else:
        arg_enum.append("	%s," % (enum_name))
arg_enum.append("};")
arg_enum = "\n".join(arg_enum) + "\n"

cmd_def = ["	const char *short_options = \"%s\";" % (short_string)]
cmd_def.append("	struct option long_options[] = {")
for (option, enum_name, requires_parameter) in opts_long:
    param = "\"%s\"," % (option)
    if requires_parameter:
        cmd_def.append("		{ %-30s required_argument, 0, %s }," %
                       (param, enum_name))
    else:
        cmd_def.append("		{ %-30s no_argument,       0, %s }," %
                       (param, enum_name))
cmd_def.append("		{ 0 }")
cmd_def.append("	};")
cmd_def = "\n".join(cmd_def) + "\n"
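# Illustrative example of one generated long_options row (option and enum names are hypothetical,
# not taken from opts_long):
#		{ "verbose",                    no_argument,       0, LONGOPT_VERBOSE },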

patcher = Patcher("../pgmopts.c")
patcher.patch("help page", help_code)
patcher.patch("command definition enum", arg_enum)
patcher.patch("command definition", cmd_def)

patcher = Patcher("../README.md", filetype="markdown")
patcher.patch("help page", markdown_help_page)
Example #27
0
class Viewer:
    def __init__(self):
        self.index = None
        self.database = []
        self.setup()

    def setup(self):
        self.root = tk.Tk()
        self.root.title("labelview")
        self.root.state("iconic")
        self.root.protocol("WM_DELETE_WINDOW", self.on_close)
        # self.root.resizable(width=False, height=False)  # cannot change window size
        self.w = self.root.winfo_screenwidth()
        self.h = self.root.winfo_screenheight()
        
        # configure ttk style
        s = ttk.Style()
        s.configure("File.TButton", foreground="blue")
        s.configure("Flow.TButton", foreground="red")
        s.configure("Save.TButton", foreground="blue")

        self.tabs = ttk.Notebook(self.root)

        # thumbnail tab
        self.i = 0.8  # the fraction of image region, horizontal
        self.f = 0.4  # the fraction of file control region, vertical
        self.c = 0.4  # the fraction of checkbox region, vertical

        self.thumb_tab = ttk.Frame(self.tabs)
        self.tabs.add(self.thumb_tab, text="  thumbnail  ")

        # left side control panel
        self.control = ttk.Panedwindow(self.thumb_tab, orient="vertical")
        self.control.grid(row=0, column=0)

        # file flow control
        self.flowctl = ttk.Labelframe(self.control, text="file flow control", width=self.w*(1-self.i), height=self.h*self.f)
        self.control.add(self.flowctl)


        # open single file
        self.open_f = ttk.Button(self.flowctl, text="open file", command=self.load_file)
        self.open_f.grid(row=0, column=0, columnspan=2, sticky="ew", ipady=5, padx=10, pady=10)
        # load single csv/xml corresponding to wsi file
        self.load_l = ttk.Button(self.flowctl, text="load csv/xml", command=self.load_labels)
        self.load_l.grid(row=0, column=2, columnspan=2, sticky="ew", ipady=5, padx=10, pady=10)

        # open directory
        self.open_d = ttk.Button(self.flowctl, text="open dir", style="File.TButton", command=self.load_files)
        self.open_d.grid(row=1, column=0, columnspan=2, sticky="ew", ipady=5, padx=10, pady=10)
        # set csv/xml directory
        self.load_ld = ttk.Button(self.flowctl, text="load csv/xml dir", style="File.TButton", command=self.load_labels_dir)
        self.load_ld.grid(row=1, column=2, columnspan=2, sticky="ew", ipady=5, padx=10, pady=10)

        # current directory name
        self.dir_name_ = ttk.Label(self.flowctl, text="dir:")
        self.dir_name_.grid(row=2, column=0, columnspan=1, sticky="ew", ipady=5, padx=10, pady=10)
        self.dir_name = ttk.Label(self.flowctl, text="----")
        self.dir_name.grid(row=2, column=1, columnspan=1, sticky="ew", ipady=5, padx=10, pady=10)
        # display file count
        self.n_count_ = ttk.Label(self.flowctl, text="count:")
        self.n_count_.grid(row=2, column=2, columnspan=1, sticky="ew", ipady=5, padx=10, pady=10)
        self.n_count = ttk.Label(self.flowctl, text="----")
        self.n_count.grid(row=2, column=3, columnspan=1, sticky="ew", ipady=5, padx=10, pady=10)

        # display fname
        self.fname = ttk.Label(self.flowctl, text="XXXX.kfb/.tif")
        self.fname.grid(row=3, column=0, columnspan=2, sticky="ew", ipady=5, padx=10, pady=10)
        # display lname
        self.lname = ttk.Label(self.flowctl, text="XXXX.csv/.xml")
        self.lname.grid(row=3, column=2, columnspan=2, sticky="ew", ipady=5, padx=10, pady=10)

        # previous
        self.prev_b = ttk.Button(self.flowctl, text="previous", style="Flow.TButton", command=lambda: self.update(step=-1))
        self.prev_b.grid(row=4, column=0, columnspan=2, sticky="ew", ipady=5, padx=10, pady=10)
        # next
        self.next_b = ttk.Button(self.flowctl, text="next", style="Flow.TButton", command=lambda: self.update(step=1))
        self.next_b.grid(row=4, column=2, columnspan=2, sticky="ew", ipady=5, padx=10, pady=10)

        

        # checkbox control panel
        self.colorctl = ttk.Labelframe(self.control, text="choose classes", width=self.w*(1-self.i), height=self.h*self.c)
        self.control.add(self.colorctl)

        # add blur thumbnail image button
        self.blur = ttk.Button(self.colorctl, text="blur", command=lambda: self.update(blur=True))
        self.blur.grid(row=0, column=0, columnspan=2, sticky="ew", ipady=5, padx=10, pady=5)
        # add confirm button
        self.conform = ttk.Button(self.colorctl, text="confirm", command=lambda: self.update())
        self.conform.grid(row=0, column=2, columnspan=2, sticky="ew", ipady=5, padx=10, pady=5)
        # add checkboxes
        self.colorblock = []
        self.checkboxes = []
        for i, class_i in enumerate(cfg.CLASSES):
            var = IntVar(value=1)
            clb = ttk.Label(self.colorctl, text="--", background=cfg.COLOURS[class_i])
            chk = ttk.Checkbutton(self.colorctl, text=class_i, variable=var)
            if i < math.ceil(len(cfg.CLASSES)/2):
                clb.grid(row=1+i, column=0, columnspan=1, sticky="ew", ipady=1, padx=10, pady=2)
                chk.grid(row=1+i, column=1, columnspan=1, sticky="w", ipady=1, padx=10, pady=2)
            else:
                clb.grid(row=1+i-math.ceil(len(cfg.CLASSES)/2), column=2, columnspan=1, sticky="ew", ipady=1, padx=10, pady=2)
                chk.grid(row=1+i-math.ceil(len(cfg.CLASSES)/2), column=3, columnspan=1, sticky="w", ipady=1, padx=10, pady=2)
            self.colorblock.append(clb)
            self.checkboxes.append(var)
        
        
        # separator
        self.separator = ttk.Separator(self.thumb_tab, orient="vertical")
        self.separator.grid(row=0, column=1, sticky="ns")

        # right side image panel
        self.display = tk.Canvas(self.thumb_tab, width=self.w*self.i, height=self.h)
        self.display.grid(row=0, column=2)



        # labeled images tab
        self.w_left_i = 256  # the size of left size panel, horizontal
        self.s_i = 0.2  # the fraction of file info display and label file save control, vertical
        self.i_i = 0.2  # the fraction of images flow control panel, vertical
        self.c_i = 0.4  # the fraction of checkbox control panel, vertical

        self.image_tab = ttk.Frame(self.tabs)
        self.image_tab.bind("<Visibility>", self.on_visibility)  # clean and update contents when switching to this tab
        self.tabs.add(self.image_tab, text=" label images ")

        # left side control panel
        self.control_i = ttk.Panedwindow(self.image_tab, orient="vertical")
        self.control_i.grid(row=0, column=0)


        # file info display and label file save control
        self.savectl_i = ttk.Labelframe(self.control_i, text="label file write control", width=self.w_left_i, height=self.h*self.s_i)
        self.control_i.add(self.savectl_i)

        # set label file save dir
        self.label_dir = ttk.Button(self.savectl_i, text="set label file save dir", command=self.set_save_dir)
        self.label_dir.grid(row=0, column=0, columnspan=4, sticky="ew", ipady=5, padx=10, pady=10)
        # display file progress
        self.n_count_i = ttk.Label(self.savectl_i, text="file progress")
        self.n_count_i.grid(row=1, column=0, columnspan=2, sticky="ew", ipady=5, padx=10, pady=10)
        # save changes
        self.save_changes = ttk.Button(self.savectl_i, text="save changes", style="Save.TButton", command=self.save_labels)
        self.save_changes.grid(row=1, column=2, columnspan=2, sticky="ew", ipady=5, padx=10, pady=10)


        # images flow control
        self.flowctl_i = ttk.Labelframe(self.control_i, text="images flow control", width=self.w_left_i, height=self.h*self.i_i)
        self.control_i.add(self.flowctl_i)

        # previous
        self.prev_b_i = ttk.Button(self.flowctl_i, text="previous batch", style="Flow.TButton", command=lambda: self.update_i(step=-1))
        self.prev_b_i.grid(row=0, column=0, columnspan=2, sticky="ew", ipady=5, padx=10, pady=10)
        # next
        self.next_b_i = ttk.Button(self.flowctl_i, text="next batch", style="Flow.TButton", command=lambda: self.update_i(step=1))
        self.next_b_i.grid(row=0, column=2, columnspan=2, sticky="ew", ipady=5, padx=10, pady=10)


        # checkbox control
        self.colorctl_i = ttk.Labelframe(self.control_i, text="choose classes", width=self.w_left_i, height=self.h*self.c_i)
        self.control_i.add(self.colorctl_i)

        # images view progress
        self.image_pro = ttk.Label(self.colorctl_i, text="images view progress")
        self.image_pro.grid(row=0, column=0, columnspan=2, sticky="ew", ipady=5, padx=5, pady=10)
        # confirm button
        self.confirm_i = ttk.Button(self.colorctl_i, text="confirm", command=lambda: self.update_i(step=0))
        self.confirm_i.grid(row=0, column=2, columnspan=2, sticky="ew", ipady=5, padx=5, pady=10)
        # set display modes: the number of images in a row
        self.M_ = ttk.Label(self.colorctl_i, text="# M:")
        self.M_.grid(row=1, column=0, columnspan=1, sticky="w", ipady=1, padx=5, pady=5)
        self.M = ttk.Entry(self.colorctl_i)
        self.M.insert(0, "3")
        self.M.config(width=8)
        self.M.grid(row=1, column=1, columnspan=1, sticky="w", ipady=1, padx=5, pady=5)
        # set image size: the ratio of the cropped image size to the label box size
        self.N_ = ttk.Label(self.colorctl_i, text="# N:")
        self.N_.grid(row=1, column=2, columnspan=1, sticky="w", ipady=1, padx=5, pady=5)
        self.N = ttk.Entry(self.colorctl_i)
        self.N.insert(0, "2")
        self.N.config(width=8)
        self.N.grid(row=1, column=3, columnspan=1, sticky="w", ipady=1, padx=5, pady=5)
        # add checkboxes
        self.colorblock_i = []
        self.checkboxes_i = []
        for i,class_i in enumerate(cfg.CLASSES):
            var = IntVar(value=0)
            clb = ttk.Label(self.colorctl_i, text="--", background='#ffffff')
            chk = ttk.Checkbutton(self.colorctl_i, text=class_i, variable=var)
            if i < math.ceil(len(cfg.CLASSES)/2):
                clb.grid(row=2+i, column=0, columnspan=1, sticky="ew", ipady=1, padx=5, pady=2)
                chk.grid(row=2+i, column=1, columnspan=1, sticky="w", ipady=1, padx=5, pady=2)
            else:
                clb.grid(row=2+i-math.ceil(len(cfg.CLASSES)/2), column=2, columnspan=1, sticky="ew", ipady=1, padx=5, pady=2)
                chk.grid(row=2+i-math.ceil(len(cfg.CLASSES)/2), column=3, columnspan=1, sticky="w", ipady=1, padx=5, pady=2)
            self.colorblock_i.append(clb)
            self.checkboxes_i.append(var)


        # thumbnail display for comparison purpose
        self.thumbmini = tk.Canvas(self.control_i, width=self.w_left_i, height=self.w_left_i)
        self.control_i.add(self.thumbmini)


        # separator
        self.separator_i = ttk.Separator(self.image_tab, orient="vertical")
        self.separator_i.grid(row=0, column=1, sticky="ns")

        # right side image panel
        self.display_i = tk.Canvas(self.image_tab, width=self.w-self.w_left_i, height=self.h)
        # self.display_i.bind("<d>", self.on_d_pressed)
        # self.display_i.focus_set()
        self.display_i.grid(row=0, column=2)


        # third tab: logging
        self.log_tab = ttk.Frame(self.tabs)
        self.logger = Logger(self.log_tab)
        self.tabs.add(self.log_tab, text="     log     ", padding=5)

        self.tabs.pack()


    def load_file(self):
        """ open wsi file dialog """
        fname = filedialog.askopenfilename(filetypes=(("*.kfb files", "*.kfb"), ("*.tif files", "*.tif")))
        if not fname:
            messagebox.showinfo("warning", "no file chosen")
        else:
            self.index = 0
            del self.database
            self.database = []
            self.database.append({"basename":os.path.splitext(os.path.basename(fname))[0], "fname":fname, "lname":None})
            self.update()
            self.logger.log_file("loaded file " + fname)


    def load_files(self):
        """ open wsi file directory dialog """
        file_dir = filedialog.askdirectory()
        if not file_dir:
            messagebox.showinfo("warning", "no directory chosen")
        else:
            self.index = None
            del self.database
            self.database = []
            fnames = os.listdir(file_dir)
            fnames.sort()
            for fname in fnames:
                if fname.endswith(".kfb") or fname.endswith(".tif"):
                    self.database.append({"basename":os.path.splitext(fname)[0], "fname":os.path.join(file_dir, fname), "lname":None})
            if not self.database:
                messagebox.showinfo("warning", "no kfb/tif file found")
            else:
                self.index = 0
                self.update()
                self.logger.log_file("loaded files from " + file_dir)


    def load_labels(self):
        """ open label file dialog """
        if self.index is None:
            messagebox.showinfo("error", "no kfb/tif file loaded")
            return
        lname = filedialog.askopenfilename(filetypes=(("csv files", "*.csv"), ("xml files", "*.xml")))
        if not lname:
            messagebox.showinfo("warning", "no file chosen")
        elif self.database[self.index]["basename"] not in os.path.basename(lname):
            messagebox.showinfo("warning", "label file does not match the kfb/tif file")
        else:
            self.database[self.index]["lname"] = lname
            self.update()
            self.logger.set_log_path(os.path.dirname(lname))
            self.logger.log_info("loaded label file " + lname)


    def load_labels_dir(self):
        """ open label file directory dialog """
        def nullify_lname():
            """ set lable name to None before loading new """
            for item in self.database:
                item["lname"] = None
        def choose_matched():
            """ choose only those wsi file name matches label file """
            database_new = [item for item in self.database if item["lname"] is not None]
            del self.database
            self.database = database_new

        if self.index is None:
            messagebox.showinfo("error", "no kfb/tif file loaded")
            return
        file_dir = filedialog.askdirectory()
        if not file_dir:
            messagebox.showinfo("warning", "no directory chosen")
        else:
            nullify_lname()
            lnames = os.listdir(file_dir)
            lnames.sort()
            for lname in lnames:
                if lname.endswith(".csv") or lname.endswith(".xml"):
                    for i,item in enumerate(self.database):
                        if item["basename"] in os.path.basename(lname):
                            self.database[i]["lname"] = os.path.join(file_dir, lname)
                            break
            choose_matched()
            self.index = 0 if self.database else None
            self.update()
            self.logger.set_log_path(file_dir)
            self.logger.log_info("loaded label files from " + file_dir)


    def update_patcher(self):
        """ initialize and update patcher """
        def need_reget_label():
            """ check if need to get label from patcher """
            # check if haven't got labels from patcher
            if not "labels" in self.database[self.index]:
                return True
            # check if have got labels from patcher, but with None label file
            # note: there is a case that this will lead to excess label getting,
            # that is, patcher.labels contains no labels.
            count = 0
            for class_i,boxes in self.database[self.index]["labels"].items():
                count += len(boxes)
            return count == 0

        # if the patcher was initialized without a label file, reinitialize it
        if need_reget_label():
            self.patcher = Patcher(self.database[self.index]["fname"], self.database[self.index]["lname"])
            self.database[self.index]["labels"] = self.patcher.get_labels()
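        # For reference: a sketch of the in-memory label structure implied by the
        # code in this class (the exact Patcher return type is defined elsewhere):
        #
        #   self.database[i]["labels"] = {
        #       class_name: {box: label_data, ...},   # one dict of boxes per class
        #       ...
        #   }
        #
        # e.g. get_label_by_cursor() looks a box up across every class dict, and
        # on_right_click() moves a box between class dicts (or into "DELETED!!!").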


    def load_thumbnail(self, blur):
        """ load thumbnail image, given selected classes """
        def resize(image, w, h):
            w0, h0 = image.size
            scale = min(w/w0, h/h0)  # fit within (w, h) while preserving aspect ratio
            return image.resize((int(w0*scale), int(h0*scale)))
        
        checked_classes = [cfg.CLASSES[i] for i,var in enumerate(self.checkboxes) if var.get()]
        image = self.patcher.patch_label({key:value for key,value in self.database[self.index]["labels"].items() if key in checked_classes}, blur)
        image = ImageTk.PhotoImage(resize(image, self.w*self.i, self.h))
        self.thumb_on = image  # stores the thumbnail image on first tab
        self.display.create_image(self.w*self.i/2, self.h/2, image=image)
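        # A minimal usage sketch of the aspect-preserving resize above, assuming a
        # PIL.Image input (the numbers below are illustrative only):
        #
        #   from PIL import Image
        #   img = Image.new("RGB", (4000, 2000))
        #   scale = min(800/4000, 600/2000)                          # -> 0.2
        #   thumb = img.resize((int(4000*scale), int(2000*scale)))   # -> 800x400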


    def update_text(self):
        """ tab1: update display text """
        self.dir_name.config(text=os.path.basename(os.path.dirname(self.database[self.index]["fname"])))
        self.n_count.config(text="{} / {}".format(self.index+1, len(self.database)))
        self.fname.config(text=os.path.basename(self.database[self.index]["fname"]))
        if self.database[self.index]["lname"] is not None:
            self.lname.config(text=os.path.basename(self.database[self.index]["lname"]))
        else:
            self.lname.config(text="--------")


    def update_label_counts(self):
        """ tab1: update label counts """
        for i,class_i in enumerate(cfg.CLASSES):
            self.colorblock[i].config(text=str(len(self.database[self.index]["labels"][class_i])))


    def clear(self):
        """ clear and update window when failed to load matching label files """
        self.dir_name.config(text="----")
        self.n_count.config(text="----")
        self.fname.config(text="----.kfb/.tif")
        self.lname.config(text="----.csv/.xml")
        self.thumb_on = None
        self.database = []
        for clb in self.colorblock:
            clb.config(text="--")


    def update(self, step=0, blur=False):
        """ update method of first tab """
        if self.index is None:
            messagebox.showinfo("error", "there is no file/label matched")
            self.clear()
            return
        if self.index+step not in range(len(self.database)):
            messagebox.showinfo("warning", "already at the end")
            return
        self.index += step
        self.update_patcher()
        self.load_thumbnail(blur)
        self.update_text()
        self.update_label_counts()


    # below are functions relative to second tab
    def set_save_dir(self):
        """ set new label file saving directory, use source label file dir if not set """
        if self.index is None:
            messagebox.showinfo("error", "there is no file/label matched")
            return
        self.save_dir = filedialog.askdirectory()
        if not self.save_dir:
            messagebox.showinfo("warning", "no directory chosen, default will be used")
        else:
            self.logger.log_info("set label file saving directory " + self.save_dir)


    def load_images(self):
        # update color
        self.update_color_i(finished=False)
        # get checked classes
        checked_classes = [cfg.CLASSES[i] for i,var in enumerate(self.checkboxes_i) if var.get()]
        # get images from patcher
        sub_labels = {key:value for key,value in self.database[self.index]["labels"].items() if key in checked_classes}
        self.image_list = self.patcher.crop_images(sub_labels, N=float(self.N.get()))
        self.cursor = 0  # stores the index of first image on canvas of second tab
        self.images_on = None  # stores the labeled images on second tab
        self.logger.log_info("selected classes: {}".format(checked_classes))

        # calculate anchor points for images
        M = int(self.M.get())  # number of images in a row
        self.size_avg = self.w * self.i / M  # average image size to display
        rows = int(self.h / self.size_avg)  # number of rows to display
        pad = (self.h-self.size_avg*rows)/rows  # padding between rows
        self.anchors = []  # stores the center position of images to show
        for row in range(rows):
            for i in range(M):
                self.anchors.append((int(self.size_avg/2 + self.size_avg*i), int(self.size_avg/2 + self.size_avg*row + pad*row)))
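        # Worked example of the anchor grid above (illustrative numbers): with
        # M = 3, self.w*self.i == 900 and self.h == 600, size_avg is 300,
        # rows is 2 and pad is 0, so the image centres are
        #   (150, 150), (450, 150), (750, 150),
        #   (150, 450), (450, 450), (750, 450)
        # in row-major order, matching how update_images() consumes self.anchors.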


    def get_cursor_of_image(self, position):
        """ get current index of image in image_list, given position (event.x, event.y) """
        distance = {}
        cursor = self.cursor
        for anchor in self.anchors:
            if cursor not in range(len(self.image_list)):
                break
            distance[(position[0]-anchor[0])**2 + (position[1]-anchor[1])**2] = cursor
            cursor += 1
        sort_dist = sorted(distance.items())
        return sort_dist[0][1] if sort_dist else -1  # bug in canvas: still able to click when there are no images
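        # Sketch of the nearest-anchor lookup above (illustrative numbers): with
        # self.anchors == [(150, 150), (450, 150)] and self.cursor == 10, a click
        # at (430, 170) builds
        #   {78800: 10, 800: 11}
        # and sorted(distance.items())[0][1] returns 11, the closer image's index.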


    def get_label_by_cursor(self, cursor_of_image):
        """ 1. get key box from image_list, by index
            2. get label from labels, by key box
        """
        box = self.image_list[cursor_of_image][1]
        # search self.database[self.index]["labels"] for class_i
        for class_i,boxes in self.database[self.index]["labels"].items():
            if box in boxes:
                return class_i
        return "DELETED!!!"


    def on_single_click(self, event):
        """ on single click, display label of current image """
        def show_label():
            # destroy any toplevel window
            destroy()
            # creates a toplevel window
            self.tw = tk.Toplevel(self.display_i)
            # Leaves only the frame and removes the app window
            self.tw.wm_overrideredirect(True)
            win = tk.Frame(self.tw, borderwidth=0)
            lbl = ttk.Label(win, text=label, justify=tk.LEFT,
                            relief=tk.SOLID, borderwidth=0)
            pad = (5, 3, 5, 3)
            lbl.grid(padx=(pad[0], pad[2]), pady=(pad[1], pad[3]), sticky=tk.NSEW)
            win.grid()
            # set the position of label to show on screen
            x_of_c, y_of_c = self.display_i.winfo_rootx(), self.display_i.winfo_rooty()
            x, y = x_can + x_of_c + 5, y_can + y_of_c - 25  # x_can,y_can are defined outside
            self.tw.wm_geometry("+%d+%d" % (x, y))

        def destroy():
            if hasattr(self, "tw") and self.tw:
                self.tw.destroy()
            self.tw = None

        def wait_and_hide():
            waittime = 500  # in milliseconds
            self.display_i.after(waittime, destroy)

        # get mouse position, relative to canvas
        x_can, y_can = event.x, event.y
        # get the index of image in self.image_list
        cursor_of_image = self.get_cursor_of_image((x_can, y_can))
        if cursor_of_image == -1:
            return
        # get the label of the image
        label = self.get_label_by_cursor(cursor_of_image)
        # show label on screen
        show_label()
        # hide label after some time
        wait_and_hide()


    def on_double_click(self, event):
        """ on double click, popup a new image window, being able change image cropping size """
        def show_image():
            # destroy existing toplevel window
            destroy()
            # create a toplevel window and configure window size
            self.tw_i = tk.Toplevel(self.display_i)
            w, h = self.image_list[cursor_of_image][2].size
            scale = min(1.0, min(self.w/w, self.h/h))
            w, h = int(w*scale), int(h*scale)
            self.tw_i.wm_geometry("{}x{}+{}+{}".format(w, h, int(self.w/2-w/2), int(self.h/2-h/2)))
            self.tw_i_x = float(self.N.get())  # stores the ratio of image dimension to cell dimension
            self.tw_i.title(label+" {}x".format(self.tw_i_x))
            # add decrease and increase menu 
            menubar = tk.Menu(self.tw_i)
            sizemenu = tk.Menu(menubar, tearoff=0)
            sizemenu.add_command(label="decrease", command=lambda: resize(delta=-1))
            sizemenu.add_separator()
            sizemenu.add_command(label="increase", command=lambda: resize(delta=1))
            menubar.add_cascade(label="change size", menu=sizemenu)
            self.tw_i.config(menu=menubar)
            # add image to canvas
            self.tw_i_can = tk.Canvas(self.tw_i)
            self.tw_i_can.pack(fill=tk.BOTH, expand=tk.YES)
            self.image_tw = ImageTk.PhotoImage(self.image_list[cursor_of_image][2].resize((w, h)))
            self.tw_i_can.create_image(w//2, h//2, image=self.image_tw)

        def resize(delta):
            self.tw_i_x += delta
            if self.tw_i_x < 1.0:
                # already down to the cell, cannot cut smaller
                self.tw_i_x -= delta
                return
            image = self.patcher.get_cell_by_N(box=self.image_list[cursor_of_image][1], N=self.tw_i_x)
            w, h = image.size
            scale = min(1.0, min(self.w/w, self.h/h))
            w, h = int(w*scale), int(h*scale)
            self.tw_i.title(label+" {}x".format(self.tw_i_x))
            self.tw_i.wm_geometry("{}x{}".format(w, h))
            self.image_tw = ImageTk.PhotoImage(image.resize((w, h)))
            self.tw_i_can.delete("all")
            self.tw_i_can.config(width=w, height=h)
            self.tw_i_can.create_image(w//2, h//2, image=self.image_tw)

        def destroy():
            if hasattr(self, "tw_i") and self.tw_i:
                self.tw_i.destroy()
            self.tw_i = None

        # get mouse position, relative to canvas
        x_can, y_can = event.x, event.y
        # get the index of image in self.image_list
        cursor_of_image = self.get_cursor_of_image((x_can, y_can))
        if cursor_of_image == -1:
            return
        # get the label of the image
        label = self.get_label_by_cursor(cursor_of_image)
        # display image
        show_image()


    def on_right_click(self, event):
        """ on double click, popup radiobox selection panel, for label change or delete """
        def show_choices():
            # destroy existing toplevel window
            destroy()
            # creates a toplevel window
            self.tw = tk.Toplevel(self.display_i)
            # Leaves only the frame and removes the app window
            self.tw.wm_overrideredirect(True)
            win = tk.Frame(self.tw, borderwidth=0)
            self.radioBox = StringVar(value=label)
            # delete choice
            tk.Radiobutton(win, text="DELETE", padx=5, variable=self.radioBox, 
                                command=choice_made, 
                                value="DELETE", fg="#ff0000", activeforeground="#ff0000").pack(anchor=tk.W)
            # classes the label can be changed to
            for class_i in cfg.CLASSES:
                tk.Radiobutton(win, text=class_i, padx=5, variable=self.radioBox,
                                    command=choice_made, 
                                    value=class_i).pack(anchor=tk.W)
            win.grid()
            # set the position of label to show on screen
            # first get the position of canvas (upleft) on screen
            x_of_c, y_of_c = self.display_i.winfo_rootx(), self.display_i.winfo_rooty()
            # then add the position of mouse on canvas 
            x, y = x_can + x_of_c + 5, y_can + y_of_c - 150
            # adjust x,y if tw falls out of window
            x = min(x, self.w-100)
            y = min(y, self.h-400)
            y = max(y, 10)
            self.tw.wm_geometry("+%d+%d" % (x, y))  

        def destroy():
            if hasattr(self, "tw") and self.tw:
                self.tw.destroy()
            self.tw = None

        def choice_made():
            waittime = 300  # in milliseconds
            self.display_i.after(waittime, destroy)

            # retrieve choice and make changes accordingly
            choice = self.radioBox.get()
            # detect if there are changes made
            if choice == label or (label == "DELETED!!!" and choice == "DELETE"):
                return
            box = self.image_list[cursor_of_image][1]           
            # change original label to choice
            if choice != "DELETE":
                self.database[self.index]["labels"][choice][box] = self.database[self.index]["labels"][label][box]
                outline_color = "blue"
                self.image_list[cursor_of_image][0] = choice  # update class_i in image_list
                self.logger.log_change("{} {}".format(box,label), "{}".format(choice))
            else:  # when choice == "DELETE"
                if "DELETED!!!" not in self.database[self.index]["labels"]:  # add new key to labels
                    self.database[self.index]["labels"]["DELETED!!!"] = {}
                self.database[self.index]["labels"]["DELETED!!!"][box] = self.database[self.index]["labels"][label][box]
                outline_color = "red"
                self.image_list[cursor_of_image][0] = "DELETED!!!"  # update class_i in image_list, note: new key here
                self.logger.log_delete("{} {}".format(box,label))
            # delete label
            del self.database[self.index]["labels"][label][box]
            # update label counts
            self.update_label_counts_i()
            # add surrounding rectangle to changed image on canvas
            self.image_list[cursor_of_image].append(outline_color)
            anchor = self.anchors[cursor_of_image - self.cursor]
            x0, y0 = anchor[0] - self.size_avg/2 + 1, anchor[1] - self.size_avg/2 + 1
            x1, y1 = anchor[0] + self.size_avg/2 - 1, anchor[1] + self.size_avg/2 - 1
            self.display_i.create_rectangle(x0, y0, x1, y1, outline=outline_color, width=2)
                
        # get mouse position, relative to canvas
        x_can, y_can = event.x, event.y
        # get the index of image in self.image_list
        cursor_of_image = self.get_cursor_of_image((x_can, y_can))
        if cursor_of_image == -1:
            return
        # get the label of the image
        label = self.get_label_by_cursor(cursor_of_image)
        # # check if label has been deleted
        # if label == "DELETED!!!":
        #     messagebox.showinfo("warning", "image has been deleted")
        # show label choose dialog on screen
        show_choices()
    

    """ key press doesn't work after canvas loading images
    def on_d_pressed(self, event):
        print("d pressed at {}, ({},{})".format(self.display_i.type(tk.CURRENT), event.x, event.y))
        if self.display_i.type(tk.CURRENT) != "image":
            return
        # get mouse position, relative to canvas
        x_can, y_can = event.x, event.y
        # get the index of image in self.image_list
        cursor_of_image = self.get_cursor_of_image((x_can, y_can))
        # get the label of the image
        label = self.get_label_by_cursor(cursor_of_image)
        if label == "DELETED!!!":
            return
        # delete label
        box = self.image_list[cursor_of_image][1] 
        if "DELETED!!!" not in self.database[self.index]["labels"]:
            self.database[self.index]["labels"]["DELETED!!!"] = {}
        self.database[self.index]["labels"]["DELETED!!!"][box] = self.database[self.index]["labels"][label][box]
        outline_color = "red"
        self.image_list[cursor_of_image][0] = "DELETED!!!"  # update class_i in image_list, note: new key here
        self.logger.log_delete("{} {}".format(box,label))
        del self.database[self.index]["labels"][label][box]
        # update label counts
        self.update_label_counts_i()
        # add surrounding rectangle to changed image on canvas
        self.image_list[cursor_of_image].append(outline_color)
        anchor = self.anchors[cursor_of_image - self.cursor]
        x0, y0 = anchor[0] - self.size_avg/2 + 1, anchor[1] - self.size_avg/2 + 1
        x1, y1 = anchor[0] + self.size_avg/2 - 1, anchor[1] + self.size_avg/2 - 1
        self.display_i.create_rectangle(x0, y0, x1, y1, outline=outline_color, width=2)
    """


    def update_images(self, step):
        """ update display images on canvas """
        def resize(image, size):
            w, h = image.size
            scale = min(1, min(size/w, size/h))
            return image.resize((int(w*scale), int(h*scale)))

        # update color when the batch about to be displayed is the last one
        if self.cursor + 2*len(self.anchors) >= len(self.image_list):
            self.update_color_i(finished=True)

        # update cursor; leave it unchanged if it would run past the end
        self.cursor += step * len(self.anchors)
        if self.cursor not in range(len(self.image_list)):
            self.cursor -= step * len(self.anchors)
            messagebox.showinfo("warning", "no more images")
            return

        # update images
        del self.images_on
        self.images_on = []
        self.display_i.delete("all")  # delete all objects before loading
        cursor = self.cursor
        for anchor in self.anchors:
            if cursor not in range(len(self.image_list)):
                break
            image = ImageTk.PhotoImage(resize(self.image_list[cursor][2], self.size_avg))
            self.images_on.append(image)
            self.display_i.create_image(anchor[0], anchor[1], image=image, tags="image_{}".format(cursor))
            self.display_i.tag_bind("image_{}".format(cursor), "<ButtonPress-1>", self.on_single_click)
            self.display_i.tag_bind("image_{}".format(cursor), "<Double-Button-1>", self.on_double_click)
            self.display_i.tag_bind("image_{}".format(cursor), "<Button-3>", self.on_right_click)
            # add bounding rectangle for changed images, on canvas
            if len(self.image_list[cursor]) > 3:
                outline_color = self.image_list[cursor][-1]  # get the last color, should i only store one instead?
                x0, y0 = anchor[0] - self.size_avg/2 + 1, anchor[1] - self.size_avg/2 + 1
                x1, y1 = anchor[0] + self.size_avg/2 - 1, anchor[1] + self.size_avg/2 - 1
                self.display_i.create_rectangle(x0, y0, x1, y1, outline=outline_color, width=2)                
            cursor += 1
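        # Paging sketch for the cursor arithmetic above (illustrative numbers):
        # with len(self.image_list) == 20 and len(self.anchors) == 6, successive
        # "next batch" presses move self.cursor through 0 -> 6 -> 12 -> 18; one
        # more step would land on 24, which is rejected and rolled back above.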


    def load_thumbnailmini(self):
        """ load mini thumbnail image, in bottom left corner """
        def resize(image, w, h):
            w0, h0 = image.size
            scale = min(w/w0, h/h0)  # fit within (w, h) while preserving aspect ratio
            return image.resize((int(w0*scale), int(h0*scale)))

        sub_labels = {}
        for i in range(len(self.anchors)):
            cursor = self.cursor + i
            if cursor not in range(len(self.image_list)):
                break
            class_i = self.image_list[cursor][0]
            if class_i == "DELETED!!!":  # omit deleted labels
                continue
            box = self.image_list[cursor][1]
            if class_i not in sub_labels:
                sub_labels[class_i] = {}
            sub_labels[class_i][box] = self.database[self.index]["labels"][class_i][box]
        
        image = self.patcher.patch_label(sub_labels, blur=True)
        image = ImageTk.PhotoImage(resize(image, self.w_left_i, self.w_left_i))
        self.thumbmini_on = image  # stores the mini thumbnail image on second tab
        self.thumbmini.create_image(self.w_left_i/2, self.w_left_i/2, image=image)


    def update_label_counts_i(self, clear=False):
        """ update label counts in second tab """
        if clear:
            for clb in self.colorblock_i:
                clb.config(text="--")
        else:
            for i,clb in enumerate(self.colorblock_i):
                clb.config(text=str(len(self.database[self.index]["labels"][cfg.CLASSES[i]])))
            # for i,class_i in enumerate(cfg.CLASSES):
            #     self.colorblock_i[i].config(text=str(len(self.database[self.index]["labels"][class_i])))

    def update_text_i(self, clear=False):
        """ update 1. file view progress and 2. image view progress """
        if clear:
            self.n_count_i.config(text="no progress")
        else:
            self.n_count_i.config(text="files: {} / {}".format(self.index+1, len(self.database)))
        if hasattr(self, "cursor"):
            self.image_pro.config(text="images in view: {} / {}".format(self.cursor+1, len(self.image_list)))
        else:
            self.image_pro.config(text="no progress")

    def update_color_i(self, finished=False, clear=False):
        """ change colorbox color
        :param finished: the status of viewing images; the color is pale yellow while viewing and green when finished.
        :param clear: reset the color to white, used on tab changes
        """
        if clear:
            for clb in self.colorblock_i:
                clb.configure(background="#ffffff")
            return
        color = "#00ff00" if finished else "#ffff99"
        checked_classes = [i for i,var in enumerate(self.checkboxes_i) if var.get()]
        for i,clb in enumerate(self.colorblock_i):
            if i in checked_classes:
                clb.configure(background=color)


    def update_i(self, step):
        """ update method for second tab """
        if self.index is None:
            messagebox.showinfo("error", "there is no file/label matched")
            return
        if step == 0:
            self.load_images()
        if hasattr(self, "cursor"):
            self.update_images(step=step)
            self.load_thumbnailmini()
        self.update_text_i()
        self.update_label_counts_i()


    def save_labels(self, index=None):
        """ upload labels changes to patcher """
        if hasattr(self, "patcher"):
            if index is None:
                if self.index is None:
                    return
                else:
                    index = self.index
            if index not in range(len(self.database)):  # happens when failed to load new label files
                return
            self.patcher.set_labels(self.database[index]["labels"])
            if self.database[index]["lname"] is None:  # happens when no label file is loaded
                return
            # use the same directory if new label file dir is not set
            if not hasattr(self, "save_dir") or not self.save_dir:
                self.save_dir = os.path.dirname(self.database[index]["lname"])
            label_file = os.path.join(self.save_dir, self.database[index]["basename"]+".xml")
            self.patcher.write_labels(label_file)
            self.logger.log_save(self.database[index]["basename"])
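        # Example of the output path built above (hypothetical values): with
        # self.save_dir == "/data/labels_out" and basename == "slide_001", the
        # labels are written to "/data/labels_out/slide_001.xml" via
        # patcher.write_labels(), regardless of the source label file format.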


    def on_visibility(self, event):
        """ clear up and update second tab upon tab switch, when self.index changed in first tab """
        if not hasattr(self, "old_index"):
            self.old_index = None
            self.save_dir = None
        if self.old_index != self.index:
            self.thumbmini.delete("all")  # delete thumbnail image
            self.display_i.delete("all")  # delete all images on canvas
            self.update_color_i(clear=True)  # reset background color in colorboxes
            if self.old_index is not None:  # save labels to label file
                self.save_labels(self.old_index)
            if self.index is None:  # happens when failed to load new label files
                self.update_text_i(clear=True)
                self.update_label_counts_i(clear=True)  # clear label counts on second tab
            else:
                self.update_text_i()
                self.update_label_counts_i()  # update label counts on second tab
                self.logger.log_open(self.database[self.index]["basename"])
            self.old_index = self.index


    def on_close(self):
        """ actions performed when window closed """
        try:
            self.save_labels(self.index)
            self.logger.on_close()
            del self.database
            self.root.destroy()
        except Exception as exc:
            print("failed to clean up program: {}".format(exc))


    def run(self):
        self.root.mainloop()