def AtAUV(x, csmT, maskT):
    """Apply the temporally-projected normal operator to ``x`` plus a regulariser.

    For every coil ``i``:

    1. ``x`` is multiplied by coil map ``csmT[i]`` and transformed with
       ``sf.pt_fft2c``.
    2. For every frame ``k``, the k-space is masked with ``maskT[k]`` and
       projected onto the temporal basis column ``VT[:, k]`` (``v @ v.T @ .``).
       These contributions are summed.
    3. The sum is inverse transformed and multiplied by the conjugate coil map,
       then accumulated.

    Finally ``sT @ (W * x)`` (with ``x`` flattened) is added.

    Relies on module-level globals ``nbasis``, ``nx``, ``nch``, ``NF1``,
    ``VT``, ``W``, ``sT`` and the helper module ``sf``.

    Args:
        x: basis images. They must hold ``nbasis * nx * nx * 2`` elements,
            since they are reshaped to ``(nbasis, nx*nx*2)`` for the
            regulariser.
        csmT: coil maps, with the coil index on the first axis.
        maskT: per-frame sampling masks, with the frame index on the first axis.

    Returns:
        CUDA tensor of shape ``(nbasis, nx, nx, 2)``.
    """
    atbv = torch.cuda.FloatTensor(nbasis, nx, nx, 2).fill_(0)
    # Accumulator for the per-frame projections of one coil's k-space.
    tmp6 = torch.cuda.FloatTensor(nbasis, nx, nx, 2).fill_(0)
    csmConj = sf.pt_conj(csmT)

    for i in range(nch):
        kspace = sf.pt_fft2c(sf.pt_cpx_multipy(x, csmT[i].repeat(nbasis, 1, 1, 1)))
        for k in range(NF1):
            masked = maskT[k].repeat(nbasis, 1, 1, 1) * kspace
            masked = torch.reshape(masked, (nbasis, nx * nx * 2))
            v = VT[:, k].unsqueeze(1)  # column for frame k
            proj = v @ (v.T @ masked)
            tmp6 = tmp6 + torch.reshape(proj, (nbasis, nx, nx, 2))
            # Free per-frame temporaries here rather than after the loop, so
            # NF1 == 0 cannot trigger a NameError on unbound names.
            del masked, v, proj
        del kspace
        image = sf.pt_ifft2c(tmp6)
        atbv = atbv + sf.pt_cpx_multipy(csmConj[i].repeat(nbasis, 1, 1, 1), image)
        tmp6.fill_(0)
        del image

    x = torch.reshape(x, (nbasis, nx * nx * 2))
    x = W * x
    reg = torch.mm(sT, x)
    reg = torch.reshape(reg, (nbasis, nx, nx, 2))
    atbv = atbv + reg
    return atbv
