def test_split_by_portion(self):
     left, right = split_numpy_arrays([np.arange(10)],
                                      portion=0.1,
                                      shuffle=False)
     np.testing.assert_equal(left, [np.arange(9)])
     np.testing.assert_equal(right, [[9]])
     left, right = split_numpy_arrays([np.arange(10)],
                                      portion=0.9,
                                      shuffle=False)
     np.testing.assert_equal(left, [[0]])
     np.testing.assert_equal(right, [np.arange(1, 10)])
 def test_shuffling_with_multiple_arrays(self):
     left, right = split_numpy_arrays(
         [np.arange(10), np.arange(10, 20)], size=1, shuffle=True)
     self.assertEqual(len(left[0]), 9)
     self.assertEqual(len(left[1]), 9)
     self.assertEqual(len(right[0]), 1)
     self.assertEqual(len(right[1]), 1)
     self.assertEqual(set(list(left[0]) + list(right[0])),
                      set(np.arange(10)))
     self.assertEqual(set(list(left[1]) + list(right[1])),
                      set(np.arange(10, 20)))
     np.testing.assert_equal(left[0] + 10, left[1])
     np.testing.assert_equal(right[0] + 10, right[1])
    def test_error_inputs(self):
        # test error inputs
        with self.assertRaisesRegex(
                ValueError, 'At least one of `portion` and `size` should '
                'be specified.'):
            split_numpy_arrays([])

        with self.assertRaisesRegex(
                ValueError, 'At least one of `portion` and `size` should '
                'be specified.'):
            split_numpy_arrays([np.arange(1)])

        with self.assertRaisesRegex(
                ValueError, 'The length of specified arrays are not equal.'):
            split_numpy_arrays([np.arange(10), np.arange(11)], portion=0.2)

        with self.assertRaisesRegex(ValueError,
                                    '`portion` must range from 0.0 to 1.0.'):
            split_numpy_arrays([np.arange(10)], portion=-0.1)

        with self.assertRaisesRegex(ValueError,
                                    '`portion` must range from 0.0 to 1.0.'):
            split_numpy_arrays([np.arange(10)], portion=1.1)
    def test_split_by_size(self):
        left, right = split_numpy_arrays([np.arange(10)],
                                         size=-1,
                                         shuffle=False)
        np.testing.assert_equal(left, [np.arange(10)])
        np.testing.assert_equal(right, [[]])

        left, right = split_numpy_arrays([np.arange(10)],
                                         size=0,
                                         shuffle=False)
        np.testing.assert_equal(left, [np.arange(10)])
        np.testing.assert_equal(right, [[]])

        left, right = split_numpy_arrays([np.arange(10)],
                                         size=1,
                                         shuffle=False)
        np.testing.assert_equal(left, [np.arange(9)])
        np.testing.assert_equal(right, [[9]])

        left, right = split_numpy_arrays([np.arange(10)],
                                         size=9,
                                         shuffle=False)
        np.testing.assert_equal(left, [[0]])
        np.testing.assert_equal(right, [np.arange(1, 10)])

        left, right = split_numpy_arrays([np.arange(10)],
                                         size=10,
                                         shuffle=False)
        np.testing.assert_equal(left, [[]])
        np.testing.assert_equal(right, [np.arange(10)])

        left, right = split_numpy_arrays([np.arange(10)],
                                         size=11,
                                         shuffle=False)
        np.testing.assert_equal(left, [[]])
        np.testing.assert_equal(right, [np.arange(10)])
 def test_empty_inputs(self):
     self.assertEqual(split_numpy_arrays([], portion=0.2), ((), ()))
     self.assertEqual(split_numpy_arrays([], size=10), ((), ()))
 def test_split_multi_dimensional_data(self):
     left, right = split_numpy_arrays([np.arange(24).reshape([6, 2, 2])],
                                      size=3,
                                      shuffle=False)
     np.testing.assert_equal(left, [np.arange(12).reshape([3, 2, 2])])
     np.testing.assert_equal(right, [np.arange(12, 24).reshape([3, 2, 2])])