def __init__(self, method='linear', **kwargs):
    super().__init__()
    if method == 'linear':
        make_layer = lambda name: self.add_module(
            name, nn.Linear(1024, 1024, bias=True))
    elif method == 'butterfly':
        make_layer = lambda name: self.add_module(
            name, Butterfly(1024, 1024, bias=True, **kwargs))
        # self.fc = Butterfly(1024, 1024, tied_weight=False, bias=False, param='regular', nblocks=0)
        # self.fc = Butterfly(1024, 1024, tied_weight=False, bias=False, param='odo', nblocks=1)
    elif method == 'low-rank':
        make_layer = lambda name: self.add_module(
            name, nn.Sequential(nn.Linear(1024, kwargs['rank'], bias=False),
                                nn.Linear(kwargs['rank'], 1024, bias=True)))
    elif method == 'toeplitz':
        make_layer = lambda name: self.add_module(
            name, sl.ToeplitzLikeC(layer_size=1024, bias=True, **kwargs))
    else:
        assert False, f"method {method} not supported"
    # self.fc10 = make_layer()
    # self.fc11 = make_layer()
    # self.fc12 = make_layer()
    # self.fc2 = make_layer()
    make_layer('fc10')
    make_layer('fc11')
    make_layer('fc12')
    make_layer('fc2')
    make_layer('fc3')
    self.logits = nn.Linear(1024, 10)
def run_raw(in_, out_, batch_size):
    L = [torch.nn.Linear(in_, out_, bias=False) for _ in range(nsteps)]
    weights = [linear.weight.t() for linear in L]
    x = torch.randn(batch_size, in_, requires_grad=False)
    x_stack = x.unsqueeze(1).expand((batch_size, max(1, out_ // in_), in_))
    B_untied = [Butterfly(in_, out_, bias=False, tied_weight=False) for _ in range(nsteps)]
    twiddle_untied = [B_untied[i].twiddle for i in range(nsteps)]
    bfly_start = time.perf_counter()
    for i in range(nsteps):
        output = butterfly_multiply_untied(twiddle_untied[i], x_stack, True, False)
    bfly_end = time.perf_counter()
    bfly_time_train = bfly_end - bfly_start
    print(f'Butterfly Training Forward: {bfly_time_train}')
    bfly_start = time.perf_counter()
    for i in range(nsteps):
        output = butterfly_multiply_untied_eval(twiddle_untied[i], x_stack, True)
    bfly_end = time.perf_counter()
    bfly_time_eval = bfly_end - bfly_start
    print(f'Butterfly Inference Forward: {bfly_time_eval}')
    gemm_start = time.perf_counter()
    for i in range(nsteps):
        output = x.matmul(weights[i])
    gemm_end = time.perf_counter()
    gemm_time = gemm_end - gemm_start
    print(f'Linear Forward: {gemm_time}')
    print(f'Dim: {in_, out_} Batch Size: {batch_size} Speedup: {gemm_time / bfly_time_eval}x')
def enter():
    global map
    map = Map()
    game_world.add_object(map, 0)
    global horn
    horn = Horn()
    game_world.add_object(horn, 1)
    global boss
    boss = Boss()
    game_world.add_object(boss, 1)
    global character
    character = Character()
    game_world.add_object(character, 1)
    global butterflys
    butterflys = [Butterfly(character) for _ in range(7)]
    game_world.add_objects(butterflys, 1)
    key = Key()
    game_world.add_object(key, 1)
    map.set_center_object(character)
    character.set_background(map)
    horn.set_background(map)
    boss.set_background(map)
def run(in_, out_, batch_size):
    # Create multiple models so the weights aren't already loaded in the cache.
    L = [torch.nn.Linear(in_, out_, bias=False) for _ in range(nsteps)]
    B_untied = [Butterfly(in_, out_, bias=False, tied_weight=False) for _ in range(nsteps)]
    twiddle_untied = [B_untied[i].twiddle for i in range(nsteps)]
    x = torch.randn(batch_size, in_, requires_grad=False)
    bfly_start = time.perf_counter()
    for i in range(nsteps):
        output = B_untied[i](x)
    bfly_end = time.perf_counter()
    bfly_time_train = bfly_end - bfly_start
    print(f'Butterfly Training Forward: {bfly_time_train}')
    B_untied = [b.eval() for b in B_untied]
    bfly_start = time.perf_counter()
    for i in range(nsteps):
        output = B_untied[i](x)
    bfly_end = time.perf_counter()
    bfly_time_eval = bfly_end - bfly_start
    print(f'Butterfly Inference Forward: {bfly_time_eval}')
    output = L[-1](x)  # warm up before timing the linear layers
    gemm_start = time.perf_counter()
    for i in range(nsteps):
        output = L[i](x)
    gemm_end = time.perf_counter()
    gemm_time = gemm_end - gemm_start
    print(f'Linear Forward: {gemm_time}')
    print(f'Dim: {in_, out_} Batch Size: {batch_size} Speedup: {gemm_time / bfly_time_eval}x')
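# A minimal driver for the two benchmarks above (a sketch: it assumes `run` and
# `run_raw` live in the same script, that the module-level `nsteps` they read is
# set first, and that torch, time, Butterfly, and the butterfly_multiply kernels
# are imported as in the surrounding files; the sizes are illustrative only).
if __name__ == '__main__':
    nsteps = 100  # fewer iterations than the full benchmark, for a quick check
    for batch_size in [256, 1024, 4096]:
        run(1024, 1024, batch_size)      # timing through the Butterfly module
        run_raw(1024, 1024, batch_size)  # timing the raw multiply kernels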
def hadamard_test():
    # Hadamard matrix for n = 4
    size = 4
    M0 = Butterfly(size,
                   diagonal=2,
                   diag=torch.tensor([1.0, 1.0, -1.0, -1.0], requires_grad=True),
                   subdiag=torch.ones(2, requires_grad=True),
                   superdiag=torch.ones(2, requires_grad=True))
    M1 = Butterfly(size,
                   diagonal=1,
                   diag=torch.tensor([1.0, -1.0, 1.0, -1.0], requires_grad=True),
                   subdiag=torch.tensor([1.0, 0.0, 1.0], requires_grad=True),
                   superdiag=torch.tensor([1.0, 0.0, 1.0], requires_grad=True))
    H = M0.matrix() @ M1.matrix()
    assert torch.allclose(H, torch.tensor(hadamard(4), dtype=torch.float))
    M = ButterflyProduct(size, fixed_order=True)
    M.factors[0] = M0
    M.factors[1] = M1
    assert torch.allclose(M.matrix(), H)
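# For reference, the two factors checked above are exactly H2 kron I2 (the
# stride-2 butterfly M0) and I2 kron H2 (the stride-1 butterfly M1); a plain
# NumPy sketch of the same identity, assuming scipy is available:
import numpy as np
from scipy.linalg import hadamard

H2 = np.array([[1.0, 1.0], [1.0, -1.0]])
M0_dense = np.kron(H2, np.eye(2))  # diag [1, 1, -1, -1] with sub/superdiag ones at offset 2
M1_dense = np.kron(np.eye(2), H2)  # diag [1, -1, 1, -1] with sub/superdiag [1, 0, 1] at offset 1
assert np.allclose(M0_dense @ M1_dense, hadamard(4))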
def __init__(self, input_size):
    super().__init__()
    filter_size = max(64, int(2 ** (2 * torch.tensor(input_size)).float().log2().ceil()))
    self.butterfly_ifft = Butterfly(filter_size, filter_size, complex=True,
                                    tied_weight=False, bias=False)  # iFFT
    self.butterfly_ifft.twiddle = torch.nn.Parameter(
        fft_twiddle(filter_size, forward=False, normalized=True))
    self.butterfly_fft = Butterfly(filter_size, filter_size, complex=True,
                                   tied_weight=False, bias=False,
                                   increasing_stride=False)  # FFT
    self.butterfly_fft.twiddle = torch.nn.Parameter(
        butterfly_transpose_conjugate(self.butterfly_ifft.twiddle))
    f = fftfreq(filter_size)
    fourier_filter = self.create_filter(f)
    br = bitreversal_permutation(filter_size)
    self.fourier_filter_br = torch.nn.Parameter(fourier_filter[br])
def run_game():
    pygame.init()
    screen = pygame.display.set_mode((1200, 800))
    bg_color = (255, 255, 255)
    pygame.display.set_caption("Learn Screen")
    butterfly = Butterfly(screen)
    while True:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                sys.exit()
        screen.fill(bg_color)
        butterfly.blitme()
        pygame.display.flip()
def test_butterfly_complex_inplace_cpu(self):
    batch_size = 10
    n = 4096
    # TODO: in-place implementation doesn't support nstack for now
    nstack = 1
    b = Butterfly(n, n, bias=False, complex=True, ortho_init=True)
    twiddle = b.twiddle
    input = torch.randn(batch_size, n, 2, requires_grad=True)
    output_inplace = butterfly_mult_inplace(twiddle.squeeze(0), input)
    output_torch = butterfly_mult_torch(twiddle, input).squeeze(1)
    self.assertTrue(
        torch.allclose(output_inplace, output_torch, rtol=self.rtol, atol=self.atol),
        (output_inplace - output_torch).abs().max().item())
def __init__(self, num_classes=1000, width_mult=1.0, round_nearest=8,
             structure=None, softmax_structure='D', sm_pooling=1):
    """structure: list of strings"""
    super(MobileNet, self).__init__()
    self.width_mult = width_mult
    self.round_nearest = round_nearest
    self.structure = [] if structure is None else structure
    self.n_structure_layer = len(self.structure)
    self.structure = ['D'] * (len(self.cfg) - self.n_structure_layer) + self.structure
    self.sm_pooling = sm_pooling
    input_channel = _make_divisible(32 * width_mult, round_nearest)
    self.conv1 = nn.Conv2d(3, input_channel, kernel_size=3, stride=2, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(input_channel)
    self.bn1.weight._no_wd = True
    self.bn1.bias._no_wd = True
    self.layers = self._make_layers(in_planes=input_channel)
    self.last_channel = _make_divisible(1024 * width_mult // sm_pooling, round_nearest)
    if softmax_structure == 'D':
        self.linear = nn.Linear(self.last_channel, num_classes)
    else:
        param = softmax_structure.split('_')[0]
        nblocks = 0 if len(softmax_structure.split('_')) <= 1 else int(softmax_structure.split('_')[1])
        self.linear = Butterfly(self.last_channel, num_classes, tied_weight=False,
                                ortho_init=True, param=param, nblocks=nblocks)
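# Note on the softmax_structure format, inferred from the split('_') parsing above:
# e.g. softmax_structure='odo_1' parses to param='odo', nblocks=1 and gives a
# Butterfly classifier head, while the default 'D' keeps the dense nn.Linear.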
def decorated_function(*args, **kwargs):
    if 'hotel_id' not in kwargs:
        return f(*args, **kwargs)
    g.hotel_id = kwargs['hotel_id']
    g.hotel = Hotel.query.filter_by(id=g.hotel_id).first()
    if not g.hotel:
        return o_render_template('hotel_not_found.html'), 404
    if not g.hotel.validate_butterfly():
        return redirect(url_for('hotel_misconfigured', hotel_id=g.hotel_id))
    g.butterfly = Butterfly(g.hotel.butterfly_user, g.hotel.butterfly_token,
                            g.hotel.butterfly_url)
    g.proxies = Proxies(g.butterfly)
    return f(*args, **kwargs)
def test_butterfly_inplace_cuda(self):
    batch_size = 10
    n = 4096
    # TODO: in-place implementation doesn't support nstack for now
    nstack = 1
    b = Butterfly(n, n, bias=False, ortho_init=True).to('cuda')
    twiddle = b.twiddle
    input = torch.randn(batch_size, n, requires_grad=True, device=twiddle.device)
    output_inplace = butterfly_mult_inplace(twiddle.squeeze(0), input)
    output_torch = butterfly_mult_torch(twiddle, input).squeeze(1)
    self.assertTrue(
        torch.allclose(output_inplace, output_torch, rtol=self.rtol, atol=self.atol),
        (output_inplace - output_torch).abs().max().item())
    grad = torch.randn_like(output_torch)
    d_twiddle_inplace, d_input_inplace = torch.autograd.grad(
        output_inplace, (twiddle, input), grad, retain_graph=True)
    d_twiddle_torch, d_input_torch = torch.autograd.grad(
        output_torch, (twiddle, input), grad, retain_graph=True)
    self.assertTrue(
        torch.allclose(d_input_inplace, d_input_torch, rtol=self.rtol, atol=self.atol),
        (d_input_inplace - d_input_torch).abs().max().item())
    # print((d_twiddle_inplace - d_twiddle_torch) / d_twiddle_torch)
    self.assertTrue(
        torch.allclose(d_twiddle_inplace, d_twiddle_torch, rtol=self.rtol, atol=self.atol),
        ((d_twiddle_inplace - d_twiddle_torch) / d_twiddle_torch).abs().max().item())
def test_butterfly_expansion(self):
    batch_size = 1
    device = 'cpu'
    in_size, out_size = (16, 16)
    expansion = 4
    b = Butterfly(in_size, out_size, bias=False, tied_weight=True, param='odo',
                  expansion=expansion, diag_init='normal').to(device)
    input = torch.randn((batch_size, in_size), device=device)
    output = b(input)
    terms = []
    for i in range(expansion):
        temp = butterfly_ortho_mult_tied(b.twiddle[[i]], input.unsqueeze(1), False)
        temp = temp * b.diag[i]
        temp = butterfly_ortho_mult_tied(b.twiddle1[[i]], temp, True)
        terms.append(temp)
    total = sum(terms)
    self.assertTrue(torch.allclose(output, total))
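# The loop above mirrors the algebra of an expansion-E 'odo' butterfly: the output
# is a sum of E terms, each orthogonal-multiply, diagonal scale, orthogonal-multiply
# (transposed). A plain-NumPy sketch of that sum-of-products structure, with random
# dense orthogonal matrices standing in for the butterfly factors (hypothetical
# names, illustration only):
import numpy as np

rng = np.random.default_rng(0)
n, E = 16, 4
Q1 = [np.linalg.qr(rng.standard_normal((n, n)))[0] for _ in range(E)]  # stand-ins for twiddle
Q2 = [np.linalg.qr(rng.standard_normal((n, n)))[0] for _ in range(E)]  # stand-ins for twiddle1
d = rng.standard_normal((E, n))                                        # stand-in for diag
x = rng.standard_normal(n)
y = sum(Q2[i].T @ (d[i] * (Q1[i] @ x)) for i in range(E))  # sum of E O-D-O terms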
def __init__(self, method='linear', **kwargs):
    super(LeNet, self).__init__()
    self.conv1 = nn.Conv2d(3, 6, 5, padding=2)
    self.conv2 = nn.Conv2d(6, 16, 5, padding=2)
    # print(method, tied_weight, kwargs)
    if method == 'linear':
        self.fc = nn.Linear(1024, 1024)
    elif method == 'butterfly':
        self.fc = Butterfly(1024, 1024, bias=True, **kwargs)
        # self.fc = Butterfly(1024, 1024, tied_weight=False, bias=False, param='regular', nblocks=0)
        # self.fc = Butterfly(1024, 1024, tied_weight=False, bias=False, param='odo', nblocks=1)
    elif method == 'low-rank':
        self.fc = nn.Sequential(nn.Linear(1024, kwargs['rank'], bias=False),
                                nn.Linear(kwargs['rank'], 1024))
    elif method == 'toeplitz':
        self.fc = sl.ToeplitzLikeC(layer_size=1024, bias=True, **kwargs)
    else:
        assert False, f"method {method} not supported"
    # self.bias = nn.Parameter(torch.zeros(1024))
    self.logits = nn.Linear(1024, 10)
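# Hypothetical instantiations (a sketch; assumes the rest of the LeNet class,
# including forward(), is defined elsewhere in this file, and that the butterfly
# kwargs mirror the commented-out alternatives above):
# net = LeNet(method='butterfly', tied_weight=False, param='odo', nblocks=1)
# net = LeNet(method='low-rank', rank=16)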
def test_butterfly(self):
    batch_size = 10
    for device in ['cpu'] + ([] if not torch.cuda.is_available() else ['cuda']):
        for in_size, out_size in [(7, 15), (15, 7)]:
            for complex in [False, True]:
                for tied_weight in [True, False]:
                    for increasing_stride in [True, False]:
                        for ortho_init in [False, True]:
                            b = Butterfly(in_size, out_size, True, complex, tied_weight,
                                          increasing_stride, ortho_init).to(device)
                            input = torch.randn((batch_size, in_size)
                                                + (() if not complex else (2,)),
                                                device=device)
                            output = b(input)
                            self.assertTrue(
                                output.shape == (batch_size, out_size)
                                + (() if not complex else (2,)),
                                (output.shape, device, (in_size, out_size),
                                 complex, tied_weight, ortho_init))
                            if ortho_init:
                                twiddle_np = b.twiddle.detach().to('cpu').numpy()
                                if complex:
                                    twiddle_np = twiddle_np.view('complex64').squeeze(-1)
                                twiddle_np = twiddle_np.reshape(-1, 2, 2)
                                twiddle_norm = np.linalg.norm(twiddle_np, ord=2, axis=(1, 2))
                                self.assertTrue(
                                    np.allclose(twiddle_norm, 1),
                                    (twiddle_norm, device, (in_size, out_size),
                                     complex, tied_weight, ortho_init))
def generate_enemy(self, stage_num):
    print(stage_num)
    self.stage_num = stage_num - 1
    enemy_dic = enemy_generation_table[stage_num]
    for part_num in range(len(self.enemies)):
        enemy_part = enemy_dic[part_num]
        number = 0
        for enemy_type in enemy_part:
            enemy = None
            if enemy_type == BEE:
                enemy = Bee(enemy_position_table[part_num][number])
            elif enemy_type == BFLY:
                enemy = Butterfly(enemy_position_table[part_num][number])
            elif enemy_type == MOTH:
                enemy = Moth(enemy_position_table[part_num][number])
            else:
                pass
            self.enemies[part_num].append(enemy)
            number += 1
        gameworld.add_objects(self.enemies[part_num], 1)
import os, sys
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
import time
import torch
from butterfly import Butterfly
from butterfly.butterfly_multiply import (butterfly_mult, butterfly_mult_untied,
                                          butterfly_mult_factors, bbt_mult_untied,
                                          bbt_ortho_mult_untied)

batch_size = 2048
n = 512
B = Butterfly(n, n, bias=False).to('cuda')
L = torch.nn.Linear(n, n, bias=False).to('cuda')
x = torch.randn(batch_size, n, requires_grad=True).to('cuda')
twiddle = B.twiddle
B_untied = Butterfly(n, n, bias=False, tied_weight=False).to('cuda')
twiddle_untied = B_untied.twiddle
B_ortho = Butterfly(n, n, bias=False, tied_weight=False, param='ortho').to('cuda')
# twiddle = torch.randn(2, 2, n - 1, device=x.device, requires_grad=True).permute(2, 0, 1)
nsteps = 1000
# nsteps = 1
grad = torch.randn_like(x)

torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(nsteps):
    output = butterfly_mult_factors(twiddle.squeeze(0), x)
def named_target_matrix(name, size):
    """
    Parameters:
        name: name of the target matrix
    Return:
        target_matrix: (n, n) numpy array for real matrices or (n, n, 2) for complex matrices.
    """
    if name == 'dft':
        return LA.dft(size, scale='sqrtn')[:, :, None].view('float64')
    elif name == 'idft':
        return np.ascontiguousarray(LA.dft(size, scale='sqrtn').conj().T)[:, :, None].view('float64')
    elif name == 'dft2':
        size_sr = int(math.sqrt(size))
        matrix = np.fft.fft2(np.eye(size_sr**2).reshape(-1, size_sr, size_sr),
                             norm='ortho').reshape(-1, size_sr**2)
        # matrix1d = LA.dft(size_sr, scale='sqrtn')
        # assert np.allclose(np.kron(m1d, m1d), matrix)
        # return matrix[:, :, None].view('float64')
        from butterfly.utils import bitreversal_permutation
        br_perm = bitreversal_permutation(size_sr)
        br_perm2 = np.arange(size_sr**2).reshape(size_sr, size_sr)[br_perm][:, br_perm].reshape(-1)
        matrix = np.ascontiguousarray(matrix[:, br_perm2])
        return matrix[:, :, None].view('float64')
    elif name == 'dct':
        # Need to transpose as dct acts on rows of np.eye, not columns
        # return dct(np.eye(size), norm='ortho').T
        return dct(np.eye(size)).T / math.sqrt(size)
    elif name == 'dst':
        return dst(np.eye(size)).T / math.sqrt(size)
    elif name == 'hadamard':
        return LA.hadamard(size) / math.sqrt(size)
    elif name == 'hadamard2':
        size_sr = int(math.sqrt(size))
        matrix1d = LA.hadamard(size_sr) / math.sqrt(size_sr)
        return np.kron(matrix1d, matrix1d)
    elif name == 'b2':
        size_sr = int(math.sqrt(size))
        from butterfly import Block2x2DiagProduct
        b = Block2x2DiagProduct(size_sr)
        matrix1d = b(torch.eye(size_sr)).t().detach().numpy()
        return np.kron(matrix1d, matrix1d)
    elif name == 'convolution':
        np.random.seed(0)
        x = np.random.randn(size)
        return LA.circulant(x) / math.sqrt(size)
    elif name == 'hartley':
        return hartley_matrix(size) / math.sqrt(size)
    elif name == 'haar':
        return haar_matrix(size, normalized=True) / math.sqrt(size)
    elif name == 'legendre':
        grid = np.linspace(-1, 1, size + 2)[1:-1]
        return legendre.legvander(grid, size - 1).T / math.sqrt(size)
    elif name == 'hilbert':
        H = hilbert_matrix(size)
        return H / np.linalg.norm(H, 2)
    elif name == 'randn':
        np.random.seed(0)
        return np.random.randn(size, size) / math.sqrt(size)
    elif name == 'permutation':
        np.random.seed(0)
        perm = np.random.permutation(size)
        P = np.eye(size)[perm]
        return P
    elif name.startswith('rank-unnorm'):
        r = int(name[11:])
        np.random.seed(0)
        G = np.random.randn(size, r)
        H = np.random.randn(size, r)
        M = G @ H.T
        # M /= math.sqrt(size * r)
        return M
    elif name.startswith('rank'):
        r = int(name[4:])
        np.random.seed(0)
        G = np.random.randn(size, r)
        H = np.random.randn(size, r)
        M = G @ H.T
        M /= math.sqrt(size * r)
        return M
    elif name.startswith('sparse'):
        s = int(name[6:])
        # 2rn parameters
        np.random.seed(0)
        mask = sparse.random(size, size, density=s / size, data_rvs=np.ones)
        M = np.random.randn(size, size) * mask.toarray()
        M /= math.sqrt(s)
        return M
    elif name.startswith('toeplitz'):
        r = int(name[8:])
        G = np.random.randn(size, r) / math.sqrt(size * r)
        H = np.random.randn(size, r) / math.sqrt(size * r)
        M = toeplitz_like(G, H)
        return M
    elif name == 'fastfood':
        n = size
        S = np.random.randn(n)
        G = np.random.randn(n)
        B = np.random.randn(n)
        # P = np.arange(n)
        P = np.random.permutation(n)
        H = hadamard(n)
        # SHGPHB
        F = S[:, np.newaxis] * (H @ (G[:, np.newaxis] * (H * B)[P, :])) / n
        return F
        # x = np.random.randn(batch_size, n)
        # HB = hadamard_transform(B)
        # PHBx = HBx[:, P]
        # HGPHBx = hadamard_transform(G * PHBx)
        # return S * HGPHBx
    elif name == 'butterfly':
        # n (log n + 1) params in the hierarchy
        b = Butterfly(in_size=size, out_size=size, bias=False, tied_weight=False,
                      param='odo', nblocks=0)
        M = b(torch.eye(size))
        return M.cpu().detach().numpy()
    else:
        assert False, 'Target matrix name not recognized or implemented'
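# Example calls (a sketch; per the docstring, real targets come back as (n, n)
# arrays and complex ones as (n, n, 2) float views of the complex entries):
dft8 = named_target_matrix('dft', 8)           # (8, 8, 2): unitary DFT with sqrt(n) scaling
conv8 = named_target_matrix('convolution', 8)  # (8, 8): seeded random circulant
lowrank8 = named_target_matrix('rank2', 8)     # (8, 8): normalized rank-2 product G @ H.T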
import os, sys
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
import time
import torch
from butterfly import Butterfly
from butterfly.butterfly_multiply import butterfly_mult, butterfly_mult_factors, butterfly_mult_inplace

batch_size = 256
n = 1024
B = Butterfly(n, n, bias=False).to('cuda')
L = torch.nn.Linear(n, n, bias=False).to('cuda')
x = torch.randn(batch_size, n, requires_grad=True).to('cuda')
twiddle = B.twiddle
# twiddle = torch.randn(2, 2, n - 1, device=x.device, requires_grad=True).permute(2, 0, 1)
nsteps = 1000
# nsteps = 1
grad = torch.randn_like(x)

torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(nsteps):
    output = butterfly_mult_factors(twiddle.squeeze(0), x)
torch.cuda.synchronize()
end = time.perf_counter()
print(f'Butterfly mult factors forward: {end - start}s')
torch.cuda.synchronize()
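# The synchronize/perf_counter bracketing above generalizes to a small helper
# (a sketch; assumes CUDA tensors, so the synchronize calls are meaningful):
def time_op(fn, nsteps=1000):
    torch.cuda.synchronize()        # make sure pending kernels finish before timing
    start = time.perf_counter()
    for _ in range(nsteps):
        fn()
    torch.cuda.synchronize()        # wait for the timed kernels to finish
    return time.perf_counter() - start

# e.g. time_op(lambda: butterfly_mult_factors(twiddle.squeeze(0), x))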
def _setup(self, config):
    device = config['device']
    self.device = device
    size = config['size']
    if isinstance(config['target_matrix'], str):
        self.target_matrix = torch.tensor(
            named_target_matrix(config['target_matrix'], size),
            dtype=torch.float).to(device)
    else:
        self.target_matrix = torch.tensor(config['target_matrix'],
                                          dtype=torch.float).to(device)
    assert self.target_matrix.shape[0] == self.target_matrix.shape[1], \
        'Only square matrices are supported'
    assert self.target_matrix.dim() in [2, 3], \
        'target matrix must be 2D if real or 3D if complex'
    complex = self.target_matrix.dim() == 3 or config['complex']
    torch.manual_seed(config['seed'])
    if config['model'] == 'B':
        self.model = nn.Sequential(
            FixedPermutation(torch.tensor(bitreversal_permutation(size))),
            Butterfly(in_size=size, out_size=size, bias=False,
                      complex=complex, ortho_init=True)).to(device)
    elif config['model'] == 'BP':
        self.model = nn.Sequential(
            Permutation(size=size, share_logit=config['share_logit'][0]),
            Butterfly(in_size=size, out_size=size, bias=False,
                      complex=complex, ortho_init=True)).to(device)
    elif config['model'] == 'PBT':
        self.model = nn.Sequential(
            Butterfly(in_size=size, out_size=size, bias=False, complex=complex,
                      increasing_stride=False, ortho_init=True),
            Permutation(size=size, share_logit=config['share_logit'][0])).to(device)
    elif config['model'] == 'BPP':
        self.model = nn.Sequential(
            PermutationFactor(size=size),
            Permutation(size=size, share_logit=config['share_logit'][0]),
            Butterfly(in_size=size, out_size=size, bias=False,
                      complex=complex, ortho_init=True)).to(device)
    elif config['model'] == 'BPBP':
        self.model = nn.Sequential(
            Permutation(size=size, share_logit=config['share_logit'][0]),
            Butterfly(in_size=size, out_size=size, bias=False,
                      complex=complex, ortho_init=True),
            Permutation(size=size, share_logit=config['share_logit'][1]),
            Butterfly(in_size=size, out_size=size, bias=False,
                      complex=complex, ortho_init=True)).to(device)
    elif config['model'] == 'BBT':
        # param_type = 'regular' if complex else 'perm'
        param_type = config['param']
        self.model = nn.Sequential(
            Butterfly(in_size=size, out_size=size, bias=False, complex=complex,
                      param=param_type, increasing_stride=False),
            Butterfly(in_size=size, out_size=size, bias=False, complex=complex,
                      param=param_type, increasing_stride=True))
    elif config['model'][0] == 'T' and (config['model'][1:]).isdigit():
        depth = int(config['model'][1:])
        param_type = config['param']
        self.model = nn.Sequential(*[
            Butterfly(in_size=size, out_size=size, bias=False, complex=complex,
                      param=param_type, increasing_stride=False)
            for _ in range(depth)
        ])
    elif config['model'][0:3] == 'BBT' and (config['model'][3:]).isdigit():
        depth = int(config['model'][3:])
        param_type = config['param']
        self.model = nn.Sequential(*[
            nn.Sequential(
                Butterfly(in_size=size, out_size=size, bias=False, complex=complex,
                          param=param_type, increasing_stride=False),
                Butterfly(in_size=size, out_size=size, bias=False, complex=complex,
                          param=param_type, increasing_stride=True))
            for _ in range(depth)
        ])
    elif config['model'][0] == 'B' and (config['model'][1:]).isdigit():
        depth = int(config['model'][1:])
        param_type = config['param']
        self.model = nn.Sequential(*[
            Butterfly(in_size=size, out_size=size, bias=False, complex=complex,
                      param=param_type, increasing_stride=True)
            for _ in range(depth)
        ])
    elif config['model'] == 'butterfly':
        # e = int(config['model'][4:])
        self.model = Butterfly(in_size=size, out_size=size, complex=complex,
                               **config['bfargs'])
    # elif config['model'][0:3] == 'ODO':
    #     if (config['model'][3:]).isdigit():
    #         width = int(config['model'][3:])
    #         self.model = Butterfly(in_size=size, out_size=size, bias=False, complex=False,
    #                                param='odo', tied_weight=True, nblocks=0,
    #                                expansion=width, diag_init='normal')
    #     elif config['model'][3] == 'k':
    #         k = int(config['model'][4:])
    #         self.model = Butterfly(in_size=size, out_size=size, bias=False, complex=False,
    #                                param='odo', tied_weight=True, nblocks=k, diag_init='normal')
    # non-butterfly transforms
    # elif config['model'][0:2] == 'TL' and (config['model'][2:]).isdigit():
    #     rank = int(config['model'][2:])
    elif config['model'][0:4] == 'rank' and (config['model'][4:]).isdigit():
        rank = int(config['model'][4:])
        self.model = nn.Sequential(
            nn.Linear(size, rank, bias=False),
            nn.Linear(rank, size, bias=False),
        )
    else:
        assert False, f"Model {config['model']} not implemented"
    self.nparameters = sum(param.nelement() for param in self.model.parameters())
    print("Parameters: ", self.nparameters)
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    self.n_epochs_per_validation = config['n_epochs_per_validation']
    self.input = torch.eye(size).to(device)
    if complex:
        self.input = real_to_complex(self.input)
def test_butterfly(self):
    batch_size = 10
    for device in ['cpu'] + ([] if not torch.cuda.is_available() else ['cuda']):
        for in_size, out_size in [(7, 15), (15, 7)]:
            for complex in [False, True]:
                for tied_weight in [True, False]:
                    for increasing_stride in [True, False]:
                        for ortho_init in [False, True]:
                            for param in ['regular'] if complex else ['regular', 'ortho', 'odo', 'obdobt']:
                                for nblocks in ([0, 1, 2, 3] if param in ['regular', 'ortho', 'odo', 'obdobt'] else [0]):
                                    for expansion in [1, 2]:
                                        if param in ['obdobt'] and tied_weight:
                                            continue
                                        if nblocks > 0 and complex:
                                            continue
                                        if not (nblocks > 0 and tied_weight and param in ['odo']):  # special case
                                            if nblocks > 0 and (tied_weight or param not in ['regular', 'ortho', 'odo', 'obdobt']):
                                                continue
                                        b = Butterfly(in_size, out_size, True, complex,
                                                      tied_weight, increasing_stride,
                                                      ortho_init, param, nblocks=nblocks,
                                                      expansion=expansion).to(device)
                                        input = torch.randn((batch_size, in_size)
                                                            + (() if not complex else (2,)),
                                                            device=device)
                                        output = b(input)
                                        self.assertTrue(
                                            output.shape == (batch_size, out_size)
                                            + (() if not complex else (2,)),
                                            (output.shape, device, (in_size, out_size),
                                             complex, tied_weight, ortho_init, nblocks))
                                        if ortho_init and param == 'regular':
                                            twiddle_np = b.twiddle.detach().to('cpu').numpy()
                                            if complex:
                                                twiddle_np = twiddle_np.view('complex64').squeeze(-1)
                                            twiddle_np = twiddle_np.reshape(-1, 2, 2)
                                            twiddle_norm = np.linalg.norm(twiddle_np, ord=2, axis=(1, 2))
                                            self.assertTrue(
                                                np.allclose(twiddle_norm, 1),
                                                (twiddle_norm, device, (in_size, out_size),
                                                 complex, tied_weight, ortho_init))
def _setup(self, config):
    device = config['device']
    self.device = device
    size = config['size']
    if isinstance(config['target_matrix'], str):
        self.target_matrix = torch.tensor(
            named_target_matrix(config['target_matrix'], size),
            dtype=torch.float).to(device)
    else:
        self.target_matrix = torch.tensor(config['target_matrix'],
                                          dtype=torch.float).to(device)
    assert self.target_matrix.shape[0] == self.target_matrix.shape[1], \
        'Only square matrices are supported'
    assert self.target_matrix.dim() in [2, 3], \
        'target matrix must be 2D if real or 3D if complex'
    complex = self.target_matrix.dim() == 3 or config['complex']
    torch.manual_seed(config['seed'])
    if config['model'] == 'B':
        self.model = nn.Sequential(
            FixedPermutation(torch.tensor(bitreversal_permutation(size))),
            Butterfly(in_size=size, out_size=size, bias=False,
                      complex=complex, ortho_init=True)).to(device)
    elif config['model'] == 'BP':
        self.model = nn.Sequential(
            Permutation(size=size, share_logit=config['share_logit'][0]),
            Butterfly(in_size=size, out_size=size, bias=False,
                      complex=complex, ortho_init=True)).to(device)
    elif config['model'] == 'PBT':
        self.model = nn.Sequential(
            Butterfly(in_size=size, out_size=size, bias=False, complex=complex,
                      increasing_stride=False, ortho_init=True),
            Permutation(size=size, share_logit=config['share_logit'][0])).to(device)
    elif config['model'] == 'BPP':
        self.model = nn.Sequential(
            PermutationFactor(size=size),
            Permutation(size=size, share_logit=config['share_logit'][0]),
            Butterfly(in_size=size, out_size=size, bias=False,
                      complex=complex, ortho_init=True)).to(device)
    elif config['model'] == 'BPBP':
        self.model = nn.Sequential(
            Permutation(size=size, share_logit=config['share_logit'][0]),
            Butterfly(in_size=size, out_size=size, bias=False,
                      complex=complex, ortho_init=True),
            Permutation(size=size, share_logit=config['share_logit'][1]),
            Butterfly(in_size=size, out_size=size, bias=False,
                      complex=complex, ortho_init=True)).to(device)
    else:
        assert False, f"Model {config['model']} not implemented"
    self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
    self.n_steps_per_epoch = config['n_steps_per_epoch']
    self.n_epochs_per_validation = config['n_epochs_per_validation']
    self.input = torch.eye(size).to(device)
    if complex:
        self.input = real_to_complex(self.input)
def __init__(self, num_classes=10, dropout=False, method='linear', tied_weight=False, **kwargs):
    super(AlexNet, self).__init__()
    self.features = nn.Sequential(
        nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
        # nn.ReLU(inplace=True),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),
        nn.Conv2d(64, 192, kernel_size=3, padding=1),
        # nn.ReLU(inplace=True),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),
        nn.Conv2d(192, 384, kernel_size=3, padding=1),
        # nn.ReLU(inplace=True),
        nn.ReLU(),
        nn.Conv2d(384, 256, kernel_size=3, padding=1),
        # nn.ReLU(inplace=True),
        nn.ReLU(),
        nn.Conv2d(256, 256, kernel_size=3, padding=1),
        # nn.ReLU(inplace=True),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),
    )
    self.dropout = nn.Dropout() if dropout else nn.Identity()
    self.features_size = 256 * 4 * 4
    self.fc1 = nn.Linear(self.features_size, self.features_size)
    if method == 'linear':
        self.fc = nn.Linear(self.features_size, self.features_size, bias=False)
    elif method == 'butterfly':
        self.fc = Butterfly(self.features_size, self.features_size,
                            tied_weight=tied_weight, bias=False, **kwargs)
        # self.fc = Butterfly(self.features_size, self.features_size, tied_weight=False, bias=False, param='regular', nblocks=0)
        # self.fc = Butterfly(self.features_size, self.features_size, tied_weight=False, bias=False, param='odo', nblocks=1)
    elif method == 'low-rank':
        self.fc = nn.Sequential(
            nn.Linear(self.features_size, kwargs['rank'], bias=False),
            nn.Linear(kwargs['rank'], self.features_size, bias=False))
    else:
        assert False, f"method {method} not supported"
    self.bias = nn.Parameter(torch.zeros(self.features_size))
    self.fc2 = nn.Linear(4096, 4096)
    # self.fc2 = nn.Identity()
    self.classifier = nn.Sequential(
        # nn.Dropout(),
        # self.dropout,
        # self.fc1,
        # nn.ReLU(),
        # nn.Dropout(),
        self.dropout,
        self.fc2,
        nn.ReLU(),
        nn.Linear(self.features_size, num_classes),
    )
                    default=None, type=int,
                    help="Camera index number")
parser.add_argument("--input_node", default="", type=str,
                    help="The name of the input node")
parser.add_argument("--output_node", default="", type=str,
                    help="The name of the output node")
args = parser.parse_args()

if __name__ == '__main__':
    # Construct a Butterfly.
    fly = Butterfly(args.model, args.input_node, args.output_node)
    # Output all the op names in the graph.
    if args.list_ops:
        for op in fly.list_ops():
            print(op)
    # Process an image.
    if args.image:
        image = cv2.imread(args.image)
        print(fly.run([image]))
    # Process video/cam.
    if args.cam is not None:
        video_source = args.cam
    elif args.video:
def test_butterfly_bmm(self):
    batch_size = 10
    matrix_batch = 3
    for device in ['cpu'] + ([] if not torch.cuda.is_available() else ['cuda']):
        for in_size, out_size in [(7, 15), (15, 7)]:
            for complex in [False, True]:
                for tied_weight in [True, False]:
                    for increasing_stride in [True, False]:
                        for ortho_init in [False, True]:
                            for param in ['regular'] if complex else ['regular', 'ortho', 'odo', 'obdobt']:
                                for nblocks in ([0, 1, 2, 3] if param in ['regular', 'ortho', 'odo', 'obdobt'] else [0]):
                                    for expansion in [1, 2]:
                                        if param in ['obdobt'] and tied_weight:
                                            continue
                                        if nblocks > 0 and complex:
                                            continue
                                        if not (nblocks > 0 and tied_weight and param in ['odo']):  # special case
                                            if nblocks > 0 and (tied_weight or param not in ['regular', 'ortho', 'odo', 'obdobt']):
                                                continue
                                        b_bmm = ButterflyBmm(in_size, out_size, matrix_batch, True,
                                                             complex, tied_weight, increasing_stride,
                                                             ortho_init, param,
                                                             expansion=expansion).to(device)
                                        input = torch.randn((batch_size, matrix_batch, in_size)
                                                            + (() if not complex else (2,)),
                                                            device=device)
                                        output = b_bmm(input)
                                        self.assertTrue(
                                            output.shape == (batch_size, matrix_batch, out_size)
                                            + (() if not complex else (2,)),
                                            (output.shape, device, (in_size, out_size),
                                             complex, tied_weight, ortho_init))
                                        # Check that the result is the same as looping over butterflies
                                        if param == 'regular':
                                            output_loop = []
                                            for i in range(matrix_batch):
                                                b = Butterfly(in_size, out_size, True, complex,
                                                              tied_weight, increasing_stride,
                                                              ortho_init, expansion=expansion)
                                                b.twiddle = torch.nn.Parameter(
                                                    b_bmm.twiddle[i * b_bmm.nstack:(i + 1) * b_bmm.nstack])
                                                b.bias = torch.nn.Parameter(b_bmm.bias[i])
                                                output_loop.append(b(input[:, i]))
                                            output_loop = torch.stack(output_loop, dim=1)
                                            self.assertTrue(
                                                torch.allclose(output, output_loop),
                                                ((output - output_loop).abs().max().item(),
                                                 output.shape, device, (in_size, out_size),
                                                 complex, tied_weight, ortho_init))
                                        if ortho_init and param == 'regular':
                                            twiddle_np = b_bmm.twiddle.detach().to('cpu').numpy()
                                            if complex:
                                                twiddle_np = twiddle_np.view('complex64').squeeze(-1)
                                            twiddle_np = twiddle_np.reshape(-1, 2, 2)
                                            twiddle_norm = np.linalg.norm(twiddle_np, ord=2, axis=(1, 2))
                                            self.assertTrue(
                                                np.allclose(twiddle_norm, 1),
                                                (twiddle_norm, device, (in_size, out_size),
                                                 complex, tied_weight, ortho_init))