def AtAUV(x, csmT, maskT1, VT):
    """Apply the temporally-projected normal operator to basis images ``x``.

    For every coil ``i``:

    1. ``x`` is multiplied by coil map ``csmT[i]`` and transformed with
       ``sf.pt_fft2c``.
    2. For every frame ``k``, the k-space is masked with ``maskT1[k]`` and
       projected onto ``VT[:, k]`` (``v @ v.T @ .``). These contributions are
       summed.
    3. The sum is inverse transformed, multiplied by the conjugate coil map and
       accumulated.

    No regularisation term is added.

    Relies on module-level globals ``nbasis``, ``nx`` and the helper module ``sf``.

    Args:
        x: basis images, with ``nbasis`` entries along the first axis
            (presumably ``(nbasis, nx, nx, 2)``; TODO confirm).
        csmT: coil maps; ``csmT.size(0)`` is the number of coils.
        maskT1: per-frame masks; ``maskT1.size(0)`` is the number of frames.
        VT: temporal basis; ``VT[:, k]`` is the column used for frame ``k``.

    Returns:
        CUDA tensor of shape ``(nbasis, nx, nx, 2)``.
    """
    nch = csmT.size(0)
    NF = maskT1.size(0)
    atbv = torch.cuda.FloatTensor(nbasis, nx, nx, 2).fill_(0)
    # Accumulator for the per-frame projections of one coil's k-space.
    tmp6 = torch.cuda.FloatTensor(nbasis, nx, nx, 2).fill_(0)
    csmConj = sf.pt_conj(csmT)

    for i in range(nch):
        kspace = sf.pt_fft2c(sf.pt_cpx_multipy(x, csmT[i].repeat(nbasis, 1, 1, 1)))
        for k in range(NF):
            # One mask is replicated across the trailing size-2 axis
            # (presumably real/imaginary parts) and across all basis images.
            masked = maskT1[k].unsqueeze(2).repeat(nbasis, 1, 1, 2) * kspace
            masked = torch.reshape(masked, (nbasis, nx * nx * 2))
            v = VT[:, k].unsqueeze(1)  # column for frame k
            proj = v @ (v.T @ masked)
            tmp6 = tmp6 + torch.reshape(proj, (nbasis, nx, nx, 2))
            # Free per-frame temporaries here rather than after the loop, so
            # NF == 0 cannot trigger a NameError on unbound names.
            del masked, v, proj
        del kspace
        image = sf.pt_ifft2c(tmp6)
        atbv = atbv + sf.pt_cpx_multipy(csmConj[i].repeat(nbasis, 1, 1, 1), image)
        tmp6.fill_(0)
        del image

    return atbv
def ATBV_NFT(FT, kdata, csmT, VT):
    """Compute a temporally-projected adjoint from per-coil data ``kdata``.

    This looks like a non-Cartesian counterpart of ``ATBV``; TODO confirm.
    It uses ``FT`` as the transform and accumulates
    ``sf.pt_cpx_multipy(FT(.), sf.pt_conj(csmT[j]))`` over coils.

    NOTE(review): as written this function cannot run (see the notes below).
    It relies on module-level globals ``nx``, ``nbasis``, ``gpu`` and ``sf``.

    Args:
        FT: callable applied to a ``(nx, nx, 2)`` tensor (presumably an adjoint
            transform; TODO confirm).
        kdata: per-coil data indexed as ``kdata[j, :, :]``.
        csmT: coil maps, with the coil index on the first axis.
        VT: temporal basis; ``VT[:, i]`` is taken as a column.

    Returns:
        Tensor reshaped to ``(nbasis, nx, nx, 2)``, scaled by ``nx``.
    """

    nch = csmT.size(0)
    N = csmT.size(1)  # NOTE(review): never used.
    nbas = VT.size(0)
    tmp1 = torch.zeros((nx, nx, 2)).to(gpu)
    atbv = torch.zeros((nx * nx * 2, nbasis)).to(gpu)
    # NOTE(review): ``x`` is not a parameter and is assigned on this line, which
    # makes it a local. This raises UnboundLocalError on every call, and ``x``
    # is never used afterwards.
    x = x.to(gpu)
    # NOTE(review): the loop runs over VT.size(0) but indexes columns VT[:, i].
    # ATBV loops over VT.size(1) for the same indexing; confirm intent.
    for i in range(nbas):
        for j in range(nch):
            tmp2 = VT[:, i].unsqueeze(1)
            # NOTE(review): tmp2 is 2-D here, so torch.diag returns its
            # diagonal (length 1) instead of building a diagonal matrix.
            tmp = torch.diag(tmp2) @ kdata[j, :, :]
            tmp = torch.reshape(tmp, (nx, nx, 2))
            # NOTE(review): tmp1[i] has shape (nx, 2). If the right-hand side
            # is (nx, nx, 2), this assignment fails for nx > 1.
            tmp1[i] = tmp1[i] + sf.pt_cpx_multipy(FT(tmp), sf.pt_conj(csmT[j]))

        tmp3 = torch.reshape(tmp1, (nx * nx * 2, 1))
        # Outer product with the basis column, accumulated over the outer loop.
        atbv = atbv + tmp3 @ tmp2.T
        tmp1 = torch.zeros((nx, nx, 2)).to(gpu)

    atbv = atbv * nx
    atbv = atbv.permute(1, 0)
    atbv = torch.reshape(atbv, (nbasis, nx, nx, 2))
    #atbV=atbV.astype(np.complex64)
    #atbV=np.fft.fftshift(np.fft.fftshift(atbV,1),2)
    #atbV=np.transpose(atbV,[0,2,1])
    # NOTE(review): raises NameError if either loop ran zero times.
    del tmp, tmp1, tmp2, tmp3
    return atbv
