def test_dim_vars(self):
    """Exercise equality, hashing and membership of dimension variables.

    A variable compared with itself gives a definite answer, while every
    way of comparing two different variables is expected to raise
    ``core.InconclusiveDimensionOperation``.
    """
    x, y = shape_poly.parse_spec("a, b", (2, 3))

    # Self-comparison is decidable.
    self.assertEqual(True, x == x)
    self.assertEqual(False, x != x)

    # `.eq`, `==` and `!=` between distinct variables all report the same
    # inconclusive comparison.
    inconclusive_msg = "Dimension polynomial comparison 'a' == 'b' is inconclusive"
    for compare in (lambda: x.eq(y), lambda: x == y, lambda: x != y):
        with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                    inconclusive_msg):
            compare()

    # Sets of variables can be built and queried without raising.
    self.assertLen({x, x}, 1)
    self.assertLen({x, y}, 2)
    self.assertIn(x, {x, y})
    self.assertIn(y, {x, y})

    # List membership scans left to right: the first element is found
    # directly, but looking up the second one first compares it with the
    # first, which is inconclusive.
    self.assertIn(x, [x, y])
    with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                "Dimension polynomial comparison 'b' == 'a' is inconclusive"):
        y in [x, y]
def test_dilate_shape(self):
    """Check ``core.dilate_shape`` on concrete and polynomial dimensions.

    Expected sizes follow ``0 if d == 0 else 1 + dilation * (d - 1)``.
    """
    (dim,) = shape_poly.parse_spec("a,", (2,))

    # Concrete dimensions, including an empty leading dimension.
    self.assertEqual((4, 7), core.dilate_shape((2, 3), (3, 3)))
    self.assertEqual((0, 7), core.dilate_shape((0, 3), (3, 3)))

    # Polynomial leading dimension: dilation 1 leaves it unchanged,
    # dilation 2 yields 2 * a - 1.
    self.assertEqual((dim, 7), core.dilate_shape((dim, 3), (1, 3)))
    self.assertEqual((2 * dim - 1, 7), core.dilate_shape((dim, 3), (2, 3)))
def test_poly_int_results(self):
    """Check that arithmetic collapsing to a constant yields a plain ``int``."""
    x, _ = shape_poly.parse_spec("a, b", (2, 3))

    # Each expression is wrapped in a thunk so that it is evaluated afresh
    # for the value check and again for the type check.
    constant_exprs = (
        lambda: x + 2 - x,
        lambda: x + (2 - x),
        lambda: x * 2 // x,
    )
    for compute in constant_exprs:
        self.assertEqual(compute(), 2)
        self.assertIsInstance(compute(), int)
def test_poly_bounds(self):
    """Check the ``(lower, upper)`` pairs returned by ``bounds()``.

    ``None`` appears in the expected pairs where no bound is reported on
    that side.
    """
    x, y = shape_poly.parse_spec("a, b", (2, 3))

    # (polynomial, expected bounds)
    cases = [
        (x, (1, None)),
        (2 * x, (2, None)),
        (2 * x - 3, (-1, None)),
        (-2 * x - 3, (None, -5)),
        (3 * x * y * y + 5 * x - 7, (1, None)),
        (3 * x * y * y - 5 * x - 7, (None, None)),
        (x + y - x * y + x * y * x, (None, None)),
        (x + 2 * y - x, (2, None)),
    ]
    for poly, expected in cases:
        self.assertEqual(poly.bounds(), expected)
def test_dilate_shape(self):
    """Check ``core.dilate_shape`` with concrete sizes and a shape variable.

    For a shape variable only dilation 1 is expected to succeed; dilation 2
    must raise ``core.InconclusiveDimensionOperation``.
    """
    (dim,) = shape_poly.parse_spec("a,", (2,))

    # Concrete dimensions, including an empty leading dimension.
    self.assertEqual((4, 7), core.dilate_shape((2, 3), (3, 3)))
    self.assertEqual((0, 7), core.dilate_shape((0, 3), (3, 3)))

    # Dilation 1 keeps the shape variable as is.
    self.assertEqual((dim, 7), core.dilate_shape((dim, 3), (1, 3)))

    unsupported_msg = re.escape("Only dilation == 1 is supported for shape variables (var = a, dilation = 2)")
    with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                unsupported_msg):
        core.dilate_shape((dim, 3), (2, 3))
