def forward(self, x, mask, v1):
    """Apply the forward SENSE-style encoding operator A(U, V).

    For each receive coil: multiply the spatial basis images ``x`` by that
    coil's sensitivity map, take a 2-D FFT, apply the temporal-basis-weighted
    sampling mask, and sum over the basis dimension to produce per-frame
    k-space estimates.

    Args:
        x: spatial basis images; reshaped internally to
           ``(nbas, self.nx, self.nx, 2)`` (real/imag in the last axis).
           The basis count ``nbas`` is read from ``x.shape[0]``.
        mask: sampling mask (moved to the GPU here).
        v1: temporal basis coefficients, forwarded to ``ss.maskV``.

    Returns:
        CPU tensor of shape ``(self.nch, self.NF, self.nx * self.nx * 2)``
        holding the estimated multi-coil k-space data.
    """
    nbas = x.shape[0]
    # Temporal-basis-weighted mask; duplicate the last axis so it covers
    # the interleaved real/imag channels.
    m_res = ss.maskV(mask.cuda(), v1)
    m_res = m_res.unsqueeze(3).repeat(1, 1, 1, 2)
    # BUG FIX: original referenced undefined locals `nbasis` and `nx`;
    # use `nbas` (basis count taken from x) and `self.nx` instead.
    m_res = torch.reshape(m_res, (self.NF, nbas, self.nx * self.nx * 2))
    tmp5 = torch.cuda.FloatTensor(self.nch, self.NF,
                                  self.nx * self.nx * 2).fill_(0)
    x = torch.reshape(x, (nbas, self.nx, self.nx, 2))
    # BUG FIX: original iterated over undefined local `nch`; the coil
    # count lives on self.
    for i in range(self.nch):
        # Coil-weighted basis images -> k-space (per basis function).
        tmp2 = sf.pt_fft2c(
            sf.pt_cpx_multipy(x, self.csmT[i].repeat(nbas, 1, 1, 1)))
        tmp2 = torch.reshape(tmp2, (nbas, self.NX))
        # Broadcast over frames, apply the weighted mask, then collapse
        # the basis axis.
        tmp2 = tmp2.repeat(self.NF, 1, 1) * m_res
        tmp5[i] = tmp2.sum(dim=1)

    # Free the large intermediates before moving the result off-device.
    del tmp2, x
    return tmp5.cpu()
# --- Example #2 ---
    # NOTE(review): this is the interior of a training routine — NF, batch_sz,
    # nch, nx, nbasis, kdataT, maskT, csmT, z, z1, u1, G, GV, gpu, ss and sf
    # are all defined in the enclosing (unseen) scope, and the loop continues
    # past this excerpt (no backward()/optimizer step visible here).
    bb_sz = int(NF / batch_sz)  # number of mini-batches per epoch
    # Split the measured k-space and the sampling mask into bb_sz batches of
    # batch_sz frames each.
    B = torch.reshape(kdataT, (nch, int(NF / batch_sz), batch_sz, nx * nx * 2))
    maskT = torch.reshape(maskT, (int(NF / batch_sz), batch_sz, nx * nx))

    F = AUV(csmT, nbasis, batch_sz)  # forward encoding operator A(U, V)
    maskT = maskT.to(gpu)

    sf.tic()  # start wall-clock timer (presumably paired with sf.toc later)
    for ep1 in range(20):  # fixed 20 epochs
        #inx=torch.randint(0,bb_sz,(bb_sz,))
        # Batches are visited sequentially; the commented-out randint above
        # suggests random shuffling was tried and abandoned.
        for bat in range(bb_sz):
            # Temporal basis V from generator GV: drop the two leading
            # singleton dims, transpose, and arrange as
            # (num_batches, batch_sz, nbasis).
            v1 = GV(z1).squeeze(0).squeeze(0)
            v1 = v1.permute(1, 0)
            v1 = torch.reshape(v1, (int(NF / batch_sz), batch_sz, nbasis))
            #mask_b=maskT[bat]
            # Temporal-basis-weighted sampling mask for this batch; the last
            # axis is duplicated to cover interleaved real/imag channels.
            m_res = ss.maskV(maskT[bat], v1[bat])
            m_res = m_res.unsqueeze(3).repeat(1, 1, 1, 2)
            m_res = torch.reshape(m_res, (batch_sz, nbasis, nx * nx * 2))
            # Spatial basis images U: one generator pass per basis function.
            for t1 in range(nbasis):
                u1[t1] = G(z[t1].unsqueeze(0))

            # Rearrange U to (nbasis, nx*nx*2), real/imag interleaved last.
            u1 = torch.reshape(u1, (2, nbasis, nx, nx))
            u1 = u1.permute(1, 2, 3, 0)
            u1 = torch.reshape(u1, (nbasis, nx * nx * 2))
            b_est = F(u1, m_res.to(gpu))  # predicted k-space for this batch
            # L1 penalty over generator weights. NOTE(review): computed every
            # iteration but unused — the term is commented out of the loss.
            l1_reg = 0.
            for param in G.parameters():
                l1_reg += param.abs().sum()
            #loss = criterion(out, target) + l1_regularization
            # Data-consistency loss against the measured k-space batch.
            loss = 0.5 * (b_est - B[:, bat].cuda()
                          ).pow(2).sum()  #+0.1*l1_reg#+(sT@W*u1).pow(2).sum()