def general_polynomial(self, poly):
    """Return a Polynomial P with P(x) == sum(poly(i) for i in range(1, x + 1)).

    The result is accumulated term by term: for every non-zero coefficient
    ``c`` found at index ``i`` of ``poly._coef``, ``base_polynomial(i, c)``
    is added to the running total.

    :param poly: object exposing its coefficients as ``poly._coef``; the
        index of each entry is passed to ``base_polynomial`` as the power.
    :return: a new Polynomial built with this solver's modulus and
        primitive root.
    """
    summed = Polynomial([], modulus=self._modulus, primitive_root=self._primitive_root)
    for power, coefficient in enumerate(poly._coef):
        # Zero coefficients contribute nothing, so skip them.
        if not coefficient:
            continue
        summed.add_polynomial(self.base_polynomial(power, coefficient))
    return summed
def __init__(self, height, modulus, primitive_root, max_width=100):
    """Initialise the solver state.

    :param height: height value; stored reduced mod ``modulus`` while its
        parity is taken from the unreduced value.
    :param modulus: forwarded to the base-class constructor and to Faulhaber.
    :param primitive_root: forwarded to the base-class constructor and to
        Faulhaber.
    :param max_width: first argument handed to the Faulhaber helper.
    """
    super(TallCastleSolver, self).__init__(modulus, primitive_root)

    def _empty_cache():
        # Reading a missing key yields None instead of raising KeyError.
        return defaultdict(lambda: None)

    self._height_parity = height % 2
    self._height = height % modulus
    self._faulhaber = Faulhaber(max_width, modulus, primitive_root)

    # Constant polynomials 0 and 1, built with the modulus / primitive root
    # stored on the instance.
    self._zero_poly = Polynomial([0], self._modulus, self._primitive_root)
    self._one_poly = Polynomial([1], self._modulus, self._primitive_root)

    # One memo table per family of sub-results.
    self._tall_solutions = _empty_cache()
    self._short_solutions = _empty_cache()
    self._all_tall_solutions = _empty_cache()
    self._all_short_solutions = _empty_cache()
def base_polynomial(self, p, scalar=1):
    """Return a Polynomial P (of degree p + 1) such that
    P(x) == scalar * sum(i^p for i in range(1, x + 1)).

    :param p: the power being summed.
    :param scalar: factor applied to the whole sum (default 1).
    :return: a new Polynomial using this object's modulus and primitive root.
    :raises ValueError: if ``p + 1`` exceeds ``self._max_degree``.
    """
    degree = p + 1
    if degree > self._max_degree:
        raise ValueError

    # First build the unscaled coefficients of sum(i^p for i in 1..n).
    # Index 0 is never written and stays 0.
    # NOTE(review): self._bn looks like a table of Bernoulli numbers and
    # self._choose.choose like a binomial coefficient - confirm.
    coefficients = [0] * (degree + 1)
    for k in range(degree):
        binomial = self._choose.choose(degree, k)
        coefficients[degree - k] = self.mult(binomial, self._bn[k])

    result = Polynomial(coefficients, modulus=self._modulus, primitive_root=self._primitive_root)
    # Apply the common factor scalar / (p + 1) using the modular division helper.
    result.scalarmult(self.div(scalar, degree))
    return result