def test_parse_poly_spec(self):
    """Check that specs without dimension variables parse to the concrete shape."""
    # Missing spec, explicit constants, `_` placeholder, `...` ellipsis and
    # a parenthesised spec with extra whitespace must all give (2, 3).
    specs = (None, "2, 3", "2, _", "2, ...", "...", " ( 2 , 3 ) ")
    for spec in specs:
        self.assertEqual((2, 3), shape_poly.parse_spec(spec, (2, 3)))
def test_poly_compare(self):
    """Check ``ge`` on polynomials for decidable and undecidable cases."""
    x, y = shape_poly.parse_spec("a, b", (2, 3))
    poly = 4 * x + y + 3

    # Comparisons expected to be decidable; 8 is the largest constant the
    # test expects to be provably not greater than 4a + b + 3.
    for lower in (0, 8, poly, poly - 1):
        self.assertTrue(poly.ge(lower))

    # One past that constant can no longer be decided.
    with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                "inconclusive"):
        poly.ge(9)

    # A polynomial with a negative coefficient cannot be compared with 0.
    with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                "inconclusive"):
        (4 * x - y).ge(0)
def test_stride_shape(self):
    """Check ``core.stride_shape`` with concrete sizes and a shape variable.

    For a shape variable only ``window_size == window_stride == 1`` is
    expected to succeed; a larger window or stride must raise
    ``core.InconclusiveDimensionOperation``.
    """
    (dim,) = shape_poly.parse_spec("a,", (2,))

    self.assertEqual((8, 9), core.stride_shape((10, 20), (3, 3), (1, 2)))
    self.assertEqual((dim, 9), core.stride_shape((dim, 20), (1, 3), (1, 2)))

    # Window size 2 on the variable dimension.
    bad_window_msg = re.escape("Only striding with window_size == window_stride == 1 is supported for shape variables (var = a, window_size = 2, stride = 1")
    with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                bad_window_msg):
        core.stride_shape((dim, 20), (2, 3), (1, 2))

    # Stride 2 on the variable dimension.
    bad_stride_msg = re.escape("Only striding with window_size == window_stride == 1 is supported for shape variables (var = a, window_size = 1, stride = 2")
    with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                bad_stride_msg):
        core.stride_shape((dim, 20), (1, 3), (2, 2))
def test_dim_vars_symbolic_equal(self):
    """Check ``core.symbolic_equal_dim`` and ``core.symbolic_equal_one_of_dim``.

    Unlike ``==`` on two different variables, these helpers are expected to
    return a plain boolean instead of raising.
    """
    x, y = shape_poly.parse_spec("a, b", (2, 3))

    self.assertTrue(core.symbolic_equal_dim(x, x))
    for other in (1, y):
        self.assertFalse(core.symbolic_equal_dim(x, other))

    # (assertion, dimension, candidate dimensions)
    one_of_cases = [
        (self.assertTrue, x, [2, x]),
        (self.assertFalse, x, [2, y]),
        (self.assertFalse, x, []),
        (self.assertTrue, 2, [x, 3, 2]),
        (self.assertFalse, 1, [2, y]),
        (self.assertFalse, 3, []),
    ]
    for check, dim, candidates in one_of_cases:
        check(core.symbolic_equal_one_of_dim(dim, candidates))
def test_poly_compare_overload(self):
    """Check the ``>=`` and ``>`` operators on polynomials."""
    x, y = shape_poly.parse_spec("a, b", (2, 3))
    poly = 4 * x + y + 3

    # Decidable comparisons.
    self.assertTrue(poly >= 0)
    self.assertTrue(poly >= 8)
    self.assertTrue(poly > 7)
    self.assertTrue(poly >= poly)
    self.assertTrue(poly >= poly - 1)

    # Undecidable comparisons raise instead of returning a boolean.
    with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                "inconclusive"):
        poly >= 9

    with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                "inconclusive"):
        (4 * x - y) >= 0
