Exemple #1
0
    def create_mask(self, u0):
        """
        Build an all-ones mask shaped after ``u0`` and store it in ``self.st``.

        The first two axes of ``u0`` give the (rows, cols) size of a float32
        array of ones.  For every further axis of ``u0`` the mask is passed
        through ``CsSolver.appendmat`` together with that axis length.

        The circular cut-off that once zeroed the mask outside a radius is
        disabled, so the index grids and radius scale it needed are no
        longer computed here.

        Parameters
        ----------
        u0 : object exposing ``.shape`` with at least two entries
            Only its shape is read.

        Returns
        -------
        The ``self.st`` mapping, with its ``'mask'`` entry replaced.
        """
        st = self.st
        shape = u0.shape

        st['mask'] = numpy.ones((shape[0], shape[1]), dtype=numpy.float32)

        # One appendmat call per extra axis of u0; nothing happens for 2-D input.
        for extent in shape[2:]:
            st['mask'] = CsSolver.appendmat(st['mask'], extent)

        return st
    def shrink(self, dd, bb, thrsld):
        """
        Soft-threshold the split variables, one index group at a time.

        Entries 0-1 of ``dd``/``bb`` are handed to ``CsSolver.shrink``
        together; entries 2, 3 and 4 are each handed over on their own.
        The individual results are concatenated with ``+`` in that order.

        Parameters
        ----------
        dd, bb : sliceable sequences of split variables
        thrsld : threshold forwarded unchanged to ``CsSolver.shrink``

        Returns
        -------
        The concatenation of the per-group results.
        """
        # The leading pair is thresholded jointly.
        shrunk = CsSolver.shrink(dd[0:2], bb[0:2], thrsld)

        # The remaining entries are thresholded one by one.
        for lo, hi in ((2, 3), (3, 4), (4, 5)):
            shrunk = shrunk + CsSolver.shrink(dd[lo:hi], bb[lo:hi], thrsld)

        return shrunk
    def create_mask(self, u0):
        """
        Build an all-ones mask shaped after ``u0`` and store it in ``self.st``.

        The first two axes of ``u0`` give the (rows, cols) size of a float32
        array of ones.  For every further axis of ``u0`` the mask is passed
        through ``CsSolver.appendmat`` together with that axis length.

        The circular cut-off that once zeroed the mask outside a radius is
        disabled, so the index grids and radius scale it needed are no
        longer computed here.

        Parameters
        ----------
        u0 : object exposing ``.shape`` with at least two entries
            Only its shape is read.

        Returns
        -------
        The ``self.st`` mapping, with its ``"mask"`` entry replaced.
        """
        st = self.st
        shape = u0.shape

        st["mask"] = numpy.ones((shape[0], shape[1]), dtype=numpy.float32)

        # One appendmat call per extra axis of u0; nothing happens for 2-D input.
        for extent in shape[2:]:
            st["mask"] = CsSolver.appendmat(st["mask"], extent)

        return st
Exemple #4
0
    def shrink(self, dd, bb, thrsld):
        """
        Soft-threshold the split variables, one index group at a time.

        Entries 0-1 of ``dd``/``bb`` are handed to ``CsSolver.shrink``
        together; entries 2, 3 and 4 are each handed over on their own.
        The individual results are concatenated with ``+`` in that order.

        Parameters
        ----------
        dd, bb : sliceable sequences of split variables
        thrsld : threshold forwarded unchanged to ``CsSolver.shrink``

        Returns
        -------
        The concatenation of the per-group results.
        """
        # The leading pair is thresholded jointly.
        shrunk = CsSolver.shrink(dd[0:2], bb[0:2], thrsld)

        # The remaining entries are thresholded one by one.
        for lo, hi in ((2, 3), (3, 4), (4, 5)):
            shrunk = shrunk + CsSolver.shrink(dd[lo:hi], bb[lo:hi], thrsld)

        return shrunk