def ATBV(x, csmT, VT):
    """Compute the temporally-projected adjoint ``A^H b`` from Cartesian data.

    For every frame ``i``, each coil's data ``x[j, i]`` is inverse transformed
    with ``sf.pt_ifft2c``, multiplied by the conjugate coil map and summed over
    coils. The frame image is then spread onto the basis via an outer product
    with ``VT[:, i]``.

    Relies on module-level globals ``nx``, ``nbasis``, ``gpu`` and ``sf``.

    Args:
        x: data indexed as ``x[coil, frame]``; each entry must reshape to
            ``(nx, nx, 2)``. It is moved to ``gpu``.
        csmT: coil maps; ``csmT.size(0)`` is the number of coils.
        VT: temporal basis; ``VT.size(1)`` is the number of frames and
            ``VT[:, i]`` is the basis column for frame ``i``.

    Returns:
        Tensor of shape ``(nbasis, nx, nx, 2)``, scaled by ``nx``.
    """
    nch = csmT.size(0)
    NF = VT.size(1)
    atbv = torch.zeros((nx * nx * 2, nbasis)).to(gpu)
    x = x.to(gpu)
    # The conjugate coil maps do not depend on the frame: compute them once.
    csm_conj = [sf.pt_conj(csmT[j]) for j in range(nch)]

    for i in range(NF):
        frame_img = torch.zeros((nx, nx, 2)).to(gpu)
        for j in range(nch):
            coil_data = torch.reshape(x[j, i], (nx, nx, 2))
            frame_img = frame_img + sf.pt_cpx_multipy(sf.pt_ifft2c(coil_data),
                                                      csm_conj[j])
        v = VT[:, i].unsqueeze(1)
        atbv = atbv + torch.reshape(frame_img, (nx * nx * 2, 1)) @ v.T

    atbv = atbv * nx
    atbv = atbv.permute(1, 0)
    return torch.reshape(atbv, (nbasis, nx, nx, 2))
# Example No. 5
# 0
def pt_At(bT, csmT, maskT):
    """Mask ``bT``, inverse transform it and combine with conjugate coil maps.

    The coil-weighted images are summed over axis ``-4`` (presumably the coil
    axis; TODO confirm against callers).
    """
    coil_images = sf.pt_ifft2c(maskT * bT)
    weighted = sf.pt_cpx_multipy(sf.pt_conj(csmT), coil_images)
    return torch.sum(weighted, dim=-4)
