def test_butterfly_fft():
    """Check that two hand-set complex butterfly factors reproduce the 4-point DFT.

    M0 (diagonal=2) and M1 (diagonal=1) are chosen so that M0 @ M1 @ P equals
    the DFT matrix, where P swaps indices 1 and 2. Complex values use the
    (..., 2) real/imaginary layout throughout.
    """
    size = 4  # DFT matrix for n = 4
    eye = torch.eye(size)
    if callable(torch.fft):
        # Legacy PyTorch (< 1.8): torch.fft(input, signal_ndim) acts on (..., 2) tensors.
        DFT = torch.fft(real_to_complex(eye), 1)
    else:
        # PyTorch >= 1.8: torch.fft is a module; convert back to the (..., 2) layout.
        DFT = torch.view_as_real(torch.fft.fft(eye.to(torch.complex64), dim=-1))
    # Permutation matrix that swaps indices 1 and 2.
    P = real_to_complex(
        torch.tensor([[1., 0., 0., 0.],
                      [0., 0., 1., 0.],
                      [0., 1., 0., 0.],
                      [0., 0., 0., 1.]]))
    M0 = Butterfly(size, diagonal=2, complex=True,
                   diag=torch.tensor([[1.0, 0.0], [1.0, 0.0], [-1.0, 0.0], [0.0, 1.0]],
                                     requires_grad=True),
                   subdiag=torch.tensor([[1.0, 0.0], [1.0, 0.0]], requires_grad=True),
                   superdiag=torch.tensor([[1.0, 0.0], [0.0, -1.0]], requires_grad=True))
    M1 = Butterfly(size, diagonal=1, complex=True,
                   diag=torch.tensor([[1.0, 0.0], [-1.0, 0.0], [1.0, 0.0], [-1.0, 0.0]],
                                     requires_grad=True),
                   subdiag=torch.tensor([[1.0, 0.0], [0.0, 0.0], [1.0, 0.0]],
                                        requires_grad=True),
                   superdiag=torch.tensor([[1.0, 0.0], [0.0, 0.0], [1.0, 0.0]],
                                          requires_grad=True))
    assert torch.allclose(
        complex_matmul(M0.matrix(), complex_matmul(M1.matrix(), P)), DFT)
    # Reordering the columns of M0 @ M1 by bit reversal must give the same DFT.
    br_perm = torch.tensor(bitreversal_permutation(size))
    assert torch.allclose(
        complex_matmul(M0.matrix(), M1.matrix())[:, br_perm], DFT)
    D = complex_matmul(DFT, P.transpose(0, 1))
    assert torch.allclose(complex_matmul(M0.matrix(), M1.matrix()), D)
def test_blockpermproduct():
    """Check that with logit[0] = +inf, BlockPermProduct acts as bit reversal on 8 elements."""
    size = 8
    x = torch.randn(3, size, 2)
    perm = BlockPermProduct(size, complex=True, share_logit=True)
    # NOTE(review): logit is presumably an nn.Parameter. PyTorch rejects in-place
    # writes to a leaf tensor that requires grad unless autograd is disabled.
    with torch.no_grad():
        perm.logit[0] = float('inf')
    from utils import bitreversal_permutation
    assert torch.allclose(perm(x), x[:, bitreversal_permutation(size)])
def __init__(self, size, complex=False, decreasing_size=True):
    """Create one Block2x2DiagBmm factor per level of a size-``size`` butterfly.

    Args:
        size: transform length; must be a power of 2.
        complex: forwarded to every factor.
        decreasing_size: if True the factor sizes go size, size/2, ..., 2;
            otherwise they are in the reverse (increasing) order.
    """
    super().__init__()
    log_n = int(math.log2(size))
    assert size == 1 << log_n, "size must be a power of 2"
    self.size = size
    self.complex = complex
    shifts = range(log_n) if decreasing_size else range(log_n)[::-1]
    factor_sizes = [size >> shift for shift in shifts]
    blocks = [Block2x2DiagBmm(block_size, complex=complex) for block_size in factor_sizes]
    self.factors = nn.ModuleList(blocks)
    self.br_perm = bitreversal_permutation(size)
