def test_dim_vars(self):
    """Equality, hashing and membership behaviour of two dimension variables."""
    dim_a, dim_b = shape_poly.parse_spec("a, b", (2, 3))
    # A variable compares equal to itself.
    self.assertEqual(True, dim_a == dim_a)
    self.assertEqual(False, dim_a != dim_a)

    # Two distinct variables cannot be compared symbolically; every spelling
    # of the comparison is expected to raise with the same message.
    inconclusive_a_b = "Dimension polynomial comparison 'a' == 'b' is inconclusive"
    for compare in (lambda: dim_a.eq(dim_b),
                    lambda: dim_a == dim_b,
                    lambda: dim_a != dim_b):
      with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                  inconclusive_a_b):
        compare()

    # Set construction and set membership succeed for these variables.
    self.assertLen({dim_a, dim_a}, 1)
    self.assertLen({dim_a, dim_b}, 2)
    self.assertIn(dim_a, {dim_a, dim_b})
    self.assertIn(dim_b, {dim_a, dim_b})
    # List membership checks identity first, so the first element matches
    # without any comparison being performed...
    self.assertIn(dim_a, [dim_a, dim_b])
    # ...while locating `b` first requires comparing it with `a`.
    with self.assertRaisesRegex(
          core.InconclusiveDimensionOperation,
          "Dimension polynomial comparison 'b' == 'a' is inconclusive"):
      _ = dim_b in [dim_a, dim_b]
  def test_dilate_shape(self):
    """Checks core.dilate_shape on constant and symbolic dimensions.

    The expected dilated size of a dimension `d` is
    `0 if d == 0 else 1 + dilation * (d - 1)`.
    """
    a, = shape_poly.parse_spec("a,", (2,))

    # Constant dims: 1 + 3 * (2 - 1) = 4 and 1 + 3 * (3 - 1) = 7.
    self.assertEqual((4, 7), core.dilate_shape((2, 3), (3, 3)))
    # A zero-sized dimension stays zero whatever the dilation.
    self.assertEqual((0, 7), core.dilate_shape((0, 3), (3, 3)))
    # Dilation 1 leaves a symbolic dimension unchanged: 1 + 1 * (a - 1) = a.
    self.assertEqual((a, 7), core.dilate_shape((a, 3), (1, 3)))
    # Dilation 2 on a symbolic dimension: 1 + 2 * (a - 1) = 2 * a - 1.
    self.assertEqual((2 * a - 1, 7), core.dilate_shape((a, 3), (2, 3)))
 def test_poly_int_results(self):
   """Arithmetic in which every variable cancels must yield the plain int 2."""
   a, _ = shape_poly.parse_spec("a, b", (2, 3))
   # Each expression is constructed so that `a` cancels out completely.
   for result in (a + 2 - a, a + (2 - a), a * 2 // a):
     self.assertEqual(result, 2)
     self.assertIsInstance(result, int)
 def test_poly_bounds(self):
   """Checks the (lower, upper) bounds reported by `bounds()`.

   In the expected values, `None` marks a side with no known bound.
   """
   a, b = shape_poly.parse_spec("a, b", (2, 3))
   expected_bounds = [
       (a, (1, None)),
       (2 * a, (2, None)),
       (2 * a - 3, (-1, None)),
       (-2 * a - 3, (None, -5)),
       (3 * a * b * b + 5 * a - 7, (1, None)),
       (3 * a * b * b - 5 * a - 7, (None, None)),
       (a + b - a * b + a * b * a, (None, None)),
       (a + 2 * b - a, (2, None)),
   ]
   for poly, bounds in expected_bounds:
     self.assertEqual(poly.bounds(), bounds)
  def test_dilate_shape(self):
    """dilate_shape with a shape variable accepts only a dilation of 1."""
    da, = shape_poly.parse_spec("a,", (2,))

    for shape, dilation, expected in [((2, 3), (3, 3), (4, 7)),
                                      ((0, 3), (3, 3), (0, 7)),
                                      ((da, 3), (1, 3), (da, 7))]:
      self.assertEqual(expected, core.dilate_shape(shape, dilation))

    expected_msg = re.escape("Only dilation == 1 is supported for shape variables (var = a, dilation = 2)")
    with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                expected_msg):
      core.dilate_shape((da, 3), (2, 3))
 def test_parse_poly_spec(self):
   """Specs that introduce no variables resolve to the concrete shape."""
   concrete_shape = (2, 3)
   for spec in (None, "2, 3", "2, _", "2, ...", "...", " ( 2 , 3 ) "):
     self.assertEqual(concrete_shape,
                      shape_poly.parse_spec(spec, concrete_shape))
  def test_poly_compare(self):
    """Checks the `ge` method on a polynomial with positive coefficients."""
    a, b = shape_poly.parse_spec("a, b", (2, 3))
    poly = 4 * a + b + 3
    # Provable lower bounds for poly.
    for lower in (0, 8, poly, poly - 1):
      self.assertTrue(poly.ge(lower))

    # 9 is above the provable bound, and 4a - b has no provable sign.
    for undecidable in (lambda: poly.ge(9), lambda: (4 * a - b).ge(0)):
      with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                  "inconclusive"):
        undecidable()
  def test_stride_shape(self):
    """stride_shape with a shape variable needs window_size == stride == 1."""
    da, = shape_poly.parse_spec("a,", (2,))

    # (10 - 3) // 1 + 1 = 8 and (20 - 3) // 2 + 1 = 9.
    self.assertEqual((8, 9), core.stride_shape((10, 20), (3, 3), (1, 2)))
    self.assertEqual((da, 9), core.stride_shape((da, 20), (1, 3), (1, 2)))

    rejected = [
        ((2, 3), (1, 2),
         "Only striding with window_size == window_stride == 1 is supported for shape variables (var = a, window_size = 2, stride = 1"),
        ((1, 3), (2, 2),
         "Only striding with window_size == window_stride == 1 is supported for shape variables (var = a, window_size = 1, stride = 2"),
    ]
    for window_size, window_stride, message in rejected:
      with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                  re.escape(message)):
        core.stride_shape((da, 20), window_size, window_stride)
