def test_butterfly_to_base4(self):
    """Check that ``Butterfly.to_base4()`` reproduces the original output.

    Sweeps devices, rectangular sizes, real/complex inputs, stride order,
    initialisation and number of blocks. For every configuration the
    module and its base-4 conversion are applied to the same random batch
    and compared with ``torch.allclose`` using ``self.rtol``/``self.atol``.
    """
    batch_size = 10
    devices = ['cpu']
    if torch.cuda.is_available():
        devices.append('cuda')
    size_pairs = [(7, 15), (15, 7)]
    for device in devices:
        for in_size, out_size in size_pairs:
            for is_complex in [False, True]:
                dtype = torch.complex64 if is_complex else torch.float32
                for increasing_stride in [True, False]:
                    for init in ['randn', 'ortho', 'identity']:
                        for nblocks in [1, 2, 3]:
                            butterfly = Butterfly(in_size, out_size, True,
                                                  is_complex, increasing_stride,
                                                  init, nblocks=nblocks).to(device)
                            x = torch.randn(batch_size, in_size, dtype=dtype,
                                            device=device)
                            expected = butterfly(x)
                            actual = butterfly.to_base4()(x)
                            # Reported on failure to identify the configuration.
                            context = (expected.shape, device,
                                       (in_size, out_size), is_complex, init,
                                       nblocks)
                            self.assertTrue(
                                torch.allclose(expected, actual, self.rtol,
                                               self.atol), context)
def test_fft_init(self):
    """Check the 'fft_no_br' / 'ifft_no_br' initialisations against torch.fft.

    The butterfly is built without the bit-reversal permutation, so the
    permutation is applied to the input (increasing stride) or to the
    output (decreasing stride) before comparing with the orthonormal
    FFT / inverse FFT.
    """
    batch_size = 10
    n = 16
    signal = torch.randn(batch_size, n, dtype=torch.complex64)
    br = torch_butterfly.permutation.bitreversal_permutation(
        n, pytorch_format=True)
    # (reference transform, matching Butterfly init name)
    cases = [(torch.fft.fft, 'fft_no_br'), (torch.fft.ifft, 'ifft_no_br')]
    for increasing_stride in [True, False]:
        for nblocks in [1, 2, 3]:
            with torch.no_grad():
                for reference_fn, init_name in cases:
                    expected = reference_fn(signal, norm='ortho')
                    butterfly = Butterfly(n, n, False, complex=True,
                                          increasing_stride=increasing_stride,
                                          init=init_name, nblocks=nblocks)
                    if increasing_stride:
                        actual = butterfly(signal[..., br])
                    else:
                        actual = butterfly(signal)[..., br]
                    self.assertTrue(
                        torch.allclose(actual, expected, self.rtol,
                                       self.atol))
def test_transpose_conjugate_multiply(self):
    """Check the ``transpose`` / ``conjugate`` flags of ``Butterfly.forward``.

    The dense matrix of each variant is recovered by feeding the identity
    through the module, then compared with the transpose / conjugate /
    conjugate-transpose of the plain matrix.
    """
    n = 16
    for is_complex in [False, True]:
        dtype = torch.complex64 if is_complex else torch.float32
        for increasing_stride in [True, False]:
            for nblocks in [1, 2, 3]:
                butterfly = Butterfly(n, n, False, is_complex,
                                      increasing_stride, nblocks=nblocks)
                identity = torch.eye(n, dtype=dtype)
                context = (is_complex, increasing_stride, nblocks)

                dense = butterfly(identity).t()
                dense_t = butterfly.forward(identity, transpose=True).t()
                dense_conj = butterfly.forward(identity, conjugate=True).t()
                dense_t_conj = butterfly.forward(identity, transpose=True,
                                                 conjugate=True).t()

                comparisons = [
                    (dense.t(), dense_t),
                    (dense.conj(), dense_conj),
                    (dense.t().conj(), dense_t_conj),
                ]
                for expected, actual in comparisons:
                    self.assertTrue(
                        torch.allclose(expected, actual, self.rtol,
                                       self.atol), context)
def butterfly_product(butterfly1: Butterfly, butterfly2: Butterfly) -> Butterfly:
    """Combine the product of two butterfly matrices into one Butterfly.

    The twiddles of ``butterfly1`` and ``butterfly2`` are concatenated along
    the block dimension. Consecutive blocks alternate stride direction, so if
    the direction expected after the last block of ``butterfly1`` differs
    from ``butterfly2.increasing_stride``, an identity block is inserted
    between them first.

    Parameters:
        butterfly1: first factor; must have no bias.
        butterfly2: second factor; must have no bias and the same
            ``complex``, ``nstacks`` and ``log_n`` as ``butterfly1``.
    Return:
        b: a Butterfly with ``butterfly1.nblocks + butterfly2.nblocks`` blocks
            (plus one if an identity block was inserted), placed on the
            device of ``butterfly1.twiddle``.
    """
    assert butterfly1.bias is None and butterfly2.bias is None
    assert butterfly1.complex == butterfly2.complex
    assert butterfly1.nstacks == butterfly2.nstacks
    assert butterfly1.log_n == butterfly2.log_n
    device = butterfly1.twiddle.device
    # Stride direction the block following butterfly1 must have.
    b1_end_increasing_stride = butterfly1.increasing_stride != (
        butterfly1.nblocks % 2 == 1)
    if b1_end_increasing_stride != butterfly2.increasing_stride:
        # Need to insert an Identity block. It is moved to butterfly1's device
        # so that the torch.cat below does not mix devices.
        identity = Butterfly(butterfly1.in_size,
                             butterfly1.out_size,
                             bias=False,
                             complex=butterfly1.complex,
                             increasing_stride=b1_end_increasing_stride,
                             init='identity').to(device)
        butterfly1 = butterfly_product(butterfly1, identity)
    n = 1 << butterfly1.log_n
    b = Butterfly(n,
                  n,
                  bias=False,
                  complex=butterfly1.complex,
                  increasing_stride=butterfly1.increasing_stride,
                  nblocks=butterfly1.nblocks + butterfly2.nblocks).to(device)
    b.in_size = butterfly1.in_size
    b.out_size = butterfly2.out_size
    with torch.no_grad():
        # Don't need view_as_complex here since all the twiddles are stored in real.
        b.twiddle.copy_(
            torch.cat((butterfly1.twiddle, butterfly2.twiddle), dim=1))
    return b