Exemple #5
0
    def solve(self):  # main function of solver
        """
        Prepare mask and sensitivity map, run ``self.kernel`` and keep its image.

        Steps, in order:
          1. allocate an uninitialised array ``u0`` that is handed to
             ``create_mask`` and queried for ``shape[2]``;
          2. replace ``self.st`` with the result of ``create_mask``;
          3. replace ``self.st`` with the result of ``make_sense(self.f.tse)``
             and expand/transposes its ``'sensemap'`` entry;
          4. set ``self.st['senseflag']`` to 1 and call ``self.kernel``.

        The first element of the kernel result is stored in ``self.u``;
        nothing is returned.
        """
        nd = self.st['Nd']
        u0 = numpy.empty((self.f.dim_x, nd[0], nd[1], 4))
        n_append = u0.shape[2]
        swap_last_two = (0, 1, 3, 2)

        # Expanded copy of the reference data; only its shape is reported.
        tse = numpy.transpose(CsSolver.appendmat(self.f.tse, n_append),
                              swap_last_two)
        print('tse.shape', tse.shape)

        # --- mask ---------------------------------------------------------
        self.st = self.create_mask(u0)
        print('mask.shape', self.st['mask'].shape)

        # --- sensitivity map (make_sense fills st['sensemap']) -------------
        self.st = self.make_sense(self.f.tse)
        expanded = CsSolver.appendmat(self.st['sensemap'], n_append)
        self.st['sensemap'] = expanded
        self.st['sensemap'] = numpy.transpose(self.st['sensemap'],
                                              swap_last_two)
        print('self.sense.shape', self.st['sensemap'].shape)

        self.st['senseflag'] = 1  # turn-on sense, to get sensemap

        u, _uf = self.kernel(self.f.f, self.st, self.mu, self.LMBD,
                             self.gamma, self.nInner, self.nBreg)
        self.u = u
    def solve(self):  # main function of solver
        """
        Prepare mask and sensitivity map, run ``self.kernel`` and keep its image.

        Steps, in order:
          1. allocate an uninitialised array ``u0`` that is handed to
             ``create_mask`` and queried for ``shape[2]``;
          2. replace ``self.st`` with the result of ``create_mask``;
          3. replace ``self.st`` with the result of ``make_sense(self.f.tse)``
             and expand/transposes its ``"sensemap"`` entry;
          4. set ``self.st["senseflag"]`` to 1 and call ``self.kernel``.

        The first element of the kernel result is stored in ``self.u``;
        nothing is returned.
        """
        nd = self.st["Nd"]
        u0 = numpy.empty((self.f.dim_x, nd[0], nd[1], 4))
        n_append = u0.shape[2]
        swap_last_two = (0, 1, 3, 2)

        # Expanded copy of the reference data; only its shape is reported.
        tse = numpy.transpose(CsSolver.appendmat(self.f.tse, n_append), swap_last_two)
        print("tse.shape", tse.shape)

        # --- mask ---------------------------------------------------------
        self.st = self.create_mask(u0)
        print("mask.shape", self.st["mask"].shape)

        # --- sensitivity map (make_sense fills st['sensemap']) -------------
        self.st = self.make_sense(self.f.tse)
        expanded = CsSolver.appendmat(self.st["sensemap"], n_append)
        self.st["sensemap"] = expanded
        self.st["sensemap"] = numpy.transpose(self.st["sensemap"], swap_last_two)
        print("self.sense.shape", self.st["sensemap"].shape)

        self.st["senseflag"] = 1  # turn-on sense, to get sensemap

        u, _uf = self.kernel(self.f.f, self.st, self.mu, self.LMBD, self.gamma, self.nInner, self.nBreg)
        self.u = u
