Example #1
0
def test_convolution():
    """Compare ``convolution`` against the dedicated convolution routines.

    Every section calls ``convolution`` with a particular set of keyword
    hints and checks the result against the corresponding specialised
    function (``convolution_fft``, ``convolution_ntt``, ``convolution_fwht``
    or ``convolution_subset``).  Each section then lists hint combinations
    for which ``convolution`` is expected to raise ``TypeError``.
    """

    def rejected(x, y, **hints):
        # ``convolution`` must raise TypeError for this combination of hints.
        raises(TypeError, lambda: convolution(x, y, **hints))

    seq_a = [1, S(5)/3, sqrt(3), S(7)/5]
    seq_b = [9, 5, 5, 4, 3, 2]
    seq_c = [3, 5, 3, 7, 8]
    seq_d = [1422, 6572, 3213, 5552]

    # fft
    assert convolution(seq_a, seq_b, fft=True) == convolution_fft(seq_a, seq_b)
    assert convolution(seq_a, seq_b, dps=9, fft=True) == \
        convolution_fft(seq_a, seq_b, dps=9)
    assert convolution(seq_a, seq_d, fft=True, dps=7) == \
        convolution_fft(seq_d, seq_a, dps=7)
    assert convolution(seq_a, seq_d[1:], dps=3) == \
        convolution_fft(seq_d[1:], seq_a, dps=3)

    # prime moduli of the form (m*2**k + 1), sequence length
    # should be a divisor of 2**k
    big_prime = 7*17*2**23 + 1
    small_prime = 19*2**10 + 1

    # ntt
    assert convolution(seq_d, seq_b, ntt=True, prime=small_prime) == \
        convolution_ntt(seq_b, seq_d, prime=small_prime)
    assert convolution(seq_c, seq_b, prime=big_prime) == \
        convolution_ntt(seq_b, seq_c, prime=big_prime)
    assert convolution(seq_d, seq_c, prime=big_prime, ntt=True) == \
        convolution_ntt(seq_c, seq_d, prime=big_prime)
    rejected(seq_b, seq_d, ntt=True)
    rejected(seq_b, seq_d, ntt=True, cycle=0)
    rejected(seq_b, seq_d, dps=5, prime=small_prime)
    rejected(seq_b, seq_d, dps=6, ntt=True, prime=small_prime)
    rejected(seq_b, seq_d, fft=True, dps=7, ntt=True, prime=small_prime)
    # ntt is a specialized variant of fft, TypeError should not be raised
    assert convolution(seq_b, seq_d, fft=True, ntt=True, prime=small_prime) == \
        convolution_ntt(seq_b, seq_d, prime=small_prime)

    # fwht
    assert convolution(seq_a, seq_b, dyadic=True) == \
        convolution_fwht(seq_a, seq_b)
    assert convolution(seq_a, seq_b, dyadic=False) == convolution(seq_a, seq_b)
    rejected(seq_b, seq_d, fft=True, dps=2, dyadic=True)
    rejected(seq_b, seq_d, ntt=True, prime=big_prime, dyadic=True)
    rejected(seq_b, seq_d, fft=True, dyadic=True)
    rejected(seq_a, seq_b, dps=2, dyadic=True)
    rejected(seq_b, seq_c, prime=big_prime, dyadic=True)

    # subset
    # All of these spellings must agree; adjacent results are compared
    # pairwise, in the order they are computed.
    subset_results = [
        convolution(seq_a, seq_b, subset=True),
        convolution_subset(seq_a, seq_b),
        convolution(seq_a, seq_b, subset=True, dyadic=False),
        convolution(seq_a, seq_b, subset=True, fft=False),
        convolution(seq_a, seq_b, subset=True, fft=False, ntt=False),
    ]
    assert all(left == right
               for left, right in zip(subset_results, subset_results[1:]))
    assert convolution(seq_a, seq_b, subset=False) == convolution(seq_a, seq_b)
    rejected(seq_a, seq_b, subset=True, dyadic=True)
    rejected(seq_b, seq_c, subset=True, fft=True)
    rejected(seq_c, seq_d, subset=True, dps=6)
    rejected(seq_a, seq_c, subset=True, prime=small_prime)