# Example No. 6
# 0
def pt_At(bT, csmT, maskT):
    """Apply the per-frame adjoint: mask, inverse transform and coil-combine.

    Relies on module-level globals ``NF`` (number of frames) and ``sf``.

    Args:
        bT: data multiplied element-wise by ``maskT``. After ``sf.pt_ifft2c``
            it is indexed by frame along its first axis.
        csmT: coil maps; their conjugate multiplies every frame.
        maskT: sampling mask.

    Returns:
        The ``NF`` coil-combined frames (each summed over axis ``-4``),
        stacked along a new first axis.
    """
    coil_images = sf.pt_ifft2c(maskT * bT)
    csmConj = sf.pt_conj(csmT)
    # Build the result locally. The previous code wrote into an ``atbT`` that
    # is not defined here, so it raised NameError or mutated a global buffer.
    frames = [torch.sum(sf.pt_cpx_multipy(csmConj, coil_images[i]), dim=-4)
              for i in range(NF)]
    return torch.stack(frames)
    def forward(self, x, mask):
        """Apply the forward operator: coil-weight, transform and mask, then sum over basis.

        Args:
            x: basis images, reshaped to ``(self.nbasis, self.nx, self.nx, 2)``.
            mask: multiplied with a ``(self.NF, self.nbasis, self.NX)`` tensor
                (presumably basis-weighted per-frame masks; TODO confirm).

        Returns:
            CUDA tensor of shape ``(self.nch, self.NF, self.nx*self.nx*2)``.
        """
        tmp5 = torch.cuda.FloatTensor(self.nch, self.NF,
                                      self.nx * self.nx * 2).fill_(0)
        x = torch.reshape(x, (self.nbasis, self.nx, self.nx, 2))
        # Loop over self.nch, the size tmp5 was allocated with, not a global.
        for i in range(self.nch):
            kspace = sf.pt_fft2c(
                sf.pt_cpx_multipy(x, self.csmT[i].repeat(self.nbasis, 1, 1, 1)))
            kspace = torch.reshape(kspace, (self.nbasis, self.NX))
            # (NF, nbasis, NX) * mask, then sum over the basis axis.
            tmp5[i] = (kspace.repeat(self.NF, 1, 1) * mask).sum(axis=1)

        return tmp5
def AtAUV(x, csmT, csmConj, maskT):
    """Apply the normal operator with a precomputed mixing mask, plus a regulariser.

    For each coil, the coil-weighted transform of ``x`` is flattened to
    ``(nbasis, nx*nx*2)`` and replicated ``nbasis`` times. It is multiplied by
    ``maskT`` and summed over axis 1, then inverse transformed and combined
    with ``csmConj``. Finally ``sT @ (W * x)`` is added.

    Relies on module-level globals ``nbasis``, ``nx``, ``nch``, ``W``, ``sT``
    and the helper module ``sf``.
    """
    acc = torch.cuda.FloatTensor(nbasis, nx, nx, 2).fill_(0)

    for coil in range(nch):
        weighted = sf.pt_cpx_multipy(x, csmT[coil].repeat(nbasis, 1, 1, 1))
        kspace = torch.reshape(sf.pt_fft2c(weighted), (nbasis, nx * nx * 2))
        mixed = (kspace.repeat(nbasis, 1, 1) * maskT).sum(axis=1)
        image = sf.pt_ifft2c(torch.reshape(mixed, (nbasis, nx, nx, 2)))
        acc = acc + sf.pt_cpx_multipy(csmConj[coil].repeat(nbasis, 1, 1, 1),
                                      image)

    flat = torch.reshape(x, (nbasis, nx * nx * 2))
    reg = torch.reshape(torch.mm(sT, W * flat), (nbasis, nx, nx, 2))
    return acc + reg
