コード例 #1
0
def test_convolution():
    """Exercise the keyword dispatch of ``convolution``.

    For each backend selector (``fft``/``dps``, ``ntt``/``prime``,
    ``dyadic``, ``subset``) the result must equal the matching dedicated
    function, and conflicting or incomplete keyword combinations must raise
    ``TypeError``.
    """
    a = [1, S(5) / 3, sqrt(3), S(7) / 5]
    b = [9, 5, 5, 4, 3, 2]
    c = [3, 5, 3, 7, 8]
    d = [1422, 6572, 3213, 5552]

    # fft -- some reference calls deliberately pass the operands swapped.
    assert convolution(a, b, fft=True) == convolution_fft(a, b)
    assert convolution(a, b, dps=9, fft=True) == convolution_fft(a, b, dps=9)
    assert convolution(a, d, fft=True, dps=7) == convolution_fft(d, a, dps=7)
    assert convolution(a, d[1:], dps=3) == convolution_fft(d[1:], a, dps=3)

    # ntt -- prime moduli have the form m*2**k + 1, and the sequence length
    # should be a divisor of 2**k.
    big_prime = 7 * 17 * 2**23 + 1
    small_prime = 19 * 2**10 + 1

    assert (convolution(d, b, ntt=True, prime=small_prime)
            == convolution_ntt(b, d, prime=small_prime))
    assert (convolution(c, b, prime=big_prime)
            == convolution_ntt(b, c, prime=big_prime))
    assert (convolution(d, c, prime=big_prime, ntt=True)
            == convolution_ntt(c, d, prime=big_prime))

    # ntt without a prime, or ``dps`` combined with ntt/prime, is rejected.
    for bad_kwargs in (
        dict(ntt=True),
        dict(ntt=True, cycle=0),
        dict(dps=5, prime=small_prime),
        dict(dps=6, ntt=True, prime=small_prime),
        dict(fft=True, dps=7, ntt=True, prime=small_prime),
    ):
        # raises() invokes the lambda immediately, so late binding is safe.
        raises(TypeError, lambda: convolution(b, d, **bad_kwargs))

    # ntt is a specialized variant of fft, so asking for both is accepted.
    assert (convolution(b, d, fft=True, ntt=True, prime=small_prime)
            == convolution_ntt(b, d, prime=small_prime))

    # fwht -- selected by dyadic=True.
    assert convolution(a, b, dyadic=True) == convolution_fwht(a, b)
    assert convolution(a, b, dyadic=False) == convolution(a, b)

    # dyadic cannot be combined with fft, ntt, dps or prime.
    for lhs, rhs, bad_kwargs in (
        (b, d, dict(fft=True, dps=2, dyadic=True)),
        (b, d, dict(ntt=True, prime=big_prime, dyadic=True)),
        (b, d, dict(fft=True, dyadic=True)),
        (a, b, dict(dps=2, dyadic=True)),
        (b, c, dict(prime=big_prime, dyadic=True)),
    ):
        raises(TypeError, lambda: convolution(lhs, rhs, **bad_kwargs))

    # subset -- explicit "off" values for the other selectors are harmless.
    assert (convolution(a, b, subset=True)
            == convolution_subset(a, b)
            == convolution(a, b, subset=True, dyadic=False)
            == convolution(a, b, subset=True, fft=False)
            == convolution(a, b, subset=True, fft=False, ntt=False))
    assert convolution(a, b, subset=False) == convolution(a, b)

    # subset cannot be combined with dyadic, fft, dps or prime.
    for lhs, rhs, bad_kwargs in (
        (a, b, dict(subset=True, dyadic=True)),
        (b, c, dict(subset=True, fft=True)),
        (c, d, dict(subset=True, dps=6)),
        (a, c, dict(subset=True, prime=small_prime)),
    ):
        raises(TypeError, lambda: convolution(lhs, rhs, **bad_kwargs))
コード例 #2
0
ファイル: test_convolution.py プロジェクト: hdkjain/sympy
def test_convolution_subset():
    """Test ``convolution_subset`` on numeric and symbolic sequences.

    The inputs are zero-padded to a common power-of-two length.  Entry ``k``
    of the result is the sum of ``a[i] * b[k ^ i]`` over every index ``i``
    whose set bits are a subset of the set bits of ``k``.  The expected
    lists below follow that definition.
    """
    # Degenerate inputs: empty sequences and single-element sequences.
    assert convolution_subset([], []) == []
    assert convolution_subset([], [S(1) / 3]) == []
    assert convolution_subset([6 + 3 * I / 7], [S(2) / 3]) == [4 + 2 * I / 7]

    a = [1, S(5) / 3, sqrt(3), 4 + 5 * I]
    b = [64, 71, 55, 47, 33, 29, 15]
    c = [3 + 2 * I / 3, 5 + 7 * I, 7, S(7) / 5, 9]

    assert convolution_subset(a, b) == [
        64, 533 / S(3), 55 + 64 * sqrt(3),
        71 * sqrt(3) + 1184 / S(3) + 320 * I, 33, 84, 15 + 33 * sqrt(3),
        29 * sqrt(3) + 157 + 165 * I
    ]

    # Every rational constant needs a SymPy operand.  A bare ``5013 / 5`` is
    # evaluated by Python before SymPy sees it.  That gives the float 1002.6
    # (or the truncated int 1002 under Python 2 without true division),
    # never the exact rational 5013/5 that convolution_subset returns.
    assert convolution_subset(b, c) == [
        192 + 128 * I / 3, 533 + 1486 * I / 3, 613 + 110 * I / 3,
        5013 / S(5) + 1249 * I / 3, 675 + 22 * I, 891 + 751 * I / 3,
        771 + 10 * I, 3736 / S(5) + 105 * I
    ]

    # Subset convolution is commutative.
    assert convolution_subset(a, c) == convolution_subset(c, a)

    # A shorter first operand is padded up to the longer one's length.
    assert convolution_subset(a[:2], b) == \
            [64, 533/S(3), 55, 416/S(3), 33, 84, 15, 25]

    assert convolution_subset(a[:2], c) == \
            [3 + 2*I/3, 10 + 73*I/9, 7, 196/S(15), 9, 15, 0, 0]

    # Symbolic entries.
    u, v, w, x, y, z = symbols('u v w x y z')

    assert convolution_subset([u, v, w],
                              [x, y]) == [u * x, u * y + v * x, w * x, w * y]
    assert convolution_subset([u, v, w, x], [y, z]) == \
                            [u*y, u*z + v*y, w*y, w*z + x*y]

    assert convolution_subset([u, v], [x, y, z]) == \
                    convolution_subset([x, y, z], [u, v])

    # Non-sequence arguments are rejected.
    raises(TypeError, lambda: convolution_subset(x, z))
    raises(TypeError, lambda: convolution_subset(S(7) / 3, u))