Пример #1
0
 def test_butterfly_to_base4(self):
     """Check that ``Butterfly.to_base4()`` reproduces the original output.

     Every combination of device, (in_size, out_size), real/complex dtype,
     stride direction, init scheme and block count is run through both the
     original module and its base-4 conversion on the same random batch.
     """
     batch_size = 10
     devices = ['cpu']
     if torch.cuda.is_available():
         devices.append('cuda')
     for device in devices:
         for in_size, out_size in ((7, 15), (15, 7)):
             for is_complex in (False, True):
                 # Input dtype follows the real/complex flavour of the module.
                 dtype = torch.complex64 if is_complex else torch.float32
                 for increasing_stride in (True, False):
                     for init in ('randn', 'ortho', 'identity'):
                         for nblocks in (1, 2, 3):
                             layer = Butterfly(in_size, out_size, True,
                                               is_complex, increasing_stride,
                                               init, nblocks=nblocks)
                             layer = layer.to(device)
                             x = torch.randn(batch_size, in_size,
                                             dtype=dtype, device=device)
                             expected = layer(x)
                             actual = layer.to_base4()(x)
                             # Shown on failure to identify the configuration.
                             context = (expected.shape, device,
                                        (in_size, out_size), is_complex,
                                        init, nblocks)
                             matches = torch.allclose(expected, actual,
                                                      self.rtol, self.atol)
                             self.assertTrue(matches, context)
Пример #2
0
 def test_fft_init(self):
     """Compare the 'fft_no_br' / 'ifft_no_br' inits against ``torch.fft``.

     These inits leave out the bit-reversal permutation, so the test applies
     it by hand: to the input when the stride is increasing, to the output
     otherwise.
     """
     batch_size = 10
     n = 16
     x = torch.randn(batch_size, n, dtype=torch.complex64)
     br = torch_butterfly.permutation.bitreversal_permutation(
         n, pytorch_format=True)
     # (butterfly init name, reference transform) pairs, checked in this order.
     cases = (('fft_no_br', torch.fft.fft), ('ifft_no_br', torch.fft.ifft))
     for increasing_stride in (True, False):
         for nblocks in (1, 2, 3):
             with torch.no_grad():
                 for init_name, reference in cases:
                     expected = reference(x, norm='ortho')
                     layer = Butterfly(n, n, False,
                                       complex=True,
                                       increasing_stride=increasing_stride,
                                       init=init_name,
                                       nblocks=nblocks)
                     if increasing_stride:
                         actual = layer(x[..., br])
                     else:
                         actual = layer(x)[..., br]
                     self.assertTrue(
                         torch.allclose(actual, expected, self.rtol,
                                        self.atol))
Пример #3
0
 def test_transpose_conjugate_multiply(self):
     """Check the ``transpose`` / ``conjugate`` flags of ``Butterfly.forward``.

     The dense matrix of the butterfly is recovered by feeding it the
     identity; the matrices obtained with the flags set must equal the
     transpose, the conjugate and the conjugate transpose of that matrix.
     """
     n = 16
     for is_complex in (False, True):
         for increasing_stride in (True, False):
             for nblocks in (1, 2, 3):
                 layer = Butterfly(n, n, False, is_complex,
                                   increasing_stride, nblocks=nblocks)
                 dtype = torch.complex64 if is_complex else torch.float32
                 eye = torch.eye(n, dtype=dtype)
                 matrix = layer(eye).t()
                 transposed = layer.forward(eye, transpose=True).t()
                 conjugated = layer.forward(eye, conjugate=True).t()
                 adjoint = layer.forward(eye, transpose=True,
                                         conjugate=True).t()
                 # Shown on failure to identify the configuration.
                 context = (is_complex, increasing_stride, nblocks)
                 pairs = (
                     (matrix.t(), transposed),
                     (matrix.conj(), conjugated),
                     (matrix.t().conj(), adjoint),
                 )
                 for expected, actual in pairs:
                     self.assertTrue(
                         torch.allclose(expected, actual, self.rtol,
                                        self.atol), context)
Пример #4
0
def butterfly_product(butterfly1: Butterfly,
                      butterfly2: Butterfly) -> Butterfly:
    """
    Combine product of two butterfly matrices into one Butterfly.

    The result takes its ``in_size`` from ``butterfly1`` and its ``out_size``
    from ``butterfly2``; its twiddle is ``butterfly1``'s blocks followed by
    ``butterfly2``'s blocks along the block dimension (dim 1).

    Parameters:
        butterfly1: first factor; must have no bias.
        butterfly2: second factor; must have no bias and the same
            ``complex``, ``nstacks`` and ``log_n`` as ``butterfly1``.
    Return:
        b: a bias-free Butterfly with ``butterfly1.nblocks + butterfly2.nblocks``
            blocks (plus one identity block when the stride directions of the
            two factors do not line up), placed on ``butterfly1``'s device.
    Raises:
        AssertionError: if either factor has a bias or the factors disagree
            on ``complex``, ``nstacks`` or ``log_n``.
    """
    assert butterfly1.bias is None and butterfly2.bias is None
    assert butterfly1.complex == butterfly2.complex
    assert butterfly1.nstacks == butterfly2.nstacks
    assert butterfly1.log_n == butterfly2.log_n
    device = butterfly1.twiddle.device
    # Stride direction that a block appended after butterfly1 would have
    # (the direction flips on every block).
    b1_end_increasing_stride = butterfly1.increasing_stride != (
        butterfly1.nblocks % 2 == 1)
    if b1_end_increasing_stride != butterfly2.increasing_stride:
        # Need to insert an Identity block so the directions alternate properly.
        # It must live on the same device as butterfly1: the recursive call
        # concatenates both twiddles with torch.cat.
        identity = Butterfly(butterfly1.in_size,
                             butterfly1.out_size,
                             bias=False,
                             complex=butterfly1.complex,
                             increasing_stride=b1_end_increasing_stride,
                             init='identity').to(device)
        butterfly1 = butterfly_product(butterfly1, identity)
    n = 1 << butterfly1.log_n
    b = Butterfly(n,
                  n,
                  bias=False,
                  complex=butterfly1.complex,
                  increasing_stride=butterfly1.increasing_stride,
                  nblocks=butterfly1.nblocks + butterfly2.nblocks).to(device)
    # Constructed at the padded power-of-2 size; restore the logical sizes.
    b.in_size = butterfly1.in_size
    b.out_size = butterfly2.out_size
    with torch.no_grad():
        # Don't need view_as_complex here since all the twiddles are stored in real.
        b.twiddle.copy_(
            torch.cat((butterfly1.twiddle, butterfly2.twiddle), dim=1))
    return b