# Example No. 9
# 0
    def forward(self, x):
        """Apply the forward operator, projecting each masked frame onto its temporal basis column.

        For each coil ``i`` and frame ``j``, the coil-weighted transform of
        ``x`` (shape ``(nbasis, NX)``) is masked with ``self.mask[j]``. It is
        then contracted with ``self.VT[:, j]`` over the basis axis.

        Args:
            x: basis images, reshaped to ``(self.nbasis, self.nx, self.nx, 2)``.

        Returns:
            CPU tensor of shape ``(self.nch, self.NF, self.nx*self.nx*2)``.
        """
        tmp5 = torch.FloatTensor(self.nch, self.NF,
                                 self.nx * self.nx * 2).fill_(0)
        x = torch.reshape(x, (self.nbasis, self.nx, self.nx, 2))
        for i in range(self.nch):
            kspace = sf.pt_fft2c(
                sf.pt_cpx_multipy(x, self.csmT[i].repeat(self.nbasis, 1, 1,
                                                         1)))
            kspace = torch.reshape(kspace, (self.nbasis, self.NX))
            for j in range(self.NF):
                # Use self.nbasis (not a global) and the frame index j.
                # The previous code referenced an undefined ``k`` here.
                masked = self.mask[j].repeat(self.nbasis, 1) * kspace
                tmp5[i, j] = (masked.T @ self.VT[:, j].unsqueeze(1)).T

        return tmp5
        def forward(self, x, mask, v1):
            """Apply the forward operator with a basis-weighted mask derived from ``v1``.

            The mask ``ss.maskV(mask.cuda(), v1)`` is replicated across a
            trailing size-2 axis and reshaped to
            ``(self.NF, nbasis, self.nx*self.nx*2)``. Each coil's transform of
            ``x`` is multiplied by it and summed over the basis axis.

            Returns:
                CPU tensor of shape ``(self.nch, self.NF, self.nx*self.nx*2)``.
            """
            nbas = x.shape[0]
            m_res = ss.maskV(mask.cuda(), v1)
            m_res = m_res.unsqueeze(3).repeat(1, 1, 1, 2)
            # Infer the basis axis instead of reading a global ``nbasis``; it
            # should agree with ``nbas`` for the product below to broadcast.
            m_res = torch.reshape(m_res, (self.NF, -1, self.nx * self.nx * 2))
            tmp5 = torch.cuda.FloatTensor(self.nch, self.NF,
                                          self.nx * self.nx * 2).fill_(0)
            x = torch.reshape(x, (nbas, self.nx, self.nx, 2))
            # Loop over self.nch (the size tmp5 was allocated with).
            for i in range(self.nch):
                kspace = sf.pt_fft2c(
                    sf.pt_cpx_multipy(x, self.csmT[i].repeat(nbas, 1, 1, 1)))
                kspace = torch.reshape(kspace, (nbas, self.NX))
                tmp5[i] = (kspace.repeat(self.NF, 1, 1) * m_res).sum(axis=1)

            return tmp5.cpu()
    def forward(self, x, mask, Vv):
        """Apply the forward operator: expand basis images to frames, then coil-weight, transform and mask.

        ``x`` is flattened to ``(nbas, nx*nx*2)`` and mapped to frame images
        with ``Vv @ x``. Each frame is then multiplied by every coil map,
        transformed with ``sf.pt_fft2c`` and masked.

        Relies on module-level globals ``nx``, ``csmT`` and the helper ``sf``.

        Args:
            x: basis images; ``x.shape[0]`` is the number of basis images.
            mask: per-frame masks, reshaped to ``(nf, nx, nx)``.
            Vv: matrix mapping ``nbas`` basis images to ``nf`` frames.

        Returns:
            CPU tensor of shape ``(self.nch, nf, nx, nx, 2)``.
        """
        nf = mask.size(0)
        nbas = x.shape[0]
        # The same mask applies to both entries of the trailing size-2 axis.
        # It does not depend on the coil, so build it once.
        mask2 = torch.reshape(mask, (nf, nx, nx)).unsqueeze(3).repeat(1, 1, 1, 2)

        tmp5 = torch.FloatTensor(self.nch, nf, nx, nx, 2).fill_(0)
        x = torch.reshape(x, (nbas, nx * nx * 2))
        uv = torch.reshape(Vv @ x, (nf, nx, nx, 2))
        # Loop over self.nch, the size tmp5 was allocated with, not a global.
        for i in range(self.nch):
            kspace = sf.pt_fft2c(
                sf.pt_cpx_multipy(uv, csmT[i].repeat(nf, 1, 1, 1)))
            tmp5[i] = kspace * mask2

        return tmp5.cpu()
# Example No. 12
# 0
def pt_A(orgT, csmT, maskT):
    """Apply the forward operator: weight by coil maps, transform, then mask.

    Equivalent to ``maskT * sf.pt_fft2c(sf.pt_cpx_multipy(orgT, csmT))``.
    """
    coil_kspace = sf.pt_fft2c(sf.pt_cpx_multipy(orgT, csmT))
    return maskT * coil_kspace