def test_convolution():
    """Check how ``convolution`` dispatches on the ``fft``/``ntt`` hints.

    Each accepted hint combination is compared against a direct call to
    ``convolution_fft`` or ``convolution_ntt``; combinations that are
    expected to be refused must raise ``TypeError``.
    """
    exact_seq = [1, S(5)/3, sqrt(3), S(7)/5]
    ints_6 = [9, 5, 5, 4, 3, 2]
    ints_5 = [3, 5, 3, 7, 8]
    ints_4 = [1422, 6572, 3213, 5552]

    # fft-backed calls, with and without an explicit precision
    actual = convolution(exact_seq, ints_6, fft=True)
    expected = convolution_fft(exact_seq, ints_6)
    assert actual == expected

    actual = convolution(exact_seq, ints_6, dps=9, fft=True)
    expected = convolution_fft(exact_seq, ints_6, dps=9)
    assert actual == expected

    # the reference call swaps the operands on purpose
    actual = convolution(exact_seq, ints_4, fft=True, dps=7)
    expected = convolution_fft(ints_4, exact_seq, dps=7)
    assert actual == expected

    actual = convolution(exact_seq, ints_4[1:], dps=3)
    expected = convolution_fft(ints_4[1:], exact_seq, dps=3)
    assert actual == expected

    # prime moduli of the form (m*2**k + 1), sequence length
    # should be a divisor of 2**k
    p_large = 7*17*2**23 + 1
    p_small = 19*2**10 + 1

    actual = convolution(ints_4, ints_6, ntt=True, prime=p_small)
    expected = convolution_ntt(ints_6, ints_4, prime=p_small)
    assert actual == expected

    actual = convolution(ints_5, ints_6, prime=p_large)
    expected = convolution_ntt(ints_6, ints_5, prime=p_large)
    assert actual == expected

    actual = convolution(ints_4, ints_5, prime=p_large, ntt=True)
    expected = convolution_ntt(ints_5, ints_4, prime=p_large)
    assert actual == expected

    # hint combinations that must be refused with TypeError
    refused_hints = (
        {'ntt': True},
        {'ntt': True, 'cycle': 0},
        {'dps': 5, 'prime': p_small},
        {'dps': 6, 'ntt': True, 'prime': p_small},
        {'fft': True, 'dps': 7, 'ntt': True, 'prime': p_small},
    )
    for hints in refused_hints:
        raises(TypeError, lambda: convolution(ints_6, ints_4, **hints))

    # ntt is a specialized variant of fft, TypeError should not be raised
    actual = convolution(ints_6, ints_4, fft=True, ntt=True, prime=p_small)
    expected = convolution_ntt(ints_6, ints_4, prime=p_small)
    assert actual == expected
def test_convolution():
    """Check the hint-based dispatch of ``convolution``.

    Covers four groups: plain/``dps`` calls compared with
    ``convolution_fft``, ``prime`` calls compared with ``convolution_ntt``,
    ``dyadic`` calls compared with ``convolution_fwht`` and ``subset`` calls
    compared with ``convolution_subset``.  Hint combinations expected to be
    refused must raise ``TypeError``.
    """

    def assert_refused(lhs, rhs, **hints):
        # The given hint combination must be refused with TypeError.
        raises(TypeError, lambda: convolution(lhs, rhs, **hints))

    exact_seq = [1, S(5) / 3, sqrt(3), S(7) / 5]
    ints_6 = [9, 5, 5, 4, 3, 2]
    ints_5 = [3, 5, 3, 7, 8]
    ints_4 = [1422, 6572, 3213, 5552]

    # fft
    default_result = convolution(exact_seq, ints_6)
    assert default_result == convolution_fft(exact_seq, ints_6)

    with_dps = convolution(exact_seq, ints_6, dps=9)
    assert with_dps == convolution_fft(exact_seq, ints_6, dps=9)

    # the reference calls below swap the operands on purpose
    with_dps = convolution(exact_seq, ints_4, dps=7)
    assert with_dps == convolution_fft(ints_4, exact_seq, dps=7)

    with_dps = convolution(exact_seq, ints_4[1:], dps=3)
    assert with_dps == convolution_fft(ints_4[1:], exact_seq, dps=3)

    # prime moduli of the form (m*2**k + 1), sequence length
    # should be a divisor of 2**k
    p_large = 7 * 17 * 2**23 + 1
    p_small = 19 * 2**10 + 1

    # ntt
    modular = convolution(ints_4, ints_6, prime=p_small)
    assert modular == convolution_ntt(ints_6, ints_4, prime=p_small)

    modular = convolution(ints_5, ints_6, prime=p_large)
    assert modular == convolution_ntt(ints_6, ints_5, prime=p_large)

    modular = convolution(ints_4, ints_5, prime=p_large)
    assert modular == convolution_ntt(ints_5, ints_4, prime=p_large)

    assert_refused(ints_6, ints_4, dps=5, prime=p_small)
    assert_refused(ints_6, ints_4, dps=6, prime=p_small)

    # fwht
    dyadic_result = convolution(exact_seq, ints_6, dyadic=True)
    assert dyadic_result == convolution_fwht(exact_seq, ints_6)

    non_dyadic = convolution(exact_seq, ints_6, dyadic=False)
    assert non_dyadic == convolution(exact_seq, ints_6)

    assert_refused(ints_6, ints_4, dps=2, dyadic=True)
    assert_refused(ints_6, ints_4, prime=p_large, dyadic=True)
    assert_refused(exact_seq, ints_6, dps=2, dyadic=True)
    assert_refused(ints_6, ints_5, prime=p_large, dyadic=True)

    # subset: the original chained comparison, unrolled pairwise in the
    # same evaluation order
    via_hint = convolution(exact_seq, ints_6, subset=True)
    direct = convolution_subset(exact_seq, ints_6)
    assert via_hint == direct
    dyadic_off = convolution(exact_seq, ints_6, subset=True, dyadic=False)
    assert direct == dyadic_off
    assert dyadic_off == convolution(exact_seq, ints_6, subset=True)

    subset_off = convolution(exact_seq, ints_6, subset=False)
    assert subset_off == convolution(exact_seq, ints_6)

    assert_refused(exact_seq, ints_6, subset=True, dyadic=True)
    assert_refused(ints_5, ints_4, subset=True, dps=6)
    assert_refused(exact_seq, ints_5, subset=True, prime=p_small)
