def test_type_should_pass_mixed(self):
    """Check that `assert_type` accepts a mixed list of inputs.

    The inputs are a Python float, a Python int, a NumPy float array and a
    JAX integer array, paired position by position with `float`, `int`,
    `float` and `int`. No exception is expected.
    """
    inputs = [
        1.,
        2,
        np.asarray([3., 4.]),
        jnp.asarray([5, 6]),
    ]
    expected_types = [float, int, float, int]
    asserts.assert_type(inputs, expected_types)
def test_type_should_fail_unsupported_dtype(self):
    """Check that `assert_type` rejects an expected type it does not support.

    The first two expected types are `complex`; the test expects an
    `AssertionError` whose message matches 'unsupported dtype'.
    """
    a_float = 1.
    an_int = 2
    a_np_float = np.asarray([3., 4.])
    a_jax_int = jnp.asarray([5, 6])
    # Use the builtin `complex` directly: `np.complex` was only a deprecated
    # alias for it (NumPy 1.20) and was removed in NumPy 1.24, where merely
    # accessing it raises AttributeError before `assert_type` is called.
    with self.assertRaisesRegex(AssertionError, 'unsupported dtype'):
        asserts.assert_type([a_float, an_int, a_np_float, a_jax_int],
                            [complex, complex, float, int])
def test_type_should_fail_mixed(self):
    """Check that `assert_type` reports a mismatch within a mixed list.

    The last input is a JAX integer array but its expected type is `float`,
    so an `AssertionError` matching the type-mismatch message is expected.
    """
    # (value, expected type) pairs; only the final pair disagrees.
    cases = [
        (1., float),
        (2, int),
        (np.asarray([3., 4.]), float),
        (jnp.asarray([5, 6]), float),
    ]
    inputs = [value for value, _ in cases]
    expected_types = [expected for _, expected in cases]

    mismatch_regex = 'input .+ has type .+ but expected .+'
    with self.assertRaisesRegex(AssertionError, mismatch_regex):
        asserts.assert_type(inputs, expected_types)
def test_type_should_fail_wrong_length(self, array, wrong_type):
    """Check that `assert_type` raises on a length mismatch.

    Args:
      array: inputs forwarded to `assert_type`.
      wrong_type: expected types forwarded to `assert_type`.

    Both arguments are presumably supplied by a parameterization outside
    this view -- confirm against the decorator.
    """
    length_regex = 'Length of `inputs` and `expected_types` must match'
    with self.assertRaisesRegex(AssertionError, length_regex):
        asserts.assert_type(array, wrong_type)
def test_type_should_pass_array(self, array, wrong_type):
    """Check that `assert_type` passes for an array run through `emplace`.

    Args:
      array: value passed through `self.variant(emplace)` before the check.
      wrong_type: expected type(s) forwarded to `assert_type`.

    NOTE(review): the second parameter is called `wrong_type` although no
    failure is expected here; the name is kept because the parameterization
    is outside this view.
    """
    emplace_variant = self.variant(emplace)
    emplaced = emplace_variant(array)
    asserts.assert_type(emplaced, wrong_type)
def test_type_should_pass_scalar(self, array, wrong_type):
    """Check that `assert_type` passes when called directly on the inputs.

    Args:
      array: inputs forwarded unchanged to `assert_type`.
      wrong_type: expected type(s) forwarded to `assert_type`.

    NOTE(review): the second parameter is called `wrong_type` although no
    failure is expected here; the name is kept because the parameterization
    is outside this view.
    """
    inputs, expected_types = array, wrong_type
    asserts.assert_type(inputs, expected_types)
def test_type_should_fail_array(self, array, wrong_type):
    """Check that `assert_type` raises for an array of the wrong type.

    Args:
      array: value passed through `self.variant(emplace)` before the check.
      wrong_type: expected type(s) that the array should not satisfy.
    """
    emplace_variant = self.variant(emplace)
    emplaced = emplace_variant(array)

    mismatch_regex = 'input .+ has type .+ but expected .+'
    with self.assertRaisesRegex(AssertionError, mismatch_regex):
        asserts.assert_type(emplaced, wrong_type)
def test_type_should_fail_scalar(self, scalars, wrong_type):
    """Check that `assert_type` raises for scalars of the wrong type.

    Args:
      scalars: inputs forwarded unchanged to `assert_type`.
      wrong_type: expected type(s) that the inputs should not satisfy.
    """
    # Callable form of assertRaisesRegex: it invokes
    # `asserts.assert_type(scalars, wrong_type)` and checks the raised error.
    self.assertRaisesRegex(
        AssertionError,
        'input .+ has type .+ but expected .+',
        asserts.assert_type,
        scalars,
        wrong_type,
    )