def test_single_array(self):
    """Concatenate a list holding the same 1-D array five times.

    The result should be one flat array of length 25 whose contents are
    ``np.arange(5)`` repeated five times, in order.
    """
    obj = [np.arange(5)] * 5
    concat = _concat(obj)
    self.assertEqual(concat.shape, (25, ))
    # A shape check alone would not catch wrong ordering or wrong values,
    # so also pin the contents (same expectation the sibling tests use).
    np.testing.assert_array_equal(concat, np.tile(np.arange(5), 5))
def test_dict_1(self):
    """Concatenate a list of five identical dictionaries of arrays.

    Each key is concatenated on its own: 'a' becomes ``np.arange(5)``
    repeated five times and 'b' becomes 25 entries equal to 2.
    """
    elements = [{'a': np.arange(5), 'b': np.ones(5) * 2}] * 5
    merged = _concat(elements)
    for key in ('a', 'b'):
        self.assertEqual(merged[key].shape, (25, ))
    # Every run of five entries in 'a' restarts at zero.
    for position in range(25):
        self.assertTrue(merged['a'][position] == position % 5)
    self.assertTrue((merged['b'] == 2).all())
def test_array_1(self):
    """Concatenate a list of ``[array, array]`` pairs (the ``[[[], []]]`` case).

    Position 0 of every pair is joined into one length-25 array, and the
    same holds for position 1.
    """
    pair = [np.arange(5), np.ones(5) * 2]
    result = _concat([pair] * 5)
    for index in (0, 1):
        self.assertEqual(result[index].shape, (25, ))
    for block in range(5):
        for offset in range(5):
            self.assertTrue(result[0][block * 5 + offset] == offset)
    self.assertTrue((result[1] == 2).all())
def test_non_concatenable_values2(self):
    """Mix concatenable arrays with scalars and 0-d arrays inside a dict.

    The 1-D arrays in the tuple under 'a' are joined into length-25 arrays.
    The scalar in that tuple, the scalar under 'b', and the 0-d array under
    'c' are instead gathered into a sequence with one entry per input.
    """
    record = {'a': (np.arange(5), np.ones(5), 2), 'b': 3, 'c': np.array(4)}
    out = _concat([record] * 5)
    self.assertEqual(out['a'][0].shape, (25, ))
    self.assertEqual(out['a'][1].shape, (25, ))
    self.assertEqual(out['a'][2], (2, ) * 5)
    self.assertEqual(out['b'], [3] * 5)
    self.assertEqual(out['c'], [4] * 5)
    for flat_index in range(25):
        self.assertTrue(out['a'][0][flat_index] == flat_index % 5)
    self.assertTrue((out['a'][1] == 1).all())
def test_tuple_2(self):
    """Concatenate a list of nested tuples shaped like ``[([], ([], []))]``.

    The outer array and both arrays of the inner tuple are each joined
    into a length-25 array, keeping the nesting structure.
    """
    sample = (np.arange(5), (np.ones(5) * 2, np.ones(5) * 3))
    joined = _concat([sample] * 5)
    self.assertEqual(joined[0].shape, (25, ))
    for inner in (0, 1):
        self.assertEqual(joined[1][inner].shape, (25, ))
    for k in range(25):
        self.assertTrue(joined[0][k] == k % 5)
    for inner, fill in ((0, 2), (1, 3)):
        self.assertTrue((joined[1][inner] == fill).all())
def test_non_concatenable_values(self):
    """A list of plain scalars comes back equal to the input, with the same type."""
    scalars = [3] * 5
    result = _concat(scalars)
    self.assertEqual(result, scalars)
    self.assertIs(type(result), type(scalars))