def AtAUV(x, csmT, maskT):
    """Apply the subspace-projected normal operator to ``x`` and add regularization.

    For every coil ``i``:
      1. weight the basis images by ``csmT[i]`` and transform with ``sf.pt_fft2c``;
      2. for every frame ``k``, mask the k-space with ``maskT[k]`` and project it
         onto the temporal basis column ``v = VT[:, k]`` via ``v @ (v.T @ .)``,
         summing the result over frames;
      3. bring the sum back with ``sf.pt_ifft2c`` and weight by the conjugate
         coil map.
    Finally ``sT @ (W * x)`` is added.

    This relies on the module-level globals ``nbasis``, ``nx``, ``nch``, ``NF1``,
    ``VT``, ``W``, ``sT`` and on the helper module ``sf``.

    Args:
        x: basis images. They must hold ``nbasis * nx * nx * 2`` elements
            (presumably shaped ``(nbasis, nx, nx, 2)`` with a real/imag trailing
            axis -- TODO confirm).
        csmT: coil maps, indexed by coil along dim 0.
        maskT: per-frame sampling masks, indexed by frame along dim 0.

    Returns:
        A float32 CUDA tensor of shape ``(nbasis, nx, nx, 2)``.
    """
    atbv = torch.zeros(nbasis, nx, nx, 2, dtype=torch.float32, device="cuda")
    csmConj = sf.pt_conj(csmT)
    for i in range(nch):
        kspace = sf.pt_fft2c(sf.pt_cpx_multipy(x, csmT[i].repeat(nbasis, 1, 1, 1)))
        # Per-coil accumulator for the frame-wise projections.
        acc = torch.zeros(nbasis, nx, nx, 2, dtype=torch.float32, device="cuda")
        for k in range(NF1):
            masked = maskT[k].repeat(nbasis, 1, 1, 1) * kspace
            masked = torch.reshape(masked, (nbasis, nx * nx * 2))
            v = VT[:, k].unsqueeze(1)  # (nbasis, 1) column for frame k
            acc = acc + torch.reshape(v @ (v.T @ masked), (nbasis, nx, nx, 2))
        # Release this coil's k-space early (safe even when NF1 == 0, unlike
        # deleting loop-only temporaries).
        del kspace
        img = sf.pt_ifft2c(acc)
        atbv = atbv + sf.pt_cpx_multipy(csmConj[i].repeat(nbasis, 1, 1, 1), img)
    weighted = W * torch.reshape(x, (nbasis, nx * nx * 2))
    reg = torch.reshape(torch.mm(sT, weighted), (nbasis, nx, nx, 2))
    return atbv + reg
def forward(self, x, mask):
    """Forward operator: masked, coil-weighted k-space summed over the basis.

    For each coil ``i`` the basis images are weighted by ``self.csmT[i]`` and
    transformed with ``sf.pt_fft2c``. The flattened k-space ``(nbasis, NX)`` is
    broadcast over ``self.NF`` frames, multiplied by ``mask`` and summed over
    the basis axis.

    Args:
        x: basis images; reshaped to ``(self.nbasis, self.nx, self.nx, 2)``.
        mask: multiplied against an ``(NF, nbasis, NX)`` tensor (presumably that
            same shape -- TODO confirm).

    Returns:
        A float32 CUDA tensor of shape ``(self.nch, self.NF, self.nx * self.nx * 2)``.
        ``self.NX`` is expected to equal ``self.nx * self.nx * 2``.
    """
    out = torch.zeros(
        self.nch, self.NF, self.nx * self.nx * 2, dtype=torch.float32, device="cuda"
    )
    x = torch.reshape(x, (self.nbasis, self.nx, self.nx, 2))
    # Iterate exactly the coils the output was sized for (previously used the
    # global ``nch``, which could disagree with ``self.nch``).
    for i in range(self.nch):
        ksp = sf.pt_fft2c(
            sf.pt_cpx_multipy(x, self.csmT[i].repeat(self.nbasis, 1, 1, 1))
        )
        ksp = torch.reshape(ksp, (self.nbasis, self.NX))
        # expand() yields the same (NF, nbasis, NX) shape as repeat() without a copy.
        out[i] = (ksp.expand(self.NF, -1, -1) * mask).sum(dim=1)
    return out
