def run_raw(in_, out_, batch_size):
    """Time raw butterfly multiplies against a dense matmul and print the results.

    Builds ``nsteps`` (module-level) dense ``torch.nn.Linear`` layers and
    ``nsteps`` untied ``Butterfly`` modules, then times three forward-only
    loops over the same random input:

    1. ``butterfly_multiply_untied`` (reported as the training forward),
    2. ``butterfly_multiply_untied_eval`` (reported as the inference forward),
    3. ``x.matmul(weight)`` with the transposed Linear weights.

    Parameters:
        in_: input feature dimension
        out_: output feature dimension
        batch_size: number of rows in the random input

    Returns:
        None. Timings and the matmul / butterfly-eval ratio are printed.
    """
    # Modules are constructed in the same order as before (Linear layers,
    # input, Butterfly modules) so random-number consumption is unchanged.
    dense_layers = [torch.nn.Linear(in_, out_, bias=False) for _ in range(nsteps)]
    weights = [layer.weight.t() for layer in dense_layers]

    x = torch.randn(batch_size, in_, requires_grad=False)
    # The butterfly kernels take a stacked view of the input:
    # (batch_size, nstack, in_) with nstack = max(1, out_ // in_).
    nstack = max(1, out_ // in_)
    x_stack = x.unsqueeze(1).expand((batch_size, nstack, in_))

    butterflies = [
        Butterfly(in_, out_, bias=False, tied_weight=False) for _ in range(nsteps)
    ]
    twiddles = [module.twiddle for module in butterflies]

    # --- butterfly, general kernel ---------------------------------------
    # NOTE(review): the trailing (True, False) are presumably
    # increasing_stride=True and "do not return intermediates" -- confirm
    # against the kernel's signature.
    tic = time.perf_counter()
    for twiddle in twiddles:
        result = butterfly_multiply_untied(twiddle, x_stack, True, False)
    bfly_time_train = time.perf_counter() - tic
    print(f'Butterfly Training Forward: {bfly_time_train}')

    # --- butterfly, inference kernel -------------------------------------
    tic = time.perf_counter()
    for twiddle in twiddles:
        result = butterfly_multiply_untied_eval(twiddle, x_stack, True)
    bfly_time_eval = time.perf_counter() - tic
    print(f'Butterfly Inference Forward: {bfly_time_eval}')

    # --- dense baseline --------------------------------------------------
    tic = time.perf_counter()
    for weight in weights:
        result = x.matmul(weight)
    gemm_time = time.perf_counter() - tic
    print(f'Linear Forward: {gemm_time}')

    print(f'Dim: {in_, out_} Batch Size: {batch_size} Speedup: {gemm_time / bfly_time_eval}x')
def forward(ctx, twiddle, input, increasing_stride=True, is_training=True, fast=False):
    """Multiply ``input`` by the untied butterfly matrices described by ``twiddle``.

    Parameters:
        ctx: autograd context; receives the tensors and flags needed later
            by the backward pass.
        twiddle: (nstack, log n, n / 2, 2, 2) if real or
            (nstack, log n, n / 2, 2, 2, 2) if complex
        input: (batch_size, nstack, n) if real or
            (batch_size, nstack, n, 2) if complex
        increasing_stride: whether the factors are applied smallest stride
            first or largest stride first. This only changes the order of
            multiplication, not how twiddle is stored: twiddle[@log_stride]
            always stores the twiddle for @stride.
        is_training: when False, the dedicated inference kernel may be used
            (see the conditions in the body).
        fast: when True and the inference kernel is not selected, use the
            ``butterfly_multiply_untied_forward_fast`` kernel.
    Returns:
        output: (batch_size, nstack, n) if real or
            (batch_size, nstack, n, 2) if complex
    """
    # The optimized inference kernel is only taken for non-training calls on
    # CPU with a 3-D float32 input whose last dimension exceeds 8.
    use_eval_kernel = (
        not is_training
        and not input.is_cuda
        and input.dim() == 3
        and input.dtype == torch.float
        and input.shape[-1] > 8
    )

    if use_eval_kernel:
        output = butterfly_multiply_untied_eval(twiddle, input, increasing_stride)
    elif fast:
        output = butterfly_multiply_untied_forward_fast(twiddle, input, increasing_stride)
    else:
        # Final False: only the product is needed here; the backward pass
        # recomputes from (twiddle, input) rather than from stored values.
        output = butterfly_multiply_untied(twiddle, input, increasing_stride, False)

    # Stash everything the backward pass reads back from ctx.
    ctx._increasing_stride = increasing_stride
    ctx._fast = fast
    ctx.save_for_backward(twiddle, input)
    return output
def backward(ctx, grad):
    """Compute gradients of the butterfly multiply w.r.t. twiddle and input.

    The forward result is not stored; this pass works from the saved
    ``(twiddle, input)`` pair, either through a fused forward+backward kernel
    or by re-running the forward with intermediates and then the backward.

    Parameters:
        ctx: autograd context holding ``saved_tensors`` (twiddle, input) and
            the ``_increasing_stride`` / ``_fast`` flags.
        grad: (batch_size, nstack, n) if real or
            (batch_size, nstack, n, 2) if complex
    Return:
        d_twiddle: (nstack, log n, n / 2, 2, 2) if real or
            (nstack, log n, n / 2, 2, 2, 2) if complex
        d_input: (batch_size, nstack, n) if real or
            (batch_size, nstack, n, 2) if complex
        followed by three ``None`` placeholders.
    """
    twiddle, input = ctx.saved_tensors
    increasing_stride = ctx._increasing_stride
    n = input.shape[2]

    if input.dim() == 3 and n <= 1024 and input.is_cuda:
        # Fused forward+backward kernels: taken only for 3-D CUDA input with
        # n <= 1024.
        fused_kernel = (
            butterfly_multiply_untied_forward_backward_fast
            if ctx._fast
            else butterfly_multiply_untied_forward_backward
        )
        d_coefficients, d_input = fused_kernel(twiddle, input, grad, increasing_stride)
    else:
        # General path: recompute the forward keeping intermediates (final
        # True), then run the backward kernel on them.
        output_and_intermediate = butterfly_multiply_untied(
            twiddle, input, increasing_stride, True)
        d_coefficients, d_input = butterfly_multiply_untied_backward(
            grad, twiddle, output_and_intermediate, increasing_stride)

    # Five values are returned; the last three are None. Presumably one slot
    # per argument of the paired forward (twiddle, input, increasing_stride,
    # is_training, fast), the non-tensor flags having no gradient -- confirm
    # against that forward's signature.
    return d_coefficients, d_input, None, None, None
def forward(ctx, twiddle, input, increasing_stride=True):
    """Butterfly multiply that keeps all intermediate results for backward.

    Parameters:
        ctx: autograd context; receives twiddle, the stacked
            output/intermediates and the stride flag.
        twiddle: documented as (nstack, log 2, n / 2, 2, 2) if real or
            (nstack, log 2, n / 2, 2, 2, 2) if complex.
            NOTE(review): "log 2" is probably meant to be "log n" -- confirm.
        input: documented as (batch_size, n) if real or (batch_size, n, 2)
            if complex.
            NOTE(review): the documented output has an nstack axis, so the
            input likely does too -- confirm.
        increasing_stride: whether the factors are applied smallest stride
            first or largest stride first. This only changes the order of
            multiplication, not how twiddle is stored: twiddle[@log_stride]
            always stores the twiddle for @stride.
    Returns:
        output: (batch_size, nstack, n) if real or
            (batch_size, nstack, n, 2) if complex
    """
    # NOTE(review): other call sites in this file pass a fourth positional
    # flag to butterfly_multiply_untied; this call relies on whatever default
    # the helper has for it and expects a result that can be indexed with
    # [-1] -- verify against the helper's signature.
    stages = butterfly_multiply_untied(twiddle, input, increasing_stride)

    ctx._increasing_stride = increasing_stride
    ctx.save_for_backward(twiddle, stages)

    # The last entry of the stacked result is the final product.
    final_output = stages[-1]
    return final_output