コード例 #1
0
ファイル: test_util.py プロジェクト: byzhang/d3p
 def test_expand_shape_input_is_scalar_and_two_dims_expected(self):
     a = 3
     expected = (1, 1)
     res = util.expand_shape(a, 2)
     self.assertEqual(expected, res)
コード例 #2
0
ファイル: test_util.py プロジェクト: byzhang/d3p
 def test_expand_shape_input_has_two_missing_dims(self):
     a = jnp.array([[3, 4], [5, 6], [7, 8]])
     expected = (3, 2, 1, 1)
     res = util.expand_shape(a, 4)
     self.assertEqual(expected, res)
コード例 #3
0
ファイル: test_util.py プロジェクト: byzhang/d3p
 def test_expand_shape_input_is_scalar_and_one_dim_expected(self):
     a = 3
     expected = (1,)
     res = util.expand_shape(a, 1)
     self.assertEqual(expected, res)
コード例 #4
0
ファイル: test_util.py プロジェクト: byzhang/d3p
 def test_expand_shape_input_has_more_dims(self):
     a = jnp.array([[3, 4], [5, 6], [7, 8]])
     expected = (3, 2)
     res = util.expand_shape(a, 1)
     self.assertEqual(expected, res)