def test_Poly_hash(self):
  """Check that Poly hashing tells constants apart and ignores term order."""
  # Ten different constant polynomials must not all collapse to one hash.
  constant_hashes = {hash(Poly({Mon(): coeff})) for coeff in range(10)}
  assert len(constant_hashes) != 1

  # The same terms inserted in a different order must hash identically.
  forward = Poly({Mon(): 3, Mon({'n': 1}): 4})
  backward = Poly({Mon({'n': 1}): 4, Mon(): 3})
  assert hash(forward) == hash(backward)
def test_Poly_equal(self):
  """Check Poly ==/!= against ints, NumPy int64 values, and other Polys."""
  three = constant_poly(3)
  four = constant_poly(4)
  int64_three = np.array(3, np.int64)

  # Constant polynomials compare like the ints they wrap, from either side.
  # Operand order is deliberate: it selects which __eq__/__ne__ runs first.
  assert three == 3
  assert int64_three == three
  assert int64_three[()] == three
  assert not int64_three != three
  assert four != 3
  assert 3 == three
  assert 4 != three
  assert four == constant_poly(4)
  assert three != four

  # Non-constant polynomials: equality ignores term insertion order.
  linear = Poly({Mon(): 3, Mon({'n': 1}): 4})
  assert linear == Poly({Mon({'n': 1}): 4, Mon(): 3})
  assert linear != Poly({Mon(): 4, Mon({'n': 1}): 4})
  assert linear != Poly({Mon(): 2})

  # These comparisons are expected to raise UndefinedPoly ("inconclusive").
  with self.assertRaisesRegex(UndefinedPoly, "inconclusive"):
    linear != Poly({Mon(): 3, Mon({'n': 2}): 4})
  with self.assertRaisesRegex(UndefinedPoly, "inconclusive"):
    Poly({Mon(): 3, Mon({'m': 1}): 4}) != linear
def test_Poly_equal(self):
  """Check Poly ==/!= against ints, onp int64 values, and other Polys."""
  three = constant_poly(3)
  four = constant_poly(4)
  int64_three = onp.array(3, onp.int64)

  # Constant polynomials compare like the ints they wrap, from either side.
  # Operand order is deliberate: it selects which __eq__/__ne__ runs first.
  assert three == 3
  assert int64_three == three
  assert int64_three[()] == three
  assert not int64_three != three
  assert four != 3
  assert 3 == three
  assert 4 != three
  assert four == constant_poly(4)
  assert three != four

  # Non-constant polynomials: equality ignores term insertion order, and
  # polynomials with different terms are reported as unequal.
  linear = Poly({Mon(): 3, Mon({'n': 1}): 4})
  assert linear == Poly({Mon({'n': 1}): 4, Mon(): 3})
  assert linear != Poly({Mon(): 3, Mon({'n': 2}): 4})
  assert Poly({Mon(): 3, Mon({'m': 1}): 4}) != linear
def test_Poly_compare(self):
  """Check Poly ordering against small constants; other bounds raise ValueError."""
  poly = Poly({Mon(): 3, Mon({'n': 1}): 4})
  # Assume poly > 0 to make various shape rules work with polymorphic shapes:
  for lower in (0, 1):
    assert poly >= lower
  assert poly > 0
  assert 0 <= poly
  assert 0 < poly

  three = constant_poly(3)
  assert three >= 1
  assert three > 1

  # These comparisons are expected to raise ValueError.
  with self.assertRaisesRegex(ValueError, ""):
    poly >= 2
  with self.assertRaisesRegex(ValueError, ""):
    poly > 1
def test_Poly_compare(self):
  """Check Poly ordering against constants and against shifted copies of itself."""
  poly = Poly({Mon(): 3, Mon({'n': 1}): 4})
  # Assume poly > 0 to make various shape rules work with polymorphic shapes:
  for lower in (0, 1):
    assert poly >= lower
  assert poly > 0
  assert 0 <= poly
  assert 0 < poly

  three = constant_poly(3)
  assert three >= 1
  assert three > 1

  # A polynomial compares consistently with itself shifted by a constant.
  assert poly >= poly
  assert poly >= poly - 1
  assert poly < poly + 1

  # Evaluated only to confirm they do not raise; the results are not checked.
  poly >= 3
  poly > 2
  # This bound is expected to raise UndefinedPoly ("inconclusive").
  with self.assertRaisesRegex(UndefinedPoly, "inconclusive"):
    poly >= 4