Exemple #7
0
    def cg_step(self, rhs, m, uker, n_iter):
        """
        Refine ``m`` with up to ``n_iter`` conjugate-gradient iterations.

        The operator applied to an image ``x`` is::

            CombineMulti(do_FhWFm(x * S) * conj(S), -1)
                - LMBD * do_laplacian(x, uker) + 2 * gamma * x

        where ``S`` is ``self.st['sensemap']``.  The iteration starts from
        the residual ``rhs - A(m)``.

        Iteration stops early when the residual energy is exactly zero:
        continuing would divide zero by zero and fill the result with NaN.

        Parameters
        ----------
        rhs : right-hand side array
        m : starting estimate (not modified in place)
        uker : kernel forwarded to ``self.do_laplacian``
        n_iter : maximum number of iterations

        Returns
        -------
        The updated estimate.
        """

        def apply_lhs(x):
            # Data-consistency term, combined by CombineMulti, plus the
            # Laplacian and gamma regularisation terms.
            sense = self.st['sensemap']
            data_term = self.do_FhWFm(x * sense) * sense.conj()
            data_term = CsSolver.CombineMulti(data_term, -1)
            return (data_term - self.LMBD * self.do_laplacian(x, uker)
                    + 2.0 * self.gamma * x)

        r = rhs - apply_lhs(m)
        p = r
        # Residual energy; carried over between iterations instead of
        # being summed twice per pass.
        rr = numpy.sum(r.conj() * r)

        for _ in range(0, n_iter):
            if rr == 0:
                # Already an exact solution; a further step would be 0/0.
                break

            Ap = apply_lhs(p)
            alfa_k = rr / numpy.sum(p.conj() * Ap)

            print('r', rr, 'alpha_k', alfa_k)

            m = m + alfa_k * p
            r = r - alfa_k * Ap

            rr_next = numpy.sum(r.conj() * r)
            beta_k = rr_next / rr
            p = r + beta_k * p
            rr = rr_next

        return m
    def external_update(self, u, f, uf, f0, u0):  # overload the update function
        """
        Outer-loop update: push ``u`` through the transform and correct ``uf``.

        ``u`` is multiplied by the sensitivity map, reordered and flattened so
        that ``self.CsTransform.forwardbackward`` sees a 3-axis array, then
        restored to its layout and multiplied by the conjugate map.  When
        ``self.st["senseflag"]`` is 1 the result goes through
        ``CsSolver.CombineMulti``.  ``uf`` is then incremented by the
        difference between ``u0`` and that (rescaled) result and rescaled to
        the peak magnitude of ``u0``.

        ``f`` and ``u`` are returned unchanged; ``f0`` is not used.  A relative
        error figure is printed for diagnostics only.

        Returns
        -------
        tuple ``(f, uf, murf, u)`` where ``murf`` is the same object as ``uf``.
        """
        sense = self.st["sensemap"]
        CsSolver.checkmax(sense)

        # Move the leading axis last, then merge the two trailing axes.
        stacked = numpy.transpose(u * sense, (1, 2, 3, 0))
        full_shape = stacked.shape
        merged_shape = full_shape[0:2] + (numpy.prod(full_shape[2:4]),)
        stacked = numpy.reshape(stacked, merged_shape, order="F")
        stacked = self.CsTransform.forwardbackward(stacked)

        # Undo the merge and the axis move.
        stacked = numpy.reshape(stacked, full_shape, order="F")
        tmpuf = numpy.transpose(stacked, (3, 0, 1, 2)) * sense.conj()

        if self.st["senseflag"] == 1:
            tmpuf = CsSolver.CombineMulti(tmpuf, -1)

        print("start of ext_update")

        # Diagnostic only: the value is printed and not used afterwards.
        ratio = numpy.sum((self.u0 - tmpuf) ** 2) / numpy.sum((u0) ** 2)
        fact = numpy.sqrt(numpy.abs(ratio.real))
        print("fact", fact)

        peak = numpy.max(numpy.abs(u0[:]))
        tmpuf = CsSolver.Normalize(tmpuf) * peak
        uf = uf + (u0 - tmpuf) * 1.0
        uf = CsSolver.Normalize(uf) * peak

        for probe in (tmpuf, u0, uf):
            CsSolver.checkmax(probe)

        print("end of ext_update")
        return (f, uf, uf, u)
