forked from xuanyan0x7c7/tchisla-solver.py
/
quadratic.py
326 lines (274 loc) · 10.6 KB
/
quadratic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
import sys
import numbers
import operator
from functools import reduce
from gmpy2 import mpz, mpq as Fraction, is_square, isqrt
# Public API of this module.
__all__ = ["Quadratic"]
# Primes whose exponents are stored, position by position, in
# Quadratic.quadratic_part (the methods zip this tuple with that one).
primes = (2, 3, 5, 7)
def exp2(n):
    """Return the number of trailing zero bits of the non-zero integer ``n``.

    Equivalently, the largest ``k`` such that ``2 ** k`` divides ``n``.
    The previous implementation used a four-entry lookup table and failed
    with ``KeyError`` as soon as ``n`` had more than three trailing zeros.

    Raises:
        KeyError: if ``n`` is zero (kept as ``KeyError`` because that is what
            the former table lookup raised for this input).
    """
    if n == 0:
        raise KeyError(n)
    # ``n & -n`` isolates the lowest set bit; its bit length minus one is
    # that bit's index.
    return (n & -n).bit_length() - 1
# Modulus applied when Quadratic.__hash__ combines the hashes of its parts.
_PyHASH_MODULUS = sys.hash_info.modulus
# Concrete gmpy2 integer / rational types, obtained from default-constructed
# instances so they can be used in isinstance() checks.
mpz_type = type(mpz())
mpq_type = type(Fraction())
class Quadratic(numbers.Real):
    """Exact number of the form ``r * prod(p_i ** (e_i / 2 ** k))``.

    * ``rational_part`` is ``r``, stored as a gmpy2 ``mpq``.
    * ``quadratic_power`` is ``k``, the number of nested square roots.
    * ``quadratic_part`` is the tuple of exponents ``e_i``, one per entry of
      the module-level ``primes`` tuple, or ``None`` for a plain rational
      (``quadratic_power == 0``).

    The arithmetic below produces reduced results: every ``e_i`` is kept
    below ``2 ** k`` and common factors of two are shifted out of the
    exponents.  Instances are treated as immutable values.

    Operations whose result cannot be written in this form either return
    ``None`` (``+``, ``-``, ``sqrt``) or raise ``NotImplementedError``.
    """

    __slots__ = ("rational_part", "quadratic_power", "quadratic_part")

    def __new__(cls, rational_part=Fraction(), quadratic_power=0, quadratic_part=None):
        """Build a value from a rational (or its string form) plus radical data.

        Passing an existing ``Quadratic`` returns that same object unchanged.

        Raises:
            NotImplementedError: if ``rational_part`` is of an unsupported type.
        """
        if type(rational_part) is str:
            rational_part = Fraction(rational_part)
        if isinstance(rational_part, Quadratic):
            return rational_part
        if not isinstance(rational_part, (int, mpz_type, mpq_type)):
            raise NotImplementedError
        self = super(Quadratic, cls).__new__(cls)
        self.rational_part = Fraction(rational_part)
        self.quadratic_power = quadratic_power
        self.quadratic_part = quadratic_part
        return self

    def __str__(self):
        """Render as e.g. ``3*sqrt(2)``; ``k`` nested roots print as ``k`` leading ``s``."""
        if not self.quadratic_part:
            return str(self.rational_part)
        radicand = reduce(operator.mul, map(operator.pow, primes, self.quadratic_part))
        quadratic_part_string = "s" * self.quadratic_power + "qrt(" + str(radicand) + ")"
        if self.rational_part == 1:
            return quadratic_part_string
        if self.rational_part == -1:
            return "-" + quadratic_part_string
        return str(self.rational_part) + "*" + quadratic_part_string

    __repr__ = __str__

    def _operator_fallbacks(monomorphic_operator, fallback_operator):
        """Return ``(forward, reverse)`` dunder implementations for a binary operator.

        ``monomorphic_operator`` takes two ``Quadratic`` operands.
        ``fallback_operator`` is accepted for signature compatibility and is
        not used.
        """
        def forward(a, b):
            if isinstance(b, (int, mpz_type, mpq_type, Quadratic)):
                return monomorphic_operator(a, Quadratic(b))
            return NotImplemented

        def reverse(b, a):
            # ``b`` is the Quadratic instance, ``a`` the left-hand operand.
            # Convert ``a`` (not ``b``) so the monomorphic operator receives
            # two Quadratic values in the original operand order.
            if isinstance(a, (int, mpz_type, mpq_type, Quadratic)):
                return monomorphic_operator(Quadratic(a), b)
            return NotImplemented

        return forward, reverse

    def _add(x, y):
        """Add two values sharing the same radical; return ``None`` otherwise."""
        if x.quadratic_power != y.quadratic_power or x.quadratic_part != y.quadratic_part:
            # The sum of unlike radicals is not representable by this class.
            return None
        total = x.rational_part + y.rational_part
        if total == 0:
            return Quadratic()
        return Quadratic(total, x.quadratic_power, x.quadratic_part)

    __add__, __radd__ = _operator_fallbacks(_add, operator.add)

    def _sub(x, y):
        """Subtract two values sharing the same radical; return ``None`` otherwise."""
        if x.quadratic_power != y.quadratic_power or x.quadratic_part != y.quadratic_part:
            # The difference of unlike radicals is not representable.
            return None
        if x.rational_part == y.rational_part:
            return Quadratic()
        return Quadratic(x.rational_part - y.rational_part, x.quadratic_power, x.quadratic_part)

    __sub__, __rsub__ = _operator_fallbacks(_sub, operator.sub)

    def _mul(x, y):
        """Multiply two values, folding whole prime powers into the rational part."""
        r = x.rational_part * y.rational_part
        if r == 0:
            # Keep zero in its canonical form so it compares equal to
            # Quadratic(), matching what _add/_sub return for zero.
            return Quadratic()
        if x.quadratic_power == 0:
            return Quadratic(r, y.quadratic_power, y.quadratic_part)
        if y.quadratic_power == 0:
            return Quadratic(r, x.quadratic_power, x.quadratic_part)
        quadratic_power = max(x.quadratic_power, y.quadratic_power)
        exp_quadratic_power = 1 << quadratic_power
        # Bring both exponent tuples to the common denominator 2 ** quadratic_power.
        shifts = quadratic_power - x.quadratic_power, quadratic_power - y.quadratic_power
        prime_power_list = []
        mask = 0
        for prime, x_power, y_power in zip(primes, x.quadratic_part, y.quadratic_part):
            power = (x_power << shifts[0]) + (y_power << shifts[1])
            if power >= exp_quadratic_power:
                # A whole power of ``prime`` appeared; move it to the rational part.
                r *= prime
                power ^= exp_quadratic_power
            prime_power_list.append(power)
            mask |= power
        if mask == 0:
            return Quadratic(r)
        # Shift out factors of two shared by every exponent to reduce the root depth.
        mask_shift = exp2(mask)
        return Quadratic(
            r,
            quadratic_power - mask_shift,
            tuple(n >> mask_shift for n in prime_power_list),
        )

    __mul__, __rmul__ = _operator_fallbacks(_mul, operator.mul)

    def _div(x, y):
        """Divide ``x`` by ``y``, keeping every stored exponent non-negative."""
        r = x.rational_part / y.rational_part
        if r == 0:
            # Canonical zero, consistent with _add/_sub.
            return Quadratic()
        if y.quadratic_power == 0:
            return Quadratic(r, x.quadratic_power, x.quadratic_part)
        quadratic_power = max(x.quadratic_power, y.quadratic_power)
        exp_quadratic_power = 1 << quadratic_power
        shifts = quadratic_power - x.quadratic_power, quadratic_power - y.quadratic_power
        prime_power_list = []
        mask = 0
        # A rational ``x`` has no exponent tuple; treat it as all zeros.
        x_powers = x.quadratic_part or (0,) * len(primes)
        for prime, x_power, y_power in zip(primes, x_powers, y.quadratic_part):
            power = (x_power << shifts[0]) - (y_power << shifts[1])
            if power < 0:
                # Borrow one whole power of ``prime`` from the rational part.
                r /= prime
                power += exp_quadratic_power
            prime_power_list.append(power)
            mask |= power
        if mask == 0:
            return Quadratic(r)
        mask_shift = exp2(mask)
        return Quadratic(
            r,
            quadratic_power - mask_shift,
            tuple(n >> mask_shift for n in prime_power_list),
        )

    __truediv__, __rtruediv__ = _operator_fallbacks(_div, operator.truediv)

    def __float__(x):
        raise NotImplementedError

    def __floordiv__(x, y):
        raise NotImplementedError

    def __rfloordiv__(x, y):
        raise NotImplementedError

    def __mod__(x, y):
        raise NotImplementedError

    def __rmod__(x, y):
        raise NotImplementedError

    @staticmethod
    def square(x):
        """Return ``x * x``, computed directly from the exponent tuple."""
        r = x.rational_part ** 2
        depth = x.quadratic_power
        if depth == 0:
            return Quadratic(r)
        if depth == 1:
            # Squaring a single square root yields the radicand itself.
            for prime, power in zip(primes, x.quadratic_part):
                if power:
                    r *= prime
            return Quadratic(r)
        # Doubling an exponent over 2 ** depth is the same exponent over
        # 2 ** (depth - 1); anything at or above that becomes a whole prime.
        power_mask = 1 << (depth - 1)
        prime_power_list = []
        for prime, power in zip(primes, x.quadratic_part):
            if power >= power_mask:
                r *= prime
                power ^= power_mask
            prime_power_list.append(power)
        return Quadratic(r, depth - 1, tuple(prime_power_list))

    @staticmethod
    def inverse(x):
        """Return ``1 / x`` with non-negative stored exponents."""
        r = x.rational_part ** -1
        if x.quadratic_power == 0:
            return Quadratic(r)
        full = 1 << x.quadratic_power
        prime_power_list = []
        for prime, power in zip(primes, x.quadratic_part):
            if power:
                # p ** (-e / 2**k) == p ** -1 * p ** ((2**k - e) / 2**k)
                r /= prime
                prime_power_list.append(full - power)
            else:
                prime_power_list.append(0)
        return Quadratic(r, x.quadratic_power, tuple(prime_power_list))

    def __pow__(x, y):
        """Raise ``x`` to an integer power (possibly negative).

        ``y`` may be an ``int``, an ``mpz``, or an ``mpq`` / ``Quadratic``
        holding an integer value.

        Raises:
            NotImplementedError: if ``y`` is numeric but not an integer.
        """
        if isinstance(y, (int, mpz_type)):
            exponent = y
        elif isinstance(y, mpq_type):
            if y.denominator != 1:
                raise NotImplementedError
            exponent = y.numerator
        elif isinstance(y, Quadratic):
            if y.quadratic_power != 0 or y.rational_part.denominator != 1:
                raise NotImplementedError
            exponent = y.rational_part.numerator
        else:
            # Unknown operand type: let Python try the reflected operation.
            return NotImplemented
        power = int(abs(exponent))
        inverse = exponent < 0
        if power == 0:
            return Quadratic(1)
        r = x.rational_part ** power
        quadratic_power = x.quadratic_power
        if quadratic_power == 0:
            return Quadratic(r ** -1) if inverse else Quadratic(r)
        # Each factor of two in the exponent removes one level of square root.
        while quadratic_power and power & 1 == 0:
            quadratic_power -= 1
            power >>= 1
        exp_quadratic_power_m1 = (1 << quadratic_power) - 1
        prime_power_list = []
        for prime, x_power in zip(primes, x.quadratic_part):
            p = x_power * power
            # Whole part of the exponent goes to the rational factor,
            # the remainder stays under the root.
            r *= prime ** (p >> quadratic_power)
            prime_power_list.append(p & exp_quadratic_power_m1)
        if quadratic_power == 0:
            result = Quadratic(r)
        else:
            result = Quadratic(r, quadratic_power, tuple(prime_power_list))
        return Quadratic.inverse(result) if inverse else result

    def __rpow__(x, y):
        raise NotImplementedError

    @staticmethod
    def sqrt(x):
        """Return the square root of ``x``, or ``None`` if it is not representable.

        The result is ``None`` when the rational part contains a non-square
        factor outside the module-level ``primes``.

        Raises:
            NotImplementedError: if ``x`` is negative.
        """
        p = x.rational_part
        s, t = p.numerator, p.denominator
        if p == 0:
            return Quadratic()
        if p < 0:
            raise NotImplementedError
        if x.quadratic_power == 0 and is_square(s) and is_square(t):
            return Quadratic(Fraction(isqrt(s), isqrt(t)))
        # sqrt(s / t) == sqrt(s * t) / t, so only one integer must be factored.
        r = Fraction(1, t)
        p = s * t
        prime_power_list = []
        for prime, x_power in zip(primes, x.quadratic_part or (0,) * len(primes)):
            power = 0
            while p % prime == 0:
                p //= prime
                power += 1
            r *= prime ** (power >> 1)
            # An odd leftover power contributes one half, i.e. 2 ** k over
            # the new denominator 2 ** (k + 1); the old exponent sits below it.
            prime_power_list.append(((power & 1) << x.quadratic_power) | x_power)
        if not is_square(p):
            # A non-square factor outside ``primes`` remains: not representable.
            return None
        return Quadratic(r * isqrt(p), x.quadratic_power + 1, tuple(prime_power_list))

    def __pos__(x):
        return Quadratic(x.rational_part, x.quadratic_power, x.quadratic_part)

    def __neg__(x):
        return Quadratic(-x.rational_part, x.quadratic_power, x.quadratic_part)

    def __abs__(x):
        return Quadratic(abs(x.rational_part), x.quadratic_power, x.quadratic_part)

    def __int__(x):
        """Return the value as ``int`` when it is an integer.

        NOTE(review): for non-integer values this returns ``None``, which
        makes ``int(q)`` raise ``TypeError``; kept as-is for compatibility.
        """
        if x.quadratic_power == 0 and x.rational_part.denominator == 1:
            return int(x.rational_part.numerator)
        return None

    def __trunc__(x):
        raise NotImplementedError

    def __floor__(x):
        raise NotImplementedError

    def __ceil__(x):
        raise NotImplementedError

    def __round__(x):
        raise NotImplementedError

    def __hash__(self):
        """Hash consistently with ``__eq__``; rationals hash like their ``mpq``."""
        if self.quadratic_power == 0:
            return hash(self.rational_part)
        return (
            hash(self.rational_part)
            * hash(self.quadratic_power)
            * hash(self.quadratic_part)
            % _PyHASH_MODULUS
        )

    def __eq__(x, y):
        if isinstance(y, (int, mpz_type, mpq_type)):
            return x.quadratic_power == 0 and x.rational_part == y
        if isinstance(y, Quadratic):
            return (
                x.rational_part == y.rational_part
                and x.quadratic_power == y.quadratic_power
                and x.quadratic_part == y.quadratic_part
            )
        return NotImplemented

    def __lt__(x, y):
        raise NotImplementedError

    def __gt__(x, y):
        raise NotImplementedError

    def __le__(x, y):
        raise NotImplementedError

    def __ge__(x, y):
        raise NotImplementedError

    def __bool__(x):
        return x.rational_part != 0

    def __reduce__(self):
        """Support pickling by rebuilding from the three stored components.

        The rational part is passed as a string so the constructor's ``str``
        path is used; the printable form from ``__str__`` (for example
        ``3*sqrt(2)``) cannot be parsed back and must not be used here.
        """
        return (
            self.__class__,
            (str(self.rational_part), self.quadratic_power, self.quadratic_part),
        )

    def __copy__(self):
        # Instances are treated as immutable, so the same object is returned.
        if type(self) is Quadratic:
            return self
        return self.__class__(self)

    def __deepcopy__(self, memo=None):
        # ``copy.deepcopy`` passes a memo dict; it is not needed because the
        # same object is returned.
        if type(self) is Quadratic:
            return self
        return self.__class__(self)