def test_assert_equal_pass_on_arrays(self):
  """Check that assert_equal accepts 0-d JAX and NumPy arrays of ones."""
  # These inputs are built inside the test body rather than passed through
  # named_parameters, because JAX cannot be used before app.run().
  jax_scalar = jnp.ones([])
  numpy_scalar = np.ones([])
  asserts.assert_equal(jax_scalar, numpy_scalar)
  # The values match even though the dtypes differ (int32 vs float64).
  asserts.assert_equal(
      jnp.ones([], dtype=jnp.int32),
      np.ones([], dtype=np.float64),
  )
def test_assert_equal_fail(self, first, second):
  """Check that assert_equal raises AssertionError for (first, second)."""
  # NOTE(review): `first` / `second` are presumably supplied by a
  # parameterized decorator outside this view — confirm.
  expect_assertion_error = self.assertRaises(AssertionError)
  with expect_assertion_error:
    asserts.assert_equal(first, second)
def test_assert_equal_pass(self, first, second):
  """Check that assert_equal does not raise for (first, second)."""
  # NOTE(review): arguments presumably come from a parameterized decorator
  # outside this view — confirm.
  pair = (first, second)
  asserts.assert_equal(*pair)