Exemple #9
0
    def constraint(self, xx, bb):
        '''
        Sum the constraint terms built from the split variables.

        Three contributions are added, in this order:
          * ``CsSolver.TVconstraint`` of entries 0-1, scaled by ``LMBD / 40``;
          * the inverse FFT along axis 2 of ``xx[3] - bb[3]``, scaled by
            ``gamma``;
          * ``xx[4] - bb[4]``, scaled by ``gamma``.

        Entry 2 of ``xx``/``bb`` does not contribute.
        '''
        tv_term = CsSolver.TVconstraint(xx[0:2], bb[0:2]) * self.LMBD / 40.0

        spectral_diff = xx[3] - bb[3]
        spectral_term = scipy.fftpack.ifftn(spectral_diff, axes=(2, )) * self.gamma

        direct_term = (xx[4] - bb[4]) * self.gamma

        return tv_term + spectral_term + direct_term
    def update_d(self, u, dd):
        """
        Recompute the split variables ``d`` from the current image ``u``.

        One output entry is produced per position of ``dd`` (only its length
        is read):

          * positions 0-2: ``CsSolver.get_Diff(u, position)``;
          * position 3: FFT of ``u`` along axis 2;
          * position 4: an independent copy of ``u``.

        Positions beyond 4 produce no entry.

        ``u`` itself is never modified: ``scipy.fftpack.fftn`` leaves its
        input intact by default, so no defensive copy is made before it, and
        the axis-2 sum that was previously computed but never used is gone.

        Returns
        -------
        tuple of arrays, at most five entries long.
        """
        out_dd = ()
        for jj in range(len(dd)):
            if jj < 3:
                out_dd = out_dd + (CsSolver.get_Diff(u, jj),)
            elif jj == 3:
                out_dd = out_dd + (scipy.fftpack.fftn(u, axes=(2,)),)
            elif jj == 4:
                # Copy so later in-place edits of the entry cannot touch u.
                out_dd = out_dd + (numpy.copy(u),)

        return out_dd
Exemple #11
0
    def update_d(self, u, dd):
        """
        Recompute the split variables ``d`` from the current image ``u``.

        One output entry is produced per position of ``dd`` (only its length
        is read):

          * positions 0-2: ``CsSolver.get_Diff(u, position)``;
          * position 3: FFT of ``u`` along axis 2;
          * position 4: an independent copy of ``u``.

        Positions beyond 4 produce no entry.

        ``u`` itself is never modified: ``scipy.fftpack.fftn`` leaves its
        input intact by default, so no defensive copy is made before it, and
        the axis-2 sum that was previously computed but never used is gone.

        Returns
        -------
        tuple of arrays, at most five entries long.
        """
        out_dd = ()
        for jj in range(len(dd)):
            if jj < 3:
                out_dd = out_dd + (CsSolver.get_Diff(u, jj), )
            elif jj == 3:
                out_dd = out_dd + (scipy.fftpack.fftn(u, axes=(2, )), )
            elif jj == 4:
                # Copy so later in-place edits of the entry cannot touch u.
                out_dd = out_dd + (numpy.copy(u), )

        return out_dd
