Ejemplo n.º 1
0
def test_convolution():
    """Check ``convolution`` against ``convolution_fft`` and ``convolution_ntt``.

    Each positive case calls the generic ``convolution`` entry point with a
    particular keyword combination and compares the result with the matching
    specialised routine (several times with the operands swapped).  The
    remaining cases are keyword combinations for which this test expects a
    ``TypeError``.
    """
    seq_a = [1, S(5)/3, sqrt(3), S(7)/5]
    seq_b = [9, 5, 5, 4, 3, 2]
    seq_c = [3, 5, 3, 7, 8]
    seq_d = [1422, 6572, 3213, 5552]

    def assert_rejected(**options):
        # Convolving seq_b with seq_d under ``options`` must raise TypeError.
        raises(TypeError, lambda: convolution(seq_b, seq_d, **options))

    # fft: results are compared with convolution_fft; the last case passes
    # only ``dps`` and is still expected to match it.
    plain_fft = convolution(seq_a, seq_b, fft=True)
    assert plain_fft == convolution_fft(seq_a, seq_b)

    fft_dps9 = convolution(seq_a, seq_b, dps=9, fft=True)
    assert fft_dps9 == convolution_fft(seq_a, seq_b, dps=9)

    fft_dps7 = convolution(seq_a, seq_d, fft=True, dps=7)
    assert fft_dps7 == convolution_fft(seq_d, seq_a, dps=7)

    dps_only = convolution(seq_a, seq_d[1:], dps=3)
    assert dps_only == convolution_fft(seq_d[1:], seq_a, dps=3)

    # prime moduli of the form (m*2**k + 1), sequence length
    # should be a divisor of 2**k
    p = 7*17*2**23 + 1
    q = 19*2**10 + 1

    # ntt: results are compared with convolution_ntt; the middle case
    # passes only ``prime``.
    ntt_q = convolution(seq_d, seq_b, ntt=True, prime=q)
    assert ntt_q == convolution_ntt(seq_b, seq_d, prime=q)

    prime_only = convolution(seq_c, seq_b, prime=p)
    assert prime_only == convolution_ntt(seq_b, seq_c, prime=p)

    ntt_p = convolution(seq_d, seq_c, prime=p, ntt=True)
    assert ntt_p == convolution_ntt(seq_c, seq_d, prime=p)

    # Keyword combinations that are expected to be refused.
    assert_rejected(ntt=True)
    assert_rejected(ntt=True, cycle=0)
    assert_rejected(dps=5, prime=q)
    assert_rejected(dps=6, ntt=True, prime=q)
    assert_rejected(fft=True, dps=7, ntt=True, prime=q)

    # ntt is a specialized variant of fft, TypeError should not be raised
    both_flags = convolution(seq_b, seq_d, fft=True, ntt=True, prime=q)
    assert both_flags == convolution_ntt(seq_b, seq_d, prime=q)
Ejemplo n.º 2
0
def test_convolution():
    """Check how ``convolution`` selects a back end from its keyword arguments.

    The sections compare the generic entry point with ``convolution_fft``,
    ``convolution_ntt``, ``convolution_fwht`` and ``convolution_subset`` in
    turn, and list the keyword combinations for which this test expects a
    ``TypeError``.
    """
    a = [1, S(5) / 3, sqrt(3), S(7) / 5]
    b = [9, 5, 5, 4, 3, 2]
    c = [3, 5, 3, 7, 8]
    d = [1422, 6572, 3213, 5552]

    # fft
    default_result = convolution(a, b)
    assert default_result == convolution_fft(a, b)

    with_dps9 = convolution(a, b, dps=9)
    assert with_dps9 == convolution_fft(a, b, dps=9)

    with_dps7 = convolution(a, d, dps=7)
    assert with_dps7 == convolution_fft(d, a, dps=7)

    with_dps3 = convolution(a, d[1:], dps=3)
    assert with_dps3 == convolution_fft(d[1:], a, dps=3)

    # prime moduli of the form (m*2**k + 1), sequence length
    # should be a divisor of 2**k
    p = 7 * 17 * 2**23 + 1
    q = 19 * 2**10 + 1

    # ntt
    for left, right, modulus in ((d, b, q), (c, b, p), (d, c, p)):
        modular = convolution(left, right, prime=modulus)
        assert modular == convolution_ntt(right, left, prime=modulus)

    for precision in (5, 6):
        raises(
            TypeError,
            lambda precision=precision: convolution(b, d, dps=precision, prime=q),
        )

    # fwht
    dyadic_result = convolution(a, b, dyadic=True)
    assert dyadic_result == convolution_fwht(a, b)

    not_dyadic = convolution(a, b, dyadic=False)
    assert not_dyadic == convolution(a, b)

    dyadic_conflicts = (
        ((b, d), {"dps": 2, "dyadic": True}),
        ((b, d), {"prime": p, "dyadic": True}),
        ((a, b), {"dps": 2, "dyadic": True}),
        ((b, c), {"prime": p, "dyadic": True}),
    )
    for seqs, options in dyadic_conflicts:
        raises(
            TypeError,
            lambda seqs=seqs, options=options: convolution(*seqs, **options),
        )

    # subset: the checks below are evaluated in the same order as a chained
    # comparison of the four results would be.
    subset_result = convolution(a, b, subset=True)
    subset_reference = convolution_subset(a, b)
    assert subset_result == subset_reference

    subset_not_dyadic = convolution(a, b, subset=True, dyadic=False)
    assert subset_reference == subset_not_dyadic

    subset_repeat = convolution(a, b, subset=True)
    assert subset_not_dyadic == subset_repeat

    not_subset = convolution(a, b, subset=False)
    assert not_subset == convolution(a, b)

    subset_conflicts = (
        ((a, b), {"subset": True, "dyadic": True}),
        ((c, d), {"subset": True, "dps": 6}),
        ((a, c), {"subset": True, "prime": q}),
    )
    for seqs, options in subset_conflicts:
        raises(
            TypeError,
            lambda seqs=seqs, options=options: convolution(*seqs, **options),
        )