def test_dim_vars_greater_equal(self):
    """Check ``core.greater_equal_dim`` / ``greater_equal_shape`` on shape variables."""
    x, y = shape_poly.parse_spec("a, b", (2, 3))

    # A variable is expected to compare >= itself, 0 and 1.
    for lower in (x, 0, 1):
        self.assertTrue(core.greater_equal_dim(x, lower))
    self.assertTrue(core.greater_equal_shape((x, 2), (1, 1)))

    # Comparing with 2 or with another variable cannot be decided.
    inconclusive_pattern = "Shape variable comparison .* is inconclusive"
    for lower in (2, y):
        with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                    inconclusive_pattern):
            core.greater_equal_dim(x, lower)
def test_core_greater_equal(self):
    """Check ``core.greater_equal_dim`` / ``greater_equal_shape`` on polynomials."""
    x, y = shape_poly.parse_spec("a, b", (2, 3))

    # A variable is expected to compare >= itself, 0 and 1.
    for lower in (x, 0, 1):
        self.assertTrue(core.greater_equal_dim(x, lower))
    self.assertTrue(core.greater_equal_shape((x, 2), (1, 1)))

    # Comparing with 2 or with another variable cannot be decided.
    inconclusive_pattern = "Dimension polynomial comparison .* is inconclusive"
    for lower in (2, y):
        with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                    inconclusive_pattern):
            core.greater_equal_dim(x, lower)
def test_stride_shape(self):
    """Check ``core.stride_shape`` on concrete and polynomial dimensions.

    Expected sizes follow ``(s - window_size) // window_stride + 1``.
    """
    dim, step = shape_poly.parse_spec("a, s", (2, 3))

    # Purely concrete shape, arguments passed by keyword.
    self.assertEqual(
        (8, 9),
        core.stride_shape((10, 20), window_size=(3, 3), window_stride=(1, 2)))

    # Polynomial leading dimension with unit stride.
    self.assertEqual((dim, 9), core.stride_shape((dim, 20), (1, 3), (1, 2)))
    self.assertEqual((dim - 1, 9), core.stride_shape((dim, 20), (2, 3), (1, 2)))

    # Polynomial stride: (a * s + 2 - 2) // s + 1 == a + 1.
    self.assertEqual(
        (dim + 1, 9),
        core.stride_shape((dim * step + 2, 20), (2, 3), (step, 2)))

    # a - 1 is not known to be a multiple of 2, so the division fails.
    not_multiple_msg = re.escape(
        "Cannot compute stride for dimension 'a', window_size '1', stride '2'. Reason: Dimension polynomial 'a + -1' is not a multiple of '2'")
    with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                not_multiple_msg):
        core.stride_shape((dim, 20), (1, 3), (2, 2))
def test_dim_vars(self):
    """Unit tests for DimVar: equality, hashing and membership."""
    x, y = shape_poly.parse_spec("a, b", (2, 3))

    # Self-comparison is decidable.
    self.assertEqual(True, x == x)
    self.assertEqual(False, x != x)

    # Comparing distinct variables raises. The empty pattern accepts any
    # message, so only the exception type is checked.
    for compare in (lambda: x == y, lambda: x != y):
        with self.assertRaisesRegex(core.InconclusiveDimensionOperation, ""):
            compare()

    # Sets of variables can be built and queried without raising.
    self.assertLen({x, x}, 1)
    self.assertLen({x, y}, 2)
    self.assertIn(x, {x, y})
    self.assertIn(y, {x, y})

    # List membership scans left to right, so looking up the second
    # variable first compares it with the first one.
    self.assertIn(x, [x, y])
    with self.assertRaisesRegex(core.InconclusiveDimensionOperation, ""):
        y in [x, y]