def test_butterfly_dct():
    """Check that a twiddle diagonal times two butterfly factors gives the 4-point DCT.

    diag(2 * exp(-i*pi*k / (2n))) @ M0 @ M1, with its columns reordered by
    ``perm`` and only the real part kept, must equal scipy's default DCT matrix.
    Folding the same twiddle into the rows of M0 (DM0) must give the same result.
    """
    from scipy.fftpack import dct
    n = 4  # DCT matrix for n = 4
    # scipy's dct transforms the rows of np.eye, so transpose to act on columns.
    DCT = torch.tensor(dct(np.eye(n)).T, dtype=torch.float)

    m0_diag = torch.tensor([[1.0, 0.0], [1.0, 0.0], [-1.0, 0.0], [0.0, 1.0]], requires_grad=True)
    m0_sub = torch.tensor([[1.0, 0.0], [1.0, 0.0]], requires_grad=True)
    m0_super = torch.tensor([[1.0, 0.0], [0.0, -1.0]], requires_grad=True)
    M0 = Butterfly(n, diagonal=2, complex=True,
                   diag=m0_diag, subdiag=m0_sub, superdiag=m0_super)

    m1_diag = torch.tensor([[1.0, 0.0], [-1.0, 0.0], [1.0, 0.0], [-1.0, 0.0]], requires_grad=True)
    m1_sub = torch.tensor([[1.0, 0.0], [0.0, 0.0], [1.0, 0.0]], requires_grad=True)
    m1_super = torch.tensor([[1.0, 0.0], [0.0, 0.0], [1.0, 0.0]], requires_grad=True)
    M1 = Butterfly(n, diagonal=1, complex=True,
                   diag=m1_diag, subdiag=m1_sub, superdiag=m1_super)

    idx = np.arange(n)
    # Even indices ascending followed by idx[::-2] (here [0, 2] + [3, 1]).
    dct_perm = np.concatenate((idx[::2], idx[::-2]))
    perm = torch.arange(n)[dct_perm][bitreversal_permutation(n)]

    angle = -math.pi * torch.arange(n, dtype=torch.float) / (2 * n)
    twiddle_re = 2 * torch.cos(angle)
    twiddle_im = 2 * torch.sin(angle)
    twiddle_diag = torch.stack((torch.diag(twiddle_re), torch.diag(twiddle_im)), dim=-1)
    product = complex_matmul(twiddle_diag, complex_matmul(M0.matrix(), M1.matrix()))
    assert torch.allclose(product[:, perm, 0], DCT)

    # Scale row i of M0 by twiddle[i]: diag rows 0..3, subdiag rows 2..3, superdiag rows 0..1.
    twiddle = torch.stack((twiddle_re, twiddle_im), dim=-1)
    DM0 = Butterfly(n, diagonal=2, complex=True,
                    diag=complex_mul(twiddle, m0_diag),
                    subdiag=complex_mul(twiddle[2:], m0_sub),
                    superdiag=complex_mul(twiddle[:2], m0_super))
    assert torch.allclose(complex_matmul(DM0.matrix(), M1.matrix())[:, perm, 0], DCT)
def __init__(self, input_size):
    """Build FFT/iFFT butterflies and store a bit-reversed Fourier-domain filter.

    Args:
        input_size: signal length (presumably an int - TODO confirm with callers).
    """
    super().__init__()
    # Smallest power of two >= 2 * input_size, and never below 64. This uses exact
    # integer arithmetic: a float32 log2 can round down for very large inputs
    # (millions of samples) and give a filter shorter than 2 * input_size.
    filter_size = max(64, 1 << max(2 * int(input_size) - 1, 0).bit_length())
    self.butterfly_ifft = Butterfly(filter_size, filter_size, complex=True,
                                    tied_weight=False, bias=False)  # iFFT
    self.butterfly_ifft.twiddle = torch.nn.Parameter(
        fft_twiddle(filter_size, forward=False, normalized=True))
    self.butterfly_fft = Butterfly(filter_size, filter_size, complex=True,
                                   tied_weight=False, bias=False,
                                   increasing_stride=False)  # FFT
    # The forward FFT twiddles are derived from the iFFT ones.
    self.butterfly_fft.twiddle = torch.nn.Parameter(
        butterfly_transpose_conjugate(self.butterfly_ifft.twiddle))
    f = fftfreq(filter_size)
    fourier_filter = self.create_filter(f)
    br = bitreversal_permutation(filter_size)
    # The filter is stored in bit-reversed order.
    self.fourier_filter_br = torch.nn.Parameter(fourier_filter[br])
