def fastfood(diag1: torch.Tensor, diag2: torch.Tensor, diag3: torch.Tensor,
             permutation: torch.Tensor, normalized: bool = False,
             increasing_stride: bool = True, separate_diagonal: bool = True) -> nn.Module:
    """ Build an nn.Module based on Butterfly that performs Fastfood multiplication:
    x -> Diag3 @ H @ Diag2 @ P @ H @ Diag1,
    where H is the Hadamard matrix and P is a permutation matrix.

    Parameters:
        diag1: (n,), where n is a power of 2.
        diag2: (n,)
        diag3: (n,)
        permutation: (n,)
        normalized: if True, corresponds to the orthogonal Hadamard transform
            (i.e. multiplied by 1/sqrt(n))
        increasing_stride: whether the first Butterfly in the sequence has increasing stride.
        separate_diagonal: if False, the diagonal is combined into the Butterfly part.
    """
    n, = diag1.shape
    assert diag2.shape == diag3.shape == permutation.shape == (n, )
    # The two Hadamard factors are built with opposite stride directions.
    first_h = hadamard(n, normalized, increasing_stride)
    second_h = hadamard(n, normalized, not increasing_stride)

    if separate_diagonal:
        # Keep each diagonal as its own module in the pipeline.
        return nn.Sequential(Diagonal(diagonal_init=diag1),
                             first_h,
                             FixedPermutation(permutation),
                             Diagonal(diagonal_init=diag2),
                             second_h,
                             Diagonal(diagonal_init=diag3))

    # Fold Diag1 into the input side of the first Hadamard, and Diag2 / Diag3 into
    # the input / output side of the second one.
    first_h = diagonal_butterfly(first_h, diag1, diag_first=True)
    second_h = diagonal_butterfly(second_h, diag2, diag_first=True)
    second_h = diagonal_butterfly(second_h, diag3, diag_first=False)
    return nn.Sequential(first_h, FixedPermutation(permutation), second_h)
def dct(n: int, type: int = 2, normalized: bool = False) -> nn.Module:
    """ Build an nn.Module based on Butterfly that exactly performs the DCT.

    Types 2 and 4 are assembled as: input permutation -> real-to-complex -> FFT butterfly
    (with the diagonals folded in) -> complex-to-real. Type 3 is assembled in the opposite
    order around an inverse-FFT butterfly, with the permutation applied last.

    Parameters:
        n: size of the DCT. Must be a power of 2.
        type: either 2, 3, or 4. These are the only types supported.
            See scipy.fft.dct's notes.
        normalized: if True, corresponds to the orthogonal DCT (see scipy.fft.dct's notes)
    """
    assert type in [2, 3, 4]
    # Permutation used around the FFT: take the even positions, then the odd positions
    # in reverse order, e.g., [0, 1, 2, 3] -> [0, 2, 3, 1].
    indices = torch.arange(n)
    perm = torch.cat((indices[::2], indices[1::2].flip([0])))
    br = bitreversal_permutation(n, pytorch_format=True)
    postprocess_diag = 2 * torch.exp(-1j * math.pi * torch.arange(0.0, n) / (2 * n))

    if type not in [2, 4]:
        assert type == 3
        b = ifft(n, normalized=normalized, br_first=False, with_br_perm=False)
        postprocess_diag[0] /= 2.0
        if normalized:
            postprocess_diag[1:] /= math.sqrt(2)
        else:
            # We want iFFT with the scaling of 1.0 instead of 1 / n
            with torch.no_grad():
                b.twiddle *= 2
        # For type 3 the (conjugated) diagonal is applied on the input side of the iFFT.
        b = diagonal_butterfly(b, postprocess_diag.conj(), diag_first=True)
        perm_inverse = invert(perm)
        return nn.Sequential(Real2Complex(), b, Complex2Real(),
                             FixedPermutation(br[perm_inverse]))

    # From here on type is 2 or 4.
    b = fft(n, normalized=normalized, br_first=True, with_br_perm=False)
    if type == 4:
        half_step = math.pi / (2 * n)
        even_mul = torch.exp(-1j * half_step * (torch.arange(0.0, n, 2) + 0.5))
        odd_mul = torch.exp(1j * half_step * (torch.arange(1.0, n, 2) + 0.5))
        # Interleave the even/odd factors back into a single length-n diagonal.
        preprocess_diag = torch.stack((even_mul, odd_mul), dim=-1).flatten()
        # This preprocess_diag is defined before the permutation.
        # To move it after the permutation, we have to permute the diagonal.
        b = diagonal_butterfly(b, preprocess_diag[perm[br]], diag_first=True)
    if normalized:
        if type == 4:
            postprocess_diag /= math.sqrt(2)
        else:
            postprocess_diag[0] /= 2.0
            postprocess_diag[1:] /= math.sqrt(2)
    b = diagonal_butterfly(b, postprocess_diag, diag_first=False)
    return nn.Sequential(FixedPermutation(perm[br]), Real2Complex(), b, Complex2Real())