# Example #9
    def test_dim_vars_symbolic_equal(self):
        """Symbolic equality helpers return booleans for the cases below."""
        da, db = shape_poly.parse_spec("a, b", (2, 3))
        equal_dim = core.symbolic_equal_dim
        equal_one_of = core.symbolic_equal_one_of_dim

        self.assertTrue(equal_dim(da, da))
        for other in (1, db):
            self.assertFalse(equal_dim(da, other))

        # A variable as the needle.
        self.assertTrue(equal_one_of(da, [2, da]))
        self.assertFalse(equal_one_of(da, [2, db]))
        self.assertFalse(equal_one_of(da, []))

        # A constant as the needle, with variables in the haystack.
        self.assertTrue(equal_one_of(2, [da, 3, 2]))
        self.assertFalse(equal_one_of(1, [2, db]))
        self.assertFalse(equal_one_of(3, []))
  def test_poly_compare_overload(self):
    """Checks the `>=` and `>` operator overloads on dimension polynomials."""
    a, b = shape_poly.parse_spec("a, b", (2, 3))
    poly = 4 * a + b + 3
    provable = [
        lambda: poly >= 0,
        lambda: poly >= 8,
        lambda: poly > 7,
        lambda: poly >= poly,
        lambda: poly >= poly - 1,
    ]
    for check in provable:
      self.assertTrue(check())

    # Neither comparison can be decided from the known variable bounds.
    for check in (lambda: poly >= 9, lambda: (4 * a - b) >= 0):
      with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                  "inconclusive"):
        check()
  def test_dim_vars_greater_equal(self):
    """greater_equal_dim / greater_equal_shape applied to a shape variable."""
    da, db = shape_poly.parse_spec("a, b", (2, 3))
    for lower in (da, 0, 1):
      self.assertTrue(core.greater_equal_dim(da, lower))

    self.assertTrue(core.greater_equal_shape((da, 2), (1, 1)))

    # Neither 2 nor another variable is a provable lower bound for `a`.
    for rhs in (2, db):
      with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                  "Shape variable comparison .* is inconclusive"):
        core.greater_equal_dim(da, rhs)
  def test_core_greater_equal(self):
    """core.greater_equal_dim / greater_equal_shape with dimension polynomials."""
    a, b = shape_poly.parse_spec("a, b", (2, 3))
    ge = core.greater_equal_dim
    self.assertTrue(ge(a, a))
    self.assertTrue(ge(a, 0))
    self.assertTrue(ge(a, 1))
    self.assertTrue(core.greater_equal_shape((a, 2), (1, 1)))

    inconclusive = "Dimension polynomial comparison .* is inconclusive"
    for rhs in (2, b):
      with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                  inconclusive):
        ge(a, rhs)
  def test_stride_shape(self):
    """Checks core.stride_shape; each size is (d - window_size) // window_stride + 1."""
    a, stride = shape_poly.parse_spec("a, s", (2, 3))

    # Concrete dims: (10 - 3) // 1 + 1 = 8 and (20 - 3) // 2 + 1 = 9.
    self.assertEqual((8, 9), core.stride_shape((10, 20), window_size=(3, 3), window_stride=(1, 2)))
    symbolic_cases = [
        # (shape, window_size, window_stride, expected)
        ((a, 20), (1, 3), (1, 2), (a, 9)),          # (a - 1) // 1 + 1
        ((a, 20), (2, 3), (1, 2), (a - 1, 9)),      # (a - 2) // 1 + 1
        ((a * stride + 2, 20), (2, 3), (stride, 2), (a + 1, 9)),  # a*s // s + 1
    ]
    for shape, window_size, window_stride, expected in symbolic_cases:
      self.assertEqual(expected,
                       core.stride_shape(shape, window_size, window_stride))

    # (a - 1) is not provably divisible by 2, so the stride is rejected.
    expected_msg = re.escape(
        "Cannot compute stride for dimension 'a', window_size '1', stride '2'. Reason: Dimension polynomial 'a + -1' is not a multiple of '2'")
    with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                expected_msg):
      core.stride_shape((a, 20), (1, 3), (2, 2))
  def test_dim_vars(self):
    """Unit tests for DimVar: equality, hashing and container membership."""
    da, db = shape_poly.parse_spec("a, b", (2, 3))
    self.assertEqual(True, da == da)
    self.assertEqual(False, da != da)
    # Distinct variables cannot be compared. Only the exception type is
    # checked (an empty regex would match any message), so use assertRaises.
    with self.assertRaises(core.InconclusiveDimensionOperation):
      _ = da == db
    with self.assertRaises(core.InconclusiveDimensionOperation):
      _ = da != db

    self.assertLen({da, da}, 1)
    self.assertLen({da, db}, 2)
    self.assertIn(da, {da, db})
    self.assertIn(db, {da, db})
    # List membership checks identity first, so the first element is found
    # without a comparison...
    self.assertIn(da, [da, db])
    # ...whereas finding `db` first compares it with `da`, which is
    # inconclusive.
    with self.assertRaises(core.InconclusiveDimensionOperation):
      _ = db in [da, db]
  def test_dim_vars_symbolic_equal(self):
    """Symbolic equality helpers on variables, constants and non-dims."""
    a, b = shape_poly.parse_spec("a, b", (2, 3))
    self.assertTrue(core.symbolic_equal_dim(a, a))
    self.assertFalse(core.symbolic_equal_dim(a, 1))
    self.assertFalse(core.symbolic_equal_dim(a, b))

    self.assertTrue(core.symbolic_equal_one_of_dim(a, [2, a]))
    self.assertFalse(core.symbolic_equal_one_of_dim(a, [2, b]))
    self.assertFalse(core.symbolic_equal_one_of_dim(a, []))

    self.assertTrue(core.symbolic_equal_one_of_dim(2, [a, 3, 2]))
    self.assertFalse(core.symbolic_equal_one_of_dim(1, [2, b]))
    self.assertFalse(core.symbolic_equal_one_of_dim(3, []))

    self.assertTrue(core.symbolic_equal_dim(1, jnp.add(0, 1)))  # A DeviceArray
    # A string is not a valid dimension: the call itself must raise, so it is
    # not wrapped in an assertion that could never run.
    with self.assertRaisesRegex(TypeError,
                                re.escape("Shapes must be 1D sequences of concrete values of integer type, got (1, 'a').")):
      core.symbolic_equal_dim(1, "a")
  def test_poly_equal(self):
    """Equality of polynomials against ints, numpy ints and other polynomials."""
    a, b = shape_poly.parse_spec("a, b", (2, 3))
    three = a + 3 - a  # the `a` terms cancel
    for same_value in (3, np.array(3, np.int64), np.array(3, np.int64)[()]):
      self.assertTrue(three == same_value)
    self.assertFalse((three + 1) == 3)
    self.assertFalse(three == three + 1)

    # Same polynomial after reordering and collecting terms.
    self.assertTrue((2 * a * b * a + 3).eq(1 + b * a * a + a * a * b + 2))

    # Pairs whose difference is provably non-zero.
    provably_different = [
        (2 * a * b * a + 3, a * b * a + 3),
        (a * b * a + 3, a * b * a + 4),
        (2 * a * b * a, a * b * a),
        (2 * a * b * a + 1, a * b * a),
        (3 * a * b * a - 1, a * b * a),
    ]
    for lhs, rhs in provably_different:
      self.assertFalse(lhs.eq(rhs))

    # Here the difference 2 a^2 b - 2 is zero when a = b = 1, so equality
    # cannot be decided.
    expected_msg = re.escape("Dimension polynomial comparison '3 a^2 b + -2' == 'a^2 b' is inconclusive")
    with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                expected_msg):
      (3 * a * b * a - 2).eq(a * b * a)
  def test_evaluate(self):
    """Substitutes concrete values for the variables of a polynomial."""
    a, b = shape_poly.parse_spec("a, b", (2, 3))
    square_minus_b = a * a - b
    self.assertEqual(1, square_minus_b.evaluate(dict(a=2, b=3)))        # 4 - 3
    self.assertEqual(2, (square_minus_b + 1).evaluate(dict(a=-2, b=3)))  # 4 - 3 + 1
  def test_get_vars(self):
    """get_vars returns the set of variable names used by a polynomial."""
    a, b = shape_poly.parse_spec("a, b", (2, 3))
    for poly, names in ((a, {"a"}), (a * b * a, {"a", "b"})):
      self.assertEqual(names, poly.get_vars())