def named_target_matrix(name, size):
    """Return the target matrix called ``name`` with ``size`` rows and columns.

    Parameter:
        name: name of the target matrix: 'dft', 'idft', 'dft2', 'dct', 'dst',
            'hadamard', 'hadamard2', 'b2', 'convolution', 'hartley', 'haar',
            'legendre', 'hilbert', 'randn', 'permutation', 'rank-unnorm<r>',
            'rank<r>', 'sparse<s>', 'toeplitz<r>', 'fastfood' or 'butterfly'.
        size: number of rows / columns n.
    Return:
        target_matrix: (n, n) numpy array for real matrices or (n, n, 2) for
            complex matrices (last axis holds the real and imaginary parts).
        All randomly generated targets are seeded, so repeated calls agree.
    """
    if name == 'dft':
        return LA.dft(size, scale='sqrtn')[:, :, None].view('float64')
    elif name == 'idft':
        return np.ascontiguousarray(LA.dft(size, scale='sqrtn').conj().T)[:, :, None].view('float64')
    elif name == 'dft2':
        # 2-D DFT on a size_sr x size_sr grid. Columns are reordered by
        # bit reversal along each grid axis.
        size_sr = int(math.sqrt(size))
        matrix = np.fft.fft2(np.eye(size_sr**2).reshape(-1, size_sr, size_sr),
                             norm='ortho').reshape(-1, size_sr**2)
        from butterfly.utils import bitreversal_permutation
        br_perm = bitreversal_permutation(size_sr)
        br_perm2 = np.arange(size_sr**2).reshape(size_sr, size_sr)[br_perm][:, br_perm].reshape(-1)
        matrix = np.ascontiguousarray(matrix[:, br_perm2])
        return matrix[:, :, None].view('float64')
    elif name == 'dct':
        # Need to transpose as dct acts on rows of matrix np.eye, not columns
        return dct(np.eye(size)).T / math.sqrt(size)
    elif name == 'dst':
        return dst(np.eye(size)).T / math.sqrt(size)
    elif name == 'hadamard':
        return LA.hadamard(size) / math.sqrt(size)
    elif name == 'hadamard2':
        size_sr = int(math.sqrt(size))
        matrix1d = LA.hadamard(size_sr) / math.sqrt(size_sr)
        return np.kron(matrix1d, matrix1d)
    elif name == 'b2':
        size_sr = int(math.sqrt(size))
        from butterfly import Block2x2DiagProduct
        b = Block2x2DiagProduct(size_sr)
        matrix1d = b(torch.eye(size_sr)).t().detach().numpy()
        return np.kron(matrix1d, matrix1d)
    elif name == 'convolution':
        np.random.seed(0)
        x = np.random.randn(size)
        return LA.circulant(x) / math.sqrt(size)
    elif name == 'hartley':
        return hartley_matrix(size) / math.sqrt(size)
    elif name == 'haar':
        return haar_matrix(size, normalized=True) / math.sqrt(size)
    elif name == 'legendre':
        grid = np.linspace(-1, 1, size + 2)[1:-1]
        return legendre.legvander(grid, size - 1).T / math.sqrt(size)
    elif name == 'hilbert':
        H = hilbert_matrix(size)
        return H / np.linalg.norm(H, 2)
    elif name == 'randn':
        np.random.seed(0)
        return np.random.randn(size, size) / math.sqrt(size)
    elif name == 'permutation':
        np.random.seed(0)
        perm = np.random.permutation(size)
        return np.eye(size)[perm]
    elif name.startswith('rank-unnorm'):
        # Must be tested before the plain 'rank' prefix below.
        r = int(name[11:])
        np.random.seed(0)
        G = np.random.randn(size, r)
        H = np.random.randn(size, r)
        return G @ H.T
    elif name.startswith('rank'):
        r = int(name[4:])
        np.random.seed(0)
        G = np.random.randn(size, r)
        H = np.random.randn(size, r)
        M = G @ H.T
        M /= math.sqrt(size * r)
        return M
    elif name.startswith('sparse'):
        s = int(name[6:])
        np.random.seed(0)
        # density s / size, i.e. about s nonzeros per row on average
        mask = sparse.random(size, size, density=s / size, data_rvs=np.ones)
        M = np.random.randn(size, size) * mask.toarray()
        M /= math.sqrt(s)
        return M
    elif name.startswith('toeplitz'):
        r = int(name[8:])
        np.random.seed(0)  # seeded like the other random targets, for reproducibility
        G = np.random.randn(size, r) / math.sqrt(size * r)
        H = np.random.randn(size, r) / math.sqrt(size * r)
        return toeplitz_like(G, H)
    elif name == 'fastfood':
        n = size
        np.random.seed(0)  # seeded like the other random targets, for reproducibility
        S = np.random.randn(n)
        G = np.random.randn(n)
        B = np.random.randn(n)
        P = np.random.permutation(n)
        H = hadamard(n)
        # diag(S) @ H @ diag(G) @ Perm @ H @ diag(B) / n  ("SHGPHB")
        return S[:, np.newaxis] * (H @ (G[:, np.newaxis] * (H * B)[P, :])) / n
    elif name == 'butterfly':
        # n (log n + 1) params in the hierarchy
        torch.manual_seed(0)  # make the randomly initialised target reproducible
        b = Butterfly(in_size=size, out_size=size, bias=False, tied_weight=False,
                      param='odo', nblocks=0)
        M = b(torch.eye(size))
        return M.cpu().detach().numpy()
    else:
        assert False, 'Target matrix name not recognized or implemented'