Пример #5
0
 def test_autograd(self):
     """Check that autograd works (especially for complex) by fitting a 4x4 matrix.

     A random bias-free linear map is the target. A real and a complex
     Butterfly are each trained with SGD on the identity input, and the loss
     after training must be lower than the loss before training.
     """
     size = 4
     niters = 10000
     true_model = nn.Linear(size, size, bias=False)
     x = torch.eye(size)
     with torch.no_grad():
         y = true_model(x)
     for is_complex in (False, True):
         if not is_complex:
             model = Butterfly(size, size, bias=False, complex=is_complex)
         else:
             # Wrap the complex butterfly so it consumes and produces real tensors.
             model = nn.Sequential(
                 torch_butterfly.complex_utils.Real2Complex(),
                 Butterfly(size, size, bias=False, complex=is_complex),
                 torch_butterfly.complex_utils.Complex2Real(),
             )
         with torch.no_grad():
             initial_loss = F.mse_loss(model(x), y)
         optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
         for _ in range(niters):
             loss = F.mse_loss(model(x), y)
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
         # At least loss should decrease
         self.assertTrue(loss.item() < initial_loss.item())
Пример #6
0
 def test_fft2d_init(self):
     """Check that FFT-initialized butterflies reproduce a circular Conv2d.

     For each kernel size, a bias-free ``nn.Conv2d`` with circular padding
     provides the reference output. Its kernel is flipped, zero-padded to
     the full (n2, n1) grid and rolled by the padding amount. The same
     result is then computed from three tensor-product butterflies:
     ``K1`` on the input, ``Kd`` on the padded kernel, a complex matmul over
     channels, and ``K2`` on the product.
     """
     batch_size = 10
     in_channels = 3
     out_channels = 4
     n1, n2 = 16, 32
     input = torch.randn(batch_size, in_channels, n2, n1)
     # One (increasing_stride, init) pair per transform, in the order Kd, K1, K2.
     increasing_strides = [False, False, True]
     inits = ['fft_no_br', 'fft_no_br', 'ifft_no_br']
     for kernel_size1 in [1, 3, 5, 7]:
         for kernel_size2 in [1, 3, 5, 7]:
             padding1 = (kernel_size1 - 1) // 2
             padding2 = (kernel_size2 - 1) // 2
             conv = nn.Conv2d(in_channels,
                              out_channels, (kernel_size2, kernel_size1),
                              padding=(padding2, padding1),
                              padding_mode='circular',
                              bias=False)
             out_torch = conv(input)
             weight = conv.weight
             # Flip, zero-pad to the full grid, then roll by the padding, one axis at a time.
             w = F.pad(weight.flip(dims=(-1, )),
                       (0, n1 - kernel_size1)).roll(-padding1, dims=-1)
             w = F.pad(w.flip(dims=(-2, )),
                       (0, 0, 0, n2 - kernel_size2)).roll(-padding2,
                                                          dims=-2)
             for nblocks in [1, 2, 3]:
                 # complex=True is passed explicitly. The original wrote
                 # complex=complex with no local of that name, which handed
                 # the builtin ``complex`` type to Butterfly as the flag.
                 Kd, K1, K2 = [
                     TensorProduct(
                         Butterfly(n1,
                                   n1,
                                   bias=False,
                                   complex=True,
                                   increasing_stride=incstride,
                                   init=i,
                                   nblocks=nblocks),
                         Butterfly(n2,
                                   n2,
                                   bias=False,
                                   complex=True,
                                   increasing_stride=incstride,
                                   init=i,
                                   nblocks=nblocks))
                     for incstride, i in zip(increasing_strides, inits)
                 ]
                 with torch.no_grad():
                     # Rescale the kernel transform by sqrt(n) along each axis.
                     Kd.map1 *= math.sqrt(n1)
                     Kd.map2 *= math.sqrt(n2)
                 out = K2(
                     complex_matmul(
                         K1(real2complex(input)).permute(2, 3, 0, 1),
                         Kd(real2complex(w)).permute(2, 3, 1, 0)).permute(
                             2, 3, 0, 1)).real
                 self.assertTrue(
                     torch.allclose(out, out_torch, self.rtol, self.atol))
Пример #7
0
 def test_butterfly_imul(self):
     """Check that in-place ``butterfly *= scale`` scales the module's output.

     For each configuration the output after the in-place multiplication is
     compared with the earlier output multiplied by the same scalar.
     """
     batch_size = 10
     device = 'cpu'
     for in_size, out_size in ((7, 15), (15, 7)):
         for is_complex in (False, True):
             dtype = torch.complex64 if is_complex else torch.float32
             for increasing_stride in (True, False):
                 for init in ('randn', 'ortho', 'identity'):
                     for nblocks in (1, 2, 3):
                         for scale in (0.13, 2.75):
                             layer = Butterfly(in_size, out_size, False,
                                               is_complex, increasing_stride,
                                               init, nblocks=nblocks)
                             layer = layer.to(device)
                             x = torch.randn(batch_size, in_size,
                                             dtype=dtype, device=device)
                             before = layer(x)
                             with torch.no_grad():
                                 layer *= scale
                             after = layer(x)
                             # Shown on failure to identify the configuration.
                             context = (before.shape, device,
                                        (in_size, out_size), is_complex,
                                        init, nblocks)
                             self.assertTrue(
                                 torch.allclose(before * scale, after,
                                                self.rtol, self.atol),
                                 context)
