Example #1
    def test_complex_matmul(self):
        """Check that our index_last_dim backward is also correct for real input
        """
        bs = (3, 5)
        for device in ['cpu', 'cuda']:
            X = torch.randn(*bs, 128, 16, dtype=torch.complex64, device=device, requires_grad=True)
            Y = torch.randn(*bs, 16, 32, dtype=torch.complex64, device=device, requires_grad=True)
            prod = complex_matmul(X, Y)
            prod_sum = complex_mul(X.unsqueeze(-1), Y.unsqueeze(-3)).sum(dim=-2)
            self.assertTrue(torch.allclose(prod, prod_sum, self.rtol, self.atol))
            g = torch.randn_like(prod)
            grad_X, grad_Y = torch.autograd.grad(prod, (X, Y), g)
            grad_X_sum, grad_Y_sum = torch.autograd.grad(prod_sum, (X, Y), g)
            self.assertTrue(torch.allclose(grad_X, grad_X_sum, self.rtol, self.atol))
            self.assertTrue(torch.allclose(grad_Y, grad_Y_sum, self.rtol, self.atol))

            X = torch.randn(5, 3, 32, 32, dtype=torch.complex64, device=device, requires_grad=True)
            Y = torch.randn(6, 3, 32, 32, dtype=torch.complex64, device=device, requires_grad=True)
            prod = complex_matmul(X.permute(2, 3, 0, 1), Y.permute(2, 3, 1, 0)).permute(2, 3, 0, 1)
            prod_sum = complex_mul(X.unsqueeze(1), Y).sum(dim=2)
            self.assertTrue(torch.allclose(prod, prod_sum, self.rtol, self.atol))
            g = torch.randn_like(prod)
            grad_X, grad_Y = torch.autograd.grad(prod, (X, Y), g)
            grad_X_sum, grad_Y_sum = torch.autograd.grad(prod_sum, (X, Y), g)
            self.assertTrue(torch.allclose(grad_X, grad_X_sum, self.rtol, self.atol))
            self.assertTrue(torch.allclose(grad_Y, grad_Y_sum, self.rtol, self.atol))
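
For reference, the operation under test is an ordinary batched matrix product carried out in complex arithmetic. A minimal sketch of the expected result, assuming complex_matmul multiplies over the last two dims (complex_matmul_reference is a hypothetical name, not the op being tested):

import torch

def complex_matmul_reference(X, Y):
    # Batched matmul on the last two dims, written out in real/imaginary parts:
    #   (Xr + i*Xi) @ (Yr + i*Yi) = (Xr@Yr - Xi@Yi) + i*(Xr@Yi + Xi@Yr)
    Xr, Xi = X.real, X.imag
    Yr, Yi = Y.real, Y.imag
    return torch.complex(Xr @ Yr - Xi @ Yi, Xr @ Yi + Xi @ Yr)

# X: (3, 5, 128, 16), Y: (3, 5, 16, 32)  ->  result: (3, 5, 128, 32)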
Example #2
 def test_fft2d_init(self):
     batch_size = 10
     in_channels = 3
     out_channels = 4
     n1, n2 = 16, 32
     complex = True  # the FFT/iFFT initializations below need complex Butterfly factors
     input = torch.randn(batch_size, in_channels, n2, n1)
     for kernel_size1 in [1, 3, 5, 7]:
         for kernel_size2 in [1, 3, 5, 7]:
             padding1 = (kernel_size1 - 1) // 2
             padding2 = (kernel_size2 - 1) // 2
             conv = nn.Conv2d(in_channels,
                              out_channels, (kernel_size2, kernel_size1),
                              padding=(padding2, padding1),
                              padding_mode='circular',
                              bias=False)
             out_torch = conv(input)
             weight = conv.weight
             w = F.pad(weight.flip(dims=(-1, )),
                       (0, n1 - kernel_size1)).roll(-padding1, dims=-1)
             w = F.pad(w.flip(dims=(-2, )),
                       (0, 0, 0, n2 - kernel_size2)).roll(-padding2,
                                                          dims=-2)
             increasing_strides = [False, False, True]
             inits = ['fft_no_br', 'fft_no_br', 'ifft_no_br']
             for nblocks in [1, 2, 3]:
                 Kd, K1, K2 = [
                     TensorProduct(
                         Butterfly(n1,
                                   n1,
                                   bias=False,
                                   complex=complex,
                                   increasing_stride=incstride,
                                   init=i,
                                   nblocks=nblocks),
                         Butterfly(n2,
                                   n2,
                                   bias=False,
                                   complex=complex,
                                   increasing_stride=incstride,
                                   init=i,
                                   nblocks=nblocks))
                     for incstride, i in zip(increasing_strides, inits)
                 ]
                 with torch.no_grad():
                     Kd.map1 *= math.sqrt(n1)
                     Kd.map2 *= math.sqrt(n2)
                 out = K2(
                     complex_matmul(
                         K1(real2complex(input)).permute(2, 3, 0, 1),
                         Kd(real2complex(w)).permute(2, 3, 1, 0)).permute(
                             2, 3, 0, 1)).real
                 self.assertTrue(
                     torch.allclose(out, out_torch, self.rtol, self.atol))
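
The Butterfly factors here are initialized to FFT/iFFT matrices (without bit reversal), so the whole pipeline is circular convolution evaluated in the frequency domain. The same identity can be sanity-checked directly with torch.fft; this sketch reuses input, w and out_torch from the loop body above and is not part of the original test:

x_f = torch.fft.fft2(input)                       # (batch, in_ch, n2, n1)
w_f = torch.fft.fft2(w)                           # (out_ch, in_ch, n2, n1)
prod = torch.einsum('bihw,oihw->bohw', x_f, w_f)  # contract over in_channels
out_fft = torch.fft.ifft2(prod).real
assert torch.allclose(out_fft, out_torch, rtol=1e-4, atol=1e-4)  # loose fp32 tolerance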
Example #3
 def forward(self, x):
     # (batch, in_ch, h, w)
     x_f = self.K1(x)
     # (out_ch, in_ch, h, w)
     # w_f = self.Kd(self.weight) * math.sqrt(self.in_size[0] * self.in_size[1])
     # w_f = self.Kd(self.weight)
     w_f = self.Kd(self.weight)
     # prod = (x_f.unsqueeze(1) * w_f).sum(dim=2)
     prod = complex_matmul(x_f.permute(2, 3, 0, 1),
                           w_f.permute(2, 3, 1, 0)).permute(2, 3, 0, 1)
     out = self.K2(prod)
     return out
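
The permute / complex_matmul / permute pattern moves the spatial dims (h, w) into the batch position so the matmul contracts over in_channels: (h, w, batch, in_ch) @ (h, w, in_ch, out_ch) -> (h, w, batch, out_ch), then back to (batch, out_ch, h, w). Assuming complex_matmul is a batched matmul over the last two dims, the same contraction can be written as a single einsum (a sketch, not the module's code):

# x_f: (batch, in_ch, h, w), w_f: (out_ch, in_ch, h, w)
prod = torch.einsum('bihw,oihw->bohw', x_f, w_f)  # (batch, out_ch, h, w)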
Example #4
 def forward(self, x):
     w = F.pad(self.weight,
               (0, self.in_size[-1] - self.kernel_size[-1])).roll(
                   -self.padding[-1], dims=-1)
     w = F.pad(w, (0, 0, 0, self.in_size[-2] - self.kernel_size[-2])).roll(
         -self.padding[-2], dims=-2)
     # (batch, in_ch, h, w)
     x_f = self.K1(x)
     # (out_ch, in_ch, h, w)
     w_f = self.Kd(w)
     # prod = (x_f.unsqueeze(1) * w_f).sum(dim=2)
     prod = complex_matmul(x_f.permute(2, 3, 0, 1),
                           w_f.permute(2, 3, 1, 0)).permute(2, 3, 0, 1)
     out = self.K2(prod)
     return out
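
The pad-and-roll embeds the small (odd-sized) kernel into a full-size circular kernel with its centre tap moved to index 0, so the taps are aligned for the frequency-domain product computed above. A tiny 1-D illustration of those two ops (values made up):

import torch
import torch.nn.functional as F

k = torch.tensor([1., 2., 3.])              # kernel_size = 3, padding = 1
w = F.pad(k, (0, 8 - 3)).roll(-1, dims=-1)  # in_size = 8
# w == tensor([2., 3., 0., 0., 0., 0., 0., 1.])  -- centre tap now at index 0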
Example #5
File: xd.py  Project: nick11roberts/XD
    def forward(self, x, weight=None):

        x = F.pad(x, self.zero_pad)

        # Symmetrically zero-pad each spatial dim up to self.in_size (the extra
        # element for an odd difference goes on the left) and record the slices
        # that later undo both this padding and self.unpadding.
        pad, unpad = [], [slice(None), slice(None)]
        for xn, n, (a, b) in zip(x.shape[2:], self.in_size, self.unpadding):
            p1 = (n - xn) // 2
            p2 = (n - xn) // 2
            p1 += p1 + p2 < n - xn  # p1 takes the extra element when n - xn is odd
            pad = [p1, p2] + pad  # F.pad wants pads for the last dim first
            unpad.append(slice(a + p1, b - p2))
        x = F.pad(x, pad)

        x = self._checkpoint(self.M, self.r2c(x))
        diag = self._diag(weight=weight)
        x = complex_matmul(x.permute(*range(2, 2 + self.dims), 0, 1),
                           diag.permute(*range(2, 2 + self.dims), 1,
                                        0)).permute(-2, -1, *range(self.dims))
        x = self.c2r(self._checkpoint(self.K, x))
        x = self.subsample(x[unpad])
        if self.bias is None:
            return x
        return x + self.bias.reshape(1, *self.bias.shape, *[1] * self.dims)
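
For self.dims == 2 the generic index juggling above reduces to the same pattern as the earlier 2-D examples. A walkthrough under the assumption that x is (batch, in_ch, h, w) after self.M and diag is (out_ch, in_ch, h, w):

# range(2, 2 + self.dims) == (2, 3), range(self.dims) == (0, 1)
x = complex_matmul(x.permute(2, 3, 0, 1),     # (h, w, batch, in_ch)
                   diag.permute(2, 3, 1, 0)   # (h, w, in_ch, out_ch)
                   ).permute(-2, -1, 0, 1)    # (batch, out_ch, h, w)
# the bias reshape then becomes self.bias.reshape(1, out_ch, 1, 1)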