def _setup(self, config):
    """Build the target matrix, model, optimizer and training input from ``config``.

    Reads config keys 'device', 'size', 'target_matrix' (a name passed to
    named_target_matrix, or an array-like), 'complex', 'seed', 'model', 'lr',
    'n_steps_per_epoch' and 'n_epochs_per_validation'. Depending on the model
    it may also read 'share_logit', 'param' or 'bfargs'.
    """
    device = config['device']
    self.device = device
    size = config['size']
    if isinstance(config['target_matrix'], str):
        target = named_target_matrix(config['target_matrix'], size)
    else:
        target = config['target_matrix']
    self.target_matrix = torch.tensor(target, dtype=torch.float).to(device)
    # Check the rank first so a 1-D target fails with this message, not an IndexError.
    assert self.target_matrix.dim() in [2, 3], 'target matrix must be 2D if real or 3D if complex'
    assert self.target_matrix.shape[0] == self.target_matrix.shape[1], 'Only square matrices are supported'
    is_complex = self.target_matrix.dim() == 3 or config['complex']
    torch.manual_seed(config['seed'])

    def ortho_butterfly(**kwargs):
        # Square, bias-free Butterfly with orthogonal initialisation.
        return Butterfly(in_size=size, out_size=size, bias=False, complex=is_complex,
                         ortho_init=True, **kwargs)

    def param_butterfly(increasing_stride):
        # Square, bias-free Butterfly using the parametrisation named by config['param'].
        return Butterfly(in_size=size, out_size=size, bias=False, complex=is_complex,
                         param=config['param'], increasing_stride=increasing_stride)

    def learned_permutation(index):
        # Learnable permutation; index selects the entry of config['share_logit'].
        return Permutation(size=size, share_logit=config['share_logit'][index])

    model_name = config['model']
    if model_name == 'B':
        model = nn.Sequential(
            FixedPermutation(torch.tensor(bitreversal_permutation(size))),
            ortho_butterfly())
    elif model_name == 'BP':
        model = nn.Sequential(learned_permutation(0), ortho_butterfly())
    elif model_name == 'PBT':
        model = nn.Sequential(ortho_butterfly(increasing_stride=False), learned_permutation(0))
    elif model_name == 'BPP':
        model = nn.Sequential(PermutationFactor(size=size), learned_permutation(0),
                              ortho_butterfly())
    elif model_name == 'BPBP':
        model = nn.Sequential(learned_permutation(0), ortho_butterfly(),
                              learned_permutation(1), ortho_butterfly())
    elif model_name == 'BBT':
        model = nn.Sequential(param_butterfly(increasing_stride=False),
                              param_butterfly(increasing_stride=True))
    elif model_name[:1] == 'T' and model_name[1:].isdigit():
        depth = int(model_name[1:])
        model = nn.Sequential(*[param_butterfly(increasing_stride=False) for _ in range(depth)])
    elif model_name[:3] == 'BBT' and model_name[3:].isdigit():
        depth = int(model_name[3:])
        model = nn.Sequential(*[
            nn.Sequential(param_butterfly(increasing_stride=False),
                          param_butterfly(increasing_stride=True))
            for _ in range(depth)
        ])
    elif model_name[:1] == 'B' and model_name[1:].isdigit():
        depth = int(model_name[1:])
        model = nn.Sequential(*[param_butterfly(increasing_stride=True) for _ in range(depth)])
    elif model_name == 'butterfly':
        model = Butterfly(in_size=size, out_size=size, complex=is_complex, **config['bfargs'])
    elif model_name[:4] == 'rank' and model_name[4:].isdigit():
        rank = int(model_name[4:])
        model = nn.Sequential(
            nn.Linear(size, rank, bias=False),
            nn.Linear(rank, size, bias=False),
        )
    else:
        assert False, f"Model {config['model']} not implemented"
    # Every architecture must live on the same device as the target and the input.
    self.model = model.to(device)
    self.nparameters = sum(param.nelement() for param in self.model.parameters())
    print("Parameters: ", self.nparameters)
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    self.n_epochs_per_validation = config['n_epochs_per_validation']
    self.input = torch.eye(size).to(device)
    if is_complex:
        self.input = real_to_complex(self.input)