def test_convolution():
    """Compare ``convolution`` against the fft, ntt and fwht routines.

    For each transform the generic ``convolution`` entry point is called
    with the matching keyword hints and its result is compared with the
    dedicated function.  The tables of rejected calls list hint
    combinations for which ``convolution`` is expected to raise
    ``TypeError``.
    """
    a = [1, S(5)/3, sqrt(3), S(7)/5]
    b = [9, 5, 5, 4, 3, 2]
    c = [3, 5, 3, 7, 8]
    d = [1422, 6572, 3213, 5552]

    def check_rejected(cases):
        # Each case is (first sequence, second sequence, keyword hints).
        for first, second, hints in cases:
            raises(TypeError, lambda: convolution(first, second, **hints))

    # fft
    fft_default = convolution_fft(a, b)
    assert convolution(a, b, fft=True) == fft_default
    fft_dps9 = convolution_fft(a, b, dps=9)
    assert convolution(a, b, dps=9, fft=True) == fft_dps9
    fft_dps7 = convolution_fft(d, a, dps=7)
    assert convolution(a, d, fft=True, dps=7) == fft_dps7
    fft_dps3 = convolution_fft(d[1:], a, dps=3)
    assert convolution(a, d[1:], dps=3) == fft_dps3

    # prime moduli of the form (m*2**k + 1), sequence length
    # should be a divisor of 2**k
    p = 7*17*2**23 + 1
    q = 19*2**10 + 1

    # ntt
    assert convolution(d, b, ntt=True, prime=q) == \
        convolution_ntt(b, d, prime=q)
    assert convolution(c, b, prime=p) == \
        convolution_ntt(b, c, prime=p)
    assert convolution(d, c, prime=p, ntt=True) == \
        convolution_ntt(c, d, prime=p)
    check_rejected([
        (b, d, dict(ntt=True)),
        (b, d, dict(ntt=True, cycle=0)),
        (b, d, dict(dps=5, prime=q)),
        (b, d, dict(dps=6, ntt=True, prime=q)),
        (b, d, dict(fft=True, dps=7, ntt=True, prime=q)),
    ])
    # ntt is a specialized variant of fft, TypeError should not be raised
    ntt_expected = convolution_ntt(b, d, prime=q)
    assert convolution(b, d, fft=True, ntt=True, prime=q) == ntt_expected

    # fwht
    assert convolution(a, b, dyadic=True) == convolution_fwht(a, b)
    assert convolution(a, b, dyadic=False) == convolution(a, b)
    check_rejected([
        (b, d, dict(fft=True, dps=2, dyadic=True)),
        (b, d, dict(ntt=True, prime=p, dyadic=True)),
        (b, d, dict(fft=True, dyadic=True)),
        (a, b, dict(dps=2, dyadic=True)),
        (b, c, dict(prime=p, dyadic=True)),
    ])
def test_convolution_fwht():
    """Check ``convolution_fwht`` on empty, numeric, complex and symbolic input.

    Expected values are spelled with exact SymPy rationals (``S(n)/d``).
    A plain Python ``n/d`` between two ints is evaluated before SymPy sees
    it, giving a float under true division (or a truncated int under
    Python 2 integer division), so the assertion would depend on
    float-vs-Rational equality rather than on the exact result.
    """
    # Empty operands yield an empty result.
    assert convolution_fwht([], []) == []
    assert convolution_fwht([], [1]) == []
    assert convolution_fwht([1, 2, 3], [4, 5, 6]) == [32, 13, 18, 27]

    assert convolution_fwht([S(5)/7, S(6)/8, S(7)/3], [2, 4, S(6)/7]) == \
                                    [S(45)/7, S(61)/14, S(776)/147, S(419)/42]

    a = [1, S(5)/3, sqrt(3), S(7)/5, 4 + 5*I]
    b = [94, 51, 53, 45, 31, 27, 13]
    c = [3 + 4*I, 5 + 7*I, 3, S(7)/6, 8]

    # Wrap the numerator in S(...) so each fraction stays an exact Rational.
    assert convolution_fwht(a, b) == [53*sqrt(3) + 366 + 155*I,
                                    45*sqrt(3) + S(5848)/15 + 135*I,
                                    94*sqrt(3) + S(1257)/5 + 65*I,
                                    51*sqrt(3) + S(3974)/15,
                                    13*sqrt(3) + 452 + 470*I,
                                    S(4513)/15 + 255*I,
                                    31*sqrt(3) + S(1314)/5 + 265*I,
                                    27*sqrt(3) + S(3676)/15 + 225*I]

    assert convolution_fwht(b, c) == [1993/S(2) + 733*I, 6215/S(6) + 862*I,
        1659/S(2) + 527*I, 1988/S(3) + 551*I, 1019 + 313*I, 3955/S(6) + 325*I,
        1175/S(2) + 52*I, 3253/S(6) + 91*I]

    # ``293*I/5`` is already exact because ``I`` is a SymPy object; only the
    # leading ``54/5`` needed sympifying.
    assert convolution_fwht(a[3:], c) == [-S(54)/5 + 293*I/5, -1 + 204*I/5,
            133/S(15) + 35*I/6, 409/S(30) + 15*I, 56/S(5), 32 + 40*I, 0, 0]

    u, v, w, x, y, z = symbols('u v w x y z')

    assert convolution_fwht([u, v], [x, y]) == [u*x + v*y, u*y + v*x]

    assert convolution_fwht([u, v, w], [x, y]) == \
        [u*x + v*y, u*y + v*x, w*x, w*y]

    assert convolution_fwht([u, v, w], [x, y, z]) == \
        [u*x + v*y + w*z, u*y + v*x, u*z + w*x, v*z + w*y]

    # Arguments that are not sequences must be rejected.
    raises(TypeError, lambda: convolution_fwht(x, y))
    raises(TypeError, lambda: convolution_fwht(x*y, u + v))
