Example #1
0
def input_extension_fft_test(half_inputs, domain, even_outputs, *, modulus: int = 337):
    """Check that ``input_extension_fft`` recovers the second half of the inputs.

    ``input_extension_fft`` is given the first half of the inputs and the
    even-indexed outputs. It must return the second half of the inputs such
    that the FFT of the concatenated input vector reproduces ``even_outputs``
    at the even positions.

    Args:
        half_inputs: First half of the input vector (a list).
        domain: Evaluation domain; must have twice as many points as
            ``half_inputs``.
        even_outputs: Expected FFT outputs at the even-indexed positions.
        modulus: Modulus the field arithmetic is done in. Defaults to 337,
            the value that used to be hard-coded here.

    Raises:
        AssertionError: if the sizes are inconsistent or the reconstructed
            inputs do not reproduce ``even_outputs``.
    """
    inverse_of_2 = modular_inverse(2, modulus)

    # Both known pieces must be exactly half of the domain.
    assert len(half_inputs) * 2 == len(even_outputs) * 2 == len(domain)
    inverse_domain = [modular_inverse(d, modulus) for d in domain]

    resolved_second_half_inputs = input_extension_fft(half_inputs,
                                                      even_outputs, modulus,
                                                      domain, inverse_domain,
                                                      inverse_of_2)
    print("resolved_second_half_inputs", resolved_second_half_inputs)

    reconstructed_inputs = half_inputs + resolved_second_half_inputs
    print("reconstructed_inputs", reconstructed_inputs)
    # Only the even-indexed outputs were supplied, so only those are checked.
    assert fft(reconstructed_inputs, modulus, domain)[::2] == even_outputs
