Exemplo n.º 1
0
 def test_correct_values(self):
     """tohost must preserve every element, in flattened order."""
     # A (4, 3, 2) batch whose innermost rows are constant pairs.
     batch = np.array([
         [[1, 1], [2, 2], [3, 3]],
         [[4, 4], [5, 5], [6, 6]],
         [[7, 7], [8, 8], [9, 9]],
         [[0, 0], [1, 1], [2, 2]],
     ])
     merged = utils.tohost(batch)
     expected_values = list(batch.flatten())
     actual_values = list(merged.flatten())
     self.assertEqual(expected_values, actual_values)
Exemplo n.º 2
0
 def test_correct_shape(self):
     """tohost must merge the two leading axes: (4, 3, 2) -> (12, 2)."""
     # Input batch of shape (4, 3, 2).
     batch = np.array([
         [[1, 1], [2, 2], [3, 3]],
         [[4, 4], [5, 5], [6, 6]],
         [[7, 7], [8, 8], [9, 9]],
         [[0, 0], [1, 1], [2, 2]],
     ])
     expected_shape = (12, 2)
     merged = utils.tohost(batch)
     self.assertEqual(merged.shape, expected_shape)
Exemplo n.º 3
0
 def test_error_if_not_enough_dims(self):
     """tohost must reject a 1-D input with a descriptive ValueError."""
     one_dim = np.ones([3])
     # Case-insensitive match; "got 1" pins the number of dims reported.
     expected_message = r'(?i)not enough values.*got 1'
     with self.assertRaisesRegex(ValueError, expected_message):
         utils.tohost(one_dim)
Exemplo n.º 4
0
 def test_batch_dim_1(self):
     """A size-1 second axis merges away: (1, 1, 4, 5) -> (1, 4, 5)."""
     array = np.ones([1, 1, 4, 5])
     output = utils.tohost(array)
     # Compare against a literal tuple instead of allocating a throwaway
     # array only to read its .shape.
     self.assertEqual(output.shape, (1, 4, 5))
Exemplo n.º 5
0
 def test_flattens_batch_dim(self):
     """The two leading axes are merged: (2, 3, 4, 5) -> (6, 4, 5)."""
     array = np.ones([2, 3, 4, 5])
     output = utils.tohost(array)
     # Compare against a literal tuple instead of allocating a throwaway
     # array only to read its .shape.
     self.assertEqual(output.shape, (6, 4, 5))
Exemplo n.º 6
0
 def test_error_on_too_many_dims(self):
     """tohost must reject a 1-D input with a descriptive ValueError.

     NOTE(review): despite the method name, the input below is 1-D, i.e. it
     has too FEW dimensions, not too many. The name is kept so that test
     selection by name keeps working; consider renaming it.
     """
     input_array = np.array([1, 2, 3, 4, 5])
     # Pin the message, not just the exception type, so that an unrelated
     # ValueError raised inside tohost cannot make this test pass. The regex
     # matches the one used for the other 1-D input case in this class.
     with self.assertRaisesRegex(ValueError, r'(?i)not enough values.*got 1'):
         utils.tohost(input_array)