def fourier_diag(*nd, diags=None, inv=False, diag_first=True, with_br_perm=False, **kwargs):
    '''returns n-dimensional FFT Butterfly matrix multiplied by a diagonal matrix
    Args:
        nd: input sizes of each dimension
        diags: torch.Tensor vectors specifying diagonals; if None uses the identity
        inv: return inverse FFT
        diag_first: returns FFT * diagonal; if False returns diagonal * FFT
        with_br_perm: uses bit-reversal permutation
        kwargs: passed to torch_butterfly.special.fft
    Returns:
        TensorProduct object
    '''

    kwargs['with_br_perm'] = with_br_perm
    if diags is None:
        # Without diagonals, 1-d and 2-d inputs go straight to the dedicated constructors.
        ndims = len(nd)
        if ndims == 1:
            transform_1d = ifft if inv else fft
            return TensorProduct(transform_1d(*nd, **kwargs))
        if ndims == 2:
            transform_2d = ifft2d if inv else fft2d
            return TensorProduct(transform_2d(*nd, **kwargs))
        # Higher-dimensional case: fall through with all-ones (identity) diagonals.
        diags = [torch.ones(n) for n in nd]

    allowed_kwargs = {'normalized', 'br_first', 'with_br_perm'}
    assert set(kwargs.keys()).issubset(allowed_kwargs) and not with_br_perm, \
        "invalid kwargs when using diags or >2 dims"

    transform = ifft if inv else fft
    # One diagonal-fused butterfly per dimension.
    factors = [
        diagonal_butterfly(transform(n, **kwargs), diag, diag_first=diag_first)
        for n, diag in zip(nd, diags)
    ]
    return TensorProduct(*factors)
def hadamard_diagonal(diagonals: torch.Tensor, normalized: bool = False,
                      increasing_stride: bool = True,
                      separate_diagonal: bool = True) -> nn.Module:
    """ Build an nn.Module based on Butterfly that performs multiplication by H D H D ... H D,
    where H is the Hadamard matrix and D is a diagonal matrix

    Parameters:
        diagonals: (k, n), where k is the number of diagonal matrices and n is the dimension
            of the Hadamard transform.
        normalized: if True, corresponds to the orthogonal Hadamard transform
            (i.e. multiplied by 1/sqrt(n))
        increasing_stride: whether the returned Butterfly has increasing stride.
        separate_diagonal: if False, the diagonal is combined into the Butterfly part.
    """
    _, n = diagonals.shape

    def _hadamard_at(position):
        # The stride direction flips on every odd-indexed factor.
        flip = position % 2 == 1
        return hadamard(n, normalized, increasing_stride != flip)

    if separate_diagonal:
        # Alternate explicit Diagonal modules with Hadamard butterflies.
        stages = []
        for position, diag in enumerate(diagonals.unbind()):
            stages.extend((Diagonal(diagonal_init=diag), _hadamard_at(position)))
        return nn.Sequential(*stages)

    # Fold each diagonal into the input side of its Hadamard, then multiply the
    # resulting butterflies together in order.
    fused = [
        diagonal_butterfly(_hadamard_at(position), diag, diag_first=True)
        for position, diag in enumerate(diagonals.unbind())
    ]
    return reduce(butterfly_product, fused)