Exemple #12
0
def foo4():
    """
    Reconstruct a raw GE data set slice by slice and save the results.

    The raw file is read with ``GeRaw.pfileparser.geV22``; axes beyond the
    fourth are reduced to their first entry.  A sampling pattern from
    ``loadRandom(0.5)`` is displayed and converted to sample indices, the
    data are Fourier transformed (axes 1-2, then axis 0), and every
    ``[jj, :, :, kk]`` slice is sampled with the transform, back-projected
    and reconstructed with ``CsSolver.pyCube2D``.

    Writes ``original.npy``, ``recon.npy`` and ``backward.npy`` via
    ``numpy.save``.  Array sizes (224, 224, 40, 4) and the input path are
    hard-coded.
    """
    import GeRaw.pfileparser
    filename = '/home/sram/Cambridge_2012/DATA_MATLAB/chengcheng/cube_raw_20130704.raw'
    rawdata = GeRaw.pfileparser.geV22(filename)

    # Keep only the first entry of every axis beyond the fourth.
    while len(rawdata.k.shape) > 4:
        rawdata.k = rawdata.k[..., 0]

    print('point size', rawdata.hdr['rdb']['point_size'])
    print(numpy.shape(rawdata.k), numpy.shape(rawdata.k)[1:3])

    R, ratio = loadRandom(0.5)

    print('true ratio', ratio)
    matplotlib.pyplot.imshow(R)
    matplotlib.pyplot.show()
    ind = convert_mask_to_index(R)
    print(ind)
    matplotlib.pyplot.plot(ind[:, 1], ind[:, 0], 'x')
    matplotlib.pyplot.show()

    om = ind
    Nd = (224, 40)
    Kd = (224, 40)
    Jd = (1, 1)

    # Shifted FFT over axes (1, 2) first, then over axis 0.
    x = rawdata.k
    for axes in ((1, 2), (0, )):
        x = scipy.fftpack.fftshift(x, axes=axes)
        x = pyfftw.interfaces.scipy_fftpack.fftn(x, axes=axes, threads=2)
        x = scipy.fftpack.fftshift(x, axes=axes)

    MyTransform = CsTransform.pynufft.pynufft(om, Nd, Kd, Jd)

    original = x
    # numpy.complex was removed in NumPy 1.24; complex128 is the same dtype.
    recon = numpy.empty((224, 224, 40, 4), dtype=numpy.complex128)
    backward = numpy.empty_like(recon)

    n_first = recon.shape[0]
    n_last = recon.shape[3]
    for jj in range(0, n_first):
        for kk in range(0, n_last):
            print(jj / float(n_first), kk / float(n_last))
            c = x[jj, :, :, kk]
            f = MyTransform.forward(c)
            backward[jj, :, :, kk] = MyTransform.backward(f)[:, :, 0]
            Solver1 = CsSolver.pyCube2D(MyTransform, f, 1.0, 0.1, 0.001, 4, 45)
            Solver1.solve()
            recon[jj, :, :, kk] = Solver1.u

    numpy.save('original', original)

    numpy.save('recon', recon)

    numpy.save('backward', backward)
    def kernel(self, f_internal, st, mu, LMBD, gamma, nInner, nBreg):
        """
        Run the split-Bregman iterations and return ``(u, uf)``.

        Set-up performed before iterating:
          * ``self.st["sensemap"]`` is multiplied by ``self.st["mask"]``;
          * ``self.ttse``, ``self.tse0`` and ``self.filter`` are derived from
            ``self.f.tse``;
          * ``self.u0`` is built from ``f_internal`` through ``fun1``,
            ``fun2`` and ``fun3``, the conjugate sensitivity map and
            ``CsSolver.CombineMulti``.

        Each of the ``nBreg`` outer passes runs ``nInner`` rounds of
        ``update_u`` / ``update_d`` / ``shrink`` / ``update_b`` and then one
        ``external_update``.  Afterwards ``u`` is normalised and one panel
        per index of axis 2 is displayed with matplotlib.

        Parameters
        ----------
        f_internal : measured data; copied, never modified here
        st : mapping; only ``st["Nd"]`` is read (for the shrink threshold)
        mu, gamma : accepted for interface compatibility, not read here
        LMBD : divides the shrink threshold
        nInner, nBreg : inner / outer iteration counts

        Returns
        -------
        tuple ``(u, uf)``
        """
        self.st["sensemap"] = self.st["sensemap"] * self.st["mask"]

        tse = self.f.tse
        tse = CsSolver.appendmat(tse, self.st["Nd"][1])
        tse = numpy.transpose(tse, (0, 1, 3, 2))
        self.ttse = CsSolver.Normalize(tse)

        self.tse0 = CsSolver.CombineMulti(tse, -1)

        # Kaiser window (scaled by 10) applied along axis 1 of an all-ones array.
        self.filter = numpy.ones(tse.shape)
        dpss = numpy.kaiser(tse.shape[1], 1.0) * 10.0
        for ppp in range(0, tse.shape[1]):
            self.filter[:, ppp, :, :] = self.filter[:, ppp, :, :] * dpss[ppp]

        print("tse.shape", tse.shape)

        # Independent copies so the caller's data is never aliased.
        f0 = numpy.copy(f_internal)
        f = numpy.copy(f_internal)

        u0 = self.fun1(f_internal)

        # NOTE(review): pdf is prepared but not read again in this method.
        pdf = self.f.pdf
        pdf = CsSolver.appendmat(pdf, self.st["Nd"][1])
        pdf = numpy.transpose(pdf, (0, 1, 3, 2))

        u0 = self.fun2(u0)

        u0 = self.fun3(u0)

        u0 = u0 * self.st["sensemap"].conj()

        u0 = CsSolver.CombineMulti(u0, -1)

        uker = self.create_laplacian_kernel()
        uker = CsSolver.appendmat(uker, u0.shape[3])

        self.u0 = u0

        u = numpy.copy(self.tse0)

        print("u0.shape", u0.shape)

        (xx, bb, dd) = self.make_split_variables(u)

        uf = numpy.copy(u)  # only used for ISRA, written here for generality

        murf = numpy.copy(u)  # initial values

        for outer in numpy.arange(0, nBreg):
            for inner in numpy.arange(0, nInner):
                print("iterating", [inner, outer])

                u = self.update_u(murf, u, uker, xx, bb)

                # Peak magnitude of u scales the shrink threshold.
                c = numpy.max(numpy.abs(u[:]))

                dd = self.update_d(u, dd)

                xx = self.shrink(dd, bb, c * 1.0 / LMBD / numpy.sqrt(numpy.prod(st["Nd"])))

                bb = self.update_b(bb, dd, xx)

            (f, uf, murf, u) = self.external_update(u, f, uf, f0, u0)  # update outer Split_bregman

        u = CsSolver.Normalize(u)

        # subplot() needs integer grid sizes and a 1-based panel index.
        n_panels = u0.shape[2]
        grid = int(numpy.sqrt(n_panels)) + 1
        for pp in range(0, n_panels):
            matplotlib.pyplot.subplot(grid, grid, pp + 1)
            # NOTE(review): `norm` is not defined in this method; presumably a
            # module-level matplotlib normaliser -- confirm.
            matplotlib.pyplot.imshow(numpy.sum(numpy.abs(u[..., pp, :]), -1), norm=norm, interpolation="nearest")
        matplotlib.pyplot.show()

        return (u, uf)
