import time

import torch

# Import paths assumed for this excerpt; these kernels come from the repo's compiled extension.
from butterfly import Butterfly
from butterfly.butterfly_multiply import (
    butterfly_multiply_untied, butterfly_multiply_untied_backward, butterfly_multiply_untied_eval,
    butterfly_multiply_untied_forward_backward, butterfly_multiply_untied_forward_backward_fast,
    butterfly_multiply_untied_forward_fast)

nsteps = 100  # timed iterations per benchmark (assumed value; a global in the original script)


def run_raw(in_, out_, batch_size):
    # Dense baselines: one fresh Linear (and its transposed weight) per timed step.
    L = [torch.nn.Linear(in_, out_, bias=False) for _ in range(nsteps)]
    weights = [linear.weight.t() for linear in L]
    x = torch.randn(batch_size, in_, requires_grad=False)
    # Butterfly input is (batch_size, nstack, n): replicate x across the stack dimension.
    x_stack = x.unsqueeze(1).expand(batch_size, max(1, out_ // in_), in_)
    B_untied = [Butterfly(in_, out_, bias=False, tied_weight=False) for _ in range(nsteps)]
    twiddle_untied = [B_untied[i].twiddle for i in range(nsteps)]

    # Training-mode forward: builds the autograd graph each step.
    bfly_start = time.perf_counter()
    for i in range(nsteps):
        output = butterfly_multiply_untied(twiddle_untied[i], x_stack, True, False)
    bfly_end = time.perf_counter()
    bfly_time_train = bfly_end - bfly_start
    print(f'Butterfly Training Forward: {bfly_time_train}')

    # Inference-mode forward: the fused eval kernel, no intermediates kept.
    bfly_start = time.perf_counter()
    for i in range(nsteps):
        output = butterfly_multiply_untied_eval(twiddle_untied[i], x_stack, True)
    bfly_end = time.perf_counter()
    bfly_time_eval = bfly_end - bfly_start
    print(f'Butterfly Inference Forward: {bfly_time_eval}')

    # Dense GEMM baseline.
    gemm_start = time.perf_counter()
    for i in range(nsteps):
        output = x.matmul(weights[i])
    gemm_end = time.perf_counter()
    gemm_time = gemm_end - gemm_start
    print(f'Linear Forward: {gemm_time}')

    print(f'Dim: {in_, out_} Batch Size: {batch_size} Speedup: {gemm_time / bfly_time_eval}x')
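
# A minimal driver for the benchmark above (a sketch; the original entry point is not
# shown, and the sizes below are illustrative). Square layers keep the stack dimension at 1.
def run_benchmarks():
    torch.manual_seed(0)
    for n in (256, 512, 1024):
        for batch_size in (1, 64, 256):
            run_raw(n, n, batch_size)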

class ButterflyMultUntied(torch.autograd.Function):
    # Class wrapper restored for this excerpt (name assumed); forward/backward below
    # are the usual torch.autograd.Function staticmethods.

    @staticmethod
    def forward(ctx, twiddle, input, increasing_stride=True, is_training=True, fast=False):
        """
        Parameters:
            twiddle: (nstack, log n, n / 2, 2, 2) if real or (nstack, log n, n / 2, 2, 2, 2) if complex
            input: (batch_size, nstack, n) if real or (batch_size, nstack, n, 2) if complex
            increasing_stride: whether to multiply with increasing stride (e.g. 1, 2, ..., n/2) or
                decreasing stride (e.g., n/2, n/4, ..., 1).
                Note that this only changes the order of multiplication, not how twiddle is stored.
                In other words, twiddle[@log_stride] always stores the twiddle for @stride.
        Returns:
            output: (batch_size, nstack, n) if real or (batch_size, nstack, n, 2) if complex
        """
        # Use the optimized kernel for real-valued CPU inference with n > 8.
        if (not is_training and not input.is_cuda and input.dim() == 3
                and input.dtype == torch.float and input.shape[-1] > 8):
            output = butterfly_multiply_untied_eval(twiddle, input, increasing_stride)
        elif not fast:
            output = butterfly_multiply_untied(twiddle, input, increasing_stride, False)
        else:
            output = butterfly_multiply_untied_forward_fast(twiddle, input, increasing_stride)
        ctx.save_for_backward(twiddle, input)
        ctx._increasing_stride = increasing_stride
        ctx._fast = fast
        return output

    @staticmethod
    def backward(ctx, grad):
        """
        Parameters:
            grad: (batch_size, nstack, n) if real or (batch_size, nstack, n, 2) if complex
            twiddle: (nstack, log n, n / 2, 2, 2) if real or (nstack, log n, n / 2, 2, 2, 2) if complex
            output + intermediate values for backward: (log n + 1, batch_size, nstack, n) if real or
                (log n + 1, batch_size, nstack, n, 2) if complex
        Returns:
            d_twiddle: (nstack, log n, n / 2, 2, 2) if real or (nstack, log n, n / 2, 2, 2, 2) if complex
            d_input: (batch_size, nstack, n) if real or (batch_size, nstack, n, 2) if complex
        """
        twiddle, input = ctx.saved_tensors
        increasing_stride = ctx._increasing_stride
        fast = ctx._fast
        n = input.shape[2]
        if input.dim() == 3 and n <= 1024 and input.is_cuda:
            # Fused CUDA kernels recompute the forward pass inside the backward pass.
            if not fast:
                d_coefficients, d_input = butterfly_multiply_untied_forward_backward(
                    twiddle, input, grad, increasing_stride)
            else:
                d_coefficients, d_input = butterfly_multiply_untied_forward_backward_fast(
                    twiddle, input, grad, increasing_stride)
        else:
            # Recompute the forward pass with intermediates, then backpropagate through them.
            output_and_intermediate = butterfly_multiply_untied(
                twiddle, input, increasing_stride, True)
            d_coefficients, d_input = butterfly_multiply_untied_backward(
                grad, twiddle, output_and_intermediate, increasing_stride)
        # Autograd requires one gradient per forward argument; the three flags get None.
        return d_coefficients, d_input, None, None, None
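
# A minimal sanity check for the Function above (a sketch: shapes follow the docstring,
# twiddle (nstack, log n, n / 2, 2, 2) and input (batch_size, nstack, n); the helper
# name is ours, not from the original repo).
def _check_untied_mult(batch_size=4, nstack=1, n=16):
    log_n = n.bit_length() - 1
    twiddle = torch.randn(nstack, log_n, n // 2, 2, 2, requires_grad=True)
    input = torch.randn(batch_size, nstack, n, requires_grad=True)
    output = ButterflyMultUntied.apply(twiddle, input, True, True, False)
    assert output.shape == (batch_size, nstack, n)
    output.sum().backward()  # exercises backward; on CPU this takes the recompute branch
    assert twiddle.grad.shape == twiddle.shape and input.grad.shape == input.shape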

class ButterflyMult(torch.autograd.Function):
    # Class wrapper restored for this excerpt (name assumed). This variant keeps all
    # intermediate results from the forward pass for use in backward.

    @staticmethod
    def forward(ctx, twiddle, input, increasing_stride=True):
        """
        Parameters:
            twiddle: (nstack, log n, n / 2, 2, 2) if real or (nstack, log n, n / 2, 2, 2, 2) if complex
            input: (batch_size, nstack, n) if real or (batch_size, nstack, n, 2) if complex
            increasing_stride: whether to multiply with increasing stride (e.g. 1, 2, ..., n/2) or
                decreasing stride (e.g., n/2, n/4, ..., 1).
                Note that this only changes the order of multiplication, not how twiddle is stored.
                In other words, twiddle[@log_stride] always stores the twiddle for @stride.
        Returns:
            output: (batch_size, nstack, n) if real or (batch_size, nstack, n, 2) if complex
        """
        output_and_intermediate = butterfly_multiply_untied(
            twiddle, input, increasing_stride, True)  # True: keep intermediates, matching the call in backward above
        ctx.save_for_backward(twiddle, output_and_intermediate)
        ctx._increasing_stride = increasing_stride
        return output_and_intermediate[-1]
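
# Forward-only check for the variant above (a sketch; its backward is not shown in this
# excerpt, so only the forward pass is exercised, and the helper name is ours).
def _check_mult_forward(batch_size=4, nstack=1, n=16):
    log_n = n.bit_length() - 1
    twiddle = torch.randn(nstack, log_n, n // 2, 2, 2)
    input = torch.randn(batch_size, nstack, n)
    output = ButterflyMult.apply(twiddle, input, True)
    assert output.shape == (batch_size, nstack, n)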