def __init__(self, in_size, in_channels, out_channels, arch_init='ortho', weight_init=nn.init.kaiming_normal_, kmatrix_depth=1, base=2, max_kernel_size=1, padding=None, stride=1, arch_shape=None, weight=None, global_biasing='additive', channel_gating='complete', perturb=0.0, crop_init=slice(0), dilation_init=1, padding_mode='circular', bias=None, checkpoint=False, fourier_position=-1, _swap=False):
    '''
    Set up (or, when '_swap' is True, re-initialize) the weights, buffers, the
    three K-matrices 'K', 'L', 'M', and the 'b' / 'C' parameters of this module.

    Args:
        in_size: input size
        in_channels: number of input channels
        out_channels: number of output_channels
        arch_init: 'ortho' or $OPTYPE (e.g. 'skip') or $OPTYPE'_'$KERNELSIZE (e.g. 'conv_3x3')
        weight_init: function that initializes weight tensor
        kmatrix_depth: depth of each kmatrix
        base: base to use for kmatrix (must be 2 or 4)
        max_kernel_size: maximum kernel size
        padding: determines padding; by default sets padding according to arch_init
        stride: governs subsampling
        arch_shape: architecture that determines the output shape; uses arch_init by default
        weight: model weights
        global_biasing: 'additive' or 'interp' or False
        channel_gating: 'complete' or 'interp' or False
        perturb: scale of perturbation to arch params
        crop_init: input slice(s) to crop
        dilation_init: kernel dilation at initialization
        padding_mode: 'circular' or 'zeros'; for 'zeros' will adjust in_size as needed
        bias: optional bias parameter
        checkpoint: apply checkpointing to kmatrix operations
        fourier_position: where to put each Fourier matrix when warm starting; -1 applies it last
        _swap: if True, skips the nn.Module constructor and the creation/registration of
            weight, buffers, K-matrices and parameters, and instead overwrites the data of
            the existing ones
    '''

    # NOTE(review): the indentation of this method was reconstructed from a
    # whitespace-collapsed listing. The nesting of the `weight` handling, the
    # Fourier warm start and the base-4 conversion relative to the `_swap`
    # checks should be confirmed against the upstream file.
    if not _swap:
        # '_swap' variable allows for fast re-initialization of a module; useful for computing metrics
        super(XD, self).__init__()
        # Remember the constructor arguments on the instance.
        self._init_args = (in_size, in_channels, out_channels)
        self._init_kwargs = {
            'arch_shape': arch_init,
            'padding': padding,
            'crop_init': crop_init,
            'dilation_init': dilation_init,
            'padding_mode': padding_mode,
            'checkpoint': checkpoint,
            'fourier_position': fourier_position
        }
    assert base in {2, 4}, "'base' must be 2 or 4"
    assert global_biasing in {'additive', 'interp', False}, "invalid value for 'global_biasing'"
    assert channel_gating in {'complete', 'interp', False}, "invalid value for 'channel_gating'"

    self.checkpoint = checkpoint
    self.base = base
    self.chan = (out_channels, in_channels)
    # One depth per K-matrix (K, L, M).
    self.depth = int2tuple(kmatrix_depth, length=3)
    # An integer in_size is treated as 2-dimensional.
    self.dims = 2 if type(in_size) == int else len(in_size)
    in_size = int2tuple(in_size, length=self.dims)
    if padding_mode == 'zeros':
        # increases effective input size if required due to zero-padding
        padding = int2tuple(0 if padding is None else padding, length=self.dims)
        in_size = tuple(n + 2 * p for n, p in zip(in_size, padding))
        # Flattened (p, p) pair per dimension.
        self.zero_pad = tuple(sum(([p, p] for p in padding), []))
        padding = [0] * self.dims
    else:
        self.zero_pad = ()
    # Round every dimension up to the next power of two.
    self.in_size = tuple(2**math.ceil(math.log2(n)) for n in in_size)
    crop_init = int2tuple(crop_init, length=self.dims)
    dilation_init = tuple(
        reversed(int2tuple(dilation_init, length=self.dims)))
    self.max_kernel_size, kd_init, skips, fourier_init, diagonal_init, self.unpadding = self._parse_init(
        arch_init, max_kernel_size, padding, arch_shape, dilation_init, _swap)
    # When True, the 'L' diagonals below are initialized to zeros.
    zeroL = diagonal_init and global_biasing == 'additive'
    # Reversed copies of the input and kernel sizes.
    self.nd = tuple(reversed(self.in_size))
    self.kd = tuple(reversed(self.max_kernel_size))
    self.pd = tuple(k // 2 for k in self.kd)
    self.stride = int2tuple(stride, length=self.dims)
    if self.dims > 3:
        assert all(
            s == 1 for s in self.stride), "must have stride 1 if using >3 dims"
        self.subsample = nn.Sequential(
        )  # TODO: handle stride>1 for >3 dimensional XD-op
    else:
        self.subsample = AvgPool(self.dims)(kernel_size=[1] * self.dims,
                                            stride=self.stride)

    if not _swap:
        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels, *self.max_kernel_size))
        weight_init(self.weight)
    if not weight is None:
        # Reuse a same-shaped nn.Parameter directly; otherwise insert the
        # provided values into the existing weight tensor.
        if type(weight) == nn.Parameter and self.weight.shape == weight.shape:
            self.weight = weight
        else:
            self._offset_insert(self.weight.data,
                                weight.data.to(self.weight.device))
    self.bias = nn.Parameter(bias) if type(bias) == torch.Tensor else bias

    # Offsets that center a channels-by-channels block inside the
    # (out_channels, in_channels) grid.
    channels = min(self.chan)
    inoff, outoff = int(0.5 * (in_channels - channels)), int(
        0.5 * (out_channels - channels))
    if not _swap:
        self.register_buffer('diag', None, persistent=False)
        self.register_buffer('kron', None, persistent=False)
        self.register_buffer('_one', self.r2c(torch.ones(1)))
        self.register_buffer('_1', self.r2c(torch.ones(self.chan)))
        self.register_buffer('_I', self.r2c(torch.zeros(self.chan)))
        # '_I' is zero except for an identity block at the centered offsets.
        self._I[outoff:outoff + channels,
                inoff:inoff + channels] = torch.eye(channels)

    # Build (or fetch, when swapping) each of the three K-matrices together
    # with its per-dimension diagonals, depth and Fourier position.
    for (kmatrix_name, diags), depth, fpos in zip(
        [
            ('K', [self.diag_K(n, s)
                   for n, s in zip(self.nd, skips)]),  # handles strides
            ('L', [
                torch.zeros(n) if zeroL else self.diag_L(n, k)
                for n, k in zip(self.nd, kd_init)
            ]),  # handles kernel size limits
            ('M', [self.diag_M(n, c) for n, c in zip(self.nd, crop_init)])
        ],  # handles input cropping
            self.depth,
            int2tuple(fourier_position, length=3)):

        if _swap:
            kmatrix = getattr(self, kmatrix_name)
        else:
            kmatrix_kwargs = {
                'bias': False,
                'increasing_stride': kmatrix_name == 'K',
                'complex': True,
                'init': 'identity' if fourier_init else arch_init,
                'nblocks': depth,
            }
            kmatrix = TensorProduct(*(Butterfly(n, n, **kmatrix_kwargs)
                                      for n in self.nd))

        if fourier_init:
            # Dimensions with dilation != 1 use all-ones diagonals here.
            fourier_kmatrix = self.get_fourier(
                kmatrix_name,
                *self.nd,
                diags=[
                    self._perturb(
                        diag if d == 1 else torch.ones(diag.shape), perturb)
                    for d, diag in zip(dilation_init, diags)
                ])
            if kmatrix_name == 'L' and any(d > 1 for d in dilation_init):
                # Keep the Fourier block at index >= 2; the first two blocks
                # are overwritten below for dilated dimensions.
                fpos = max(2, depth + fpos if fpos < 0 else fpos)
            for dim, d, k, n in zip(range(1, self.dims + 1), dilation_init,
                                    self.kd, self.nd):
                if kmatrix_name == 'L' and d > 1:
                    # handles initialization of middle K-matrix for the case of dilated convs; requires kmatrix_depth >= 3
                    assert depth >= 3, "using dilation > 1 requires depth at least (1, 3, 1)"
                    kmatrix.getmap(
                        dim).twiddle.data[:, :2] = diagonal_butterfly(
                            perm2butterfly(self._atrous_permutation(
                                n, k, d),
                                           complex=True),
                            diags[dim - 1],
                            diag_first=True).twiddle.data.to(
                                kmatrix.device())
                # Copy the Fourier twiddle into block 'fpos' of this dimension's map.
                kmatrix.getmap(dim).twiddle.data[
                    0, fpos] = fourier_kmatrix.getmap(dim).twiddle.data[
                        0, 0].to(kmatrix.device())
        if base == 4:
            for dim in range(1, self.dims + 1):
                kmatrix.setmap(dim, kmatrix.getmap(dim).to_base4())
        setattr(self, kmatrix_name, kmatrix)

    self.global_biasing = global_biasing
    # Filter inserted into a zero tensor of the maximum kernel size: an averaging
    # filter of size kd_init for 'pool' architectures, a single 1 otherwise.
    filt = self._offset_insert(
        torch.zeros(1, 1, *self.max_kernel_size),
        torch.ones(1, 1, *kd_init) /
        np.prod(kd_init) if 'pool' in arch_init else torch.ones(
            1, 1, *[1] * self.dims))
    if self.global_biasing == 'additive':
        if diagonal_init:
            L = self.get_fourier('L',
                                 *self.nd,
                                 diags=[
                                     self.diag_L(n, k)
                                     for n, k in zip(self.nd, kd_init)
                                 ])
            b = L(self.r2c(self._circular_pad(filt)))
        else:
            b = self.r2c(torch.zeros(1, 1, *self.in_size))
    elif self.global_biasing == 'interp':
        if diagonal_init:
            b = self.r2c(torch.cat((torch.ones(1), filt.flatten())))
        else:
            b = self.r2c(torch.zeros(1 + np.prod(self.max_kernel_size)))
    else:
        b = self.r2c(torch.Tensor(0))
    # When swapping, overwrite the existing parameter's data instead of registering a new one.
    if _swap:
        self.b.data = b.to(self.b.device)
    else:
        self.register_parameter('b', nn.Parameter(b))

    self.channel_gating = channel_gating
    if self.channel_gating == 'complete':
        if diagonal_init:
            # Zeros with an identity block at the centered channel offsets.
            C = self.r2c(torch.zeros(self.chan))
            C[outoff:outoff + channels,
              inoff:inoff + channels] = torch.eye(channels)
        else:
            C = self.r2c(torch.ones(self.chan))
    elif self.channel_gating == 'interp':
        C = self.r2c(torch.Tensor([float(diagonal_init)]))
    else:
        C = self.r2c(torch.Tensor(0))
    if _swap:
        self.C.data = C.to(self.C.device)
    else:
        self.register_parameter('C', nn.Parameter(C))
    self.to(self.device())
