Code example #1
0
File: test_base_execute.py  Project: zzxx-husky/mars
    def testArraySplitExecution(self):
        """Check executed ``array_split`` pieces against ``np.array_split``.

        A 2x3x8 tensor is split along axis 2 twice:
        - into 3 sections, which is uneven because 8 is not divisible by 3;
        - at explicit indices, including 10, which lies past the end of the axis.

        Each piece is executed and compared element-wise with the
        corresponding NumPy piece.
        """
        x = arange(48, chunks=3).reshape(2, 3, 8)
        # NumPy reference data; identical for both scenarios, so build it once.
        raw = np.arange(48).reshape(2, 3, 8)

        for indices_or_sections in (3, [3, 5, 6, 10]):
            ss = array_split(x, indices_or_sections, axis=2)

            res = [self.executor.execute_tensor(i, concat=True)[0] for i in ss]
            expected = np.array_split(raw, indices_or_sections, axis=2)
            self.assertEqual(len(res), len(expected))
            # Assertions are side effects: use a plain loop rather than a
            # throwaway list comprehension of ``None`` results.
            for r, e in zip(res, expected):
                np.testing.assert_equal(r, e)
Code example #2
0
    def testArraySplit(self):
        """Check ``array_split`` piece lengths and tiled chunk layout.

        1-d tensors of length 8 and 7 (chunk size 2) are each split into
        three pieces. Each piece must have the expected length. After the
        first piece is tiled, every piece must report the expected nsplits.
        """
        cases = [
            # (tensor length, expected piece lengths, expected nsplits per piece)
            (8, [3, 3, 2], [((2, 1), ), ((1, 2), ), ((2, ), )]),
            (7, [3, 2, 2], [((2, 1), ), ((1, 1), ), ((1, 1), )]),
        ]
        for length, piece_lengths, piece_nsplits in cases:
            pieces = array_split(arange(length, chunk_size=2), 3)
            self.assertEqual(len(pieces), 3)
            self.assertEqual([p.shape[0] for p in pieces], piece_lengths)

            # Only the first piece is tiled explicitly, yet all pieces' nsplits
            # are checked; presumably the pieces share one operand — TODO confirm.
            pieces[0].tiles()
            for piece, nsplits in zip(pieces, piece_nsplits):
                self.assertEqual(piece.nsplits, nsplits)
Code example #3
0
    def testArraySplit(self):
        """Check ``array_split`` shapes, inferred shapes and tiled chunks.

        1-d tensors of length 8 and 7 (chunk size 2) are each split into
        three pieces. For every piece the test checks:
        - the piece's length;
        - what ``calc_shape`` reports for it;
        - after tiling, its nsplits;
        - that ``calc_shape`` of its first chunk matches that chunk's shape.
        """
        cases = [
            # (tensor length, expected piece shapes, expected nsplits per piece)
            (8, ((3, ), (3, ), (2, )), [((2, 1), ), ((1, 2), ), ((2, ), )]),
            (7, ((3, ), (2, ), (2, )), [((2, 1), ), ((1, 1), ), ((1, 1), )]),
        ]
        for length, piece_shapes, piece_nsplits in cases:
            splits = array_split(arange(length, chunk_size=2), 3)
            self.assertEqual(len(splits), 3)
            self.assertEqual([s.shape[0] for s in splits],
                             [shape[0] for shape in piece_shapes])
            # Every piece is compared with the shapes of ALL pieces; this
            # assumes calc_shape reports every output of the multi-output
            # split operand — TODO confirm. assertEqual (instead of
            # assertTrue(all(...))) shows the mismatching value on failure.
            for s in splits:
                self.assertEqual(calc_shape(s), piece_shapes)

            # Only the first piece is tiled explicitly, yet all pieces' nsplits
            # are checked; presumably the pieces share one operand — TODO confirm.
            splits[0].tiles()
            for s, nsplits in zip(splits, piece_nsplits):
                self.assertEqual(s.nsplits, nsplits)
            for s in splits:
                first_chunk = s.chunks[0]
                self.assertEqual(calc_shape(first_chunk), first_chunk.shape)