class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
  """Tests for the dimension polynomials produced by shape_poly.parse_spec."""

  def test_parse_poly_spec(self):
    """Specs that introduce no variables resolve to the concrete shape."""
    self.assertEqual((2, 3), shape_poly.parse_spec(None, (2, 3)))
    self.assertEqual((2, 3), shape_poly.parse_spec("2, 3", (2, 3)))
    self.assertEqual((2, 3), shape_poly.parse_spec("2, _", (2, 3)))
    self.assertEqual((2, 3), shape_poly.parse_spec("2, ...", (2, 3)))
    self.assertEqual((2, 3), shape_poly.parse_spec("...", (2, 3)))
    self.assertEqual((2, 3), shape_poly.parse_spec(" ( 2 , 3 ) ", (2, 3)))

  def test_dim_vars(self):
    """Equality, hashing and container membership of dimension variables."""
    a, b = shape_poly.parse_spec("a, b", (2, 3))
    self.assertEqual(True, a == a)
    self.assertEqual(False, a != a)
    # Two distinct variables cannot be compared symbolically, whichever
    # spelling of the comparison is used.
    with self.assertRaisesRegex(
        core.InconclusiveDimensionOperation,
        "Dimension polynomial comparison 'a' == 'b' is inconclusive"):
      a.eq(b)

    with self.assertRaisesRegex(
        core.InconclusiveDimensionOperation,
        "Dimension polynomial comparison 'a' == 'b' is inconclusive"):
      a == b

    with self.assertRaisesRegex(
        core.InconclusiveDimensionOperation,
        "Dimension polynomial comparison 'a' == 'b' is inconclusive"):
      a != b

    self.assertLen({a, a}, 1)
    self.assertLen({a, b}, 2)
    self.assertIn(a, {a, b})
    self.assertIn(b, {a, b})
    # List membership checks identity first, so `a` is found without a
    # comparison...
    self.assertIn(a, [a, b])
    # ...whereas finding `b` first compares it with `a`, which is inconclusive.
    with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                "Dimension polynomial comparison 'b' == 'a' is inconclusive"):
      b in [a, b]

  def test_get_vars(self):
    """get_vars returns the set of variable names used by a polynomial."""
    a, b = shape_poly.parse_spec("a, b", (2, 3))

    self.assertEqual({"a"}, a.get_vars())
    self.assertEqual({"a", "b"}, (a * b * a).get_vars())

  def test_evaluate(self):
    """Substitutes concrete values for the variables of a polynomial."""
    a, b = shape_poly.parse_spec("a, b", (2, 3))

    self.assertEqual(1, (a * a - b).evaluate(dict(a=2, b=3)))
    self.assertEqual(2, (a * a - b + 1).evaluate(dict(a=-2, b=3)))

  def test_dim_vars_symbolic_equal(self):
    """Symbolic equality helpers on variables, constants and non-dims."""
    a, b = shape_poly.parse_spec("a, b", (2, 3))
    self.assertTrue(core.symbolic_equal_dim(a, a))
    self.assertFalse(core.symbolic_equal_dim(a, 1))
    self.assertFalse(core.symbolic_equal_dim(a, b))

    self.assertTrue(core.symbolic_equal_one_of_dim(a, [2, a]))
    self.assertFalse(core.symbolic_equal_one_of_dim(a, [2, b]))
    self.assertFalse(core.symbolic_equal_one_of_dim(a, []))

    self.assertTrue(core.symbolic_equal_one_of_dim(2, [a, 3, 2]))
    self.assertFalse(core.symbolic_equal_one_of_dim(1, [2, b]))
    self.assertFalse(core.symbolic_equal_one_of_dim(3, []))

    self.assertTrue(core.symbolic_equal_dim(1, jnp.add(0, 1)))  # A DeviceArray
    # A string is not a valid dimension: the call itself must raise, so it is
    # not wrapped in an assertion that could never run.
    with self.assertRaisesRegex(TypeError,
                                re.escape("Shapes must be 1D sequences of concrete values of integer type, got (1, 'a').")):
      core.symbolic_equal_dim(1, "a")

  def test_poly_bounds(self):
    """Checks (lower, upper) bounds; `None` marks a side with no known bound."""
    a, b = shape_poly.parse_spec("a, b", (2, 3))
    self.assertEqual(a.bounds(), (1, None))
    self.assertEqual((2 * a).bounds(), (2, None))
    self.assertEqual((2 * a - 3).bounds(), (-1, None))
    self.assertEqual((-2 * a - 3).bounds(), (None, -5))
    self.assertEqual((3 * a * b * b + 5 * a - 7).bounds(), (1, None))
    self.assertEqual((3 * a * b * b - 5 * a - 7).bounds(), (None, None))
    self.assertEqual((a + b - a * b + a * b * a).bounds(), (None, None))
    self.assertEqual((a + 2 * b - a).bounds(), (2, None))

  def test_poly_equal(self):
    """Equality of polynomials against ints, numpy ints and other polynomials."""
    a, b = shape_poly.parse_spec("a, b", (2, 3))
    poly3 = a + 3 - a  # the `a` terms cancel
    self.assertTrue(poly3 == 3)
    self.assertTrue(poly3 == np.array(3, np.int64))
    self.assertTrue(poly3 == np.array(3, np.int64)[()])
    self.assertFalse((poly3 + 1) == 3)
    self.assertFalse(poly3 == poly3 + 1)
    self.assertTrue((2 * a * b * a + 3).eq(1 + b * a * a + a * a * b + 2))
    self.assertFalse((2 * a * b * a + 3).eq(a * b * a + 3))

    # Each difference below is provably non-zero.
    self.assertFalse((a * b * a + 3).eq(a * b * a + 4))
    self.assertFalse((2 * a * b * a).eq(a * b * a))
    self.assertFalse((2 * a * b * a + 1).eq(a * b * a))
    self.assertFalse((3 * a * b * a - 1).eq(a * b * a))
    # The difference 2 a^2 b - 2 is zero when a = b = 1: undecidable.
    with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                re.escape("Dimension polynomial comparison '3 a^2 b + -2' == 'a^2 b' is inconclusive")):
      (3 * a * b * a - 2).eq(a * b * a)

  def test_poly_compare(self):
    """Checks the `ge` method on a polynomial with positive coefficients."""
    a, b = shape_poly.parse_spec("a, b", (2, 3))
    poly = 4 * a + b + 3
    self.assertTrue(poly.ge(0))
    self.assertTrue(poly.ge(8))
    self.assertTrue(poly.ge(poly))
    self.assertTrue(poly.ge(poly - 1))

    with self.assertRaisesRegex(core.InconclusiveDimensionOperation, "inconclusive"):
      poly.ge(9)

    with self.assertRaisesRegex(core.InconclusiveDimensionOperation, "inconclusive"):
      (4 * a - b).ge(0)

  def test_poly_compare_overload(self):
    """Checks the `>=` and `>` operator overloads on dimension polynomials."""
    a, b = shape_poly.parse_spec("a, b", (2, 3))
    poly = 4 * a + b + 3
    self.assertTrue(poly >= 0)
    self.assertTrue(poly >= 8)
    self.assertTrue(poly > 7)
    self.assertTrue(poly >= poly)
    self.assertTrue(poly >= poly - 1)

    with self.assertRaisesRegex(core.InconclusiveDimensionOperation, "inconclusive"):
      poly >= 9

    with self.assertRaisesRegex(core.InconclusiveDimensionOperation, "inconclusive"):
      (4 * a - b) >= 0

  def test_core_greater_equal(self):
    """core.greater_equal_dim / greater_equal_shape with dimension polynomials."""
    a, b = shape_poly.parse_spec("a, b", (2, 3))
    self.assertTrue(core.greater_equal_dim(a, a))
    self.assertTrue(core.greater_equal_dim(a, 0))
    self.assertTrue(core.greater_equal_dim(a, 1))

    self.assertTrue(core.greater_equal_shape((a, 2), (1, 1)))

    with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                "Dimension polynomial comparison .* is inconclusive"):
      core.greater_equal_dim(a, 2)

    with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                "Dimension polynomial comparison .* is inconclusive"):
      core.greater_equal_dim(a, b)

  def test_poly_int_results(self):
    """Arithmetic in which every variable cancels must yield a plain int."""
    a, b = shape_poly.parse_spec("a, b", (2, 3))
    self.assertEqual(a + 2 - a, 2)
    self.assertIsInstance(a + 2 - a, int)
    self.assertEqual(a + (2 - a), 2)
    self.assertIsInstance(a + (2 - a), int)
    self.assertEqual(a * 2 // a, 2)
    self.assertIsInstance(a * 2 // a, int)

  # Class-scope variables used only to build the test_poly_divmod cases; the
  # parameter list is evaluated while the class body executes.
  a, b = shape_poly.parse_spec("a, b", (2, 3))
  @parameterized.named_parameters(
      dict(testcase_name=f"_D={dividend}_d={divisor}_q={quotient}_r={remainder}",
           dividend=dividend, divisor=divisor, quotient=quotient,
           remainder=remainder)
      for dividend, divisor, quotient, remainder in [
          (a, 1, a, 0),
          (3 * a, 3, a, 0),
          (3 * a + 3, 3, a + 1, 0),
          (3 * a + 2, 3, a, 2),
          (3 * a + 5, 3, a + 1, 2),
          (3 * a - 2, 3, a - 1, 1),
          (3 * a * a * b + 2 * b * b * a, a * b, 3 * a + 2 * b, 0),
          (a * a - b * b, a + b, a - b, 0),
          (a, b, None, None),
          (3 * a, 2, None, None),
          (2 * a * b + b * b, a + b, None, None),
  ])
  def test_poly_divmod(self, dividend, divisor, quotient, remainder):
    """Checks dividend.divmod(divisor) == (quotient, remainder).

    A `quotient` of None marks a case where divmod is expected to raise
    InconclusiveDimensionOperation. Parameters are supplied by keyword, in
    the same order as the case tuples above.
    """
    if quotient is None:
      with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                  "Dimension polynomial .* is not a multiple of .*"):
        dividend.divmod(divisor)
    else:
      self.assertEqual((quotient, remainder), dividend.divmod(divisor))

  # The parameter list above is already built; drop the helper variables so
  # they do not remain as class attributes.
  del a, b

  def test_dilate_shape(self):
    """Checks core.dilate_shape on constant and symbolic dimensions.

    The expected dilated size of a dimension `d` is
    `0 if d == 0 else 1 + dilation * (d - 1)`.
    """
    a, = shape_poly.parse_spec("a,", (2,))

    self.assertEqual((4, 7), core.dilate_shape((2, 3), (3, 3)))
    self.assertEqual((0, 7), core.dilate_shape((0, 3), (3, 3)))
    self.assertEqual((a, 7), core.dilate_shape((a, 3), (1, 3)))
    self.assertEqual((2 * a - 1, 7), core.dilate_shape((a, 3), (2, 3)))

  def test_stride_shape(self):
    """Checks core.stride_shape; each size is (d - window_size) // window_stride + 1."""
    a, stride = shape_poly.parse_spec("a, s", (2, 3))

    self.assertEqual((8, 9), core.stride_shape((10, 20), window_size=(3, 3), window_stride=(1, 2)))
    self.assertEqual((a, 9), core.stride_shape((a, 20), (1, 3), (1, 2)))

    self.assertEqual((a - 1, 9), core.stride_shape((a, 20), (2, 3), (1, 2)))
    self.assertEqual((a + 1, 9), core.stride_shape((a * stride + 2, 20), (2, 3), (stride, 2)))

    # (a - 1) is not provably divisible by 2, so the stride is rejected.
    with self.assertRaisesRegex(
        core.InconclusiveDimensionOperation,
        re.escape(
          "Cannot compute stride for dimension 'a', window_size '1', stride '2'. Reason: Dimension polynomial 'a + -1' is not a multiple of '2'")):
      core.stride_shape((a, 20), (1, 3), (2, 2))
 def shaped_array(shape_spec: str, actual_shape: core.Shape):
   """Builds a float32 core.ShapedArray whose shape is parsed from `shape_spec`.

   `actual_shape` is passed to shape_poly.parse_spec together with the spec.
   """
   poly_shape = shape_poly.parse_spec(shape_spec, actual_shape)
   return core.ShapedArray(poly_shape, np.float32)