def acdc(diag1: torch.Tensor, diag2: torch.Tensor, dct_first: bool = True,
         separate_diagonal: bool = True) -> nn.Module:
    """ Construct an nn.Module based on Butterfly that exactly performs either the multiplication:
        x -> diag2 @ iDCT @ diag1 @ DCT @ x
    or
        x -> diag2 @ DCT @ diag1 @ iDCT @ x.
    In the paper [1], the math describes the 2nd type while the implementation uses the 1st type.
    Note that the DCT and iDCT are normalized.
    [1] Marcin Moczulski, Misha Denil, Jeremy Appleyard, Nando de Freitas.
    ACDC: A Structured Efficient Linear Layer.
    http://arxiv.org/abs/1511.05946

    Parameters:
        diag1: (n,), where n is a power of 2.
        diag2: (n,), where n is a power of 2.
        dct_first: if True, uses the first type above; otherwise use the second type.
        separate_diagonal: if False, the diagonal is combined into the Butterfly part.

    Returns:
        An nn.Sequential chaining the permutations, real/complex conversions, the two
        butterflies and (when separate_diagonal is True) the two Diagonal modules.

    Raises:
        AssertionError: if diag2 does not have shape (n,) or n is not a power of 2.
    """
    n, = diag1.shape
    assert diag2.shape == (n, )
    assert n == 1 << int(math.ceil(math.log2(n))), 'n must be a power of 2'
    # Construct the permutation before the FFT: separate the even and odd and then reverse the odd
    # e.g., [0, 1, 2, 3] -> [0, 2, 3, 1].
    # This permutation is actually in B (not just B^T B or B B^T). This can be checked with
    # perm2butterfly.
    perm = torch.arange(n)
    perm = torch.cat((perm[::2], perm[1::2].flip([0])))
    perm_inverse = invert(perm)
    br = bitreversal_permutation(n, pytorch_format=True)
    postprocess_diag = 2 * torch.exp(-1j * math.pi * torch.arange(0.0, n) / (2 * n))
    # Normalize
    postprocess_diag[0] /= 2.0
    postprocess_diag[1:] /= math.sqrt(2)
    if dct_first:
        b_fft = fft(n, normalized=True, br_first=False, with_br_perm=False)
        b_ifft = ifft(n, normalized=True, br_first=True, with_br_perm=False)
        # DCT / iDCT stages: FFT followed by, and iFFT preceded by, the bit-reversed
        # post-processing diagonal.
        b1 = diagonal_butterfly(b_fft, postprocess_diag[br], diag_first=False)
        b2 = diagonal_butterfly(b_ifft, postprocess_diag.conj()[br], diag_first=True)
        if not separate_diagonal:
            # Fold the user diagonals into the already-combined butterflies (b1 / b2),
            # not the raw FFT handle, so the result does not rely on diagonal_butterfly
            # mutating its argument in place.
            b1 = diagonal_butterfly(b1, diag1[br], diag_first=False)
            b2 = diagonal_butterfly(b2, diag2[perm], diag_first=False)
            return nn.Sequential(FixedPermutation(perm),
                                 Real2Complex(), b1, Complex2Real(),
                                 Real2Complex(), b2, Complex2Real(),
                                 FixedPermutation(perm_inverse))
        else:
            return nn.Sequential(FixedPermutation(perm),
                                 Real2Complex(), b1, Complex2Real(),
                                 Diagonal(diagonal_init=diag1[br]),
                                 Real2Complex(), b2, Complex2Real(),
                                 Diagonal(diagonal_init=diag2[perm]),
                                 FixedPermutation(perm_inverse))
    else:
        b_fft = fft(n, normalized=True, br_first=True, with_br_perm=False)
        b_ifft = ifft(n, normalized=True, br_first=False, with_br_perm=False)
        # iDCT first, then DCT; here the post-processing diagonal is used un-permuted.
        b1 = diagonal_butterfly(b_ifft, postprocess_diag.conj(), diag_first=True)
        b2 = diagonal_butterfly(b_fft, postprocess_diag, diag_first=False)
        if not separate_diagonal:
            # As above: keep building on b1 / b2 so the post-processing diagonal is
            # retained regardless of whether diagonal_butterfly works in place.
            b1 = diagonal_butterfly(b1, diag1[perm][br], diag_first=False)
            b2 = diagonal_butterfly(b2, diag2, diag_first=False)
            return nn.Sequential(Real2Complex(), b1, Complex2Real(),
                                 Real2Complex(), b2, Complex2Real())
        else:
            return nn.Sequential(Real2Complex(), b1, Complex2Real(),
                                 Diagonal(diagonal_init=diag1[perm][br]),
                                 Real2Complex(), b2, Complex2Real(),
                                 Diagonal(diagonal_init=diag2))