def AtAUV(x, csmT, maskT1, VT):
    """Apply the subspace-projected normal operator to ``x`` (no regularization).

    Coil and frame counts are taken from ``csmT.size(0)`` and ``maskT1.size(0)``.
    For every coil, the basis images are weighted by the coil map and transformed
    with ``sf.pt_fft2c``. For every frame ``k``, the k-space is masked and
    projected onto ``v = VT[:, k]`` via ``v @ (v.T @ .)``, and the projections
    are summed. The sum is brought back with ``sf.pt_ifft2c`` and weighted by the
    conjugate coil map.

    This relies on the module-level globals ``nbasis`` and ``nx`` and on the
    helper module ``sf``.

    Args:
        x: basis images (presumably ``(nbasis, nx, nx, 2)`` -- TODO confirm).
        csmT: coil maps, indexed by coil along dim 0.
        maskT1: per-frame masks. ``maskT1[k].unsqueeze(2).repeat(nbasis, 1, 1, 2)``
            must broadcast against the k-space (presumably ``(nx, nx)`` per frame).
        VT: temporal basis; column ``k`` has ``nbasis`` entries.

    Returns:
        A float32 CUDA tensor of shape ``(nbasis, nx, nx, 2)``.
    """
    n_coils = csmT.size(0)
    n_frames = maskT1.size(0)
    atbv = torch.zeros(nbasis, nx, nx, 2, dtype=torch.float32, device="cuda")
    csmConj = sf.pt_conj(csmT)
    for i in range(n_coils):
        kspace = sf.pt_fft2c(sf.pt_cpx_multipy(x, csmT[i].repeat(nbasis, 1, 1, 1)))
        acc = torch.zeros(nbasis, nx, nx, 2, dtype=torch.float32, device="cuda")
        for k in range(n_frames):
            masked = maskT1[k].unsqueeze(2).repeat(nbasis, 1, 1, 2) * kspace
            masked = torch.reshape(masked, (nbasis, nx * nx * 2))
            v = VT[:, k].unsqueeze(1)  # (nbasis, 1) column for frame k
            acc = acc + torch.reshape(v @ (v.T @ masked), (nbasis, nx, nx, 2))
        # Free this coil's k-space early; valid even when there are no frames.
        del kspace
        img = sf.pt_ifft2c(acc)
        atbv = atbv + sf.pt_cpx_multipy(csmConj[i].repeat(nbasis, 1, 1, 1), img)
    return atbv
def forward(self, x):
    """Forward operator: per coil and frame, masked k-space combined over the basis.

    For each coil ``i``, the basis images are weighted by ``self.csmT[i]`` and
    transformed with ``sf.pt_fft2c``, then flattened to ``(nbasis, NX)``. For
    each frame ``j``, that k-space is masked by ``self.mask[j]`` and contracted
    over the basis axis with the temporal coefficients ``self.VT[:, j]``.

    Args:
        x: basis images; reshaped to ``(self.nbasis, self.nx, self.nx, 2)``.

    Returns:
        A float32 CPU tensor of shape ``(self.nch, self.NF, self.nx * self.nx * 2)``.
        ``self.NX`` is expected to equal ``self.nx * self.nx * 2``.
    """
    out = torch.zeros(
        self.nch, self.NF, self.nx * self.nx * 2, dtype=torch.float32, device="cpu"
    )
    x = torch.reshape(x, (self.nbasis, self.nx, self.nx, 2))
    for i in range(self.nch):
        ksp = sf.pt_fft2c(
            sf.pt_cpx_multipy(x, self.csmT[i].repeat(self.nbasis, 1, 1, 1))
        )
        ksp = torch.reshape(ksp, (self.nbasis, self.NX))
        for j in range(self.NF):
            masked = self.mask[j].repeat(self.nbasis, 1) * ksp
            # (NX, nbasis) @ (nbasis,) -> (NX,). Frame j's basis column; the
            # previous code indexed with an undefined ``k``.
            out[i, j] = masked.T @ self.VT[:, j]
    return out
def forward(self, x, mask, v1):
    """Forward operator with a mask/temporal-basis product built by ``ss.maskV``.

    ``ss.maskV(mask.cuda(), v1)`` is expanded with a trailing axis of size 2 and
    flattened to ``(self.NF, nbasis, self.nx * self.nx * 2)``. For each coil,
    the coil-weighted k-space of the basis images is multiplied by it and
    summed over the basis axis.

    Args:
        x: basis images; ``x.shape[0]`` is the number of basis components.
        mask: sampling mask passed (on CUDA) to ``ss.maskV``.
        v1: second argument to ``ss.maskV`` (presumably temporal basis -- TODO
            confirm).

    Returns:
        A float32 CPU tensor of shape ``(self.nch, self.NF, self.nx * self.nx * 2)``.
    """
    nbas = x.shape[0]
    m_res = ss.maskV(mask.cuda(), v1)
    m_res = m_res.unsqueeze(3).repeat(1, 1, 1, 2)
    # Infer the basis dimension (-1) rather than relying on the global ``nbasis``.
    m_res = torch.reshape(m_res, (self.NF, -1, self.nx * self.nx * 2))
    out = torch.zeros(
        self.nch, self.NF, self.nx * self.nx * 2, dtype=torch.float32, device="cuda"
    )
    x = torch.reshape(x, (nbas, self.nx, self.nx, 2))
    # Iterate exactly the coils the output was sized for.
    for i in range(self.nch):
        ksp = sf.pt_fft2c(sf.pt_cpx_multipy(x, self.csmT[i].repeat(nbas, 1, 1, 1)))
        ksp = torch.reshape(ksp, (nbas, self.NX))
        # expand() gives the same (NF, nbas, NX) shape as repeat() without a copy.
        out[i] = (ksp.expand(self.NF, -1, -1) * m_res).sum(dim=1)
    return out.cpu()