def test_dim_vars_symbolic_equal(self):
    """Check ``core.symbolic_equal_dim`` and ``core.symbolic_equal_one_of_dim``.

    Covers dimension variables, plain integers, an array-valued dimension
    and the ``TypeError`` expected for a string dimension.
    """
    x, y = shape_poly.parse_spec("a, b", (2, 3))

    self.assertTrue(core.symbolic_equal_dim(x, x))
    for other in (1, y):
        self.assertFalse(core.symbolic_equal_dim(x, other))

    # (assertion, dimension, candidate dimensions)
    one_of_cases = [
        (self.assertTrue, x, [2, x]),
        (self.assertFalse, x, [2, y]),
        (self.assertFalse, x, []),
        (self.assertTrue, 2, [x, 3, 2]),
        (self.assertFalse, 1, [2, y]),
        (self.assertFalse, 3, []),
    ]
    for check, dim, candidates in one_of_cases:
        check(core.symbolic_equal_one_of_dim(dim, candidates))

    # The second operand here is a DeviceArray rather than a Python int.
    self.assertTrue(core.symbolic_equal_dim(1, jnp.add(0, 1)))

    # A string is not an acceptable dimension.
    bad_type_msg = re.escape("Shapes must be 1D sequences of concrete values of integer type, got (1, 'a').")
    with self.assertRaisesRegex(TypeError, bad_type_msg):
        self.assertTrue(core.symbolic_equal_dim(1, "a"))
def test_poly_equal(self):
    """Check ``==`` and ``eq`` for constant results and general polynomials."""
    x, y = shape_poly.parse_spec("a, b", (2, 3))

    # a + 3 - a reduces to the constant 3, which must compare equal to
    # Python and NumPy integers alike.
    three = x + 3 - x
    self.assertTrue(three == 3)
    self.assertTrue(three == np.array(3, np.int64))
    self.assertTrue(three == np.array(3, np.int64)[()])
    self.assertFalse((three + 1) == 3)
    self.assertFalse(three == three + 1)

    # The order of factors and terms must not matter.
    self.assertTrue((2 * x * y * x + 3).eq(1 + y * x * x + x * x * y + 2))

    # Pairs expected to be provably different.
    self.assertFalse((2 * x * y * x + 3).eq(x * y * x + 3))
    self.assertFalse((x * y * x + 3).eq(x * y * x + 4))
    self.assertFalse((2 * x * y * x).eq(x * y * x))
    self.assertFalse((2 * x * y * x + 1).eq(x * y * x))
    self.assertFalse((3 * x * y * x - 1).eq(x * y * x))

    # 3 a^2 b - 2 versus a^2 b cannot be decided.
    inconclusive_msg = re.escape("Dimension polynomial comparison '3 a^2 b + -2' == 'a^2 b' is inconclusive")
    with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                inconclusive_msg):
        (3 * x * y * x - 2).eq(x * y * x)
def test_evaluate(self):
    """Check ``evaluate`` with an environment mapping variable names to values."""
    x, y = shape_poly.parse_spec("a, b", (2, 3))

    # 2 * 2 - 3 == 1
    self.assertEqual(1, (x * x - y).evaluate({"a": 2, "b": 3}))
    # A negative value for `a`: (-2) * (-2) - 3 + 1 == 2
    self.assertEqual(2, (x * x - y + 1).evaluate({"a": -2, "b": 3}))
def test_get_vars(self):
    """Check that ``get_vars`` returns the set of variable names used."""
    x, y = shape_poly.parse_spec("a, b", (2, 3))

    single_var_names = x.get_vars()
    self.assertEqual({"a"}, single_var_names)

    # A variable that occurs twice in a product is reported once.
    product_var_names = (x * y * x).get_vars()
    self.assertEqual({"a", "b"}, product_var_names)
