Esempio n. 1
0
 def test_Poly_hash(self):
     # The ten constant polynomials 0..9 must not all share one hash value.
     constant_hashes = {hash(Poly({Mon(): coeff})) for coeff in range(10)}
     assert len(constant_hashes) != 1
     # Listing the same terms in a different order must give the same hash.
     constant_first = Poly({Mon(): 3, Mon({'n': 1}): 4})
     linear_first = Poly({Mon({'n': 1}): 4, Mon(): 3})
     assert hash(constant_first) == hash(linear_first)
Esempio n. 2
0
 def test_Poly_equal(self):
   # Constant polynomials compare against plain ints, 0-d NumPy arrays and
   # the scalar extracted from such an array, on either side of the operator.
   np_three = np.array(3, np.int64)
   assert constant_poly(3) == 3
   assert np_three == constant_poly(3)
   assert np_three[()] == constant_poly(3)
   assert not np_three != constant_poly(3)
   assert constant_poly(4) != 3
   assert 3 == constant_poly(3)
   assert 4 != constant_poly(3)
   four = constant_poly(4)
   assert four == constant_poly(4)
   assert constant_poly(3) != four

   # Term order is irrelevant; differing constant terms make polynomials
   # unequal.
   base = Poly({Mon(): 3, Mon({'n': 1}): 4})
   assert base == Poly({Mon({'n': 1}): 4, Mon(): 3})
   assert base != Poly({Mon(): 4, Mon({'n': 1}): 4})
   assert base != Poly({Mon(): 2})

   # When only the non-constant term differs, `!=` is expected to raise
   # UndefinedPoly with a message containing "inconclusive".
   with self.assertRaisesRegex(UndefinedPoly, "inconclusive"):
     _ = base != Poly({Mon(): 3, Mon({'n': 2}): 4})
   with self.assertRaisesRegex(UndefinedPoly, "inconclusive"):
     _ = Poly({Mon(): 3, Mon({'m': 1}): 4}) != base
Esempio n. 3
0
 def test_Poly_equal(self):
     # Constant polynomials compare against plain ints, 0-d NumPy arrays and
     # the scalar extracted from such an array, on either side of the operator.
     onp_three = onp.array(3, onp.int64)
     assert constant_poly(3) == 3
     assert onp_three == constant_poly(3)
     assert onp_three[()] == constant_poly(3)
     assert not onp_three != constant_poly(3)
     assert constant_poly(4) != 3
     assert 3 == constant_poly(3)
     assert 4 != constant_poly(3)
     four = constant_poly(4)
     assert four == constant_poly(4)
     assert constant_poly(3) != four

     # Term order is irrelevant for equality.
     base = Poly({Mon(): 3, Mon({'n': 1}): 4})
     assert base == Poly({Mon({'n': 1}): 4, Mon(): 3})
     # A different non-constant term makes the polynomials compare unequal.
     assert base != Poly({Mon(): 3, Mon({'n': 2}): 4})
     assert Poly({Mon(): 3, Mon({'m': 1}): 4}) != base
Esempio n. 4
0
    def test_Poly_compare(self):
        poly = Poly({Mon(): 3, Mon({'n': 1}): 4})
        # The polynomial is assumed to be positive so that various shape rules
        # keep working with polymorphic shapes.
        assert poly >= 0
        assert poly >= 1
        assert poly > 0

        # Same checks with the polynomial as the right-hand operand.
        assert 0 <= poly
        assert 0 < poly
        three = constant_poly(3)
        assert three >= 1
        assert three > 1
        # Comparing the non-constant polynomial against larger bounds must
        # raise ValueError. The empty pattern matches any error message.
        with self.assertRaisesRegex(ValueError, ""):
            _ = poly >= 2
        with self.assertRaisesRegex(ValueError, ""):
            _ = poly > 1
Esempio n. 5
0
  def test_Poly_compare(self):
    poly = Poly({Mon(): 3, Mon({'n': 1}): 4})
    # The polynomial is assumed to be positive so that various shape rules
    # keep working with polymorphic shapes.
    assert poly >= 0
    assert poly >= 1
    assert poly > 0

    # Same checks with the polynomial as the right-hand operand.
    assert 0 <= poly
    assert 0 < poly
    three = constant_poly(3)
    assert three >= 1
    assert three > 1
    # The polynomial can be ordered against itself shifted by a constant.
    assert poly >= poly
    assert poly >= poly - 1
    assert poly < poly + 1

    # The results are discarded: these two comparisons are only checked to
    # complete without raising.
    _ = poly >= 3
    _ = poly > 2
    with self.assertRaisesRegex(UndefinedPoly, "inconclusive"):
      _ = poly >= 4
