Example #1
0
 def test_single_array(self):
     """
     Test the concatenation of a list of plain arrays.

     Five references to ``np.arange(5)`` are expected to be joined end to
     end into one array of length 25 holding 0..4 repeated five times.
     """
     obj = [np.arange(5)] * 5
     concat = _concat(obj)
     self.assertEqual(concat.shape, (25, ))
     # The shape check alone would accept any 25-element result; also pin
     # the values and their order (same expectation as the dict/list tests).
     np.testing.assert_array_equal(concat, np.tile(np.arange(5), 5))
Example #2
0
 def test_dict_1(self):
     """
     Test list of dictionaries.

     Every element maps 'a' to ``arange(5)`` and 'b' to an array of 2s; the
     result is expected to map each key to the end-to-end concatenation of
     that key's arrays.
     """
     obj = [{'a': np.arange(5), 'b': np.ones(5) * 2}] * 5
     concat = _concat(obj)
     self.assertEqual(concat['a'].shape, (25, ))
     self.assertEqual(concat['b'].shape, (25, ))
     # assert_array_equal reports the mismatching indices and values, unlike
     # a loop of assertTrue(x == y) which only says "False is not true".
     np.testing.assert_array_equal(concat['a'], np.tile(np.arange(5), 5))
     np.testing.assert_array_equal(concat['b'], np.full(25, 2.0))
Example #3
0
 def test_array_1(self):
     """
     Test the concatenation of a list of lists of arrays ([[arr, arr], ...]).

     Position 0 of every inner list holds ``arange(5)`` and position 1 an
     array of 2s; the result is expected to hold, per position, the
     end-to-end concatenation of those arrays.
     """
     obj = [[np.arange(5), np.ones(5) * 2]] * 5
     concat = _concat(obj)
     self.assertEqual(concat[0].shape, (25, ))
     self.assertEqual(concat[1].shape, (25, ))
     # Whole-array comparison: same check as an element-by-element loop, but
     # a failure shows which elements differ.
     np.testing.assert_array_equal(concat[0], np.tile(np.arange(5), 5))
     np.testing.assert_array_equal(concat[1], np.full(25, 2.0))
Example #4
0
    def test_non_concatenable_values2(self):
        """
        Test a list of dictionaries mixing array and non-array values.

        The arrays in the 'a' tuple are expected to be concatenated. The
        plain value 2 inside that tuple, the int under 'b' and the 0-d array
        under 'c' are instead expected to be gathered one entry per input
        element: a tuple for the value nested in the tuple, a list for the
        top-level values, as the assertions below spell out.
        """
        obj = [{'a': (np.arange(5), np.ones(5), 2), 'b': 3, 'c': np.array(4)}] * 5
        concat = _concat(obj)
        self.assertEqual(concat['a'][0].shape, (25, ))
        self.assertEqual(concat['a'][1].shape, (25, ))
        self.assertEqual(concat['a'][2], (2, ) * 5)
        self.assertEqual(concat['b'], [3] * 5)
        self.assertEqual(concat['c'], [4] * 5)

        # Whole-array comparisons give element-level diagnostics on failure.
        np.testing.assert_array_equal(concat['a'][0], np.tile(np.arange(5), 5))
        np.testing.assert_array_equal(concat['a'][1], np.ones(25))
Example #5
0
 def test_tuple_2(self):
     """
     Test the concatenation of a list of nested tuples ([(arr, (arr, arr))]).

     The nesting is expected to be preserved in the result, with every leaf
     array replaced by the end-to-end concatenation of the arrays found at
     that position in each input element.
     """
     obj = [(np.arange(5), (np.ones(5) * 2, np.ones(5) * 3))] * 5
     concat = _concat(obj)
     self.assertEqual(concat[0].shape, (25, ))
     self.assertEqual(concat[1][0].shape, (25, ))
     self.assertEqual(concat[1][1].shape, (25, ))
     # Whole-array comparisons replace the index loop and .all() checks and
     # report which elements differ on failure.
     np.testing.assert_array_equal(concat[0], np.tile(np.arange(5), 5))
     np.testing.assert_array_equal(concat[1][0], np.full(25, 2.0))
     np.testing.assert_array_equal(concat[1][1], np.full(25, 3.0))
Example #6
0
 def test_non_concatenable_values(self):
     """
     Test a list of plain ints, which cannot be concatenated.

     The result is expected to equal the input list and to be of the same
     type (list).
     """
     values = [3] * 5
     result = _concat(values)
     self.assertEqual(result, values)
     self.assertEqual(type(result), type(values))