class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
    """Tests for dimension polynomials obtained from ``shape_poly.parse_spec``."""

    def test_parse_poly_spec(self):
        """Specs without dimension variables parse to the concrete shape."""
        self.assertEqual((2, 3), shape_poly.parse_spec(None, (2, 3)))
        self.assertEqual((2, 3), shape_poly.parse_spec("2, 3", (2, 3)))
        self.assertEqual((2, 3), shape_poly.parse_spec("2, _", (2, 3)))
        self.assertEqual((2, 3), shape_poly.parse_spec("2, ...", (2, 3)))
        self.assertEqual((2, 3), shape_poly.parse_spec("...", (2, 3)))
        self.assertEqual((2, 3), shape_poly.parse_spec(" ( 2 , 3 ) ", (2, 3)))

    def test_dim_vars(self):
        """Equality, hashing and membership of dimension variables."""
        a, b = shape_poly.parse_spec("a, b", (2, 3))
        self.assertEqual(True, a == a)
        self.assertEqual(False, a != a)

        # `.eq`, `==` and `!=` between distinct variables are inconclusive.
        with self.assertRaisesRegex(
                core.InconclusiveDimensionOperation,
                "Dimension polynomial comparison 'a' == 'b' is inconclusive"):
            a.eq(b)

        with self.assertRaisesRegex(
                core.InconclusiveDimensionOperation,
                "Dimension polynomial comparison 'a' == 'b' is inconclusive"):
            a == b

        with self.assertRaisesRegex(
                core.InconclusiveDimensionOperation,
                "Dimension polynomial comparison 'a' == 'b' is inconclusive"):
            a != b

        self.assertLen({a, a}, 1)
        self.assertLen({a, b}, 2)
        self.assertIn(a, {a, b})
        self.assertIn(b, {a, b})
        self.assertIn(a, [a, b])
        # List membership scans left to right, so `b` is first compared
        # with `a`, which is inconclusive.
        with self.assertRaisesRegex(
                core.InconclusiveDimensionOperation,
                "Dimension polynomial comparison 'b' == 'a' is inconclusive"):
            b in [a, b]

    def test_get_vars(self):
        """``get_vars`` returns the set of variable names used."""
        a, b = shape_poly.parse_spec("a, b", (2, 3))

        self.assertEqual({"a"}, a.get_vars())
        self.assertEqual({"a", "b"}, (a * b * a).get_vars())

    def test_evaluate(self):
        """``evaluate`` substitutes values given by variable name."""
        a, b = shape_poly.parse_spec("a, b", (2, 3))

        self.assertEqual(1, (a * a - b).evaluate(dict(a=2, b=3)))
        self.assertEqual(2, (a * a - b + 1).evaluate(dict(a=-2, b=3)))

    def test_dim_vars_symbolic_equal(self):
        """``core.symbolic_equal_dim`` and ``core.symbolic_equal_one_of_dim``."""
        a, b = shape_poly.parse_spec("a, b", (2, 3))
        self.assertTrue(core.symbolic_equal_dim(a, a))
        self.assertFalse(core.symbolic_equal_dim(a, 1))
        self.assertFalse(core.symbolic_equal_dim(a, b))

        self.assertTrue(core.symbolic_equal_one_of_dim(a, [2, a]))
        self.assertFalse(core.symbolic_equal_one_of_dim(a, [2, b]))
        self.assertFalse(core.symbolic_equal_one_of_dim(a, []))

        self.assertTrue(core.symbolic_equal_one_of_dim(2, [a, 3, 2]))
        self.assertFalse(core.symbolic_equal_one_of_dim(1, [2, b]))
        self.assertFalse(core.symbolic_equal_one_of_dim(3, []))

        self.assertTrue(core.symbolic_equal_dim(1, jnp.add(0, 1)))  # A DeviceArray
        with self.assertRaisesRegex(
                TypeError,
                re.escape("Shapes must be 1D sequences of concrete values of integer type, got (1, 'a').")):
            self.assertTrue(core.symbolic_equal_dim(1, "a"))

    def test_poly_bounds(self):
        """``bounds()`` returns ``(lower, upper)``; ``None`` where no bound is reported."""
        a, b = shape_poly.parse_spec("a, b", (2, 3))
        self.assertEqual(a.bounds(), (1, None))
        self.assertEqual((2 * a).bounds(), (2, None))
        self.assertEqual((2 * a - 3).bounds(), (-1, None))
        self.assertEqual((-2 * a - 3).bounds(), (None, -5))
        self.assertEqual((3 * a * b * b + 5 * a - 7).bounds(), (1, None))
        self.assertEqual((3 * a * b * b - 5 * a - 7).bounds(), (None, None))
        self.assertEqual((a + b - a * b + a * b * a).bounds(), (None, None))
        self.assertEqual((a + 2 * b - a).bounds(), (2, None))

    def test_poly_equal(self):
        """``==`` and ``eq`` for constant results and general polynomials."""
        a, b = shape_poly.parse_spec("a, b", (2, 3))
        poly3 = a + 3 - a
        self.assertTrue(poly3 == 3)
        self.assertTrue(poly3 == np.array(3, np.int64))
        self.assertTrue(poly3 == np.array(3, np.int64)[()])
        self.assertFalse((poly3 + 1) == 3)
        self.assertFalse(poly3 == poly3 + 1)
        self.assertTrue((2 * a * b * a + 3).eq(1 + b * a * a + a * a * b + 2))
        self.assertFalse((2 * a * b * a + 3).eq(a * b * a + 3))

        self.assertFalse((a * b * a + 3).eq(a * b * a + 4))
        self.assertFalse((2 * a * b * a).eq(a * b * a))
        self.assertFalse((2 * a * b * a + 1).eq(a * b * a))
        self.assertFalse((3 * a * b * a - 1).eq(a * b * a))
        with self.assertRaisesRegex(
                core.InconclusiveDimensionOperation,
                re.escape("Dimension polynomial comparison '3 a^2 b + -2' == 'a^2 b' is inconclusive")):
            (3 * a * b * a - 2).eq(a * b * a)

    def test_poly_compare(self):
        """``ge`` for decidable and undecidable comparisons."""
        a, b = shape_poly.parse_spec("a, b", (2, 3))
        poly = 4 * a + b + 3
        self.assertTrue(poly.ge(0))
        self.assertTrue(poly.ge(8))
        self.assertTrue(poly.ge(poly))
        self.assertTrue(poly.ge(poly - 1))

        with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                    "inconclusive"):
            poly.ge(9)

        with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                    "inconclusive"):
            (4 * a - b).ge(0)

    def test_poly_compare_overload(self):
        """The ``>=`` and ``>`` operators on polynomials."""
        a, b = shape_poly.parse_spec("a, b", (2, 3))
        poly = 4 * a + b + 3
        self.assertTrue(poly >= 0)
        self.assertTrue(poly >= 8)
        self.assertTrue(poly > 7)
        self.assertTrue(poly >= poly)
        self.assertTrue(poly >= poly - 1)

        with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                    "inconclusive"):
            poly >= 9

        with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                    "inconclusive"):
            (4 * a - b) >= 0

    def test_core_greater_equal(self):
        """``core.greater_equal_dim`` / ``greater_equal_shape`` on polynomials."""
        a, b = shape_poly.parse_spec("a, b", (2, 3))
        self.assertTrue(core.greater_equal_dim(a, a))
        self.assertTrue(core.greater_equal_dim(a, 0))
        self.assertTrue(core.greater_equal_dim(a, 1))

        self.assertTrue(core.greater_equal_shape((a, 2), (1, 1)))

        with self.assertRaisesRegex(
                core.InconclusiveDimensionOperation,
                "Dimension polynomial comparison .* is inconclusive"):
            core.greater_equal_dim(a, 2)

        with self.assertRaisesRegex(
                core.InconclusiveDimensionOperation,
                "Dimension polynomial comparison .* is inconclusive"):
            core.greater_equal_dim(a, b)

    def test_poly_int_results(self):
        """Arithmetic that collapses to a constant yields a plain ``int``."""
        a, b = shape_poly.parse_spec("a, b", (2, 3))
        self.assertEqual(a + 2 - a, 2)
        self.assertIsInstance(a + 2 - a, int)
        self.assertEqual(a + (2 - a), 2)
        self.assertIsInstance(a + (2 - a), int)
        self.assertEqual(a * 2 // a, 2)
        self.assertIsInstance(a * 2 // a, int)

    # Class-level variables used to build the parameter list below. The list
    # must stay the outermost iterable of the generator expression: only that
    # part is evaluated in class scope, where `a` and `b` are visible.
    a, b = shape_poly.parse_spec("a, b", (2, 3))

    @parameterized.named_parameters(
        dict(testcase_name=f"_D={dividend}_d={divisor}_q={quotient}_r={remainder}",
             dividend=dividend, divisor=divisor, quotient=quotient,
             remainder=remainder)
        for dividend, divisor, quotient, remainder in [
            (a, 1, a, 0),
            (3 * a, 3, a, 0),
            (3 * a + 3, 3, a + 1, 0),
            (3 * a + 2, 3, a, 2),
            (3 * a + 5, 3, a + 1, 2),
            (3 * a - 2, 3, a - 1, 1),
            (3 * a * a * b + 2 * b * b * a, a * b, 3 * a + 2 * b, 0),
            (a * a - b * b, a + b, a - b, 0),
            (a, b, None, None),
            (3 * a, 2, None, None),
            (2 * a * b + b * b, a + b, None, None),
        ])
    def test_poly_divmod(self, dividend, divisor, quotient, remainder):
        """Check ``dividend.divmod(divisor)`` for one parameterized case.

        The parameter order matches the tuples in the decorator. Arguments
        are supplied by keyword from the dicts built there.

        Args:
          dividend: polynomial being divided.
          divisor: polynomial or integer to divide by.
          quotient: expected quotient, or None when the division is expected
            to raise ``core.InconclusiveDimensionOperation``.
          remainder: expected remainder; unused when ``quotient`` is None.
        """
        if quotient is None:
            with self.assertRaisesRegex(
                    core.InconclusiveDimensionOperation,
                    "Dimension polynomial .* is not a multiple of .*"):
                dividend.divmod(divisor)
        else:
            self.assertEqual((quotient, remainder), dividend.divmod(divisor))

    def test_dilate_shape(self):
        """0 if d == 0 else 1 + dilation * (d - 1)"""
        a, = shape_poly.parse_spec("a,", (2,))

        self.assertEqual((4, 7), core.dilate_shape((2, 3), (3, 3)))
        self.assertEqual((0, 7), core.dilate_shape((0, 3), (3, 3)))
        self.assertEqual((a, 7), core.dilate_shape((a, 3), (1, 3)))
        self.assertEqual((2 * a - 1, 7), core.dilate_shape((a, 3), (2, 3)))

    def test_stride_shape(self):
        """(s - window_size) // window_stride + 1"""
        a, stride = shape_poly.parse_spec("a, s", (2, 3))

        self.assertEqual((8, 9), core.stride_shape((10, 20), window_size=(3, 3), window_stride=(1, 2)))
        self.assertEqual((a, 9), core.stride_shape((a, 20), (1, 3), (1, 2)))

        self.assertEqual((a - 1, 9), core.stride_shape((a, 20), (2, 3), (1, 2)))
        self.assertEqual((a + 1, 9), core.stride_shape((a * stride + 2, 20), (2, 3), (stride, 2)))

        # a - 1 is not known to be a multiple of 2, so the division fails.
        with self.assertRaisesRegex(
                core.InconclusiveDimensionOperation,
                re.escape(
                    "Cannot compute stride for dimension 'a', window_size '1', stride '2'. Reason: Dimension polynomial 'a + -1' is not a multiple of '2'")):
            core.stride_shape((a, 20), (1, 3), (2, 2))
def shaped_array(shape_spec: str, actual_shape: core.Shape):
    """Build a float32 ``core.ShapedArray`` from a shape specification.

    Args:
      shape_spec: specification string handed to ``shape_poly.parse_spec``.
      actual_shape: concrete shape handed to ``shape_poly.parse_spec``
        together with ``shape_spec``.

    Returns:
      A ``core.ShapedArray`` with the parsed shape and dtype ``np.float32``.
    """
    parsed_shape = shape_poly.parse_spec(shape_spec, actual_shape)
    return core.ShapedArray(parsed_shape, np.float32)