def named_target_matrix(name, size):
    """Return the target matrix called ``name`` with ``size`` rows and columns.

    Args:
        name: one of 'dft', 'idft', 'dft2', 'dct', 'dst', 'hadamard',
            'hadamard2', 'b2', 'convolution', 'hartley', 'haar', 'legendre',
            'hilbert', 'randn'.
        size: number of rows / columns n.

    Returns:
        An (n, n) numpy array for real matrices, or an (n, n, 2) array for
        complex ones, whose last axis holds the real and imaginary parts.
    """
    if name == 'dft':
        return LA.dft(size, scale='sqrtn')[..., None].view('float64')
    if name == 'idft':
        inverse = LA.dft(size, scale='sqrtn').conj().T
        return np.ascontiguousarray(inverse)[..., None].view('float64')
    if name == 'dft2':
        # 2-D DFT on a side x side grid, with columns bit-reversed along each axis.
        side = int(math.sqrt(size))
        n_cells = side ** 2
        basis = np.eye(n_cells).reshape(-1, side, side)
        transformed = np.fft.fft2(basis, norm='ortho').reshape(-1, n_cells)
        from butterfly.utils import bitreversal_permutation
        rev = bitreversal_permutation(side)
        column_order = np.arange(n_cells).reshape(side, side)[rev][:, rev].reshape(-1)
        reordered = np.ascontiguousarray(transformed[:, column_order])
        return reordered[..., None].view('float64')
    if name == 'dct':
        # scipy's dct transforms the rows of np.eye, hence the transpose.
        return dct(np.eye(size)).T / math.sqrt(size)
    if name == 'dst':
        return dst(np.eye(size)).T / math.sqrt(size)
    if name == 'hadamard':
        return LA.hadamard(size) / math.sqrt(size)
    if name == 'hadamard2' or name == 'b2':
        # Kronecker square of a side x side factor.
        side = int(math.sqrt(size))
        if name == 'hadamard2':
            factor = LA.hadamard(side) / math.sqrt(side)
        else:
            import torch
            from butterfly import Block2x2DiagProduct
            product = Block2x2DiagProduct(side)
            factor = product(torch.eye(side)).t().detach().numpy()
        return np.kron(factor, factor)
    if name == 'convolution':
        np.random.seed(0)
        kernel = np.random.randn(size)
        return LA.circulant(kernel) / math.sqrt(size)
    if name == 'hartley':
        return hartley_matrix(size) / math.sqrt(size)
    if name == 'haar':
        return haar_matrix(size, normalized=True) / math.sqrt(size)
    if name == 'legendre':
        interior = np.linspace(-1, 1, size + 2)[1:-1]
        return legendre.legvander(interior, size - 1).T / math.sqrt(size)
    if name == 'hilbert':
        hilb = hilbert_matrix(size)
        return hilb / np.linalg.norm(hilb, 2)
    if name == 'randn':
        np.random.seed(0)
        return np.random.randn(size, size) / math.sqrt(size)
    assert False, 'Target matrix name not recognized or implemented'