def circulant(col, transposed=False, separate_diagonal=True) -> nn.Module:
    """ Build an nn.Module based on Butterfly that exactly performs circulant matrix
    multiplication.

    Parameters:
        col: torch.Tensor of size (n, ). The first column of the circulant matrix.
        transposed: if True, then the circulant matrix is transposed, i.e. col is the first *row*
            of the matrix.
        separate_diagonal: if True, the returned nn.Module is Butterfly, Diagonal, Butterfly.
            if False, the diagonal is combined into the Butterfly part.
    """
    assert col.dim() == 1, 'Vector col must have dimension 1'
    input_is_complex = col.is_complex()
    n = col.shape[0]
    log_n = int(math.ceil(math.log2(n)))
    # For non-power-of-2, maybe there's a way to only pad up to size 1 << log_n?
    # I've only figured out how to pad to size 1 << (log_n + 1).
    # e.g., [a, b, c] -> [a, b, c, 0, 0, a, b, c]
    if n == 1 << log_n:
        n_extended = n
    else:
        n_extended = 1 << (log_n + 1)

    b_fft = fft(n_extended, normalized=True, br_first=False, with_br_perm=False).to(col.device)
    b_fft.in_size = n
    b_ifft = ifft(n_extended, normalized=True, br_first=True, with_br_perm=False).to(col.device)
    b_ifft.out_size = n

    if n < n_extended:
        # Zero-pad the column, then append a second copy of the original column.
        zero_padded = F.pad(col, (0, 2 * ((1 << log_n) - n)))
        col = torch.cat((zero_padded, col))
    if not col.is_complex():
        col = real2complex(col)
    # This fft must have normalized=False for the correct scaling. These are the eigenvalues of the
    # circulant matrix.
    eigenvalues = torch.fft.fft(col, norm=None)
    if transposed:
        # We could have just transposed the iFFT * Diag * FFT to get FFT * Diag * iFFT.
        # Instead we use the fact that row is the reverse of col, but the 0-th element stays put.
        # This corresponds to the same reversal in the frequency domain.
        # https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Time_and_frequency_reversal
        eigenvalues = torch.cat((eigenvalues[:1], eigenvalues[1:].flip([0])))
    br_perm = bitreversal_permutation(n_extended, pytorch_format=True).to(col.device)
    # Reorder the eigenvalues along the last dimension by the bit-reversal permutation.
    diag = index_last_dim(eigenvalues, br_perm)

    if separate_diagonal:
        if input_is_complex:
            return nn.Sequential(b_fft, Diagonal(diagonal_init=diag), b_ifft)
        # Real input: wrap the complex pipeline with real/complex conversions.
        return nn.Sequential(Real2Complex(), b_fft, Diagonal(diagonal_init=diag), b_ifft,
                             Complex2Real())

    # Combine the diagonal with the last twiddle factor of b_fft
    with torch.no_grad():
        b_fft = diagonal_butterfly(b_fft, diag, diag_first=False, inplace=True)
    # Combine the b_fft and b_ifft into one Butterfly (with nblocks=2).
    combined = butterfly_product(b_fft, b_ifft)
    combined.in_size = n
    combined.out_size = n
    if input_is_complex:
        return combined
    return nn.Sequential(Real2Complex(), combined, Complex2Real())