Пример #1
0
 def test_type_should_pass_mixed(self):
   """A mix of Python scalars, NumPy arrays and JAX arrays passes assert_type.

   The test relies on assert_type pairing each input with the expected type
   at the same position; every pair here is expected to match, so no
   AssertionError should be raised.
   """
   inputs = [1., 2, np.asarray([3., 4.]), jnp.asarray([5, 6])]
   expected_types = [float, int, float, int]
   asserts.assert_type(inputs, expected_types)
Пример #2
0
 def test_type_should_fail_unsupported_dtype(self):
   """assert_type rejects expected types it does not support (here: complex).

   `np.complex` was merely a deprecated alias of the builtin `complex`
   (deprecated in NumPy 1.20, removed in NumPy 1.24). Using the alias makes
   this test crash with AttributeError on current NumPy before assert_type
   is even called, so the builtin is used directly; the semantics are the
   same as the original alias.
   """
   a_float = 1.
   an_int = 2
   a_np_float = np.asarray([3., 4.])
   a_jax_int = jnp.asarray([5, 6])
   with self.assertRaisesRegex(AssertionError, 'unsupported dtype'):
     asserts.assert_type([a_float, an_int, a_np_float, a_jax_int],
                         [complex, complex, float, int])
Пример #3
0
 def test_type_should_fail_mixed(self):
   """A single mismatched expected type in a mixed input list fails.

   The last input is built from integer literals via jnp.asarray, but its
   expected type is declared as float, so assert_type should raise an
   AssertionError whose message matches the type-mismatch pattern.
   """
   inputs = [1., 2, np.asarray([3., 4.]), jnp.asarray([5, 6])]
   mismatched_types = [float, int, float, float]
   mismatch_pattern = 'input .+ has type .+ but expected .+'
   with self.assertRaisesRegex(AssertionError, mismatch_pattern):
     asserts.assert_type(inputs, mismatched_types)
Пример #4
0
 def test_type_should_fail_wrong_length(self, array, wrong_type):
   """assert_type fails when inputs and expected types differ in length.

   `array` and `wrong_type` are presumably supplied by a parameterization
   decorator outside this view (verify against the test class).
   """
   length_msg = 'Length of `inputs` and `expected_types` must match'
   with self.assertRaisesRegex(AssertionError, length_msg):
     asserts.assert_type(array, wrong_type)
Пример #5
0
 def test_type_should_pass_array(self, array, wrong_type):
   """An array passed through `emplace` under the test variant passes.

   Despite its name, `wrong_type` here is the type the array is expected to
   match (the parameter name is kept because callers may pass it by name).
   NOTE(review): `self.variant(emplace)` presumably wraps `emplace` for the
   active test variant — confirm against the enclosing class.
   """
   emplace_fn = self.variant(emplace)
   asserts.assert_type(emplace_fn(array), wrong_type)
Пример #6
0
 def test_type_should_pass_scalar(self, array, wrong_type):
   """A scalar input matching its expected type passes assert_type.

   `wrong_type` is, despite its name, the matching expected type; both
   arguments presumably come from a parameterization decorator not shown.
   """
   value, expected = array, wrong_type
   asserts.assert_type(value, expected)
Пример #7
0
 def test_type_should_fail_array(self, array, wrong_type):
   """An array checked against a mismatched type raises a type error.

   The array is first passed through `self.variant(emplace)` (presumably the
   active test variant's placement of the data — confirm), then checked
   against `wrong_type`, which should not match.
   """
   placed = self.variant(emplace)(array)
   mismatch_pattern = 'input .+ has type .+ but expected .+'
   with self.assertRaisesRegex(AssertionError, mismatch_pattern):
     asserts.assert_type(placed, wrong_type)
Пример #8
0
 def test_type_should_fail_scalar(self, scalars, wrong_type):
   """Scalar inputs checked against a mismatched type raise a type error.

   `scalars` and `wrong_type` presumably come from a parameterization
   decorator outside this view; the pair is expected not to match.
   """
   mismatch_pattern = 'input .+ has type .+ but expected .+'
   with self.assertRaisesRegex(AssertionError, mismatch_pattern):
     asserts.assert_type(scalars, wrong_type)