Exemplo n.º 1
0
 def forward(ctx, twiddle, input, increasing_stride=True):
     """Experimental in-place butterfly multiply that stores no intermediate results.

     Only the twiddle and the final output are saved on ``ctx``; the
     intermediate results are meant to be recomputed from the output during
     the backward pass.

     Parameters:
         twiddle: (n - 1, 2, 2) if real or (n - 1, 2, 2, 2) if complex
         input: (batch_size, n) if real or (batch_size, n, 2) if complex
         increasing_stride: whether to multiply with increasing stride (e.g. 2, 4, ..., n/2) or
             decreasing stride (e.g., n/2, n/4, ..., 2). Only the increasing order is
             implemented, so this must be truthy.
             Note that this only changes the order of multiplication, not how twiddle is stored.
             In other words, twiddle[@log_stride] always stores the twiddle for @stride.
     Returns:
         output: (batch_size, n) if real or (batch_size, n, 2) if complex
     Raises:
         AssertionError: if increasing_stride is falsy.
     """
     # Raise explicitly rather than via `assert`: assert statements are removed
     # under `python -O`, which would let an unsupported decreasing-stride
     # request through silently. Exception type and message are kept as before.
     if not increasing_stride:
         raise AssertionError('Decreasing stride not implemented')
     output = butterfly_multiply_inplace(twiddle, input)
     ctx.save_for_backward(twiddle, output)
     return output
Exemplo n.º 2
0
    # Tail of the `dst` timing run: call dst(x) `ntrial` times and store the
    # mean time per call. `start` is read here but set outside this excerpt --
    # presumably by a `start = timer()` just before; confirm against the full file.
    [dst(x) for _ in range(ntrial)]
    end = timer()
    dst_times[idx_n] = (end-start) / ntrial

    # BP: time `ntrial` calls of BP_mul_cy_inplace(ABCDs, perm, x) and store
    # the mean time per call at index idx_n.
    start = timer()
    [BP_mul_cy_inplace(ABCDs, perm, x) for _ in range(ntrial)]
    end = timer()
    bp_times[idx_n] = (end-start) / ntrial

    # BP_inplace_all: time butterfly_multiply_inplace on torch inputs.
    # The twiddle concatenation and tensor conversion below happen before
    # `start` is taken, so they are not included in the measured time.
    twiddles = twiddle_list_concat(B)
    # unsqueeze(0) adds a leading dimension of size 1
    # (presumably the batch dimension -- TODO confirm).
    x_torch = torch.tensor(x).unsqueeze(0)
    # perm_torch = torch.tensor(perm)
    start = timer()
    # NOTE(review): the same x_torch is passed on every trial; if the callee
    # mutates its input in place, later trials see modified data -- confirm.
    [butterfly_multiply_inplace(twiddles, x_torch) for _ in range(ntrial)]
    end = timer()
    bp_all_times[idx_n] = (end-start) / ntrial

# Print every timing collection, one per line, in a fixed order.
for timings in (dense_times, fft_times, scipyfft_times, dct_times,
                dst_times, bp_times, bp_all_times):
    print(timings)

# print(bp_times / fft_times)
# print(bp_times / dct_times)
# print(bp_times / dst_times)
# print(bp_times / bp_all_times)
Exemplo n.º 3
0
 def forward(ctx, twiddle, input):
     """Apply ``butterfly_multiply_inplace`` and stash what backward needs.

     Parameters:
         ctx: context object; receives ``twiddle`` and the result through
             ``ctx.save_for_backward``.
         twiddle: first argument forwarded to ``butterfly_multiply_inplace``.
         input: second argument forwarded to ``butterfly_multiply_inplace``.
     Returns:
         The value returned by ``butterfly_multiply_inplace(twiddle, input)``.
     """
     result = butterfly_multiply_inplace(twiddle, input)
     # The result, not the input, is saved alongside the twiddle.
     ctx.save_for_backward(twiddle, result)
     return result