def output_extension_fft_test(half_inputs, domain, even_outputs, *, modulus: int = 337):
    """Check that ``output_extension_fft`` recovers the odd-indexed outputs.

    ``output_extension_fft`` is given the first half of the inputs and the
    even-indexed outputs and must return the odd-indexed outputs. Interleaving
    them with ``even_outputs`` and applying the inverse FFT must give back
    ``half_inputs`` in the first half of the result.

    Args:
        half_inputs: First half of the input vector.
        domain: Evaluation domain; must have twice as many points as
            ``half_inputs``.
        even_outputs: Known FFT outputs at the even-indexed positions.
        modulus: Modulus the field arithmetic is done in. Defaults to 337,
            the value that used to be hard-coded here.

    Raises:
        AssertionError: if the sizes are inconsistent or the round trip
            does not reproduce ``half_inputs``.
    """
    inverse_of_2 = modular_inverse(2, modulus)

    # Both known pieces must be exactly half of the domain.
    assert len(half_inputs) * 2 == len(even_outputs) * 2 == len(domain)
    inverse_domain = [modular_inverse(d, modulus) for d in domain]

    resolved_odd_outputs = output_extension_fft(half_inputs, even_outputs,
                                                modulus, domain,
                                                inverse_domain, inverse_of_2)
    print("resolved_odd_outputs", resolved_odd_outputs)

    # Interleave: even positions from even_outputs, odd positions from the result.
    reconstructed_outputs = [
        even_outputs[i // 2] if i % 2 == 0 else resolved_odd_outputs[i // 2]
        for i in range(len(even_outputs) + len(resolved_odd_outputs))
    ]
    print("reconstructed_outputs", reconstructed_outputs)
    assert inverse_fft(reconstructed_outputs, modulus,
                       domain)[:len(half_inputs)] == half_inputs
Example #3
0
def das_fft_test(domain, even_outputs, *, modulus: int = 337):
    """Check that ``das_fft_wrapper`` extends even outputs with the right odd outputs.

    The odd-indexed outputs produced by ``das_fft_wrapper`` are interleaved
    with ``even_outputs``. The inverse FFT of that vector must have an
    all-zero second half, and a forward FFT must round-trip back to the
    interleaved outputs.

    Args:
        domain: Evaluation domain; must have twice as many points as
            ``even_outputs``.
        even_outputs: Known values at the even-indexed positions.
        modulus: Modulus the field arithmetic is done in. Defaults to 337,
            the value that used to be hard-coded here.

    Raises:
        AssertionError: if the sizes are inconsistent or either check fails.
    """
    assert len(even_outputs) * 2 == len(domain)
    inverse_domain = [modular_inverse(d, modulus) for d in domain]

    half = len(even_outputs)

    resolved_odd_outputs = das_fft_wrapper(even_outputs, modulus, domain, inverse_domain)
    print("resolved_odd_outputs", resolved_odd_outputs)

    # Interleave: even positions from even_outputs, odd positions from the result.
    reconstructed_outputs = [even_outputs[i // 2] if i % 2 == 0 else resolved_odd_outputs[i // 2] for i in range(2*half)]
    print("reconstructed_outputs", reconstructed_outputs)
    reconstructed_inputs = inverse_fft(reconstructed_outputs, modulus, domain)
    print("reconstructed_inputs", reconstructed_inputs)

    # The extension is only valid if the upper half of the coefficients is zero.
    assert reconstructed_inputs[half:] == [0] * half
    assert fft(reconstructed_inputs, modulus, domain) == reconstructed_outputs
Example #4
0
def input_odd_extension_fft(even_vals, half_out, modulus, domain):
    """Recover the odd-indexed FFT inputs from the even inputs and the upper half of the outputs.

    One radix-2 butterfly step of an FFT over ``domain`` (root ``w``) combines
    ``x = FFT(even)[i]`` and ``y = FFT(odd)[i]`` as::

        out[i]        = x + y * w**i
        out[i + half] = x - y * w**i

    ``half_out`` supplies ``out[i + half]``, so each odd-half transform value is
    ``y = (x - half_out[i]) * w**-i``. The result is converted back to values
    with an inverse FFT over the half-size domain ``domain[::2]``.

    ``domain[-i]`` is used as the inverse of ``domain[i]``; an assertion checks
    this for every index, at the cost of one ``modular_inverse`` per element.
    """
    even_transform = fft(even_vals, modulus, domain[::2])
    odd_transform = []
    for idx, x in enumerate(even_transform):
        assert modular_inverse(domain[idx], modulus) == domain[-idx]
        odd_transform.append(((x - half_out[idx]) * domain[-idx]) % modulus)
    return inverse_fft(odd_transform, modulus, domain[::2])
Example #5
0
def partial_fft_test(half_inputs, domain, even_outputs, *, modulus: int = 337):
    """Check that ``partial_fft`` completes both the inputs and the outputs.

    ``partial_fft`` receives the first half of the inputs plus the
    even-indexed outputs. It returns the second half of the inputs and the
    odd-indexed outputs. The completed input and output vectors must then be
    an exact FFT / inverse-FFT pair over ``domain``.

    Args:
        half_inputs: First half of the input vector.
        domain: Evaluation domain; must have twice as many points as
            ``half_inputs``.
        even_outputs: Known FFT outputs at the even-indexed positions.
        modulus: Modulus the field arithmetic is done in. Defaults to 337,
            the value that used to be hard-coded here.

    Raises:
        AssertionError: if the sizes are inconsistent or either transform
            direction does not match.
    """
    inverse_of_2 = modular_inverse(2, modulus)

    # Both known pieces must be exactly half of the domain.
    assert len(half_inputs) * 2 == len(even_outputs) * 2 == len(domain)
    inverse_domain = [modular_inverse(d, modulus) for d in domain]

    resolved_second_half_inputs, resolved_odd_outputs = partial_fft(
        half_inputs, even_outputs, modulus, domain, inverse_domain,
        inverse_of_2)
    print("resolved_second_half_inputs", resolved_second_half_inputs)
    print("resolved_odd_outputs", resolved_odd_outputs)

    reconstructed_inputs = half_inputs + resolved_second_half_inputs
    # Interleave: even positions from even_outputs, odd positions from the result.
    reconstructed_outputs = [
        even_outputs[i // 2] if i % 2 == 0 else resolved_odd_outputs[i // 2]
        for i in range(len(even_outputs) + len(resolved_odd_outputs))
    ]
    print("reconstructed_inputs", reconstructed_inputs)
    print("reconstructed_outputs", reconstructed_outputs)
    assert fft(reconstructed_inputs, modulus, domain) == reconstructed_outputs
    assert inverse_fft(reconstructed_outputs, modulus,
                       domain) == reconstructed_inputs
Example #6
0
    262144:
    20439484849038267462774237595151440867617792718791690563928621375157525968123,
    524288:
    37115000097562964541269718788523040559386243094666416358585267518228781043101,
}

# Total number of evaluation points. It selects the root of unity below and
# bounds the interleaved extended_data; the original data fills WIDTH // 2 of them.
WIDTH = 256

from das_fft import das_fft
from classic_fft import modular_inverse

# NOTE(review): this looks like the BLS12-381 scalar-field order — confirm.
MODULUS = 52435875175126190479447740508185965837690552500527637822603658699938581184513
# Root of unity for a WIDTH-point domain, looked up in rootOfUnityCandidates.
ROOT_OF_UNITY = rootOfUnityCandidates[WIDTH]
# Powers of ROOT_OF_UNITY with the last element dropped. Presumably
# expand_root_of_unity returns the cycle closed back at 1 — TODO confirm.
domain = expand_root_of_unity(ROOT_OF_UNITY, MODULUS)[:-1]

# Precomputed modular inverses that das_fft takes as arguments.
inverse_domain = [modular_inverse(d, MODULUS) for d in domain]
inverse_of_2 = modular_inverse(2, MODULUS)

# even data will be the original data of interest
# (i * 42 is just an arbitrary deterministic fill for the demo).
even_data = [i * 42 % MODULUS for i in range(WIDTH // 2)]
debug_bigs("even_data", even_data)

# Derive the values for the odd positions from the even-position data.
odd_data = das_fft(even_data, MODULUS, domain, inverse_domain, inverse_of_2)
debug_bigs("odd data", odd_data)

# Interleave: even positions take even_data, odd positions take odd_data.
extended_data = [
    even_data[i // 2] if i % 2 == 0 else odd_data[i // 2] for i in range(WIDTH)
]
debug_bigs("extended data", extended_data)

# odd_data was constructed so that the inverse FFT of extended_data (even and odd values interleaved) has an all-zero second half of coefficients.
Example #7
0
def bench(scale: int):
    """Benchmark each FFT variant on a ``2**scale``-point domain and print ns/op.

    Each variant is called 200 times on inputs that are built once, right
    before its timing loop. The reported figure is the mean time per call in
    nanoseconds, measured with ``time.perf_counter`` (monotonic, high
    resolution). It includes the cheap length sanity check on each result.
    Output lines keep the original fixed-width label format.

    Args:
        scale: log2 of the domain size; ``rootOfUnityCandidates`` must have an
            entry for ``2**scale``.
    """
    width = 2**scale
    half = width // 2
    iterations = 200

    modulus = MODULUS
    root_of_unity = rootOfUnityCandidates[width]
    domain = expand_root_of_unity(root_of_unity, modulus)[:-1]
    assert len(domain) == width
    inverse_domain = [modular_inverse(d, modulus) for d in domain]
    inverse_of_2 = modular_inverse(2, modulus)
    # Trailing arguments shared by every half-size variant below.
    field_args = (modulus, domain, inverse_domain, inverse_of_2)

    def run(label, fn, args, expected_len, returns_pair=False):
        """Time ``fn(*args)`` ``iterations`` times and print mean ns/op under ``label``."""
        start = time.perf_counter()
        for _ in range(iterations):
            result = fn(*args)
            if returns_pair:
                # Pair-returning variants yield (inputs, outputs); check both.
                first, second = result
                assert len(first) == expected_len
                assert len(second) == expected_len
            else:
                assert len(result) == expected_len
        ns = (time.perf_counter() - start) * 1e9
        print("%s: scale_%-5d %10d ops %15.0f ns/op" % (label, scale, iterations, ns / iterations))

    data = list(range(width))

    # Each argument tuple is built immediately before its run, so every section
    # gets fresh input lists, exactly as the hand-unrolled version did.
    run("          FFT                 ", fft,
        (data, modulus, domain), width)
    run("  Partial FFT                 ", partial_fft,
        (data[:half], [0] * half, *field_args), half, returns_pair=True)
    run("OtherPart FFT                 ", other_partial_fft,
        (data[half:], [0] * half, *field_args), half, returns_pair=True)
    run(" InputExt FFT (zeroed outputs)", input_extension_fft,
        (data[:half], [0] * half, *field_args), half)
    run("OutputExt FFT (zeroed outputs)", output_extension_fft,
        (data[:half], [0] * half, *field_args), half)
    run(" InputExt FFT  (zeroed inputs)", input_extension_fft,
        ([0] * half, data[::2], *field_args), half)
    run("OutputExt FFT  (zeroed inputs)", output_extension_fft,
        ([0] * half, data[::2], *field_args), half)
    run("      DAS FFT                 ", das_fft,
        (list(data[::2]), *field_args), half)
    run("      DAS FFT (zeroed outputs)", das_fft,
        ([0] * half, *field_args), half)
Example #8
0
def das_fft_wrapper(a: list, modulus: int, domain: list, inverse_domain: list):
    """Run ``das_fft`` on ``a`` and scale every result by ``len(a)**-1 mod modulus``.

    The modular inverse of the input length is computed before ``das_fft`` is
    called. Each returned element is reduced modulo ``modulus``.
    """
    length_inverse = modular_inverse(len(a), modulus)
    raw_values = das_fft(a, modulus, domain, inverse_domain)
    return [value * length_inverse % modulus for value in raw_values]