class PolyTest(jtu.JaxTestCase):
  """Tests for shape-spec parsing and Poly/Mon equality, hashing and arithmetic."""

  @parameterized.parameters([
      ['(m, n)', 'ShapeSpec(m, n)'],
      ['(m * n)', 'ShapeSpec(m n)'],
      ['m * n', 'ShapeSpec(m n)'],
      ['(m * n,)', 'ShapeSpec(m n)'],
      ['(3, m)', 'ShapeSpec(3, m)'],
      ['(10, m)', 'ShapeSpec(10, m)'],
      ['(-10, m)', 'ShapeSpec(-10, m)'],
      ['(3 * m)', 'ShapeSpec(3 m)'],
      ['m', 'ShapeSpec(m)'],
      ['', 'ShapeSpec()'],
      ['n + -1*n', 'ShapeSpec(0)'],
      ['m + n', 'ShapeSpec(m + n)'],
      ['m + n * k', 'ShapeSpec(k n + m)'],
      ['m + 3 * k', 'ShapeSpec(3 k + m)'],
      ['-3 + k + k * k', 'ShapeSpec(k^2 + k + -3)'],
      ['_', 'ShapeSpec(_)'],
  ])
  def test_parse_spec(self, spec, ans):
    """Check str() of a parsed spec, both directly and after remap_ids."""
    self.assertEqual(str(parse_spec(spec)), ans)
    self.assertEqual(str(remap_ids(UniqueIds(), parse_spec(spec))), ans)

  def test_Poly_equal(self):
    """Check Poly ==/!= against ints, NumPy int64 values, and other Polys."""
    # Operand order is deliberate: it selects which __eq__/__ne__ runs first.
    assert constant_poly(3) == 3
    assert np.array(3, np.int64) == constant_poly(3)
    assert np.array(3, np.int64)[()] == constant_poly(3)
    assert not np.array(3, np.int64) != constant_poly(3)
    assert constant_poly(4) != 3
    assert 3 == constant_poly(3)
    assert 4 != constant_poly(3)
    assert constant_poly(4) == constant_poly(4)
    assert constant_poly(3) != constant_poly(4)
    assert Poly({Mon(): 3, Mon({'n': 1}): 4}) == Poly({Mon({'n': 1}): 4, Mon(): 3})
    assert Poly({Mon(): 3, Mon({'n': 1}): 4}) != Poly({Mon(): 4, Mon({'n': 1}): 4})
    assert Poly({Mon(): 3, Mon({'n': 1}): 4}) != Poly({Mon(): 2})
    # These comparisons are expected to raise UndefinedPoly ("inconclusive").
    with self.assertRaisesRegex(UndefinedPoly, "inconclusive"):
      Poly({Mon(): 3, Mon({'n': 1}): 4}) != Poly({Mon(): 3, Mon({'n': 2}): 4})
    with self.assertRaisesRegex(UndefinedPoly, "inconclusive"):
      Poly({Mon(): 3, Mon({'m': 1}): 4}) != Poly({Mon(): 3, Mon({'n': 1}): 4})

  def test_Poly_hash(self):
    """Check that Poly hashing tells constants apart and ignores term order."""
    assert len({hash(Poly({Mon(): i})) for i in range(10)}) != 1
    assert (hash(Poly({Mon(): 3, Mon({'n': 1}): 4}))
            == hash(Poly({Mon({'n': 1}): 4, Mon(): 3})))

  def test_Mon_hash(self):
    """Check that Mon hashing tells exponents apart and ignores key order."""
    assert len({hash(Mon({'a': i})) for i in range(10)}) != 1
    assert hash(Mon({'a': 1, 'b': 1})) == hash(Mon({'b': 1, 'a': 1}))

  @parameterized.parameters([
      (Mon({'a': 1}), Mon({'b': 1})),
      (Mon({'a': 2, 'b': 1}), Mon({'b': 1})),
  ])
  def test_Mon_floordiv(self, divisor, quotient):
    """Check that (quotient * divisor) // divisor recovers quotient."""
    dividend = quotient * divisor
    self.assertEqual(quotient, dividend // divisor)

  def test_Poly_compare(self):
    """Check Poly ordering against constants and against shifted copies of itself."""
    poly = Poly({Mon(): 3, Mon({'n': 1}): 4})
    # Assume poly > 0 to make various shape rules work with polymorphic shapes:
    assert poly >= 0
    assert poly >= 1
    assert poly > 0

    assert 0 <= poly
    assert 0 < poly
    assert constant_poly(3) >= 1
    assert constant_poly(3) > 1
    assert poly >= poly
    assert poly >= poly - 1
    assert poly < poly + 1

    # Evaluated only to confirm they do not raise; the results are not checked.
    poly >= 3
    poly > 2
    with self.assertRaisesRegex(UndefinedPoly, "inconclusive"):
      poly >= 4

  # Class-level fixtures referenced by the test_Poly_divmod parameter list.
  n = Poly({Mon({'n': 1}): 1})
  m = Poly({Mon({'m': 1}): 1})
  must_divide_msg = " must divide size"

  @parameterized.parameters([
      (1, constant_poly(0), 0),
      (n, 0, 0),
      (2, n, 1),
      (5, 2 * n, 0),
      (5, 2 * n + 4, 3),
      (n * n, n + 1, 0),
      (2 * n + 1, 2 * n + 1, n + 2, must_divide_msg),
      (n * m + 1, m + n + 1, n - 1, must_divide_msg),
      (n, n, 0),
      (n, n, 1, must_divide_msg),
      (n + 1, -n + 1, -1, must_divide_msg),
  ])
  def test_Poly_divmod(self, divisor, quotient, remainder, error_message=None):
    """Check divmod(quotient * divisor + remainder, divisor).

    Cases with ``error_message`` expect UndefinedPoly matching that message;
    the others expect ``(quotient, remainder)`` back.
    """
    dividend = quotient * divisor + remainder
    expected = (quotient, remainder)
    # A constant dividend is passed as a plain int so that the int-dividend
    # path of divmod is exercised too.
    if dividend.is_constant:
      dividend = int(dividend)
    if error_message:
      with self.assertRaisesRegex(UndefinedPoly, error_message):
        divmod(dividend, divisor)
    else:
      self.assertEqual(expected, divmod(dividend, divisor))

  def test_Poly_rsub(self):
    """Check that int - Poly agrees with -Poly - int."""
    n = Poly({Mon({'n': 1}): 1})
    assert -1 - n == -n - 1
def constant_poly(c):
  """Build a Poly whose only term is the constant monomial Mon() with coefficient c."""
  terms = {Mon(): c}
  return Poly(terms)
def test_Poly_rsub(self):
  """Check that int - Poly agrees with -Poly - int."""
  n_poly = Poly({Mon({'n': 1}): 1})
  reflected = -1 - n_poly
  negated_first = -n_poly - 1
  assert reflected == negated_first
def test_Poly_divmod(self):
  """Check divmod of polynomials in n by integer divisors."""
  n = Poly({Mon({'n': 1}): 1})
  cases = [
      # (expected (quotient, remainder), dividend, divisor)
      ((n, 1), 2 * n + 1, 2),
      ((2 * n, 0), 10 * n, 5),
      ((2 * n + 4, 3), 10 * n + 23, 5),
  ]
  for expected, dividend, divisor in cases:
    assert expected == divmod(dividend, divisor)
def test_slice_index_poly_start(self):
  """Check _polymorphic_slice_indices for slice(n, None, None) over length 2n."""
  n = Poly({Mon({'n': 1}): 1})
  start_only = slice(n, None, None)
  # Separate `2 * n` objects keep the tuple comparison going through Poly.__eq__.
  indices = _polymorphic_slice_indices(start_only, 2 * n)
  assert (n, 2 * n, 1) == indices