Exemple #14
0
    def external_update(self, u, f, uf, f0, u0):  # overload the update function
        """
        Outer-loop update: push ``u`` through the transform and correct ``uf``.

        ``u`` is multiplied by the sensitivity map, reordered and flattened so
        that ``self.CsTransform.forwardbackward`` sees a 3-axis array, then
        restored to its layout and multiplied by the conjugate map.  When
        ``self.st['senseflag']`` is 1 the result goes through
        ``CsSolver.CombineMulti``.  ``uf`` is then incremented by the
        difference between ``u0`` and that (rescaled) result and rescaled to
        the peak magnitude of ``u0``.

        ``f`` and ``u`` are returned unchanged; ``f0`` is not used.  A relative
        error figure is printed for diagnostics only.

        Returns
        -------
        tuple ``(f, uf, murf, u)`` where ``murf`` is the same object as ``uf``.
        """
        sense = self.st['sensemap']
        CsSolver.checkmax(sense)

        # Move the leading axis last, then merge the two trailing axes.
        stacked = numpy.transpose(u * sense, (1, 2, 3, 0))
        full_shape = stacked.shape
        merged_shape = full_shape[0:2] + (numpy.prod(full_shape[2:4]), )
        stacked = numpy.reshape(stacked, merged_shape, order='F')
        stacked = self.CsTransform.forwardbackward(stacked)

        # Undo the merge and the axis move.
        stacked = numpy.reshape(stacked, full_shape, order='F')
        tmpuf = numpy.transpose(stacked, (3, 0, 1, 2)) * sense.conj()

        if self.st['senseflag'] == 1:
            tmpuf = CsSolver.CombineMulti(tmpuf, -1)

        print('start of ext_update')

        # Diagnostic only: the value is printed and not used afterwards.
        ratio = numpy.sum((self.u0 - tmpuf)**2) / numpy.sum((u0)**2)
        fact = numpy.sqrt(numpy.abs(ratio.real))
        print('fact', fact)

        peak = numpy.max(numpy.abs(u0[:]))
        tmpuf = CsSolver.Normalize(tmpuf) * peak
        uf = uf + (u0 - tmpuf) * 1.0
        uf = CsSolver.Normalize(uf) * peak

        for probe in (tmpuf, u0, uf):
            CsSolver.checkmax(probe)

        print('end of ext_update')
        return (f, uf, uf, u)