def test_autograd(self):
    """Check that autograd works (especially for complex) on a 4x4 fit.

    A Butterfly (wrapped in Real2Complex / Complex2Real for the complex
    case) is trained with SGD to match a random linear map; the test only
    requires that the final loss is below the initial loss.
    """
    size = 4
    niters = 10000
    target_model = nn.Linear(size, size, bias=False)
    x = torch.eye(size)
    with torch.no_grad():
        y = target_model(x)
    for is_complex in [False, True]:
        butterfly = Butterfly(size, size, bias=False, complex=is_complex)
        if is_complex:
            model = nn.Sequential(
                torch_butterfly.complex_utils.Real2Complex(),
                butterfly,
                torch_butterfly.complex_utils.Complex2Real(),
            )
        else:
            model = butterfly
        with torch.no_grad():
            initial_loss = F.mse_loss(model(x), y)
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
        for _ in range(niters):
            prediction = model(x)
            loss = F.mse_loss(prediction, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # At least the loss should decrease.
        self.assertTrue(loss.item() < initial_loss.item())
def test_fft2d_init(self):
    """Check that FFT-initialised butterflies reproduce a circular Conv2d.

    For each kernel size, the conv weight is flipped, zero-padded to the
    input size and rolled by the padding, then the convolution is computed
    as K2(K1(input) * Kd(weight)) with tensor-product butterflies and
    compared against ``nn.Conv2d`` with ``padding_mode='circular'``.
    """
    batch_size = 10
    in_channels = 3
    out_channels = 4
    n1, n2 = 16, 32
    input = torch.randn(batch_size, in_channels, n2, n1)
    # Settings for (Kd, K1, K2), in that order.
    increasing_strides = [False, False, True]
    inits = ['fft_no_br', 'fft_no_br', 'ifft_no_br']
    for kernel_size1 in [1, 3, 5, 7]:
        for kernel_size2 in [1, 3, 5, 7]:
            padding1 = (kernel_size1 - 1) // 2
            padding2 = (kernel_size2 - 1) // 2
            conv = nn.Conv2d(in_channels,
                             out_channels, (kernel_size2, kernel_size1),
                             padding=(padding2, padding1),
                             padding_mode='circular',
                             bias=False)
            out_torch = conv(input)
            weight = conv.weight
            # Flip, zero-pad to full size and roll along the last dim, then
            # the same along the second-to-last dim.
            w = F.pad(weight.flip(dims=(-1, )),
                      (0, n1 - kernel_size1)).roll(-padding1, dims=-1)
            w = F.pad(w.flip(dims=(-2, )),
                      (0, 0, 0, n2 - kernel_size2)).roll(-padding2, dims=-2)
            for nblocks in [1, 2, 3]:
                # complex=True is passed explicitly: there is no local named
                # `complex` here, so `complex=complex` handed the builtin
                # type to Butterfly and only worked because it is truthy.
                Kd, K1, K2 = [
                    TensorProduct(
                        Butterfly(n1,
                                  n1,
                                  bias=False,
                                  complex=True,
                                  increasing_stride=incstride,
                                  init=i,
                                  nblocks=nblocks),
                        Butterfly(n2,
                                  n2,
                                  bias=False,
                                  complex=True,
                                  increasing_stride=incstride,
                                  init=i,
                                  nblocks=nblocks))
                    for incstride, i in zip(increasing_strides, inits)
                ]
                with torch.no_grad():
                    Kd.map1 *= math.sqrt(n1)
                    Kd.map2 *= math.sqrt(n2)
                out = K2(
                    complex_matmul(
                        K1(real2complex(input)).permute(2, 3, 0, 1),
                        Kd(real2complex(w)).permute(2, 3, 1, 0)).permute(
                            2, 3, 0, 1)).real
                self.assertTrue(
                    torch.allclose(out, out_torch, self.rtol, self.atol))
def test_butterfly_imul(self):
    """Check that in-place scaling (``b *= scale``) scales the output.

    For every configuration the output before scaling, multiplied by the
    scale, must match the output of the scaled module.
    """
    batch_size = 10
    device = 'cpu'
    for in_size, out_size in [(7, 15), (15, 7)]:
        for is_complex in [False, True]:
            dtype = torch.complex64 if is_complex else torch.float32
            for increasing_stride in [True, False]:
                for init in ['randn', 'ortho', 'identity']:
                    for nblocks in [1, 2, 3]:
                        for scale in [0.13, 2.75]:
                            butterfly = Butterfly(in_size, out_size, False,
                                                  is_complex,
                                                  increasing_stride, init,
                                                  nblocks=nblocks).to(device)
                            x = torch.randn(batch_size, in_size, dtype=dtype,
                                            device=device)
                            unscaled = butterfly(x)
                            with torch.no_grad():
                                butterfly *= scale
                            scaled = butterfly(x)
                            context = (unscaled.shape, device,
                                       (in_size, out_size), is_complex, init,
                                       nblocks)
                            self.assertTrue(
                                torch.allclose(unscaled * scale, scaled,
                                               self.rtol, self.atol),
                                context)
def perm2butterfly_slow(v: Union[np.ndarray, torch.Tensor],
                        complex: bool = False,
                        increasing_stride: bool = False) -> Butterfly:
    """Convert a permutation to a Butterfly that performs the same permutation.

    This implementation is slower but follows the proofs in Appendix G more
    closely.

    Parameter:
        v: a permutation, stored as a vector, in left-multiplication format.
            (i.e., applying v to a vector x is equivalent to x[v])
        complex: whether the Butterfly is complex or real.
        increasing_stride: whether the returned Butterfly should have
            increasing_stride=False or True. False corresponds to Lemma G.3
            and True corresponds to Lemma G.6.
    Return:
        b: a Butterfly (two blocks) that performs the same permutation as v.
    """
    if isinstance(v, torch.Tensor):
        v = v.detach().cpu().numpy()
    n = len(v)
    log_n = int(math.ceil(math.log2(n)))
    padded_n = 1 << log_n
    if n < padded_n:
        # Pad the permutation with fixed points up to the next power of 2.
        v = np.concatenate([v, np.arange(n, padded_n)])

    if increasing_stride:
        # Follow proof of Lemma G.6: conjugate by bit reversal, build the
        # decreasing-stride butterfly, then reorder its twiddle entries.
        br = bitreversal_permutation(padded_n)
        b = perm2butterfly_slow(br[v[br]], complex=complex,
                                increasing_stride=False)
        b.increasing_stride = True
        br_half = bitreversal_permutation(padded_n // 2, pytorch_format=True)
        with torch.no_grad():
            b.twiddle.copy_(b.twiddle[:, :, :, br_half])
        b.in_size = b.out_size = n
        return b

    # modular_balance expects right-multiplication format, so invert v first.
    Rinv_perms, L_vec = modular_balance(invert(v))
    L_perms = list(reversed(modular_balanced_to_butterfly_factor(L_vec)))
    R_perms = [
        perm_vec_to_mat(invert(perm), left=True)
        for perm in reversed(Rinv_perms)
    ]

    # Stored in increasing_stride=True twiddle format.
    # Transposed because the matrices are in right-multiplication format.
    L_factors = []
    for level, l_mat in enumerate(L_perms, start=1):
        L_factors.append(
            matrix_to_butterfly_factor(l_mat.T, log_k=level,
                                       pytorch_format=True))
    L_twiddle = torch.stack(L_factors)

    # Stored in increasing_stride=False twiddle format, hence the flip.
    R_factors = []
    for level, r_mat in enumerate(R_perms, start=1):
        R_factors.append(
            matrix_to_butterfly_factor(r_mat, log_k=level,
                                       pytorch_format=True))
    R_twiddle = torch.stack(R_factors).flip([0])

    twiddle = torch.stack([R_twiddle, L_twiddle]).unsqueeze(0)
    init = real2complex(twiddle) if complex else twiddle
    return Butterfly(n, n, bias=False, complex=complex,
                     increasing_stride=False, init=init, nblocks=2)
def test_butterfly_bmm(self):
    """Check ButterflyBmm output shape and agreement with per-matrix Butterflies.

    Each of the ``matrix_batch`` matrices of a ButterflyBmm is rebuilt as a
    standalone Butterfly from the matching twiddle slice and bias row; the
    stacked outputs must equal the batched output.
    """
    batch_size = 10
    matrix_batch = 3
    devices = ['cpu']
    if torch.cuda.is_available():
        devices.append('cuda')
    for device in devices:
        for in_size, out_size in [(7, 15), (15, 7)]:
            for is_complex in [False, True]:
                dtype = torch.complex64 if is_complex else torch.float32
                for increasing_stride in [True, False]:
                    for nblocks in [1, 2, 3]:
                        # Test shape
                        b_bmm = torch_butterfly.ButterflyBmm(
                            in_size, out_size, matrix_batch, True, is_complex,
                            increasing_stride, nblocks=nblocks).to(device)
                        x = torch.randn(batch_size, matrix_batch, in_size,
                                        dtype=dtype, device=device)
                        batched = b_bmm(x)
                        self.assertTrue(
                            batched.shape == (batch_size, matrix_batch,
                                              out_size),
                            (batched.shape, device, (in_size, out_size),
                             nblocks))

                        # Check that the result is the same as looping over
                        # butterflies.
                        per_matrix = []
                        for idx in range(matrix_batch):
                            start = idx * b_bmm.nstacks
                            stop = (idx + 1) * b_bmm.nstacks
                            single = Butterfly(
                                in_size, out_size, True, is_complex,
                                increasing_stride,
                                init=b_bmm.twiddle[start:stop],
                                nblocks=nblocks).to(device)
                            with torch.no_grad():
                                single.bias.copy_(b_bmm.bias[idx])
                            per_matrix.append(single(x[:, idx]))
                        with torch.no_grad():
                            looped = torch.stack(per_matrix, dim=1)
                        self.assertTrue(
                            torch.allclose(batched, looped),
                            ((batched - looped).abs().max().item(),
                             batched.shape, device, (in_size, out_size),
                             is_complex))
def butterfly_kronecker(butterfly1: Butterfly, butterfly2: Butterfly) -> Butterfly:
    """Combine two butterflies of size n1 and n2 into their Kronecker product.

    The result has size n1 * n2. Both inputs must have the same
    increasing_stride. If butterfly1 or butterfly2 has padding, then the
    Kronecker product (after flattening the input) will not produce the same
    result unless the input is padded in the same way before flattening.

    Only nstacks == 1 and nblocks == 1 are supported for now.

    Parameters:
        butterfly1: first factor; must have no bias.
        butterfly2: second factor; must have no bias and the same ``complex``
            and ``increasing_stride`` as ``butterfly1``.
    Return:
        b: a Butterfly of size n1 * n2 on the device of the combined twiddle.
    Raises:
        AssertionError: if any of the preconditions above is violated.
    """
    assert butterfly1.increasing_stride == butterfly2.increasing_stride
    assert butterfly1.complex == butterfly2.complex
    # `is None` rather than truthiness: `not tensor` raises for multi-element
    # tensors and is True for a single zero-valued bias.
    assert butterfly1.bias is None and butterfly2.bias is None
    assert butterfly1.nstacks == 1 and butterfly2.nstacks == 1
    assert butterfly1.nblocks == 1 and butterfly2.nblocks == 1
    increasing_stride = butterfly1.increasing_stride
    complex = butterfly1.complex
    log_n1 = butterfly1.twiddle.shape[2]
    log_n2 = butterfly2.twiddle.shape[2]
    log_n = log_n1 + log_n2
    n = 1 << log_n
    twiddle1 = butterfly1.twiddle if not complex else view_as_complex(
        butterfly1.twiddle)
    twiddle2 = butterfly2.twiddle if not complex else view_as_complex(
        butterfly2.twiddle)
    # Tile butterfly1's factors and interleave butterfly2's along dim 3.
    twiddle1 = twiddle1.repeat(1, 1, 1, 1 << log_n2, 1, 1)
    twiddle2 = twiddle2.repeat_interleave(1 << log_n1, dim=3)
    # The order of the concatenated factors depends on the stride direction.
    if increasing_stride:
        twiddle = torch.cat((twiddle1, twiddle2), dim=2)
    else:
        twiddle = torch.cat((twiddle2, twiddle1), dim=2)
    b = Butterfly(n,
                  n,
                  bias=False,
                  complex=complex,
                  increasing_stride=increasing_stride).to(twiddle.device)
    b.in_size = butterfly1.in_size * butterfly2.in_size
    b.out_size = butterfly1.out_size * butterfly2.out_size
    with torch.no_grad():
        b_twiddle = b.twiddle if not complex else view_as_complex(b.twiddle)
        b_twiddle.copy_(twiddle)
    return b
def perm2butterfly(v: Union[np.ndarray, torch.Tensor],
                   complex: bool = False,
                   increasing_stride: bool = False) -> Butterfly:
    """Convert a permutation to a Butterfly that performs the same permutation.

    Parameter:
        v: a permutation, stored as a vector, in left-multiplication format.
            (i.e., applying v to a vector x is equivalent to x[v])
        complex: whether the Butterfly is complex or real.
        increasing_stride: whether the returned Butterfly should have
            increasing_stride=False or True. False corresponds to Lemma G.3
            and True corresponds to Lemma G.6.
    Return:
        b: a Butterfly (two blocks) that performs the same permutation as v.
    """
    if isinstance(v, torch.Tensor):
        v = v.detach().cpu().numpy()
    n = len(v)
    log_n = int(math.ceil(math.log2(n)))
    padded_n = 1 << log_n
    if n < padded_n:
        # Pad the permutation with fixed points up to the next power of 2.
        v = np.concatenate([v, np.arange(n, padded_n)])

    if increasing_stride:
        # Follow proof of Lemma G.6: conjugate by bit reversal, build the
        # decreasing-stride butterfly, then reorder its twiddle entries.
        br = bitreversal_permutation(padded_n)
        b = perm2butterfly(br[v[br]], complex=complex,
                           increasing_stride=False)
        b.increasing_stride = True
        br_half = bitreversal_permutation(padded_n // 2, pytorch_format=True)
        with torch.no_grad():
            b.twiddle.copy_(b.twiddle[:, :, :, br_half])
        b.in_size = b.out_size = n
        return b

    # Peel one pair of outer factors per level; each call returns the
    # remaining permutation(s) to process at the next level.
    remaining = v[None]
    right_factors = []
    left_factors = []
    for _ in range(log_n):
        right_factor, left_factor, remaining = outer_twiddle_factors(remaining)
        right_factors.append(right_factor)
        left_factors.append(left_factor)

    b = Butterfly(n, n, bias=False, complex=complex, increasing_stride=False,
                  nblocks=2)
    with torch.no_grad():
        target = view_as_complex(b.twiddle) if complex else b.twiddle
        right_block = torch.stack(right_factors)
        left_block = torch.stack(left_factors).flip([0])
        twiddle = torch.stack([right_block, left_block]).unsqueeze(0)
        target.copy_(real2complex(twiddle) if complex else twiddle)
    return b
def test_subtwiddle(self):
    """Check the output shape when calling a Butterfly with ``subtwiddle=True``.

    A size-16 Butterfly is applied to size-8 inputs; the output is expected
    to keep the input's shape.
    """
    batch_size = 10
    n = 16
    input_size = 8
    for is_complex in [False, True]:
        dtype = torch.complex64 if is_complex else torch.float32
        for increasing_stride in [True, False]:
            for nblocks in [1, 2, 3]:
                butterfly = Butterfly(n, n, True, is_complex,
                                      increasing_stride, nblocks=nblocks)
                x = torch.randn(batch_size, input_size, dtype=dtype)
                result = butterfly(x, subtwiddle=True)
                context = (result.shape, n, input_size, is_complex, nblocks)
                self.assertTrue(result.shape == (batch_size, input_size),
                                context)
def test_butterfly(self):
    """Check Butterfly output shape and, for 'ortho' init, unit-norm twiddles.

    For ``init='ortho'`` the twiddle is reshaped into 2x2 blocks and every
    block must have spectral norm 1.
    """
    batch_size = 10
    devices = ['cpu']
    if torch.cuda.is_available():
        devices.append('cuda')
    for device in devices:
        for in_size, out_size in [(7, 15), (15, 7)]:
            for is_complex in [False, True]:
                dtype = torch.complex64 if is_complex else torch.float32
                for increasing_stride in [True, False]:
                    for init in ['randn', 'ortho', 'identity']:
                        for nblocks in [1, 2, 3]:
                            butterfly = Butterfly(in_size, out_size, True,
                                                  is_complex,
                                                  increasing_stride, init,
                                                  nblocks=nblocks).to(device)
                            x = torch.randn(batch_size, in_size, dtype=dtype,
                                            device=device)
                            result = butterfly(x)
                            self.assertTrue(
                                result.shape == (batch_size, out_size),
                                (result.shape, device, (in_size, out_size),
                                 is_complex, init, nblocks))
                            if init != 'ortho':
                                continue
                            blocks = butterfly.twiddle.detach().to(
                                'cpu').numpy().reshape(-1, 2, 2)
                            block_norms = np.linalg.norm(blocks, ord=2,
                                                         axis=(1, 2))
                            self.assertTrue(
                                np.allclose(block_norms, 1),
                                (block_norms, device, (in_size, out_size),
                                 is_complex, init))
def test_butterfly_bmm_tensorproduct(self):
    """Show how to do a TensorProduct (e.g., Conv2d) with ButterflyBmm.

    First applies out_channels x in_channels separate TensorProduct
    butterflies in a Python loop, then computes the same result with two
    ButterflyBmm modules built from the concatenated twiddles, and checks
    that both agree.
    """
    batch_size = 10
    in_channels = 3
    out_channels = 6
    n1, n2 = 32, 16
    dtype = torch.complex64
    input = torch.randn(batch_size, in_channels, n2, n1, dtype=dtype)
    n_matrices = out_channels * in_channels
    # Generate out_channels x in_channels butterfly matrices and loop over them
    b1s = [
        Butterfly(n1, n1, bias=False, complex=True) for _ in range(n_matrices)
    ]
    b2s = [
        Butterfly(n2, n2, bias=False, complex=True) for _ in range(n_matrices)
    ]
    b_tp = [
        torch_butterfly.combine.TensorProduct(b1, b2)
        for b1, b2 in zip(b1s, b2s)
    ]
    with torch.no_grad():
        outputs = []
        for o in range(out_channels):
            output = []
            for i in range(in_channels):
                index = o * in_channels + i
                output.append(b_tp[index](input[:, i]))
            outputs.append(torch.stack(output, dim=1))
        out = torch.stack(outputs, dim=1)
    # Use a unittest assertion (not a bare assert, which is stripped under -O).
    self.assertEqual(tuple(out.shape),
                     (batch_size, out_channels, in_channels, n2, n1))
    # Use ButterflyBmm instead
    b1_bmm = torch_butterfly.ButterflyBmm(
        n1,
        n1,
        matrix_batch=n_matrices,
        bias=False,
        complex=True,
        init=torch.cat([b1.twiddle for b1 in b1s]))
    b2_bmm = torch_butterfly.ButterflyBmm(
        n2,
        n2,
        matrix_batch=n_matrices,
        bias=False,
        complex=True,
        init=torch.cat([b2.twiddle for b2 in b2s]))
    input_reshaped = input.transpose(1, 2).reshape(batch_size, n2, 1,
                                                   in_channels, n1)
    input_expanded = input_reshaped.expand(batch_size, n2, out_channels,
                                           in_channels, n1)
    out_bmm = b1_bmm(
        input_expanded.reshape(batch_size, n2, n_matrices, n1))
    out_bmm = out_bmm.transpose(
        1, 3)  # (batch_size, n1, out_channels * in_channels, n2)
    out_bmm = b2_bmm(
        out_bmm)  # (batch_size, n1, out_channels * in_channels, n2)
    out_bmm = out_bmm.permute(
        0, 2, 3, 1)  # (batch_size, out_channels * in_channels, n2, n1)
    out_bmm = out_bmm.reshape(batch_size, out_channels, in_channels, n2, n1)
    self.assertTrue(torch.allclose(out_bmm, out))
def __init__(self, in_size, in_ch, out_ch, kernel_size, complex=True,
             init='ortho', nblocks=1, base=2, zero_pad=True):
    """Set up the conv weight and the three tensor-product butterfly maps.

    Args:
        in_size: input spatial size; an int is expanded to a pair.
        in_ch: number of input channels.
        out_ch: number of output channels.
        kernel_size: kernel size; an int is expanded to a pair.
        complex: whether the butterflies are complex. Required for init='fft'.
        init: 'ortho' or 'fft'.
        nblocks: number of blocks of each Butterfly.
        base: 2 or 4; with 4, every Butterfly is converted via to_base4().
        zero_pad: if true (and complex), Kd's twiddle is multiplied by a
            complex exponential instead of zero-padding and rolling the weight.
    """
    super().__init__()
    self.in_size = in_size
    self.in_ch = in_ch
    self.out_ch = out_ch
    self.kernel_size = kernel_size
    self.complex = complex
    assert init in ['ortho', 'fft']
    if init == 'fft':
        assert self.complex, 'fft init requires complex=True'
    self.init = init
    self.nblocks = nblocks
    assert base in [2, 4]
    self.base = base
    self.zero_pad = zero_pad

    # Normalise scalar sizes to pairs.
    if isinstance(self.in_size, int):
        self.in_size = (self.in_size, ) * 2
    if isinstance(self.kernel_size, int):
        self.kernel_size = (self.kernel_size, ) * 2
    pad0 = (self.kernel_size[0] - 1) // 2
    pad1 = (self.kernel_size[1] - 1) // 2
    self.padding = (pad0, pad1)

    # Just to use nn.Conv2d's initialization
    template_conv = nn.Conv2d(self.in_ch, self.out_ch, self.kernel_size,
                              padding=self.padding, bias=False)
    self.weight = nn.Parameter(template_conv.weight.flip([-1, -2]))

    # Per-map settings for (Kd, K1, K2), in that order.
    stride_flags = [False, False, True]
    if self.init == 'ortho':
        init_names = ['ortho'] * 3
    else:
        init_names = ['fft_no_br', 'fft_no_br', 'ifft_no_br']
    sizes = (self.in_size[-1], self.in_size[-2])

    def make_map(incstride, init_name):
        # One Butterfly per spatial dim, last dim first.
        factors = [
            Butterfly(size, size, bias=False, complex=complex,
                      increasing_stride=incstride, init=init_name,
                      nblocks=nblocks) for size in sizes
        ]
        return TensorProduct(*factors)

    self.Kd, self.K1, self.K2 = [
        make_map(incstride, init_name)
        for incstride, init_name in zip(stride_flags, init_names)
    ]
    with torch.no_grad():
        self.Kd.map1 *= math.sqrt(self.in_size[-1])
        self.Kd.map2 *= math.sqrt(self.in_size[-2])

    if self.zero_pad and self.complex:
        # Instead of zero-padding and calling weight.roll(-self.padding[-1], dims=-1)
        # and weight.roll(-self.padding[-2], dims=-2), multiply self.Kd by a
        # complex exponential, using the Shift theorem.
        # https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Shift_theorem
        with torch.no_grad():
            device = self.Kd.map1.twiddle.device
            shifted_maps = ((self.Kd.map1, self.padding[-1]),
                            (self.Kd.map2, self.padding[-2]))
            for butterfly, pad in shifted_maps:
                n = butterfly.n
                br = bitreversal_permutation(
                    n, pytorch_format=True).to(device)
                diagonal = torch.exp(1j * 2 * math.pi / n * pad *
                                     torch.arange(n, device=device))[br]
                # We multiply the 1st block instead of the last block (only
                # the first block is not the identity if init=fft). This seems
                # to perform a tiny bit better. If init=ortho, this won't
                # correspond exactly to rolling the weight.
                for row in (0, 1):
                    butterfly.twiddle[:, 0, -1, :, row, :] *= (
                        diagonal[row::2].unsqueeze(-1))

    if base == 4:
        for tensor_product in (self.Kd, self.K1, self.K2):
            tensor_product.map1, tensor_product.map2 = (
                tensor_product.map1.to_base4(),
                tensor_product.map2.to_base4())

    if complex:
        # Kd and K1 take real input; K2 returns real output.
        self.Kd = nn.Sequential(Real2Complex(), self.Kd)
        self.K1 = nn.Sequential(Real2Complex(), self.K1)
        self.K2 = nn.Sequential(self.K2, Complex2Real())
def __init__(self, in_size, in_channels, out_channels, arch_init='ortho', weight_init=nn.init.kaiming_normal_, kmatrix_depth=1, base=2, max_kernel_size=1, padding=None, stride=1, arch_shape=None, weight=None, global_biasing='additive', channel_gating='complete', perturb=0.0, crop_init=slice(0), dilation_init=1, padding_mode='circular', bias=None, checkpoint=False, fourier_position=-1, _swap=False):
    '''
    Builds the weight, the three K-matrices ('K', 'L', 'M'), the global
    bias parameter 'b' and the channel-gating parameter 'C'.

    Args:
        in_size: input size
        in_channels: number of input channels
        out_channels: number of output_channels
        arch_init: 'ortho' or $OPTYPE (e.g. 'skip') or $OPTYPE'_'$KERNELSIZE (e.g. 'conv_3x3')
        weight_init: function that initializes weight tensor
        kmatrix_depth: depth of each kmatrix
        base: base to use for kmatrix (must be 2 or 4)
        max_kernel_size: maximum kernel size
        padding: determines padding; by default sets padding according to arch_init
        stride: governs subsampling
        arch_shape: architecture that determines the output shape; uses arch_init by default
        weight: model weights
        global_biasing: 'additive' or 'interp' or False
        channel_gating: 'complete' or 'interp' or False
        perturb: scale of perturbation to arch params
        crop_init: input slice(s) to crop
        dilation_init: kernel dilation at initialization
        padding_mode: 'circular' or 'zeros'; for 'zeros' will adjust in_size as needed
        bias: optional bias parameter
        checkpoint: apply checkpointing to kmatrix operations
        fourier_position: where to put each Fourier matrix when warm starting; -1 applies it last
        _swap: if True, re-initializes an already-constructed module in place
               (skips nn.Module init, weight creation and buffer/parameter registration)
    '''
    if not _swap:
        # '_swap' variable allows for fast re-initialization of a module; useful for computing metrics
        super(XD, self).__init__()
    # Constructor arguments are recorded on the instance.
    # NOTE(review): 'arch_shape' is stored as arch_init, not the arch_shape argument — confirm intended.
    self._init_args = (in_size, in_channels, out_channels)
    self._init_kwargs = {
        'arch_shape': arch_init,
        'padding': padding,
        'crop_init': crop_init,
        'dilation_init': dilation_init,
        'padding_mode': padding_mode,
        'checkpoint': checkpoint,
        'fourier_position': fourier_position
    }
    assert base in {2, 4}, "'base' must be 2 or 4"
    assert global_biasing in {'additive', 'interp', False}, "invalid value for 'global_biasing'"
    assert channel_gating in {'complete', 'interp', False}, "invalid value for 'channel_gating'"
    self.checkpoint = checkpoint
    self.base = base
    self.chan = (out_channels, in_channels)
    self.depth = int2tuple(kmatrix_depth, length=3)
    # An int in_size is treated as 2-dimensional.
    self.dims = 2 if type(in_size) == int else len(in_size)
    in_size = int2tuple(in_size, length=self.dims)
    if padding_mode == 'zeros':
        # increases effective input size if required due to zero-padding
        padding = int2tuple(0 if padding is None else padding, length=self.dims)
        in_size = tuple(n + 2 * p for n, p in zip(in_size, padding))
        # Flattened (p, p) pair per dimension.
        self.zero_pad = tuple(sum(([p, p] for p in padding), []))
        padding = [0] * self.dims
    else:
        self.zero_pad = ()
    # Each dimension is rounded up to the next power of two.
    self.in_size = tuple(2**math.ceil(math.log2(n)) for n in in_size)
    crop_init = int2tuple(crop_init, length=self.dims)
    dilation_init = tuple(
        reversed(int2tuple(dilation_init, length=self.dims)))
    self.max_kernel_size, kd_init, skips, fourier_init, diagonal_init, self.unpadding = self._parse_init(
        arch_init, max_kernel_size, padding, arch_shape, dilation_init, _swap)
    zeroL = diagonal_init and global_biasing == 'additive'
    # nd / kd hold the input / kernel sizes in reversed dimension order.
    self.nd = tuple(reversed(self.in_size))
    self.kd = tuple(reversed(self.max_kernel_size))
    self.pd = tuple(k // 2 for k in self.kd)
    self.stride = int2tuple(stride, length=self.dims)
    if self.dims > 3:
        assert all(
            s == 1 for s in self.stride), "must have stride 1 if using >3 dims"
        self.subsample = nn.Sequential(
        )  # TODO: handle stride>1 for >3 dimensional XD-op
    else:
        self.subsample = AvgPool(self.dims)(kernel_size=[1] * self.dims,
                                            stride=self.stride)
    if not _swap:
        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels, *self.max_kernel_size))
        weight_init(self.weight)
    if not weight is None:
        # A Parameter of matching shape is adopted directly; anything else is
        # written into the existing weight via _offset_insert.
        if type(weight
                ) == nn.Parameter and self.weight.shape == weight.shape:
            self.weight = weight
        else:
            self._offset_insert(self.weight.data,
                                weight.data.to(self.weight.device))
    self.bias = nn.Parameter(bias) if type(bias) == torch.Tensor else bias
    # Offsets that centre a channels x channels identity inside a
    # (out_channels, in_channels) matrix.
    channels = min(self.chan)
    inoff, outoff = int(0.5 * (in_channels - channels)), int(
        0.5 * (out_channels - channels))
    if not _swap:
        self.register_buffer('diag', None, persistent=False)
        self.register_buffer('kron', None, persistent=False)
        self.register_buffer('_one', self.r2c(torch.ones(1)))
        self.register_buffer('_1', self.r2c(torch.ones(self.chan)))
        self.register_buffer('_I', self.r2c(torch.zeros(self.chan)))
        self._I[outoff:outoff + channels,
                inoff:inoff + channels] = torch.eye(channels)
    # Build (or, when swapping, update) the three K-matrices in turn.
    for (kmatrix_name, diags), depth, fpos in zip(
        [
            ('K', [self.diag_K(n, s)
                   for n, s in zip(self.nd, skips)]),  # handles strides
            ('L', [
                torch.zeros(n) if zeroL else self.diag_L(n, k)
                for n, k in zip(self.nd, kd_init)
            ]),  # handles kernel size limits
            ('M', [self.diag_M(n, c) for n, c in zip(self.nd, crop_init)])
        ],  # handles input cropping
            self.depth,
            int2tuple(fourier_position, length=3)):
        if _swap:
            kmatrix = getattr(self, kmatrix_name)
        else:
            kmatrix_kwargs = {
                'bias': False,
                'increasing_stride': kmatrix_name == 'K',
                'complex': True,
                'init': 'identity' if fourier_init else arch_init,
                'nblocks': depth,
            }
            kmatrix = TensorProduct(*(Butterfly(n, n, **kmatrix_kwargs)
                                      for n in self.nd))
        if fourier_init:
            # Dimensions with dilation > 1 get an all-ones diagonal here.
            fourier_kmatrix = self.get_fourier(
                kmatrix_name,
                *self.nd,
                diags=[
                    self._perturb(
                        diag if d == 1 else torch.ones(diag.shape),
                        perturb) for d, diag in zip(dilation_init, diags)
                ])
            if kmatrix_name == 'L' and any(d > 1 for d in dilation_init):
                # Blocks 0 and 1 are overwritten below, so the Fourier
                # block is placed at index 2 or later.
                fpos = max(2, depth + fpos if fpos < 0 else fpos)
            for dim, d, k, n in zip(range(1, self.dims + 1),
                                    dilation_init, self.kd, self.nd):
                if kmatrix_name == 'L' and d > 1:
                    # handles initialization of middle K-matrix for the case of dilated convs; requires kmatrix_depth >= 3
                    assert depth >= 3, "using dilation > 1 requires depth at least (1, 3, 1)"
                    kmatrix.getmap(
                        dim).twiddle.data[:, :2] = diagonal_butterfly(
                            perm2butterfly(self._atrous_permutation(
                                n, k, d),
                                           complex=True),
                            diags[dim - 1],
                            diag_first=True).twiddle.data.to(
                                kmatrix.device())
                # Copy the Fourier twiddle block into position fpos.
                kmatrix.getmap(dim).twiddle.data[
                    0, fpos] = fourier_kmatrix.getmap(dim).twiddle.data[
                        0, 0].to(kmatrix.device())
        if base == 4:
            for dim in range(1, self.dims + 1):
                kmatrix.setmap(dim, kmatrix.getmap(dim).to_base4())
        setattr(self, kmatrix_name, kmatrix)
    self.global_biasing = global_biasing
    # Filter inserted into a max-kernel-sized zero tensor: an averaging
    # filter for 'pool' architectures, otherwise a single one.
    filt = self._offset_insert(
        torch.zeros(1, 1, *self.max_kernel_size),
        torch.ones(1, 1, *kd_init) / np.prod(kd_init)
        if 'pool' in arch_init else torch.ones(1, 1, *[1] * self.dims))
    if self.global_biasing == 'additive':
        if diagonal_init:
            L = self.get_fourier('L',
                                 *self.nd,
                                 diags=[
                                     self.diag_L(n, k)
                                     for n, k in zip(self.nd, kd_init)
                                 ])
            b = L(self.r2c(self._circular_pad(filt)))
        else:
            b = self.r2c(torch.zeros(1, 1, *self.in_size))
    elif self.global_biasing == 'interp':
        if diagonal_init:
            b = self.r2c(torch.cat((torch.ones(1), filt.flatten())))
        else:
            b = self.r2c(torch.zeros(1 + np.prod(self.max_kernel_size)))
    else:
        b = self.r2c(torch.Tensor(0))
    # When swapping, overwrite the existing parameter's data in place.
    if _swap:
        self.b.data = b.to(self.b.device)
    else:
        self.register_parameter('b', nn.Parameter(b))
    self.channel_gating = channel_gating
    if self.channel_gating == 'complete':
        if diagonal_init:
            C = self.r2c(torch.zeros(self.chan))
            C[outoff:outoff + channels,
              inoff:inoff + channels] = torch.eye(channels)
        else:
            C = self.r2c(torch.ones(self.chan))
    elif self.channel_gating == 'interp':
        C = self.r2c(torch.Tensor([float(diagonal_init)]))
    else:
        C = self.r2c(torch.Tensor(0))
    if _swap:
        self.C.data = C.to(self.C.device)
    else:
        self.register_parameter('C', nn.Parameter(C))
    self.to(self.device())