def AtAUV(x, csmT, csmConj, maskT):
    """Normal operator using a precomputed basis-mixing tensor, plus regularization.

    Per coil: the coil-weighted basis images go to k-space (``sf.pt_fft2c``) and
    are flattened to ``(nbasis, nx*nx*2)``. They are then mixed across the basis
    by ``(k.repeat(nbasis, 1, 1) * maskT).sum(axis=1)``, returned with
    ``sf.pt_ifft2c`` and weighted by ``csmConj``. ``sT @ (W * x)`` is added
    at the end.

    This relies on the module-level globals ``nbasis``, ``nx``, ``nch``, ``W``,
    ``sT`` and on the helper module ``sf``.

    Returns:
        A float32 CUDA tensor of shape ``(nbasis, nx, nx, 2)``.
    """
    flat_len = nx * nx * 2
    result = torch.zeros(nbasis, nx, nx, 2, dtype=torch.float32, device="cuda")
    for coil in range(nch):
        coil_map = csmT[coil].repeat(nbasis, 1, 1, 1)
        ksp = torch.reshape(sf.pt_fft2c(sf.pt_cpx_multipy(x, coil_map)), (nbasis, flat_len))
        mixed = (ksp.repeat(nbasis, 1, 1) * maskT).sum(axis=1)
        mixed = torch.reshape(mixed, (nbasis, nx, nx, 2))
        back = sf.pt_cpx_multipy(csmConj[coil].repeat(nbasis, 1, 1, 1), sf.pt_ifft2c(mixed))
        result = result + back
    weighted = W * torch.reshape(x, (nbasis, flat_len))
    penalty = torch.reshape(torch.mm(sT, weighted), (nbasis, nx, nx, 2))
    return result + penalty
def forward(self, x, mask, Vv):
    """Forward operator: project basis images to frames with ``Vv``, then mask k-space.

    The frame images are formed as ``Vv @ x.reshape(nbas, nx*nx*2)`` and reshaped
    to ``(nf, nx, nx, 2)``. For each coil, they are weighted by the coil map,
    transformed with ``sf.pt_fft2c`` and multiplied by the per-frame mask.

    This relies on the module-level globals ``nx`` and ``csmT`` and on the helper
    module ``sf``.

    Args:
        x: basis images; ``x.shape[0]`` is the number of basis components.
        mask: per-frame masks with ``nf * nx * nx`` elements (``nf = mask.size(0)``).
        Vv: matrix applied on the left of the flattened basis images
            (presumably ``(nf, nbas)`` -- TODO confirm).

    Returns:
        A float32 CPU tensor of shape ``(self.nch, nf, nx, nx, 2)``.
    """
    nf = mask.size(0)
    nbas = x.shape[0]
    # The mask is identical for every coil, so expand it to real/imag once.
    mask2 = torch.reshape(mask, (nf, nx, nx)).unsqueeze(3).repeat(1, 1, 1, 2)
    out = torch.zeros(self.nch, nf, nx, nx, 2, dtype=torch.float32, device="cpu")
    uv = Vv @ torch.reshape(x, (nbas, nx * nx * 2))
    uv = torch.reshape(uv, (nf, nx, nx, 2))
    # Iterate exactly the coils the output was sized for.
    for i in range(self.nch):
        ksp = sf.pt_fft2c(sf.pt_cpx_multipy(uv, csmT[i].repeat(nf, 1, 1, 1)))
        out[i] = ksp * mask2
    return out.cpu()
def pt_A(orgT, csmT, maskT):
    """Forward encoding of ``orgT``: coil weighting, ``sf.pt_fft2c``, then masking.

    Returns ``maskT * sf.pt_fft2c(sf.pt_cpx_multipy(orgT, csmT))``.
    """
    kspace = sf.pt_fft2c(sf.pt_cpx_multipy(orgT, csmT))
    return maskT * kspace