def test_convolution():
    """Check ``convolution`` with explicit ``fft``/``ntt``/``dyadic`` hints.

    Accepted hint combinations are compared against ``convolution_fft``,
    ``convolution_ntt`` and ``convolution_fwht``; combinations expected to
    be refused must raise ``TypeError``.
    """
    # fft
    exact_seq = [1, S(5)/3, sqrt(3), S(7)/5]
    ints_6 = [9, 5, 5, 4, 3, 2]
    ints_5 = [3, 5, 3, 7, 8]
    ints_4 = [1422, 6572, 3213, 5552]

    result = convolution(exact_seq, ints_6, fft=True)
    reference = convolution_fft(exact_seq, ints_6)
    assert result == reference

    result = convolution(exact_seq, ints_6, dps=9, fft=True)
    reference = convolution_fft(exact_seq, ints_6, dps=9)
    assert result == reference

    # the reference calls below swap the operands on purpose
    result = convolution(exact_seq, ints_4, fft=True, dps=7)
    reference = convolution_fft(ints_4, exact_seq, dps=7)
    assert result == reference

    result = convolution(exact_seq, ints_4[1:], dps=3)
    reference = convolution_fft(ints_4[1:], exact_seq, dps=3)
    assert result == reference

    # prime moduli of the form (m*2**k + 1), sequence length
    # should be a divisor of 2**k
    p_large = 7*17*2**23 + 1
    p_small = 19*2**10 + 1

    # ntt
    result = convolution(ints_4, ints_6, ntt=True, prime=p_small)
    reference = convolution_ntt(ints_6, ints_4, prime=p_small)
    assert result == reference

    result = convolution(ints_5, ints_6, prime=p_large)
    reference = convolution_ntt(ints_6, ints_5, prime=p_large)
    assert result == reference

    result = convolution(ints_4, ints_5, prime=p_large, ntt=True)
    reference = convolution_ntt(ints_5, ints_4, prime=p_large)
    assert result == reference

    # (first sequence, second sequence, hints) triples that must be refused
    refused_ntt = [
        (ints_6, ints_4, {'ntt': True}),
        (ints_6, ints_4, {'ntt': True, 'cycle': 0}),
        (ints_6, ints_4, {'dps': 5, 'prime': p_small}),
        (ints_6, ints_4, {'dps': 6, 'ntt': True, 'prime': p_small}),
        (ints_6, ints_4,
            {'fft': True, 'dps': 7, 'ntt': True, 'prime': p_small}),
    ]
    for lhs, rhs, hints in refused_ntt:
        raises(TypeError, lambda: convolution(lhs, rhs, **hints))

    # ntt is a specialized variant of fft, TypeError should not be raised
    result = convolution(ints_6, ints_4, fft=True, ntt=True, prime=p_small)
    reference = convolution_ntt(ints_6, ints_4, prime=p_small)
    assert result == reference

    # fwht
    result = convolution(exact_seq, ints_6, dyadic=True)
    reference = convolution_fwht(exact_seq, ints_6)
    assert result == reference

    result = convolution(exact_seq, ints_6, dyadic=False)
    reference = convolution(exact_seq, ints_6)
    assert result == reference

    refused_dyadic = [
        (ints_6, ints_4, {'fft': True, 'dps': 2, 'dyadic': True}),
        (ints_6, ints_4, {'ntt': True, 'prime': p_large, 'dyadic': True}),
        (ints_6, ints_4, {'fft': True, 'dyadic': True}),
        (exact_seq, ints_6, {'dps': 2, 'dyadic': True}),
        (ints_6, ints_5, {'prime': p_large, 'dyadic': True}),
    ]
    for lhs, rhs, hints in refused_dyadic:
        raises(TypeError, lambda: convolution(lhs, rhs, **hints))