def _setup(self, config):
    """Build the target matrix, model, optimizer and training input from ``config``.

    Reads config keys 'device', 'size', 'target_matrix' (a name passed to
    named_target_matrix, or an array-like), 'complex', 'seed', 'model' (one of
    'B', 'BP', 'PBT', 'BPP', 'BPBP'), 'share_logit', 'lr', 'n_steps_per_epoch'
    and 'n_epochs_per_validation'.
    """
    device = config['device']
    self.device = device
    size = config['size']
    if isinstance(config['target_matrix'], str):
        target = named_target_matrix(config['target_matrix'], size)
    else:
        target = config['target_matrix']
    self.target_matrix = torch.tensor(target, dtype=torch.float).to(device)
    # Check the rank first so a 1-D target fails with this message, not an IndexError.
    assert self.target_matrix.dim() in [2, 3], 'target matrix must be 2D if real or 3D if complex'
    assert self.target_matrix.shape[0] == self.target_matrix.shape[1], 'Only square matrices are supported'
    is_complex = self.target_matrix.dim() == 3 or config['complex']
    torch.manual_seed(config['seed'])

    def ortho_butterfly(**kwargs):
        # Square, bias-free Butterfly with orthogonal initialisation.
        return Butterfly(in_size=size, out_size=size, bias=False, complex=is_complex,
                         ortho_init=True, **kwargs)

    def learned_permutation(index):
        # Learnable permutation; index selects the entry of config['share_logit'].
        return Permutation(size=size, share_logit=config['share_logit'][index])

    model_name = config['model']
    if model_name == 'B':
        model = nn.Sequential(
            FixedPermutation(torch.tensor(bitreversal_permutation(size))),
            ortho_butterfly())
    elif model_name == 'BP':
        model = nn.Sequential(learned_permutation(0), ortho_butterfly())
    elif model_name == 'PBT':
        model = nn.Sequential(ortho_butterfly(increasing_stride=False), learned_permutation(0))
    elif model_name == 'BPP':
        model = nn.Sequential(PermutationFactor(size=size), learned_permutation(0),
                              ortho_butterfly())
    elif model_name == 'BPBP':
        model = nn.Sequential(learned_permutation(0), ortho_butterfly(),
                              learned_permutation(1), ortho_butterfly())
    else:
        # Report the configured model name (the message must not reference an undefined local).
        assert False, f"Model {config['model']} not implemented"
    self.model = model.to(device)
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    self.n_epochs_per_validation = config['n_epochs_per_validation']
    self.input = torch.eye(size).to(device)
    if is_complex:
        self.input = real_to_complex(self.input)