Ejemplo n.º 3
0
def test_convolution():
    """Check ``convolution`` with the ``fft``, ``ntt`` and ``dyadic`` keywords.

    Positive cases compare the generic entry point with ``convolution_fft``,
    ``convolution_ntt`` or ``convolution_fwht``; the other cases are keyword
    combinations for which this test expects a ``TypeError``.
    """
    a = [1, S(5)/3, sqrt(3), S(7)/5]
    b = [9, 5, 5, 4, 3, 2]
    c = [3, 5, 3, 7, 8]
    d = [1422, 6572, 3213, 5552]

    def expect_type_error(*seqs, **options):
        """Assert that ``convolution(*seqs, **options)`` raises TypeError."""
        raises(TypeError, lambda: convolution(*seqs, **options))

    # fft
    lhs = convolution(a, b, fft=True)
    rhs = convolution_fft(a, b)
    assert lhs == rhs

    lhs = convolution(a, b, dps=9, fft=True)
    rhs = convolution_fft(a, b, dps=9)
    assert lhs == rhs

    lhs = convolution(a, d, fft=True, dps=7)
    rhs = convolution_fft(d, a, dps=7)
    assert lhs == rhs

    lhs = convolution(a, d[1:], dps=3)
    rhs = convolution_fft(d[1:], a, dps=3)
    assert lhs == rhs

    # prime moduli of the form (m*2**k + 1), sequence length
    # should be a divisor of 2**k
    p = 7*17*2**23 + 1
    q = 19*2**10 + 1

    # ntt
    lhs = convolution(d, b, ntt=True, prime=q)
    rhs = convolution_ntt(b, d, prime=q)
    assert lhs == rhs

    lhs = convolution(c, b, prime=p)
    rhs = convolution_ntt(b, c, prime=p)
    assert lhs == rhs

    lhs = convolution(d, c, prime=p, ntt=True)
    rhs = convolution_ntt(c, d, prime=p)
    assert lhs == rhs

    expect_type_error(b, d, ntt=True)
    expect_type_error(b, d, ntt=True, cycle=0)
    expect_type_error(b, d, dps=5, prime=q)
    expect_type_error(b, d, dps=6, ntt=True, prime=q)
    expect_type_error(b, d, fft=True, dps=7, ntt=True, prime=q)

    # ntt is a specialized variant of fft, TypeError should not be raised
    lhs = convolution(b, d, fft=True, ntt=True, prime=q)
    rhs = convolution_ntt(b, d, prime=q)
    assert lhs == rhs

    # fwht
    lhs = convolution(a, b, dyadic=True)
    rhs = convolution_fwht(a, b)
    assert lhs == rhs

    lhs = convolution(a, b, dyadic=False)
    rhs = convolution(a, b)
    assert lhs == rhs

    expect_type_error(b, d, fft=True, dps=2, dyadic=True)
    expect_type_error(b, d, ntt=True, prime=p, dyadic=True)
    expect_type_error(b, d, fft=True, dyadic=True)
    expect_type_error(a, b, dps=2, dyadic=True)
    expect_type_error(b, c, prime=p, dyadic=True)