def test_convolution_ntt():
    """Exercise ``convolution_ntt`` with several prime moduli.

    Checks empty inputs, small hand-checkable products, two fixed
    regression vectors, agreement between the moduli ``p`` and ``q`` on the
    same inputs, and the ``ValueError``/``TypeError`` cases for inputs that
    are expected to be rejected.
    """
    # prime moduli of the form (m*2**k + 1), sequence length
    # should be a divisor of 2**k
    p = 7*17*2**23 + 1
    q = 19*2**10 + 1
    r = 2*500000003 + 1  # only for sequences of length 1 or 2

    # An empty first operand gives an empty result for every modulus.
    # The loop variables are deliberately not called x/y so they cannot be
    # confused with the outer x, y used by the lambdas at the end.
    assert all(convolution_ntt([], seq, prime=modulus) == []
               for seq in ([], [1]) for modulus in (p, q, r))
    assert convolution_ntt([2], [3], r) == [6]
    assert convolution_ntt([2, 3], [4], r) == [8, 12]

    # fixed expected outputs
    assert convolution_ntt([32121, 42144, 4214, 4241],
                           [32132, 3232, 87242], p) == \
        [33867619, 459741727, 79180879, 831885249, 381344700, 369993322]
    assert convolution_ntt([121913, 3171831, 31888131, 12],
                           [17882, 21292, 29921, 312], q) == \
        [8158, 3065, 3682, 7090, 1239, 2232, 3744]

    # for these inputs both moduli must give the same result
    assert convolution_ntt([12, 19, 21, 98, 67], [2, 6, 7, 8, 9], p) == \
        convolution_ntt([12, 19, 21, 98, 67], [2, 6, 7, 8, 9], q)
    assert convolution_ntt([12, 19, 21, 98, 67], [21, 76, 17, 78, 69], p) == \
        convolution_ntt([12, 19, 21, 98, 67], [21, 76, 17, 78, 69], q)

    # two length-2 operands are expected to be rejected under r
    raises(ValueError, lambda: convolution_ntt([2, 3], [4, 5], r))
    # x and y are defined outside this function
    # (presumably symbolic placeholders -- confirm against the imports)
    raises(ValueError, lambda: convolution_ntt([x, y], [y, x], q))
    raises(TypeError, lambda: convolution_ntt(x, y, p))
def test_convolution_ntt():
    """Exercise ``convolution_ntt`` with several prime moduli.

    Checks empty inputs, small hand-checkable products, two fixed
    regression vectors, agreement between the moduli ``p`` and ``q`` on the
    same inputs, and the ``ValueError``/``TypeError`` cases for inputs that
    are expected to be rejected.
    """
    # prime moduli of the form (m*2**k + 1), sequence length
    # should be a divisor of 2**k
    p = 7 * 17 * 2**23 + 1
    q = 19 * 2**10 + 1
    r = 2 * 500000003 + 1  # only for sequences of length 1 or 2

    # An empty first operand gives an empty result for every modulus.
    # The loop variables are deliberately not called x/y so they cannot be
    # confused with the outer x, y used by the lambdas at the end.
    assert all(
        convolution_ntt([], operand, prime=modulus) == []
        for operand in ([], [1])
        for modulus in (p, q, r)
    )
    assert convolution_ntt([2], [3], r) == [6]
    assert convolution_ntt([2, 3], [4], r) == [8, 12]

    # fixed expected outputs
    assert convolution_ntt(
        [32121, 42144, 4214, 4241], [32132, 3232, 87242], p
    ) == [33867619, 459741727, 79180879, 831885249, 381344700, 369993322]
    assert convolution_ntt(
        [121913, 3171831, 31888131, 12], [17882, 21292, 29921, 312], q
    ) == [8158, 3065, 3682, 7090, 1239, 2232, 3744]

    # for these inputs both moduli must give the same result
    assert convolution_ntt(
        [12, 19, 21, 98, 67], [2, 6, 7, 8, 9], p
    ) == convolution_ntt([12, 19, 21, 98, 67], [2, 6, 7, 8, 9], q)
    assert convolution_ntt(
        [12, 19, 21, 98, 67], [21, 76, 17, 78, 69], p
    ) == convolution_ntt([12, 19, 21, 98, 67], [21, 76, 17, 78, 69], q)

    # two length-2 operands are expected to be rejected under r
    raises(ValueError, lambda: convolution_ntt([2, 3], [4, 5], r))
    # x and y are defined outside this function
    # (presumably symbolic placeholders -- confirm against the imports)
    raises(ValueError, lambda: convolution_ntt([x, y], [y, x], q))
    raises(TypeError, lambda: convolution_ntt(x, y, p))