Пример #8
0
def perm2butterfly_slow(v: Union[np.ndarray, torch.Tensor],
                        complex: bool = False,
                        increasing_stride: bool = False) -> Butterfly:
    """
    Build a Butterfly that applies the permutation v (slow reference version).

    This implementation is slower but follows the proofs in Appendix G more closely.
    Parameter:
        v: a permutation, stored as a vector, in left-multiplication format.
            (i.e., applying v to a vector x is equivalent to x[v])
        complex: whether the Butterfly is complex or real.
        increasing_stride: whether the returned Butterfly should have increasing_stride=False or
            True. False corresponds to Lemma G.3 and True corresponds to Lemma G.6.
    Return:
        b: a Butterfly that performs the same permutation as v.
    """
    if isinstance(v, torch.Tensor):
        v = v.detach().cpu().numpy()
    n = len(v)
    log_n = int(math.ceil(math.log2(n)))
    padded_n = 1 << log_n
    if n < padded_n:
        # Pad permutation to the next power-of-2 size with fixed points.
        v = np.concatenate([v, np.arange(n, padded_n)])

    if not increasing_stride:
        # modular_balance expects right-multiplication format so we convert the format of v.
        Rinv_perms, L_vec = modular_balance(invert(v))
        left_perms = list(
            reversed(modular_balanced_to_butterfly_factor(L_vec)))
        right_perms = [
            perm_vec_to_mat(invert(perm), left=True)
            for perm in reversed(Rinv_perms)
        ]
        # Stored in increasing_stride=True twiddle format.
        # Need to take transpose because the matrices are in right-multiplication format.
        left_factors = [
            matrix_to_butterfly_factor(mat.T, log_k=k, pytorch_format=True)
            for k, mat in enumerate(left_perms, start=1)
        ]
        # Stored in increasing_stride=False twiddle format so we need to flip the order
        right_factors = [
            matrix_to_butterfly_factor(mat, log_k=k, pytorch_format=True)
            for k, mat in enumerate(right_perms, start=1)
        ]
        L_twiddle = torch.stack(left_factors)
        R_twiddle = torch.stack(right_factors).flip([0])
        twiddle = torch.stack([R_twiddle, L_twiddle]).unsqueeze(0)
        init = real2complex(twiddle) if complex else twiddle
        return Butterfly(n,
                         n,
                         bias=False,
                         complex=complex,
                         increasing_stride=False,
                         init=init,
                         nblocks=2)

    # Follow proof of Lemma G.6: conjugate v by the bit-reversal permutation,
    # solve the increasing_stride=False case, then reorder the twiddle entries.
    br = bitreversal_permutation(padded_n)
    b = perm2butterfly_slow(br[v[br]],
                            complex=complex,
                            increasing_stride=False)
    b.increasing_stride = True
    br_half = bitreversal_permutation(padded_n // 2, pytorch_format=True)
    with torch.no_grad():
        b.twiddle.copy_(b.twiddle[:, :, :, br_half])
    b.in_size = n
    b.out_size = n
    return b
Пример #9
0
 def test_butterfly_bmm(self):
     """Check ``ButterflyBmm`` output shape and agreement with a per-matrix loop.

     The batched module's twiddle is sliced per matrix (``nstacks`` entries
     each) to build individual ``Butterfly`` modules with the matching bias;
     their stacked outputs must equal the batched output.
     """
     batch_size = 10
     matrix_batch = 3
     devices = ['cpu']
     if torch.cuda.is_available():
         devices.append('cuda')
     for device in devices:
         for in_size, out_size in ((7, 15), (15, 7)):
             for is_complex in (False, True):
                 dtype = torch.complex64 if is_complex else torch.float32
                 for increasing_stride in (True, False):
                     for nblocks in (1, 2, 3):
                         # Test shape
                         batched = torch_butterfly.ButterflyBmm(
                             in_size, out_size, matrix_batch, True,
                             is_complex, increasing_stride,
                             nblocks=nblocks).to(device)
                         x = torch.randn(batch_size, matrix_batch, in_size,
                                         dtype=dtype, device=device)
                         output = batched(x)
                         expected_shape = (batch_size, matrix_batch,
                                           out_size)
                         self.assertTrue(
                             output.shape == expected_shape,
                             (output.shape, device,
                              (in_size, out_size), nblocks))
                         # Check that the result is the same as looping over butterflies
                         per_matrix = []
                         for idx in range(matrix_batch):
                             start = idx * batched.nstacks
                             stop = (idx + 1) * batched.nstacks
                             single = Butterfly(
                                 in_size, out_size, True, is_complex,
                                 increasing_stride,
                                 init=batched.twiddle[start:stop],
                                 nblocks=nblocks).to(device)
                             with torch.no_grad():
                                 single.bias.copy_(batched.bias[idx])
                             per_matrix.append(single(x[:, idx]))
                         with torch.no_grad():
                             output_loop = torch.stack(per_matrix, dim=1)
                         self.assertTrue(
                             torch.allclose(output, output_loop),
                             ((output - output_loop).abs().max().item(),
                              output.shape, device,
                              (in_size, out_size), is_complex))
Пример #10
0
def butterfly_kronecker(butterfly1: Butterfly,
                        butterfly2: Butterfly) -> Butterfly:
    """Combine two butterflies of size n1 and n2 into their Kronecker product of size n1 * n2.

    They must both have increasing_stride=True or increasing_stride=False.
    If butterfly1 or butterfly2 has padding, then the kronecker product (after flattening input)
    will not produce the same result unless the input is padded in the same way before flattening.

    Only support nstacks==1, nblocks==1 for now.

    Parameters:
        butterfly1: first factor; must have no bias.
        butterfly2: second factor; must have no bias and the same ``complex``
            and ``increasing_stride`` as ``butterfly1``.
    Return:
        b: a bias-free Butterfly of padded size n1 * n2 on the factors' device,
            with ``in_size`` / ``out_size`` set to the products of the
            factors' sizes.
    Raises:
        AssertionError: if any of the preconditions above does not hold.
    """
    assert butterfly1.increasing_stride == butterfly2.increasing_stride
    assert butterfly1.complex == butterfly2.complex
    # Compare against None rather than using truthiness. `not bias` on a
    # tensor raises RuntimeError when it has more than one element and is
    # True for a single zero-valued element.
    assert butterfly1.bias is None and butterfly2.bias is None
    assert butterfly1.nstacks == 1 and butterfly2.nstacks == 1
    assert butterfly1.nblocks == 1 and butterfly2.nblocks == 1
    increasing_stride = butterfly1.increasing_stride
    complex = butterfly1.complex
    # Dimension 2 of the twiddle indexes the log2(n) butterfly factors.
    log_n1 = butterfly1.twiddle.shape[2]
    log_n2 = butterfly2.twiddle.shape[2]
    log_n = log_n1 + log_n2
    n = 1 << log_n
    twiddle1 = butterfly1.twiddle if not complex else view_as_complex(
        butterfly1.twiddle)
    twiddle2 = butterfly2.twiddle if not complex else view_as_complex(
        butterfly2.twiddle)
    # Expand both twiddles along dim 3 to the combined width:
    # butterfly1's entries are tiled, butterfly2's are repeated element-wise.
    twiddle1 = twiddle1.repeat(1, 1, 1, 1 << log_n2, 1, 1)
    twiddle2 = twiddle2.repeat_interleave(1 << log_n1, dim=3)
    # The order of the factors along dim 2 depends on the stride direction.
    if increasing_stride:
        twiddle = torch.cat((twiddle1, twiddle2), dim=2)
    else:
        twiddle = torch.cat((twiddle2, twiddle1), dim=2)
    b = Butterfly(n,
                  n,
                  bias=False,
                  complex=complex,
                  increasing_stride=increasing_stride).to(twiddle.device)
    b.in_size = butterfly1.in_size * butterfly2.in_size
    b.out_size = butterfly1.out_size * butterfly2.out_size
    with torch.no_grad():
        b_twiddle = b.twiddle if not complex else view_as_complex(b.twiddle)
        b_twiddle.copy_(twiddle)
    return b
Пример #11
0
def perm2butterfly(v: Union[np.ndarray, torch.Tensor],
                   complex: bool = False,
                   increasing_stride: bool = False) -> Butterfly:
    """
    Build a two-block Butterfly that applies the permutation v.

    Parameter:
        v: a permutation, stored as a vector, in left-multiplication format.
            (i.e., applying v to a vector x is equivalent to x[v])
        complex: whether the Butterfly is complex or real.
        increasing_stride: whether the returned Butterfly should have increasing_stride=False or
            True. False corresponds to Lemma G.3 and True corresponds to Lemma G.6.
    Return:
        b: a Butterfly that performs the same permutation as v.
    """
    if isinstance(v, torch.Tensor):
        v = v.detach().cpu().numpy()
    n = len(v)
    log_n = int(math.ceil(math.log2(n)))
    padded_n = 1 << log_n
    if n < padded_n:
        # Pad permutation to the next power-of-2 size with fixed points.
        v = np.concatenate([v, np.arange(n, padded_n)])

    if not increasing_stride:
        v = v[None]
        right_factors, left_factors = [], []
        # Peel off one outer pair of factors per level; the helper returns
        # the remaining permutation to process next.
        for _ in range(log_n):
            right, left, v = outer_twiddle_factors(v)
            right_factors.append(right)
            left_factors.append(left)
        b = Butterfly(n,
                      n,
                      bias=False,
                      complex=complex,
                      increasing_stride=False,
                      nblocks=2)
        with torch.no_grad():
            target = view_as_complex(b.twiddle) if complex else b.twiddle
            right_block = torch.stack(right_factors)
            left_block = torch.stack(left_factors).flip([0])
            twiddle = torch.stack([right_block, left_block]).unsqueeze(0)
            target.copy_(real2complex(twiddle) if complex else twiddle)
        return b

    # Follow proof of Lemma G.6: conjugate v by the bit-reversal permutation,
    # solve the increasing_stride=False case, then reorder the twiddle entries.
    br = bitreversal_permutation(padded_n)
    b = perm2butterfly(br[v[br]], complex=complex, increasing_stride=False)
    b.increasing_stride = True
    br_half = bitreversal_permutation(padded_n // 2, pytorch_format=True)
    with torch.no_grad():
        b.twiddle.copy_(b.twiddle[:, :, :, br_half])
    b.in_size = n
    b.out_size = n
    return b
Пример #12
0
 def test_subtwiddle(self):
     """Check the output shape when calling a Butterfly with ``subtwiddle=True``.

     A size-16 butterfly is given size-8 inputs; the output must keep the
     input's (batch_size, input_size) shape.
     """
     batch_size = 10
     n = 16
     input_size = 8
     for is_complex in (False, True):
         dtype = torch.complex64 if is_complex else torch.float32
         for increasing_stride in (True, False):
             for nblocks in (1, 2, 3):
                 layer = Butterfly(n, n, True, is_complex,
                                   increasing_stride, nblocks=nblocks)
                 x = torch.randn(batch_size, input_size, dtype=dtype)
                 output = layer(x, subtwiddle=True)
                 shape_ok = output.shape == (batch_size, input_size)
                 self.assertTrue(
                     shape_ok,
                     (output.shape, n, input_size, is_complex, nblocks))
Пример #13
0
 def test_butterfly(self):
     """Check Butterfly output shape and the norm of 'ortho'-initialized twiddles.

     For every configuration the output must have shape
     (batch_size, out_size). With init='ortho', the twiddle is additionally
     reshaped into 2x2 blocks and each block must have spectral norm 1.
     """
     batch_size = 10
     devices = ['cpu']
     if torch.cuda.is_available():
         devices.append('cuda')
     for device in devices:
         for in_size, out_size in ((7, 15), (15, 7)):
             for is_complex in (False, True):
                 dtype = torch.complex64 if is_complex else torch.float32
                 for increasing_stride in (True, False):
                     for init in ('randn', 'ortho', 'identity'):
                         for nblocks in (1, 2, 3):
                             layer = Butterfly(in_size, out_size, True,
                                               is_complex, increasing_stride,
                                               init, nblocks=nblocks)
                             layer = layer.to(device)
                             x = torch.randn(batch_size, in_size,
                                             dtype=dtype, device=device)
                             output = layer(x)
                             self.assertTrue(
                                 output.shape == (batch_size, out_size),
                                 (output.shape, device, (in_size, out_size),
                                  is_complex, init, nblocks))
                             if init != 'ortho':
                                 continue
                             # View the twiddle as a flat list of 2x2 blocks.
                             blocks = layer.twiddle.detach().to('cpu').numpy()
                             blocks = blocks.reshape(-1, 2, 2)
                             twiddle_norm = np.linalg.norm(blocks,
                                                           ord=2,
                                                           axis=(1, 2))
                             self.assertTrue(
                                 np.allclose(twiddle_norm, 1),
                                 (twiddle_norm, device,
                                  (in_size, out_size), is_complex, init))
Пример #14
0
 def test_butterfly_bmm_tensorproduct(self):
     """Show how to do TensorProduct (e.g., Conv2d) with ButterflyBmm.

     A reference result is built by looping over out_channels * in_channels
     ``TensorProduct`` modules; the same result is then obtained from two
     ``ButterflyBmm`` modules initialized with the concatenated twiddles.
     """
     batch_size = 10
     in_channels = 3
     out_channels = 6
     n1, n2 = 32, 16
     n_matrices = out_channels * in_channels
     x = torch.randn(batch_size, in_channels, n2, n1, dtype=torch.complex64)
     # Generate out_channels x in_channels butterfly matrices and loop over them
     b1s = [
         Butterfly(n1, n1, bias=False, complex=True)
         for _ in range(n_matrices)
     ]
     b2s = [
         Butterfly(n2, n2, bias=False, complex=True)
         for _ in range(n_matrices)
     ]
     products = [
         torch_butterfly.combine.TensorProduct(first, second)
         for first, second in zip(b1s, b2s)
     ]
     with torch.no_grad():
         out = torch.stack([
             torch.stack([
                 products[o * in_channels + i](x[:, i])
                 for i in range(in_channels)
             ], dim=1) for o in range(out_channels)
         ], dim=1)
     assert out.shape == (batch_size, out_channels, in_channels, n2, n1)
     # Use ButterflyBmm instead
     twiddle1 = torch.cat([first.twiddle for first in b1s])
     b1_bmm = torch_butterfly.ButterflyBmm(n1,
                                           n1,
                                           matrix_batch=n_matrices,
                                           bias=False,
                                           complex=True,
                                           init=twiddle1)
     twiddle2 = torch.cat([second.twiddle for second in b2s])
     b2_bmm = torch_butterfly.ButterflyBmm(n2,
                                           n2,
                                           matrix_batch=n_matrices,
                                           bias=False,
                                           complex=True,
                                           init=twiddle2)
     # Broadcast the input over out_channels so each (o, i) pair sees channel i.
     expanded = x.transpose(1, 2).reshape(batch_size, n2, 1, in_channels, n1)
     expanded = expanded.expand(batch_size, n2, out_channels, in_channels,
                                n1)
     out_bmm = b1_bmm(expanded.reshape(batch_size, n2, n_matrices, n1))
     # (batch_size, n1, out_channels * in_channels, n2)
     out_bmm = out_bmm.transpose(1, 3)
     # (batch_size, n1, out_channels * in_channels, n2)
     out_bmm = b2_bmm(out_bmm)
     # (batch_size, out_channels * in_channels, n2, n1)
     out_bmm = out_bmm.permute(0, 2, 3, 1)
     out_bmm = out_bmm.reshape(batch_size, out_channels, in_channels, n2,
                               n1)
     self.assertTrue(torch.allclose(out_bmm, out))
Пример #15
0
    def __init__(self,
                 in_size,
                 in_ch,
                 out_ch,
                 kernel_size,
                 complex=True,
                 init='ortho',
                 nblocks=1,
                 base=2,
                 zero_pad=True):
        """Set up a convolution-like layer built from three butterfly tensor products.

        Args:
            in_size: spatial input size; an int is expanded to (in_size, in_size).
            in_ch: number of input channels (used to shape the weight).
            out_ch: number of output channels (used to shape the weight).
            kernel_size: kernel size; an int is expanded to
                (kernel_size, kernel_size).
            complex: forwarded to every Butterfly; when truthy the maps are
                wrapped with Real2Complex / Complex2Real at the end.
            init: 'ortho' or 'fft'. 'fft' requires complex=True and uses the
                'fft_no_br' / 'ifft_no_br' Butterfly inits.
            nblocks: forwarded to every Butterfly.
            base: 2 or 4; with 4 every Butterfly is converted via to_base4().
            zero_pad: when True (and complex), Kd's twiddles are multiplied by
                complex exponentials instead of rolling the weight.

        Raises:
            AssertionError: if init or base is not one of the allowed values,
                or init == 'fft' is combined with complex=False.
        """
        super().__init__()
        self.in_size = in_size
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.kernel_size = kernel_size
        self.complex = complex
        assert init in ['ortho', 'fft']
        if init == 'fft':
            assert self.complex, 'fft init requires complex=True'
        self.init = init
        self.nblocks = nblocks
        assert base in [2, 4]
        self.base = base
        self.zero_pad = zero_pad
        # Normalize scalar sizes to 2-tuples.
        if isinstance(self.in_size, int):
            self.in_size = (self.in_size, self.in_size)
        if isinstance(self.kernel_size, int):
            self.kernel_size = (self.kernel_size, self.kernel_size)
        # Per-axis padding of (k - 1) // 2.
        self.padding = (self.kernel_size[0] - 1) // 2, (self.kernel_size[1] -
                                                        1) // 2
        # Just to use nn.Conv2d's initialization
        # The Conv2d weight is flipped along its last two dims before being stored.
        self.weight = nn.Parameter(
            nn.Conv2d(self.in_ch,
                      self.out_ch,
                      self.kernel_size,
                      padding=self.padding,
                      bias=False).weight.flip([-1, -2]))

        # One (increasing_stride, init) pair per map, in the order Kd, K1, K2.
        increasing_strides = [False, False, True]
        inits = ['ortho'] * 3 if self.init == 'ortho' else [
            'fft_no_br', 'fft_no_br', 'ifft_no_br'
        ]
        # Each map is a TensorProduct of a Butterfly over in_size[-1] and one over in_size[-2].
        self.Kd, self.K1, self.K2 = [
            TensorProduct(
                Butterfly(self.in_size[-1],
                          self.in_size[-1],
                          bias=False,
                          complex=complex,
                          increasing_stride=incstride,
                          init=i,
                          nblocks=nblocks),
                Butterfly(self.in_size[-2],
                          self.in_size[-2],
                          bias=False,
                          complex=complex,
                          increasing_stride=incstride,
                          init=i,
                          nblocks=nblocks))
            for incstride, i in zip(increasing_strides, inits)
        ]
        with torch.no_grad():
            # Rescale Kd's two maps by sqrt of the corresponding spatial size.
            # NOTE(review): presumably compensates for a unitary FFT normalization — confirm.
            self.Kd.map1 *= math.sqrt(self.in_size[-1])
            self.Kd.map2 *= math.sqrt(self.in_size[-2])
        if self.zero_pad and self.complex:
            # Instead of zero-padding and calling weight.roll(-self.padding[-1], dims=-1) and
            # weight.roll(-self.padding[-2], dims=-2), we multiply self.Kd by complex exponential
            # instead, using the Shift theorem.
            # https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Shift_theorem
            with torch.no_grad():
                n1, n2 = self.Kd.map1.n, self.Kd.map2.n
                device = self.Kd.map1.twiddle.device
                br1 = bitreversal_permutation(n1,
                                              pytorch_format=True).to(device)
                br2 = bitreversal_permutation(n2,
                                              pytorch_format=True).to(device)
                # Phase factors exp(2*pi*i * padding * k / n), reordered by bit reversal.
                diagonal1 = torch.exp(1j * 2 * math.pi / n1 *
                                      self.padding[-1] *
                                      torch.arange(n1, device=device))[br1]
                diagonal2 = torch.exp(1j * 2 * math.pi / n2 *
                                      self.padding[-2] *
                                      torch.arange(n2, device=device))[br2]
                # We multiply the 1st block instead of the last block (only the first block is not
                # the identity if init=fft). This seems to perform a tiny bit better.
                # If init=ortho, this won't correspond exactly to rolling the weight.
                # Even-indexed phases scale index 0 of the second-to-last twiddle dim,
                # odd-indexed phases scale index 1.
                self.Kd.map1.twiddle[:, 0, -1, :,
                                     0, :] *= diagonal1[::2].unsqueeze(-1)
                self.Kd.map1.twiddle[:, 0, -1, :,
                                     1, :] *= diagonal1[1::2].unsqueeze(-1)
                self.Kd.map2.twiddle[:, 0, -1, :,
                                     0, :] *= diagonal2[::2].unsqueeze(-1)
                self.Kd.map2.twiddle[:, 0, -1, :,
                                     1, :] *= diagonal2[1::2].unsqueeze(-1)

        # Convert every Butterfly to base 4 only after the twiddles have been adjusted above.
        if base == 4:
            self.Kd.map1, self.Kd.map2 = self.Kd.map1.to_base4(
            ), self.Kd.map2.to_base4()
            self.K1.map1, self.K1.map2 = self.K1.map1.to_base4(
            ), self.K1.map2.to_base4()
            self.K2.map1, self.K2.map2 = self.K2.map1.to_base4(
            ), self.K2.map2.to_base4()

        # Kd and K1 get a Real2Complex stage in front; K2 gets a Complex2Real stage behind it.
        if complex:
            self.Kd = nn.Sequential(Real2Complex(), self.Kd)
            self.K1 = nn.Sequential(Real2Complex(), self.K1)
            self.K2 = nn.Sequential(self.K2, Complex2Real())
Пример #16
0
    def __init__(self,
                 in_size,
                 in_channels,
                 out_channels,
                 arch_init='ortho',
                 weight_init=nn.init.kaiming_normal_,
                 kmatrix_depth=1,
                 base=2,
                 max_kernel_size=1,
                 padding=None,
                 stride=1,
                 arch_shape=None,
                 weight=None,
                 global_biasing='additive',
                 channel_gating='complete',
                 perturb=0.0,
                 crop_init=slice(0),
                 dilation_init=1,
                 padding_mode='circular',
                 bias=None,
                 checkpoint=False,
                 fourier_position=-1,
                 _swap=False):
        '''
        Builds (or, with '_swap', re-initializes in place) the operation: the weight, the three
        K-matrices stored as attributes 'K', 'L' and 'M', the global bias parameter 'b' and the
        channel gate parameter 'C'.

        Args:
            in_size: input size; an int implies 2 dimensions, otherwise the number of dimensions is len(in_size).
                     Each entry is rounded up to the next power of two when stored in self.in_size
            in_channels: number of input channels
            out_channels: number of output_channels
            arch_init: 'ortho' or $OPTYPE (e.g. 'skip') or $OPTYPE'_'$KERNELSIZE (e.g. 'conv_3x3')
            weight_init: function that initializes weight tensor
            kmatrix_depth: depth of each kmatrix
            base: base to use for kmatrix (must be 2 or 4)
            max_kernel_size: maximum kernel size
            padding: determines padding; by default sets padding according to arch_init
            stride: governs subsampling
            arch_shape: architecture that determines the output shape; uses arch_init by default
            weight: model weights
            global_biasing: 'additive' or 'interp' or False
            channel_gating: 'complete' or 'interp' or False
            perturb: scale of perturbation to arch params
            crop_init: input slice(s) to crop
            dilation_init: kernel dilation at initialization
            padding_mode: 'circular' or 'zeros'; for 'zeros' will adjust in_size as needed
            bias: optional bias parameter
            checkpoint: apply checkpointing to kmatrix operations
            fourier_position: where to put each Fourier matrix when warm starting; -1 applies it last
            _swap: if True, the module is assumed to be already constructed: super().__init__(), the
                   weight/buffer creation and the K-matrix construction are skipped, the existing
                   'K'/'L'/'M' attributes are reused, and only the data of 'b' and 'C' is overwritten

        Raises:
            AssertionError: if 'base' is not 2 or 4, if 'global_biasing' or 'channel_gating' is not one of
                            the listed values, if stride != 1 with more than 3 dimensions, or if a
                            dilation > 1 is requested while the depth of 'L' is below 3
        '''

        if not _swap:
            # '_swap' variable allows for fast re-initialization of a module; useful for computing metrics
            super(XD, self).__init__()
            # remember the constructor arguments; note that 'arch_shape' is recorded as the
            # original arch_init, not as the arch_shape argument
            # NOTE(review): presumably so that a later re-initialization keeps the output shape -- confirm
            self._init_args = (in_size, in_channels, out_channels)
            self._init_kwargs = {
                'arch_shape': arch_init,
                'padding': padding,
                'crop_init': crop_init,
                'dilation_init': dilation_init,
                'padding_mode': padding_mode,
                'checkpoint': checkpoint,
                'fourier_position': fourier_position
            }
        # option validation runs for both fresh construction and '_swap' re-initialization
        assert base in {2, 4}, "'base' must be 2 or 4"
        assert global_biasing in {'additive', 'interp',
                                  False}, "invalid value for 'global_biasing'"
        assert channel_gating in {'complete', 'interp',
                                  False}, "invalid value for 'channel_gating'"

        self.checkpoint = checkpoint
        self.base = base
        self.chan = (out_channels, in_channels)
        # one depth per K-matrix ('K', 'L', 'M'); int2tuple is defined elsewhere
        # (presumably broadcasts a scalar to a tuple of the given length -- confirm)
        self.depth = int2tuple(kmatrix_depth, length=3)
        # a plain int in_size means a 2-dimensional operation
        self.dims = 2 if type(in_size) == int else len(in_size)
        in_size = int2tuple(in_size, length=self.dims)
        if padding_mode == 'zeros':
            # increases effective input size if required due to zero-padding
            padding = int2tuple(0 if padding is None else padding,
                                length=self.dims)
            in_size = tuple(n + 2 * p for n, p in zip(in_size, padding))
            # flattened (p, p) pairs, one pair per dimension in the order of 'padding'
            self.zero_pad = tuple(sum(([p, p] for p in padding), []))
            # the padding is now part of in_size, so none is passed on to _parse_init
            padding = [0] * self.dims
        else:
            self.zero_pad = ()
        # round every dimension up to the next power of two
        self.in_size = tuple(2**math.ceil(math.log2(n)) for n in in_size)
        crop_init = int2tuple(crop_init, length=self.dims)
        # dilations are kept in reversed dimension order, matching self.nd / self.kd below
        dilation_init = tuple(
            reversed(int2tuple(dilation_init, length=self.dims)))
        # _parse_init is defined elsewhere; it turns the architecture spec into the six values unpacked here
        self.max_kernel_size, kd_init, skips, fourier_init, diagonal_init, self.unpadding = self._parse_init(
            arch_init, max_kernel_size, padding, arch_shape, dilation_init,
            _swap)
        # when set, the diagonals used for 'L' in the loop below are all zeros instead of diag_L
        zeroL = diagonal_init and global_biasing == 'additive'
        # reversed copies of the input and kernel sizes; pd is half of each kernel size (floor)
        self.nd = tuple(reversed(self.in_size))
        self.kd = tuple(reversed(self.max_kernel_size))
        self.pd = tuple(k // 2 for k in self.kd)
        self.stride = int2tuple(stride, length=self.dims)
        if self.dims > 3:
            assert all(
                s == 1
                for s in self.stride), "must have stride 1 if using >3 dims"
            # empty Sequential used as the subsampling module when there are more than 3 dimensions
            self.subsample = nn.Sequential(
            )  # TODO: handle stride>1 for >3 dimensional XD-op
        else:
            # AvgPool is defined elsewhere (presumably selects the pooling class for this
            # dimensionality -- confirm); kernel size 1 with the requested stride
            self.subsample = AvgPool(self.dims)(kernel_size=[1] * self.dims,
                                                stride=self.stride)

        if not _swap:
            # fresh weight of shape (out_channels, in_channels, *max_kernel_size), filled by weight_init
            self.weight = nn.Parameter(
                torch.Tensor(out_channels, in_channels, *self.max_kernel_size))
            weight_init(self.weight)
        if not weight is None:
            # an nn.Parameter (exact type) of matching shape is adopted directly, so it is shared
            # rather than copied; anything else is handed to _offset_insert together with the
            # existing weight's data
            if type(weight
                    ) == nn.Parameter and self.weight.shape == weight.shape:
                self.weight = weight
            else:
                self._offset_insert(self.weight.data,
                                    weight.data.to(self.weight.device))
        # only a plain torch.Tensor (exact type) is wrapped in a Parameter; any other value,
        # including None or an existing nn.Parameter, is stored unchanged
        self.bias = nn.Parameter(bias) if type(bias) == torch.Tensor else bias

        # offsets of a channels x channels block placed (roughly) in the middle of the
        # (out_channels, in_channels) grid
        channels = min(self.chan)
        inoff, outoff = int(0.5 * (in_channels - channels)), int(
            0.5 * (out_channels - channels))
        if not _swap:
            # 'diag' and 'kron' start out empty and are excluded from the state dict
            self.register_buffer('diag', None, persistent=False)
            self.register_buffer('kron', None, persistent=False)
            # constants passed through self.r2c (defined elsewhere; presumably real-to-complex -- confirm)
            self.register_buffer('_one', self.r2c(torch.ones(1)))
            self.register_buffer('_1', self.r2c(torch.ones(self.chan)))
            self.register_buffer('_I', self.r2c(torch.zeros(self.chan)))
            # '_I' is zero except for an identity block at the offsets computed above
            self._I[outoff:outoff + channels,
                    inoff:inoff + channels] = torch.eye(channels)

        # set up the three K-matrices; each name comes with its per-dimension diagonals,
        # its depth and its Fourier position
        for (kmatrix_name, diags), depth, fpos in zip(
            [
                ('K', [self.diag_K(n, s)
                       for n, s in zip(self.nd, skips)]),  # handles strides
                ('L', [
                    torch.zeros(n) if zeroL else self.diag_L(n, k)
                    for n, k in zip(self.nd, kd_init)
                ]),  # handles kernel size limits
                ('M', [self.diag_M(n, c) for n, c in zip(self.nd, crop_init)])
            ],  # handles input cropping
                self.depth,
                int2tuple(fourier_position, length=3)):
            if _swap:
                # reuse the K-matrix that already exists on the module
                kmatrix = getattr(self, kmatrix_name)
            else:
                # one complex Butterfly per dimension, combined by TensorProduct; only 'K' uses
                # increasing stride, and a Fourier warm start begins from identity blocks
                kmatrix_kwargs = {
                    'bias': False,
                    'increasing_stride': kmatrix_name == 'K',
                    'complex': True,
                    'init': 'identity' if fourier_init else arch_init,
                    'nblocks': depth,
                }
                kmatrix = TensorProduct(*(Butterfly(n, n, **kmatrix_kwargs)
                                          for n in self.nd))
            if fourier_init:
                # reference Fourier K-matrix; a dimension with dilation 1 uses its own diagonal,
                # every other dimension uses all ones, and each goes through self._perturb
                fourier_kmatrix = self.get_fourier(
                    kmatrix_name,
                    *self.nd,
                    diags=[
                        self._perturb(
                            diag if d == 1 else torch.ones(diag.shape),
                            perturb) for d, diag in zip(dilation_init, diags)
                    ])
                if kmatrix_name == 'L' and any(d > 1 for d in dilation_init):
                    # resolve a negative position relative to depth and keep it at index 2 or later;
                    # indices 0 and 1 are written by the dilation branch below
                    fpos = max(2, depth + fpos if fpos < 0 else fpos)
                for dim, d, k, n in zip(range(1, self.dims + 1), dilation_init,
                                        self.kd, self.nd):
                    if kmatrix_name == 'L' and d > 1:
                        # handles initialization of middle K-matrix for the case of dilated convs; requires kmatrix_depth >= 3
                        assert depth >= 3, "using dilation > 1 requires depth at least (1, 3, 1)"
                        # overwrite the first two entries along axis 1 of the twiddle with a butterfly
                        # built from the atrous permutation and this dimension's diagonal
                        kmatrix.getmap(
                            dim).twiddle.data[:, :2] = diagonal_butterfly(
                                perm2butterfly(self._atrous_permutation(
                                    n, k, d),
                                               complex=True),
                                diags[dim - 1],
                                diag_first=True).twiddle.data.to(
                                    kmatrix.device())
                    # copy entry [0, 0] of the reference Fourier twiddle into entry [0, fpos]
                    kmatrix.getmap(dim).twiddle.data[
                        0, fpos] = fourier_kmatrix.getmap(dim).twiddle.data[
                            0, 0].to(kmatrix.device())
            if base == 4:
                # convert each per-dimension map to its base-4 form
                for dim in range(1, self.dims + 1):
                    kmatrix.setmap(dim, kmatrix.getmap(dim).to_base4())
            setattr(self, kmatrix_name, kmatrix)

        self.global_biasing = global_biasing
        # filter passed to _offset_insert along with a zero tensor of the maximum kernel size:
        # a mean filter of shape kd_init when arch_init contains 'pool', otherwise a single one
        filt = self._offset_insert(
            torch.zeros(1, 1, *self.max_kernel_size),
            torch.ones(1, 1, *kd_init) / np.prod(kd_init)
            if 'pool' in arch_init else torch.ones(1, 1, *[1] * self.dims))
        if self.global_biasing == 'additive':
            if diagonal_init:
                # apply a freshly built Fourier 'L' (with the regular diag_L diagonals) to the
                # circularly padded filter
                L = self.get_fourier('L',
                                     *self.nd,
                                     diags=[
                                         self.diag_L(n, k)
                                         for n, k in zip(self.nd, kd_init)
                                     ])
                b = L(self.r2c(self._circular_pad(filt)))
            else:
                b = self.r2c(torch.zeros(1, 1, *self.in_size))
        elif self.global_biasing == 'interp':
            # one leading entry followed by one entry per element of the maximum-size kernel
            if diagonal_init:
                b = self.r2c(torch.cat((torch.ones(1), filt.flatten())))
            else:
                b = self.r2c(torch.zeros(1 + np.prod(self.max_kernel_size)))
        else:
            # biasing disabled: empty tensor
            b = self.r2c(torch.Tensor(0))
        if _swap:
            # keep the existing Parameter object and replace only its data
            self.b.data = b.to(self.b.device)
        else:
            self.register_parameter('b', nn.Parameter(b))

        self.channel_gating = channel_gating
        if self.channel_gating == 'complete':
            if diagonal_init:
                # zero gate with an identity block at the channel offsets computed earlier
                C = self.r2c(torch.zeros(self.chan))
                C[outoff:outoff + channels,
                  inoff:inoff + channels] = torch.eye(channels)
            else:
                C = self.r2c(torch.ones(self.chan))
        elif self.channel_gating == 'interp':
            # single scalar: 1.0 when diagonal_init is truthy, else 0.0
            C = self.r2c(torch.Tensor([float(diagonal_init)]))
        else:
            # gating disabled: empty tensor
            C = self.r2c(torch.Tensor(0))
        if _swap:
            # keep the existing Parameter object and replace only its data
            self.C.data = C.to(self.C.device)
        else:
            self.register_parameter('C', nn.Parameter(C))

        # move everything to the device reported by self.device() (defined elsewhere)
        self.to(self.device())