Example #4
0
def test_convolution_fwht():
    """Check ``convolution_fwht`` on empty, numeric, complex and symbolic input.

    Expected values are spelled with exact SymPy rationals (``S(n) / d``).
    A plain Python ``n / d`` between two ints is evaluated before SymPy sees
    it, giving a float under true division (or a truncated int under
    Python 2 integer division), so the assertion would depend on
    float-vs-Rational equality rather than on the exact result.
    """
    # Empty operands yield an empty result.
    assert convolution_fwht([], []) == []
    assert convolution_fwht([], [1]) == []
    assert convolution_fwht([1, 2, 3], [4, 5, 6]) == [32, 13, 18, 27]

    assert convolution_fwht([S(5)/7, S(6)/8, S(7)/3], [2, 4, S(6)/7]) == \
                                    [S(45)/7, S(61)/14, S(776)/147, S(419)/42]

    a = [1, S(5) / 3, sqrt(3), S(7) / 5, 4 + 5 * I]
    b = [94, 51, 53, 45, 31, 27, 13]
    c = [3 + 4 * I, 5 + 7 * I, 3, S(7) / 6, 8]

    # Wrap the numerator in S(...) so each fraction stays an exact Rational.
    assert convolution_fwht(a, b) == [
        53 * sqrt(3) + 366 + 155 * I,
        45 * sqrt(3) + S(5848) / 15 + 135 * I,
        94 * sqrt(3) + S(1257) / 5 + 65 * I,
        51 * sqrt(3) + S(3974) / 15,
        13 * sqrt(3) + 452 + 470 * I,
        S(4513) / 15 + 255 * I,
        31 * sqrt(3) + S(1314) / 5 + 265 * I,
        27 * sqrt(3) + S(3676) / 15 + 225 * I
    ]

    assert convolution_fwht(b, c) == [
        1993 / S(2) + 733 * I, 6215 / S(6) + 862 * I, 1659 / S(2) + 527 * I,
        1988 / S(3) + 551 * I, 1019 + 313 * I, 3955 / S(6) + 325 * I,
        1175 / S(2) + 52 * I, 3253 / S(6) + 91 * I
    ]

    # ``293 * I / 5`` is already exact because ``I`` is a SymPy object; only
    # the leading ``54 / 5`` needed sympifying.
    assert convolution_fwht(a[3:], c) == [
        -S(54) / 5 + 293 * I / 5, -1 + 204 * I / 5, 133 / S(15) + 35 * I / 6,
        409 / S(30) + 15 * I, 56 / S(5), 32 + 40 * I, 0, 0
    ]

    u, v, w, x, y, z = symbols('u v w x y z')

    assert convolution_fwht([u, v], [x, y]) == [u * x + v * y, u * y + v * x]

    assert convolution_fwht([u, v, w], [x, y]) == \
        [u*x + v*y, u*y + v*x, w*x, w*y]

    assert convolution_fwht([u, v, w], [x, y, z]) == \
        [u*x + v*y + w*z, u*y + v*x, u*z + w*x, v*z + w*y]

    # Arguments that are not sequences must be rejected.
    raises(TypeError, lambda: convolution_fwht(x, y))
    raises(TypeError, lambda: convolution_fwht(x * y, u + v))