コード例 #1
0
ファイル: asserts_test.py プロジェクト: graingert/chex
 def test_type_should_pass_mixed(self):
   a_float = 1.
   an_int = 2
   a_np_float = np.asarray([3., 4.])
   a_jax_int = jnp.asarray([5, 6])
   asserts.assert_type([a_float, an_int, a_np_float, a_jax_int],
                       [float, int, float, int])
コード例 #2
0
ファイル: asserts_test.py プロジェクト: graingert/chex
 def test_type_should_fail_unsupported_dtype(self):
   a_float = 1.
   an_int = 2
   a_np_float = np.asarray([3., 4.])
   a_jax_int = jnp.asarray([5, 6])
   with self.assertRaisesRegex(AssertionError, 'unsupported dtype'):
     asserts.assert_type([a_float, an_int, a_np_float, a_jax_int],
                         [np.complex, np.complex, float, int])
コード例 #3
0
ファイル: asserts_test.py プロジェクト: graingert/chex
 def test_type_should_fail_mixed(self):
   a_float = 1.
   an_int = 2
   a_np_float = np.asarray([3., 4.])
   a_jax_int = jnp.asarray([5, 6])
   with self.assertRaisesRegex(AssertionError,
                               'input .+ has type .+ but expected .+'):
     asserts.assert_type([a_float, an_int, a_np_float, a_jax_int],
                         [float, int, float, float])
コード例 #4
0
ファイル: asserts_test.py プロジェクト: graingert/chex
 def test_type_should_fail_wrong_length(self, array, wrong_type):
   with self.assertRaisesRegex(
       AssertionError, 'Length of `inputs` and `expected_types` must match'):
     asserts.assert_type(array, wrong_type)
コード例 #5
0
ファイル: asserts_test.py プロジェクト: graingert/chex
 def test_type_should_pass_array(self, array, wrong_type):
   array = self.variant(emplace)(array)
   asserts.assert_type(array, wrong_type)
コード例 #6
0
ファイル: asserts_test.py プロジェクト: graingert/chex
 def test_type_should_pass_scalar(self, array, wrong_type):
   asserts.assert_type(array, wrong_type)
コード例 #7
0
ファイル: asserts_test.py プロジェクト: graingert/chex
 def test_type_should_fail_array(self, array, wrong_type):
   array = self.variant(emplace)(array)
   with self.assertRaisesRegex(AssertionError,
                               'input .+ has type .+ but expected .+'):
     asserts.assert_type(array, wrong_type)
コード例 #8
0
ファイル: asserts_test.py プロジェクト: graingert/chex
 def test_type_should_fail_scalar(self, scalars, wrong_type):
   with self.assertRaisesRegex(AssertionError,
                               'input .+ has type .+ but expected .+'):
     asserts.assert_type(scalars, wrong_type)