def test_has_shape_false_for_tuple(self):
    """util.has_shape should be falsy for a plain Python tuple."""
    pair = (2, 3)
    self.assertFalse(util.has_shape(pair))
def test_has_shape_deals_gracefully_with_none(self):
    """util.has_shape should return a falsy result for None rather than raise."""
    missing = None
    self.assertFalse(util.has_shape(missing))
def test_has_shape_false_for_scalar(self):
    """util.has_shape should be falsy for a bare Python float scalar."""
    scalar = 1.0
    self.assertFalse(util.has_shape(scalar))
def test_has_shape_true_for_complex_array(self):
    """util.has_shape should be truthy for a nested 2x2 jax.numpy array.

    NOTE(review): despite the name, the values here are real numbers
    (a mix of ints and floats), not complex dtype. "complex" presumably
    means "non-trivial / nested"; confirm with the author.
    """
    matrix = jnp.array([[1, 2], [3.0, 4.0]])
    self.assertTrue(util.has_shape(matrix))
def test_has_shape_true_for_single_element_numpy_array(self):
    """util.has_shape should be truthy even for a one-element NumPy array."""
    singleton = np.array([1])  # 1-D array of shape (1,)
    self.assertTrue(util.has_shape(singleton))