Exemple #15
0
    def kernel(self, f_internal, st, mu, LMBD, gamma, nInner, nBreg):
        """
        Run the split-Bregman iterations and return ``(u, uf)``.

        Set-up performed before iterating:
          * ``self.st['sensemap']`` is multiplied by ``self.st['mask']``;
          * ``self.ttse``, ``self.tse0`` and ``self.filter`` are derived from
            ``self.f.tse``;
          * ``self.u0`` is built from ``f_internal`` through ``fun1``,
            ``fun2`` and ``fun3``, the conjugate sensitivity map and
            ``CsSolver.CombineMulti``.

        Each of the ``nBreg`` outer passes runs ``nInner`` rounds of
        ``update_u`` / ``update_d`` / ``shrink`` / ``update_b`` and then one
        ``external_update``.  Afterwards ``u`` is normalised and one panel
        per index of axis 2 is displayed with matplotlib.

        Parameters
        ----------
        f_internal : measured data; copied, never modified here
        st : mapping; only ``st['Nd']`` is read (for the shrink threshold)
        mu, gamma : accepted for interface compatibility, not read here
        LMBD : divides the shrink threshold
        nInner, nBreg : inner / outer iteration counts

        Returns
        -------
        tuple ``(u, uf)``
        """
        self.st['sensemap'] = self.st['sensemap'] * self.st['mask']

        tse = self.f.tse
        tse = CsSolver.appendmat(tse, self.st['Nd'][1])
        tse = numpy.transpose(tse, (0, 1, 3, 2))
        self.ttse = CsSolver.Normalize(tse)

        self.tse0 = CsSolver.CombineMulti(tse, -1)

        # Kaiser window (scaled by 10) applied along axis 1 of an all-ones array.
        self.filter = numpy.ones(tse.shape)
        dpss = numpy.kaiser(tse.shape[1], 1.0) * 10.0
        for ppp in range(0, tse.shape[1]):
            self.filter[:, ppp, :, :] = self.filter[:, ppp, :, :] * dpss[ppp]

        print('tse.shape', tse.shape)

        # Independent copies so the caller's data is never aliased.
        f0 = numpy.copy(f_internal)
        f = numpy.copy(f_internal)

        u0 = self.fun1(f_internal)

        # NOTE(review): pdf is prepared but not read again in this method.
        pdf = self.f.pdf
        pdf = CsSolver.appendmat(pdf, self.st['Nd'][1])
        pdf = numpy.transpose(pdf, (0, 1, 3, 2))

        u0 = self.fun2(u0)

        u0 = self.fun3(u0)

        u0 = u0 * self.st['sensemap'].conj()

        u0 = CsSolver.CombineMulti(u0, -1)

        uker = self.create_laplacian_kernel()
        uker = CsSolver.appendmat(uker, u0.shape[3])

        self.u0 = u0

        u = numpy.copy(self.tse0)

        print('u0.shape', u0.shape)

        (xx, bb, dd) = self.make_split_variables(u)

        uf = numpy.copy(u)  # only used for ISRA, written here for generality

        murf = numpy.copy(u)  # initial values

        for outer in numpy.arange(0, nBreg):
            for inner in numpy.arange(0, nInner):
                print('iterating', [inner, outer])

                u = self.update_u(murf, u, uker, xx, bb)

                # Peak magnitude of u scales the shrink threshold.
                c = numpy.max(numpy.abs(u[:]))

                dd = self.update_d(u, dd)

                xx = self.shrink(
                    dd, bb, c * 1.0 / LMBD / numpy.sqrt(numpy.prod(st['Nd'])))

                bb = self.update_b(bb, dd, xx)

            (f, uf, murf,
             u) = self.external_update(u, f, uf, f0,
                                       u0)  # update outer Split_bregman

        u = CsSolver.Normalize(u)

        # subplot() needs integer grid sizes and a 1-based panel index.
        n_panels = u0.shape[2]
        grid = int(numpy.sqrt(n_panels)) + 1
        for pp in range(0, n_panels):
            matplotlib.pyplot.subplot(grid, grid, pp + 1)
            # NOTE(review): `norm` is not defined in this method; presumably a
            # module-level matplotlib normaliser -- confirm.
            matplotlib.pyplot.imshow(numpy.sum(numpy.abs(u[..., pp, :]), -1),
                                     norm=norm,
                                     interpolation='nearest')
        matplotlib.pyplot.show()

        return (u, uf)