Ejemplo n.º 4
0
def test_convolution_fft():
    """Check ``convolution_fft`` on fixed inputs with known results.

    Covers an empty first operand, small integer sequences, entries built
    with ``I``, rational entries, entries built from ``pi``/``E``/``sqrt``,
    larger integers, and two calls that are expected to raise.
    """
    # An empty first operand is expected to give an empty result for each
    # second operand and each ``dps`` value tried.  The loop variables must
    # not be called ``x``/``y``: those names are read from the enclosing
    # module in the ``raises`` calls at the end of this function.
    for other in ([], [1]):
        for precision in (None, 3):
            assert convolution_fft([], other, dps=precision) == []

    # Small integer sequences.
    result = convolution_fft([1, 2, 3], [4, 5, 6])
    assert result == [4, 13, 28, 27, 18]

    result = convolution_fft([1], [5, 6, 7])
    assert result == [5, 6, 7]

    result = convolution_fft([1, 3], [5, 6, 7])
    assert result == [5, 21, 25, 21]

    # Entries involving I.
    result = convolution_fft([1 + 2*I], [2 + 3*I])
    assert result == [-4 + 7*I]

    result = convolution_fft(
        [1 + 2*I, 3 + 4*I, 5 + S(3)/5*I],
        [S(2)/5 + S(4)/7*I],
    )
    assert result == [
        -S(26)/35 + 48*I/35,
        -S(38)/35 + 116*I/35,
        S(58)/35 + 542*I/175,
    ]

    # Rational entries.
    result = convolution_fft([S(3)/4, S(5)/6], [S(7)/8, S(1)/3, S(2)/5])
    assert result == [S(21)/32, S(47)/48, S(26)/45, S(1)/3]

    result = convolution_fft(
        [S(1)/9, S(2)/3, S(3)/5],
        [S(2)/5, S(3)/7, S(4)/9],
    )
    assert result == [
        S(2)/45,
        S(11)/35,
        S(8152)/14175,
        S(523)/945,
        S(4)/15,
    ]

    # Entries built from pi, E and square roots.
    result = convolution_fft([pi, E, sqrt(2)], [sqrt(3), 1/pi, 1/E])
    assert result == [
        sqrt(3)*pi,
        1 + sqrt(3)*E,
        E/pi + pi*exp(-1) + sqrt(6),
        sqrt(2)/pi + 1,
        sqrt(2)*exp(-1),
    ]

    # Larger integers.
    result = convolution_fft([2321, 33123], [5321, 6321, 71323])
    assert result == [12350041, 190918524, 374911166, 2362431729]

    result = convolution_fft([312313, 31278232], [32139631, 319631])
    assert result == [10037624576503, 1005370659728895, 9997492572392]

    # ``x`` and ``y`` come from the enclosing module (presumably SymPy
    # symbols -- confirm against the file's imports).
    raises(TypeError, lambda: convolution_fft(x, y))
    raises(ValueError, lambda: convolution_fft([x, y], [y, x]))
Ejemplo n.º 5
0
def test_convolution_fft():
    """Compare ``convolution_fft`` outputs with hand-written expected lists.

    The cases cover an empty first operand, integer sequences, entries built
    with ``I``, rational entries, entries built from ``pi``/``E``/``sqrt``,
    larger integers, and two calls that are expected to raise.
    """
    def check(first, second, expected):
        # One positive case: the convolution must equal ``expected`` exactly.
        assert convolution_fft(first, second) == expected

    # Empty first operand, for each second operand and ``dps`` value tried.
    # The loop names avoid ``x``/``y`` on purpose: the ``raises`` calls at
    # the end read those two names from the enclosing module.
    for second_operand in ([], [1]):
        for digits in (None, 3):
            assert convolution_fft([], second_operand, dps=digits) == []

    check([1, 2, 3], [4, 5, 6], [4, 13, 28, 27, 18])
    check([1], [5, 6, 7], [5, 6, 7])
    check([1, 3], [5, 6, 7], [5, 21, 25, 21])

    check([1 + 2 * I], [2 + 3 * I], [-4 + 7 * I])
    check(
        [1 + 2*I, 3 + 4*I, 5 + S(3)/5*I],
        [S(2)/5 + S(4)/7*I],
        [-S(26)/35 + 48*I/35, -S(38)/35 + 116*I/35, S(58)/35 + 542*I/175],
    )

    check(
        [S(3)/4, S(5)/6],
        [S(7)/8, S(1)/3, S(2)/5],
        [S(21)/32, S(47)/48, S(26)/45, S(1)/3],
    )

    check(
        [S(1)/9, S(2)/3, S(3)/5],
        [S(2)/5, S(3)/7, S(4)/9],
        [S(2)/45, S(11)/35, S(8152)/14175, S(523)/945, S(4)/15],
    )

    check(
        [pi, E, sqrt(2)],
        [sqrt(3), 1/pi, 1/E],
        [
            sqrt(3)*pi,
            1 + sqrt(3)*E,
            E/pi + pi*exp(-1) + sqrt(6),
            sqrt(2)/pi + 1,
            sqrt(2)*exp(-1),
        ],
    )

    check(
        [2321, 33123],
        [5321, 6321, 71323],
        [12350041, 190918524, 374911166, 2362431729],
    )

    check(
        [312313, 31278232],
        [32139631, 319631],
        [10037624576503, 1005370659728895, 9997492572392],
    )

    # ``x`` and ``y`` are resolved from the enclosing module (presumably
    # SymPy symbols -- confirm against the file's imports).
    raises(TypeError, lambda: convolution_fft(x, y))
    raises(ValueError, lambda: convolution_fft([x, y], [y, x]))