Esempio n. 6
0
class PolyTest(jtu.JaxTestCase):
  # Tests for shape-spec parsing and for hashing, equality, ordering and
  # arithmetic of `Mon` and `Poly`.

  @parameterized.parameters([
      ['(m, n)', 'ShapeSpec(m, n)'],
      ['(m * n)', 'ShapeSpec(m n)'],
      ['m * n', 'ShapeSpec(m n)'],
      ['(m * n,)', 'ShapeSpec(m n)'],
      ['(3, m)', 'ShapeSpec(3, m)'],
      ['(10, m)', 'ShapeSpec(10, m)'],
      ['(-10, m)', 'ShapeSpec(-10, m)'],
      ['(3 * m)', 'ShapeSpec(3 m)'],
      ['m', 'ShapeSpec(m)'],
      ['', 'ShapeSpec()'],
      ['n + -1*n', 'ShapeSpec(0)'],
      ['m + n', 'ShapeSpec(m + n)'],
      ['m + n * k', 'ShapeSpec(k n + m)'],
      ['m + 3 * k', 'ShapeSpec(3 k + m)'],
      ['-3 + k + k * k', 'ShapeSpec(k^2 + k + -3)'],
      ['', 'ShapeSpec()'],
      ['_', 'ShapeSpec(_)'],
  ])
  def test_parse_spec(self, spec, ans):
    # `ans` is the expected string form of the parsed `spec`.
    parsed_text = str(parse_spec(spec))
    self.assertEqual(parsed_text, ans)
    # Remapping ids through a fresh UniqueIds() must not change the string.
    remapped = remap_ids(UniqueIds(), parse_spec(spec))
    self.assertEqual(str(remapped), ans)

  def test_Poly_equal(self):
    # Constant polynomials compare against plain ints, 0-d NumPy arrays and
    # the scalar extracted from such an array, on either side of the operator.
    np_three = np.array(3, np.int64)
    assert constant_poly(3) == 3
    assert np_three == constant_poly(3)
    assert np_three[()] == constant_poly(3)
    assert not np_three != constant_poly(3)
    assert constant_poly(4) != 3
    assert 3 == constant_poly(3)
    assert 4 != constant_poly(3)
    four = constant_poly(4)
    assert four == constant_poly(4)
    assert constant_poly(3) != four

    # Term order is irrelevant; differing constant terms make polynomials
    # unequal.
    base = Poly({Mon(): 3, Mon({'n': 1}): 4})
    assert base == Poly({Mon({'n': 1}): 4, Mon(): 3})
    assert base != Poly({Mon(): 4, Mon({'n': 1}): 4})
    assert base != Poly({Mon(): 2})

    # When only the non-constant term differs, `!=` is expected to raise
    # UndefinedPoly with a message containing "inconclusive".
    with self.assertRaisesRegex(UndefinedPoly, "inconclusive"):
      _ = base != Poly({Mon(): 3, Mon({'n': 2}): 4})
    with self.assertRaisesRegex(UndefinedPoly, "inconclusive"):
      _ = Poly({Mon(): 3, Mon({'m': 1}): 4}) != base

  def test_Poly_hash(self):
    # The ten constant polynomials 0..9 must not all share one hash value.
    constant_hashes = {hash(Poly({Mon(): coeff})) for coeff in range(10)}
    assert len(constant_hashes) != 1
    # Listing the same terms in a different order must give the same hash.
    constant_first = Poly({Mon(): 3, Mon({'n': 1}): 4})
    linear_first = Poly({Mon({'n': 1}): 4, Mon(): 3})
    assert hash(constant_first) == hash(linear_first)

  def test_Mon_hash(self):
    # Monomials built from {'a': 0} .. {'a': 9} must not all hash alike.
    mon_hashes = {hash(Mon({'a': value})) for value in range(10)}
    assert len(mon_hashes) != 1
    # Key order of the constructor dict must not influence the hash.
    a_first = Mon({'a': 1, 'b': 1})
    b_first = Mon({'b': 1, 'a': 1})
    assert hash(a_first) == hash(b_first)

  @parameterized.parameters([
    (Mon({'a': 1}), Mon({'b': 1})),
    (Mon({'a': 2, 'b': 1}), Mon({'b': 1})),
  ])
  def test_Mon_floordiv(self, divisor, quotient):
    # Floor-dividing the product by `divisor` must give back `quotient`.
    product = quotient * divisor
    recovered = product // divisor
    self.assertEqual(quotient, recovered)

  def test_Poly_compare(self):
    poly = Poly({Mon(): 3, Mon({'n': 1}): 4})
    # The polynomial is assumed to be positive so that various shape rules
    # keep working with polymorphic shapes.
    assert poly >= 0
    assert poly >= 1
    assert poly > 0

    # Same checks with the polynomial as the right-hand operand.
    assert 0 <= poly
    assert 0 < poly
    three = constant_poly(3)
    assert three >= 1
    assert three > 1
    # The polynomial can be ordered against itself shifted by a constant.
    assert poly >= poly
    assert poly >= poly - 1
    assert poly < poly + 1

    # The results are discarded: these two comparisons are only checked to
    # complete without raising.
    _ = poly >= 3
    _ = poly > 2
    with self.assertRaisesRegex(UndefinedPoly, "inconclusive"):
      _ = poly >= 4

  # Class-level polynomials and message fragment used by the divmod
  # parameter table below.
  n = Poly({Mon({'n': 1}): 1})
  m = Poly({Mon({'m': 1}): 1})

  must_divide_msg = " must divide size"

  @parameterized.parameters([
    (1, constant_poly(0), 0),
    (n, 0, 0),
    (2, n, 1),
    (5, 2 * n, 0),
    (5, 2 * n + 4, 3),
    (n * n, n + 1, 0),
    (2 * n + 1, 2 * n + 1, n + 2, must_divide_msg),
    (n * m + 1, m + n + 1, n - 1, must_divide_msg),
    (n, n, 0),
    (n, n, 1, must_divide_msg),
    (n + 1, -n + 1, -1, must_divide_msg),
  ])
  def test_Poly_divmod(self, divisor, quotient, remainder, error_message=None):
    # The dividend is rebuilt from the expected result, then divided again.
    dividend = quotient * divisor + remainder
    # A constant dividend is passed to divmod as a plain int.
    if dividend.is_constant:
      dividend = int(dividend)
    if not error_message:
      self.assertEqual((quotient, remainder), divmod(dividend, divisor))
      return
    # Rows carrying an error message must raise UndefinedPoly matching it.
    with self.assertRaisesRegex(UndefinedPoly, error_message):
      divmod(dividend, divisor)

  def test_Poly_rsub(self):
    # Subtracting a polynomial from an int must equal the negated polynomial
    # minus the negated int.
    n = Poly({Mon({'n': 1}): 1})
    int_minus_poly = -1 - n
    negated_poly_minus_int = -n - 1
    assert int_minus_poly == negated_poly_minus_int