class TallCastleSolver(ModularSolver):
    """Solver that builds per-width results as AlternatingPolynomial objects.

    Results are indexed by ``(width, parity, rh_parity)`` and memoised in four
    tables: polynomial-valued "tall" and "short" results, and the scalar
    "all tall" / "all short" values obtained by summing and solving those
    polynomials at the configured height.

    NOTE(review): the inline comments suggest ``rh`` is the height of the
    right-hand column of a castle - confirm against the problem statement.
    """

    def __init__(self, height, modulus, primitive_root, max_width=100):
        """Initialise the solver state.

        :param height: height value; stored reduced mod ``modulus`` while its
            parity is taken from the unreduced value.
        :param modulus: forwarded to the base-class constructor and Faulhaber.
        :param primitive_root: forwarded to the base-class constructor and
            Faulhaber.
        :param max_width: first argument handed to the Faulhaber helper.
        """
        super(TallCastleSolver, self).__init__(modulus, primitive_root)
        self._height = height % modulus
        self._faulhaber = Faulhaber(max_width, modulus, primitive_root)
        self._height_parity = height % 2
        self._zero_poly = Polynomial([0], self._modulus, self._primitive_root)
        self._one_poly = Polynomial([1], self._modulus, self._primitive_root)
        # Memo tables; reading a missing key yields None.
        self._tall_solutions = defaultdict(lambda: None)
        self._short_solutions = defaultdict(lambda: None)
        self._all_tall_solutions = defaultdict(lambda: None)
        self._all_short_solutions = defaultdict(lambda: None)

    def tall_solutions(self, width, parity, rh_parity):
        """Return the memoised result of ``compute_tall_solutions``.

        Callers must ``copy()`` the returned object before mutating it, since
        the same instance is handed out on every call.
        """
        key = (width, parity, rh_parity)
        # Compare with None explicitly (as the all_* caches do) so that a
        # cached value that happens to be falsy is not recomputed.
        if self._tall_solutions[key] is None:
            self._tall_solutions[key] = self.compute_tall_solutions(width, parity, rh_parity)
        return self._tall_solutions[key]

    def compute_tall_solutions(self, width, parity, rh_parity):
        """Build the "tall" AlternatingPolynomial for the given key.

        For ``width == 1`` both component polynomials are zero. For larger
        widths the result is derived from the ``width - 1`` entries.

        :raises NotImplementedError: if ``width < 1``.
        """
        if width == 1:
            return AlternatingPolynomial(
                self._zero_poly.copy(), self._zero_poly.copy(),
                self._modulus, self._primitive_root)
        elif width < 1:
            raise NotImplementedError

        # rh <= new row height
        new_row_taller = self.tall_solutions(width - 1, 0, (parity + rh_parity) % 2).copy()
        new_row_taller.add_ap(self.tall_solutions(width - 1, 1, (parity + rh_parity + 1) % 2))
        new_row_taller = self.sum_alternating_polynomial(new_row_taller)

        # rh > new row height
        new_row_shorter = self.tall_solutions(width - 1, parity, 0).copy()
        new_row_shorter.add_ap(self.tall_solutions(width - 1, parity, 1))
        new_row_shorter = self.sum_alternating_polynomial(new_row_shorter)

        new_row_taller.sub_ap(new_row_shorter)
        total = new_row_taller

        valid_subsolutions_e = self.all_tall_solutions(width - 1, parity, 0)
        valid_subsolutions_o = self.all_tall_solutions(width - 1, parity, 1)
        total.add_constant(self.add(valid_subsolutions_e, valid_subsolutions_o))

        # Only the component matching rh_parity is kept; the other is zeroed.
        if rh_parity == 1:
            total._even_poly = self._zero_poly.copy()
        else:
            total._odd_poly = self._zero_poly.copy()
        return total

    def short_solutions(self, width, parity, rh_parity):
        """Return the memoised result of ``compute_short_solutions``.

        Callers must ``copy()`` the returned object before mutating it, since
        the same instance is handed out on every call.
        """
        key = (width, parity, rh_parity)
        # Compare with None explicitly (as the all_* caches do) so that a
        # cached value that happens to be falsy is not recomputed.
        if self._short_solutions[key] is None:
            self._short_solutions[key] = self.compute_short_solutions(width, parity, rh_parity)
        return self._short_solutions[key]

    def compute_short_solutions(self, width, parity, rh_parity):
        """Build the "short" AlternatingPolynomial for the given key.

        For ``width == 1`` the result is the constant 1 in the component
        selected by ``parity`` when ``parity == rh_parity``, otherwise zero in
        both components. For larger widths the result is derived from the
        ``width - 1`` entries.

        :raises NotImplementedError: if ``width < 1``.
        """
        if width == 1:
            even_poly = self._zero_poly.copy()
            odd_poly = self._zero_poly.copy()
            if parity == rh_parity:
                if parity == 0:
                    even_poly = self._one_poly.copy()
                else:
                    odd_poly = self._one_poly.copy()
            return AlternatingPolynomial(even_poly, odd_poly, self._modulus, self._primitive_root)
        elif width < 1:
            raise NotImplementedError

        # rh <= new row height
        new_row_taller = self.short_solutions(width - 1, 0, (parity + rh_parity) % 2).copy()
        new_row_taller.add_ap(self.short_solutions(width - 1, 1, (parity + rh_parity + 1) % 2))
        new_row_taller = self.sum_alternating_polynomial(new_row_taller)

        # rh > new row height
        new_row_shorter = self.short_solutions(width - 1, parity, 0).copy()
        new_row_shorter.add_ap(self.short_solutions(width - 1, parity, 1))
        new_row_shorter = self.sum_alternating_polynomial(new_row_shorter)

        new_row_taller.sub_ap(new_row_shorter)
        total = new_row_taller

        short_subsolutions_e = self.all_short_solutions(width - 1, parity, 0)
        short_subsolutions_o = self.all_short_solutions(width - 1, parity, 1)
        total.add_constant(self.add(short_subsolutions_e, short_subsolutions_o))

        # Only the component matching rh_parity is kept; the other is zeroed.
        if rh_parity == 1:
            total._even_poly = self._zero_poly.copy()
        else:
            total._odd_poly = self._zero_poly.copy()
        return total

    def all_tall_solutions(self, width, parity, rh_parity):
        """Return the memoised scalar from ``compute_all_tall_solutions``."""
        key = (width, parity, rh_parity)
        # The cached value may legitimately be 0, hence the identity test.
        if self._all_tall_solutions[key] is None:
            self._all_tall_solutions[key] = self.compute_all_tall_solutions(width, parity, rh_parity)
        return self._all_tall_solutions[key]

    def compute_all_tall_solutions(self, width, parity, rh_parity):
        """Sum the tall polynomial, solve it at the stored height, and add the
        contribution of the short sub-results."""
        # add short subsolutions
        short_subsolutions = 0
        if rh_parity == self._height_parity and width > 1:
            short_subsolutions = self.all_short_solutions(width - 1, 0, (parity + rh_parity) % 2)
            short_subsolutions = self.add(
                short_subsolutions,
                self.all_short_solutions(width - 1, 1, (parity + rh_parity + 1) % 2))
        if width == 1 and parity == rh_parity and parity == self._height_parity:
            short_subsolutions = self.add(short_subsolutions, 1)

        summed = self.sum_alternating_polynomial(self.tall_solutions(width, parity, rh_parity))
        tall_value = summed.solve(self._height, parity=self._height_parity)
        return self.add(tall_value, short_subsolutions)

    def all_short_solutions(self, width, parity, rh_parity):
        """Return the memoised scalar from ``compute_all_short_solutions``."""
        key = (width, parity, rh_parity)
        # The cached value may legitimately be 0, hence the identity test.
        if self._all_short_solutions[key] is None:
            self._all_short_solutions[key] = self.compute_all_short_solutions(width, parity, rh_parity)
        return self._all_short_solutions[key]

    def compute_all_short_solutions(self, width, parity, rh_parity):
        """Sum the short polynomial and solve it at ``height - 1`` with the
        opposite of the stored height parity."""
        summed = self.sum_alternating_polynomial(self.short_solutions(width, parity, rh_parity))
        return summed.solve(self.subtract(self._height, 1),
                            parity=(self._height_parity + 1) % 2)

    def sum_alternating_polynomial(self, ap):
        """Return a new AlternatingPolynomial derived from ``ap`` via Faulhaber sums.

        The even component is ``S(ap._even_poly) + S(ap._odd_poly)`` where
        ``S`` is ``Faulhaber.general_polynomial``; the odd component is that
        same polynomial minus ``ap._even_poly``. ``ap`` is not modified.
        """
        sum_even = self._faulhaber.general_polynomial(ap._even_poly)
        sum_odd = self._faulhaber.general_polynomial(ap._odd_poly)
        new_even = sum_even.copy()
        new_even.add_polynomial(sum_odd)
        new_odd = new_even.copy()
        new_odd.sub_polynomial(ap._even_poly)
        return AlternatingPolynomial(new_even, new_odd, self._modulus, self._primitive_root)

    def sum_minus_one_ap(self, ap):
        """Variant of ``sum_alternating_polynomial``.

        The even component is ``S(ap._even_poly) + S(ap._odd_poly) -
        ap._even_poly`` and the odd component is the even component minus
        ``ap._odd_poly``. ``ap`` is not modified.
        """
        sum_even = self._faulhaber.general_polynomial(ap._even_poly)
        sum_odd = self._faulhaber.general_polynomial(ap._odd_poly)
        new_even = sum_even.copy()
        new_even.add_polynomial(sum_odd)
        new_even.sub_polynomial(ap._even_poly)
        new_odd = new_even.copy()
        new_odd.sub_polynomial(ap._odd_poly)
        return AlternatingPolynomial(new_even, new_odd, self._modulus, self._primitive_root)

    def validate(self, width):
        """Sanity check: the tall and short totals over both parities must add
        up to ``height ** width`` (mod the modulus).

        Prints the individual terms, then asserts the identity.

        :raises AssertionError: if the totals do not match.
        """
        all_castles = self.pow(self._height, width)
        even_tall = self.add(self.all_tall_solutions(width, 0, 0), self.all_tall_solutions(width, 0, 1))
        odd_tall = self.add(self.all_tall_solutions(width, 1, 0), self.all_tall_solutions(width, 1, 1))
        even_short = self.add(self.all_short_solutions(width, 0, 0), self.all_short_solutions(width, 0, 1))
        odd_short = self.add(self.all_short_solutions(width, 1, 0), self.all_short_solutions(width, 1, 1))
        # Single-argument print(...) behaves identically under the Python 2
        # print statement and the Python 3 print function.
        print("{even_tall} + {odd_tall} + {even_short} + {odd_short} == {all_castles}".format(
            even_tall=even_tall, odd_tall=odd_tall,
            even_short=even_short, odd_short=odd_short,
            all_castles=all_castles))
        assert (even_tall + odd_tall + even_short + odd_short) % self._modulus == all_castles

    def F(self, width):
        """Return ``all_tall_solutions(width, 0, 0) + all_tall_solutions(width, 0, 1)``
        combined with the solver's modular ``add``."""
        total = 0
        for r in (0, 1):
            total = self.add(total, self.all_tall_solutions(width, 0, r))
        return total