def foo4():
    """
    Reconstruct a raw GE data set slice by slice and save the results.

    The raw file is read with ``GeRaw.pfileparser.geV22``; axes beyond the
    fourth are reduced to their first entry.  A sampling pattern from
    ``loadRandom(0.5)`` is displayed and converted to sample indices, the
    data are Fourier transformed (axes 1-2, then axis 0), and every
    ``[jj, :, :, kk]`` slice is sampled with the transform, back-projected
    and reconstructed with ``CsSolver.pyCube2D``.

    Writes ``original.npy``, ``recon.npy`` and ``backward.npy`` via
    ``numpy.save``.  Array sizes (224, 224, 40, 4) and the input path are
    hard-coded.
    """
    import GeRaw.pfileparser
    filename = '/home/sram/Cambridge_2012/DATA_MATLAB/chengcheng/cube_raw_20130704.raw'
    rawdata = GeRaw.pfileparser.geV22(filename)

    # Keep only the first entry of every axis beyond the fourth.
    while len(rawdata.k.shape) > 4:
        rawdata.k = rawdata.k[..., 0]

    print('point size', rawdata.hdr['rdb']['point_size'])
    print(numpy.shape(rawdata.k), numpy.shape(rawdata.k)[1:3])

    R, ratio = loadRandom(0.5)

    print('true ratio', ratio)
    matplotlib.pyplot.imshow(R)
    matplotlib.pyplot.show()
    ind = convert_mask_to_index(R)
    print(ind)
    matplotlib.pyplot.plot(ind[:, 1], ind[:, 0], 'x')
    matplotlib.pyplot.show()

    om = ind
    Nd = (224, 40)
    Kd = (224, 40)
    Jd = (1, 1)

    # Shifted FFT over axes (1, 2) first, then over axis 0.
    x = rawdata.k
    for axes in ((1, 2), (0, )):
        x = scipy.fftpack.fftshift(x, axes=axes)
        x = pyfftw.interfaces.scipy_fftpack.fftn(x, axes=axes, threads=2)
        x = scipy.fftpack.fftshift(x, axes=axes)

    MyTransform = CsTransform.pynufft.pynufft(om, Nd, Kd, Jd)

    original = x
    # numpy.complex was removed in NumPy 1.24; complex128 is the same dtype.
    recon = numpy.empty((224, 224, 40, 4), dtype=numpy.complex128)
    backward = numpy.empty_like(recon)

    n_first = recon.shape[0]
    n_last = recon.shape[3]
    for jj in range(0, n_first):
        for kk in range(0, n_last):
            print(jj / float(n_first), kk / float(n_last))
            c = x[jj, :, :, kk]
            f = MyTransform.forward(c)
            backward[jj, :, :, kk] = MyTransform.backward(f)[:, :, 0]
            Solver1 = CsSolver.pyCube2D(MyTransform, f, 1.0, 0.1, 0.001, 4, 45)
            Solver1.solve()
            recon[jj, :, :, kk] = Solver1.u

    numpy.save('original', original)

    numpy.save('recon', recon)

    numpy.save('backward', backward)