def forward(self, x, mask, v1):
    """Apply the masked multi-channel forward model to a set of basis images.

    For every channel ``i`` in ``range(self.nch)`` the basis images in ``x``
    are multiplied by ``self.csmT[i]`` (tiled over the basis axis) with
    ``sf.pt_cpx_multipy``, transformed with ``sf.pt_fft2c``, weighted by the
    per-frame tensor built from ``ss.maskV(mask, v1)``, and summed over the
    basis axis.

    Args:
        x: Basis images; must hold ``nbas * self.nx * self.nx * 2`` values,
            where ``nbas = x.shape[0]``. The trailing axis of size 2 is
            presumably (real, imag) -- TODO confirm against
            ``sf.pt_cpx_multipy``.
        mask: Mask forwarded to ``ss.maskV``; it is first moved to ``x``'s
            device.
        v1: Second argument forwarded to ``ss.maskV``. The ``maskV`` result
            must be rank 3 and, after its last axis is duplicated, reshape to
            ``(self.NF, nbas, self.nx * self.nx * 2)``.

    Returns:
        A float32 CPU tensor of shape
        ``(self.nch, self.NF, self.nx * self.nx * 2)``.
    """
    nbas = x.shape[0]
    npix = self.nx * self.nx * 2
    device = x.device

    # Duplicate the mask along a new trailing axis of size 2 so it matches
    # the trailing size-2 axis of the data, then flatten per (frame, basis).
    # Uses the local nbas / self.* sizes rather than module-level globals.
    m_res = ss.maskV(mask.to(device), v1)
    m_res = m_res.unsqueeze(3).repeat(1, 1, 1, 2)
    m_res = torch.reshape(m_res, (self.NF, nbas, npix))

    out = torch.zeros(self.nch, self.NF, npix, dtype=torch.float32, device=device)
    x = torch.reshape(x, (nbas, self.nx, self.nx, 2))
    for i in range(self.nch):
        kspace = sf.pt_fft2c(
            sf.pt_cpx_multipy(x, self.csmT[i].repeat(nbas, 1, 1, 1)))
        kspace = torch.reshape(kspace, (nbas, self.NX))
        # Broadcasting (nbas, NX) against (NF, nbas, NX) yields the same
        # product as an explicit repeat(self.NF, 1, 1), without allocating
        # NF copies.
        out[i] = (kspace * m_res).sum(dim=1)
    return out.cpu()
bb_sz = int(NF / batch_sz) B = torch.reshape(kdataT, (nch, int(NF / batch_sz), batch_sz, nx * nx * 2)) maskT = torch.reshape(maskT, (int(NF / batch_sz), batch_sz, nx * nx)) F = AUV(csmT, nbasis, batch_sz) maskT = maskT.to(gpu) sf.tic() for ep1 in range(20): #inx=torch.randint(0,bb_sz,(bb_sz,)) for bat in range(bb_sz): v1 = GV(z1).squeeze(0).squeeze(0) v1 = v1.permute(1, 0) v1 = torch.reshape(v1, (int(NF / batch_sz), batch_sz, nbasis)) #mask_b=maskT[bat] m_res = ss.maskV(maskT[bat], v1[bat]) m_res = m_res.unsqueeze(3).repeat(1, 1, 1, 2) m_res = torch.reshape(m_res, (batch_sz, nbasis, nx * nx * 2)) for t1 in range(nbasis): u1[t1] = G(z[t1].unsqueeze(0)) u1 = torch.reshape(u1, (2, nbasis, nx, nx)) u1 = u1.permute(1, 2, 3, 0) u1 = torch.reshape(u1, (nbasis, nx * nx * 2)) b_est = F(u1, m_res.to(gpu)) l1_reg = 0. for param in G.parameters(): l1_reg += param.abs().sum() #loss = criterion(out, target) + l1_regularization loss = 0.5 * (b_est - B[:, bat].cuda() ).pow(2).sum() #+0.1*l1_reg#+(sT@W*u1).pow(2).sum()