def test_diop_general_sum_of_squares_quick(): for i in range(3, 10): assert check_solutions(sum(i**2 for i in symbols(':%i' % i)) - i) assert diop_general_sum_of_squares(x**2 + y**2 - 2) is None assert diop_general_sum_of_squares(x**2 + y**2 + z**2 + 2) == set() eq = x**2 + y**2 + z**2 - (1 + 4 + 9) assert diop_general_sum_of_squares(eq) == \ {(1, 2, 3)} eq = u**2 + v**2 + x**2 + y**2 + z**2 - 1313 assert len(diop_general_sum_of_squares(eq, 3)) == 3 # issue 11016 var = symbols(':5') + (symbols('6', negative=True), ) eq = Add(*[i**2 for i in var]) - 112 base_soln = {(0, 1, 1, 5, 6, -7), (1, 1, 1, 3, 6, -8), (2, 3, 3, 4, 5, -7), (0, 1, 1, 1, 3, -10), (0, 0, 4, 4, 4, -8), (1, 2, 3, 3, 5, -8), (0, 1, 2, 3, 7, -7), (2, 2, 4, 4, 6, -6), (1, 1, 3, 4, 6, -7), (0, 2, 3, 3, 3, -9), (0, 0, 2, 2, 2, -10), (1, 1, 2, 3, 4, -9), (0, 1, 1, 2, 5, -9), (0, 0, 2, 6, 6, -6), (1, 3, 4, 5, 5, -6), (0, 2, 2, 2, 6, -8), (0, 3, 3, 3, 6, -7), (0, 2, 3, 5, 5, -7), (0, 1, 5, 5, 5, -6)} assert diophantine(eq) == base_soln assert len(diophantine(eq, permute=True)) == 196800 # handle negated squares with signsimp assert diophantine(12 - x**2 - y**2 - z**2) == {(2, 2, 2)} # diophantine handles simplification, so classify_diop should # not have to look for additional patterns that are removed # by diophantine eq = a**2 + b**2 + c**2 + d**2 - 4 raises(NotImplementedError, lambda: classify_diop(-eq))
def test_sum_of_squares_powers(): tru = set([ (0, 0, 1, 1, 11), (0, 0, 5, 7, 7), (0, 1, 3, 7, 8), (0, 1, 4, 5, 9), (0, 3, 4, 7, 7), (0, 3, 5, 5, 8), (1, 1, 2, 6, 9), (1, 1, 6, 6, 7), (1, 2, 3, 3, 10), (1, 3, 4, 4, 9), (1, 5, 5, 6, 6), (2, 2, 3, 5, 9), (2, 3, 5, 6, 7), (3, 3, 4, 5, 8)]) eq = u**2 + v**2 + x**2 + y**2 + z**2 - 123 ans = diop_general_sum_of_squares(eq, oo) # allow oo to be used assert len(ans) == 14 assert ans == tru raises(ValueError, lambda: list(sum_of_squares(10, -1))) assert list(sum_of_squares(-10, 2)) == [] assert list(sum_of_squares(2, 3)) == [] assert list(sum_of_squares(0, 3, True)) == [(0, 0, 0)] assert list(sum_of_squares(0, 3)) == [] assert list(sum_of_squares(4, 1)) == [(2,)] assert list(sum_of_squares(5, 1)) == [] assert list(sum_of_squares(50, 2)) == [(5, 5), (1, 7)] assert list(sum_of_squares(11, 5, True)) == [ (1, 1, 1, 2, 2), (0, 0, 1, 1, 3)] assert list(sum_of_squares(8, 8)) == [(1, 1, 1, 1, 1, 1, 1, 1)] assert [len(list(sum_of_squares(i, 5, True))) for i in range(30)] == [ 1, 1, 1, 1, 2, 2, 1, 1, 2, 2, 2, 2, 2, 3, 2, 1, 3, 3, 3, 3, 4, 3, 3, 2, 2, 4, 4, 4, 4, 5] assert [len(list(sum_of_squares(i, 5))) for i in range(30)] == [ 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 3] for i in range(30): s1 = set(sum_of_squares(i, 5, True)) assert not s1 or all(sum(j**2 for j in t) == i for t in s1) s2 = set(sum_of_squares(i, 5)) assert all(sum(j**2 for j in t) == i for t in s2) raises(ValueError, lambda: list(sum_of_powers(2, -1, 1))) raises(ValueError, lambda: list(sum_of_powers(2, 1, -1))) assert list(sum_of_powers(-2, 3, 2)) == [(-1, -1)] assert list(sum_of_powers(-2, 4, 2)) == [] assert list(sum_of_powers(2, 1, 1)) == [(2,)] assert list(sum_of_powers(2, 1, 3, True)) == [(0, 0, 2), (0, 1, 1)] assert list(sum_of_powers(5, 1, 2, True)) == [(0, 5), (1, 4), (2, 3)] assert list(sum_of_powers(6, 2, 2)) == [] assert list(sum_of_powers(3**5, 3, 1)) == [] assert list(sum_of_powers(3**6, 3, 1)) == [(9,)] and (9**3 == 3**6) assert list(sum_of_powers(2**1000, 5, 2)) == []