Esempio n. 7
0
def constant_poly(c):
  """Return a `Poly` whose only entry maps the argument-less `Mon()` to `c`."""
  constant_term = Mon()
  return Poly({constant_term: c})
Esempio n. 8
0
 def test_Poly_rsub(self):
   # Subtracting a polynomial from an int must equal the negated polynomial
   # minus the negated int.
   n = Poly({Mon({'n': 1}): 1})
   int_minus_poly = -1 - n
   negated_poly_minus_int = -n - 1
   assert int_minus_poly == negated_poly_minus_int
Esempio n. 9
0
 def test_Poly_divmod(self):
     # divmod of a polynomial by an int must return (quotient, remainder).
     n = Poly({Mon({'n': 1}): 1})

     expected = (n, 1)
     assert expected == divmod(2 * n + 1, 2)

     expected = (2 * n, 0)
     assert expected == divmod(10 * n, 5)

     expected = (2 * n + 4, 3)
     assert expected == divmod(10 * n + 23, 5)
Esempio n. 10
0
 def test_slice_index_poly_start(self):
     # A slice whose start is the polynomial `n`, with no stop or step, is
     # resolved against a dimension of size 2 * n.
     n = Poly({Mon({'n': 1}): 1})
     size = 2 * n
     open_ended = slice(n, None)
     expected = (n, 2 * n, 1)
     assert expected == _polymorphic_slice_indices(open_ended, size)