def test_convolution():
    """Check that ``convolution`` agrees with each specialised routine.

    The fft, ntt, fwht (dyadic) and subset modes are each compared against
    their dedicated function.  Hint combinations that are ambiguous or
    incomplete must raise ``TypeError``.
    """
    def rejects(lhs, rhs, **hints):
        # This combination of hints must be refused with a TypeError.
        raises(TypeError, lambda: convolution(lhs, rhs, **hints))

    a = [1, S(5) / 3, sqrt(3), S(7) / 5]
    b = [9, 5, 5, 4, 3, 2]
    c = [3, 5, 3, 7, 8]
    d = [1422, 6572, 3213, 5552]

    # --- fft (some checks deliberately swap the operand order) ---
    assert convolution(a, b, fft=True) == convolution_fft(a, b)
    assert convolution(a, b, dps=9, fft=True) == convolution_fft(a, b, dps=9)
    assert convolution(a, d, fft=True, dps=7) == convolution_fft(d, a, dps=7)
    assert convolution(a, d[1:], dps=3) == convolution_fft(d[1:], a, dps=3)

    # Prime moduli of the form m*2**k + 1; the sequence length should be a
    # divisor of 2**k.
    p = 7 * 17 * 2**23 + 1
    q = 19 * 2**10 + 1

    # --- ntt ---
    assert convolution(d, b, ntt=True, prime=q) == \
        convolution_ntt(b, d, prime=q)
    assert convolution(c, b, prime=p) == convolution_ntt(b, c, prime=p)
    assert convolution(d, c, prime=p, ntt=True) == \
        convolution_ntt(c, d, prime=p)

    # A missing prime, or a prime paired with a dps hint, is refused.
    rejects(b, d, ntt=True)
    rejects(b, d, ntt=True, cycle=0)
    rejects(b, d, dps=5, prime=q)
    rejects(b, d, dps=6, ntt=True, prime=q)
    rejects(b, d, fft=True, dps=7, ntt=True, prime=q)

    # ntt is a specialised variant of fft, so asking for both is allowed.
    assert convolution(b, d, fft=True, ntt=True, prime=q) == \
        convolution_ntt(b, d, prime=q)

    # --- fwht (dyadic) ---
    assert convolution(a, b, dyadic=True) == convolution_fwht(a, b)
    assert convolution(a, b, dyadic=False) == convolution(a, b)

    # dyadic=True cannot be mixed with fft/ntt/dps/prime hints.
    rejects(b, d, fft=True, dps=2, dyadic=True)
    rejects(b, d, ntt=True, prime=p, dyadic=True)
    rejects(b, d, fft=True, dyadic=True)
    rejects(a, b, dps=2, dyadic=True)
    rejects(b, c, prime=p, dyadic=True)

    # --- subset ---
    # Every spelling below must yield the same sequence; adjacent pairs are
    # compared exactly like the chained ``==`` they replace.
    subset_results = [
        convolution(a, b, subset=True),
        convolution_subset(a, b),
        convolution(a, b, subset=True, dyadic=False),
        convolution(a, b, subset=True, fft=False),
        convolution(a, b, subset=True, fft=False, ntt=False),
    ]
    assert all(
        left == right
        for left, right in zip(subset_results, subset_results[1:])
    )
    assert convolution(a, b, subset=False) == convolution(a, b)

    # subset=True cannot be mixed with any other convolution mode hint.
    rejects(a, b, subset=True, dyadic=True)
    rejects(b, c, subset=True, fft=True)
    rejects(c, d, subset=True, dps=6)
    rejects(a, c, subset=True, prime=q)
def test_convolution_subset():
    """Check ``convolution_subset`` on edge cases, numbers and symbols.

    The hand-computed expected lists agree with
    ``out[mask] = sum(x[s] * y[mask ^ s] for s submask of mask)``, with
    both operands zero-padded to a common power-of-two length.
    """
    # Empty and single-element operands.
    assert convolution_subset([], []) == []
    assert convolution_subset([], [S(1) / 3]) == []
    assert convolution_subset([6 + 3 * I / 7], [S(2) / 3]) == [4 + 2 * I / 7]

    a = [1, S(5) / 3, sqrt(3), 4 + 5 * I]
    b = [64, 71, 55, 47, 33, 29, 15]
    c = [3 + 2 * I / 3, 5 + 7 * I, 7, S(7) / 5, 9]

    assert convolution_subset(a, b) == [
        64, 533 / S(3), 55 + 64 * sqrt(3),
        71 * sqrt(3) + 1184 / S(3) + 320 * I,
        33, 84, 15 + 33 * sqrt(3),
        29 * sqrt(3) + 157 + 165 * I,
    ]

    # The real parts 5013/5 and 3736/5 must be exact SymPy rationals.  A bare
    # ``5013 / 5`` is Python true division and would produce a float
    # (1002.6), which is not an exact Rational.
    assert convolution_subset(b, c) == [
        192 + 128 * I / 3, 533 + 1486 * I / 3,
        613 + 110 * I / 3, S(5013) / 5 + 1249 * I / 3,
        675 + 22 * I, 891 + 751 * I / 3,
        771 + 10 * I, S(3736) / 5 + 105 * I,
    ]

    # Subset convolution is commutative.
    assert convolution_subset(a, c) == convolution_subset(c, a)

    # Shorter operands are zero-padded to the common power-of-two length.
    assert convolution_subset(a[:2], b) == \
        [64, 533 / S(3), 55, 416 / S(3), 33, 84, 15, 25]
    assert convolution_subset(a[:2], c) == \
        [3 + 2 * I / 3, 10 + 73 * I / 9, 7, 196 / S(15), 9, 15, 0, 0]

    # Symbolic operands.
    u, v, w, x, y, z = symbols('u v w x y z')
    assert convolution_subset([u, v, w], [x, y]) == \
        [u * x, u * y + v * x, w * x, w * y]
    assert convolution_subset([u, v, w, x], [y, z]) == \
        [u * y, u * z + v * y, w * y, w * z + x * y]
    assert convolution_subset([u, v], [x, y, z]) == \
        convolution_subset([x, y, z], [u, v])

    # Non-sequence arguments are rejected.
    raises(TypeError, lambda: convolution_subset(x, z))
    raises(TypeError, lambda: convolution_subset(S(7) / 3, u))