# gentests.py -- forked from ctz/bignum.
# Generates check("<function>(<args>) == <expected>"); test vectors for the
# bignum library, one test-<function>.inc file per operation by default.
import random
import operator
import itertools
import math
import optparse
import sys
from dumbegcd import egcd
# Number of check() lines emitted per (signs, generator, operand sizes) combination.
TESTS = 2

# Multipliers applied to each generated operand to choose its sign.
SIGNS = (+1, -1)

# Operand bit-lengths tried for every argument by default.
SIZES = (16, 32, 64, 128, 192, 512, 1024, 2048)

# Bit-lengths used for the second operand of shl/shr/trunc (the shift/width).
SHIFT_SIZES = range(1, 8)

# Bit-lengths used for the modexp exponent: the six smallest SIZES entries.
EXP_SIZES = SIZES[0:6]
def wordsz(n):
    """
    Return the number of 32-bit words needed to hold the magnitude of n.

    Zero still occupies one word.

    :param n: an int; negative values are measured by their magnitude.
    :returns: a positive int.
    """
    # Exact integer bit counting instead of math.log(n, 2): the float version
    # gave 0 words for n == 1, one word too few for exact powers of two
    # (2**32 needs 33 bits), and is subject to rounding for large n.
    return max(1, (abs(n).bit_length() + 31) // 32)
def random_carry(sz):
    """
    Make a number which is at most sz bits long, and has some
    32-bit words set or cleared to exercise carry/borrow
    code paths.

    Between 0 and 7 times, a random 32-bit word inside the sz-bit range is
    set to all ones, cleared, or has its top byte or top nibble set. The
    result is then masked back to sz bits, so small sizes (16-bit operands,
    1..7-bit shift amounts) are not widened to a full 32-bit word.

    :param sz: bit-length bound.
    :returns: a non-negative int in [0, 2**sz).
    """
    # Words spanned by an sz-bit value; at least one so randrange below
    # always has a non-empty range (sz is expected to be >= 1).
    words = max(1, (sz + 31) // 32)
    a = random.getrandbits(sz)
    for _ in range(random.getrandbits(3)):
        shift = random.randrange(0, words) * 32
        code = random.getrandbits(2)
        if code == 0:
            a |= 0xffffffff << shift        # whole word all ones
        elif code == 1:
            a &= ~(0xffffffff << shift)     # whole word cleared
        elif code == 2:
            a |= 0xff000000 << shift        # top byte of the word set
        else:
            a |= 0xf0000000 << shift        # top nibble of the word set
    # Setting bits in a partially-used top word can push the value past sz bits.
    return a & ((1 << sz) - 1)
def gen_tests(f, function, nargs, op, reject = lambda *x: False, sizesa = SIZES, sizesb = SIZES, sizesc = SIZES):
    """
    Write randomly generated check() lines for one bignum function to f.

    Each line has the form
        check("<function>(<arg>, ...) == <expected>");
    where the arguments are random ints and <expected> is op(*args).

    :param f: writable text file object receiving the lines.
    :param function: function name written into each check() call.
    :param nargs: number of operands (1, 2 or 3).
    :param op: Python reference implementation computing the expected value.
    :param reject: predicate over a candidate argument tuple; candidates for
        which it returns true are redrawn.
    :param sizesa, sizesb, sizesc: bit sizes tried for the first, second and
        third operands (only the first nargs are used).
    """
    def gen(sizes, mkcandidate, signs):
        # Emit TESTS lines for one combination of operand sizes, candidate
        # generator and operand signs.
        for _ in range(TESTS):
            # Draw at least once; redraw the whole tuple while reject() says so.
            while True:
                candidates = [mkcandidate(size) * sign
                              for size, sign in zip(sizes, signs)]
                if not reject(*candidates):
                    break
            args = ', '.join(str(x) for x in candidates)
            # f.write instead of the Python 2 `print >>f` statement: same
            # output on Python 2, and it does not raise TypeError on Python 3.
            f.write('check("%s(%s) == %d");\n' % (function, args, op(*candidates)))

    # NOTE(review): `sizes` is a one-shot iterator, so it is exhausted by the
    # first (signs, generator) pair: only all-positive operands drawn with
    # random.getrandbits are ever emitted. Materialising it with list(...)
    # would also emit negative and random_carry operands, but the current
    # emit_tests configuration relies on the narrower behaviour (for example
    # a negative shift count makes operator.ilshift raise ValueError), so
    # that change must be made together with per-function sign constraints.
    sizes = itertools.product(*[sizesa, sizesb, sizesc][:nargs])
    signs = itertools.product(*([SIGNS] * nargs))
    for sg in signs:
        for can in (random.getrandbits, random_carry):
            for sz in sizes:
                gen(sz, can, sg)
def gen_tests_with_file(fout, funcname, *args, **kwargs):
    """
    Run gen_tests for funcname, writing either to fout or to a dedicated file.

    :param fout: open text file to write the tests to, or None to write them
        to 'test-<funcname>.inc' in the current directory (mode 'w', so any
        existing file is overwritten) and report the file name on stdout.
    :param funcname: bignum function under test; forwarded to gen_tests.
    Remaining positional and keyword arguments are forwarded to gen_tests.
    """
    if fout is not None:
        gen_tests(fout, funcname, *args, **kwargs)
        return
    filename = 'test-%s.inc' % funcname
    with open(filename, 'w') as f:
        gen_tests(f, funcname, *args, **kwargs)
    # A single parenthesised argument prints the same text under Python 2's
    # print statement and Python 3's print function; the old
    # `print filename, 'written.'` form is a SyntaxError on Python 3.
    print('%s written.' % filename)
def gcd(a, b):
    """
    Euclid's algorithm; the reference op for the 'gcd' tests.

    Uses Python's % semantics, so the sign of the result depends on the
    operands' signs (e.g. gcd(3, -5) == -1), and gcd(a, 0) == a.
    """
    x, y = a, b
    while y != 0:
        x, y = y, x % y
    return x
def _egcd_part(a, b, index):
    """Return element `index` of the tuple produced by egcd(a, b)."""
    return egcd(a, b)[index]


def egcd_v(a, b):
    """Reference op for 'egcd-v': element 0 of egcd(a, b), the gcd."""
    return _egcd_part(a, b, 0)


def egcd_a(a, b):
    """Reference op for 'egcd-a': element 1 of egcd(a, b).

    Presumably the Bezout coefficient of a -- confirm against dumbegcd.
    """
    return _egcd_part(a, b, 1)


def egcd_b(a, b):
    """Reference op for 'egcd-b': element 2 of egcd(a, b).

    Presumably the Bezout coefficient of b -- confirm against dumbegcd.
    """
    return _egcd_part(a, b, 2)
def gcd_eq_zero(x, m):
    """
    Reject predicate for the modinv tests.

    Despite its name, this returns True when gcd(x, m) != 1, i.e. when x and
    m are not coprime and so x has no inverse modulo m; gen_tests then redraws
    the operands.

    NOTE: gcd() follows Python's % sign rules, so a negative m always yields
    a negative gcd and such pairs are always rejected.
    """
    has_inverse = gcd(x, m) == 1
    return not has_inverse
def modinv(x, m):
    """
    Reference op for the 'modinv' tests: the inverse of x modulo m.

    :param x: value to invert.
    :param m: modulus, assumed positive.
    :returns: the canonical inverse a in [0, m) with (a * x) % m == 1 % m.
        For m == 1 this is 0.
    :raises ValueError: if x and m are not coprime.
    """
    # egcd yields (gcd, a, b); a is the coefficient used as the inverse of x.
    # Named `g`, not `gcd`, so the module-level gcd() is not shadowed.
    g, a, _ = egcd(x, m)
    if g != 1:
        # Input validation must survive `python -O`, so raise rather than assert.
        raise ValueError('%d has no inverse modulo %d' % (x, m))
    # Full reduction: a single `a += m` only fixed coefficients in (-m, 0).
    a %= m
    # Internal sanity check on egcd's output. Compare against 1 % m so that
    # m == 1 (where every residue is 0) does not spuriously fail.
    assert (a * x) % m == 1 % m
    return a
def trunc(a, b):
    """
    Reference op for the 'trunc' tests: a reduced modulo 2**b.

    For b >= 0 this keeps the low b bits of a. Python's % takes the sign of
    the divisor, so a negative a yields non-negative low bits
    (e.g. trunc(-1, 8) == 255).
    """
    modulus = pow(2, b)
    return a % modulus
def emit_tests(fout = None):
    """
    Generate tests for every bignum operation covered by this script.

    :param fout: open text file receiving every operation's tests, or None to
        write each operation's tests to its own test-<name>.inc file.
    """
    gen_tests_with_file(fout, 'mul', 2, operator.mul)
    gen_tests_with_file(fout, 'add', 2, operator.add)
    gen_tests_with_file(fout, 'sub', 2, operator.sub)
    gen_tests_with_file(fout, 'sqr', 1, lambda x: operator.pow(x, 2))
    gen_tests_with_file(fout, 'mod', 2, operator.mod, reject = lambda p, d: d == 0)
    # operator.floordiv is what operator.div computed for ints on Python 2,
    # and unlike operator.div it also exists on Python 3.
    gen_tests_with_file(fout, 'div', 2, operator.floordiv, reject = lambda p, d: d == 0)
    gen_tests_with_file(fout, 'shl', 2, operator.ilshift, sizesb = SHIFT_SIZES)
    gen_tests_with_file(fout, 'shr', 2, operator.irshift, sizesb = SHIFT_SIZES)
    gen_tests_with_file(fout, 'trunc', 2, trunc, sizesb = SHIFT_SIZES)
    # A zero modulus makes the reference op raise (ZeroDivisionError for %,
    # ValueError for three-argument pow) and would abort generation.
    gen_tests_with_file(fout, 'modmul', 3, lambda x, y, p: (x * y) % p,
                        reject = lambda x, y, p: p == 0)
    gen_tests_with_file(fout, 'modexp', 3, pow, sizesb = EXP_SIZES,
                        reject = lambda x, y, p: p == 0)
    gen_tests_with_file(fout, 'gcd', 2, gcd)
    gen_tests_with_file(fout, 'egcd-v', 2, egcd_v)
    gen_tests_with_file(fout, 'egcd-a', 2, egcd_a)
    gen_tests_with_file(fout, 'egcd-b', 2, egcd_b)
    gen_tests_with_file(fout, 'modinv', 2, modinv, reject = gcd_eq_zero)
if __name__ == '__main__':
    parser = optparse.OptionParser()
    parser.add_option('-c', '--continuous', action = 'store_true',
                      default = False,
                      help = 'Continuously generate tests to stdout.')
    options, _args = parser.parse_args()
    if not options.continuous:
        # One pass: each operation's tests go to its own test-<name>.inc file.
        emit_tests()
    else:
        # Endless stream of tests on stdout; runs until interrupted.
        while True:
            emit_tests(sys.stdout)