def test_split_array_at_indices_wrong_dtype(self) -> None:
    """Test that a ValueError is raised when the split indices are int32.

    The other tests in this group pass int64 indices; this one only changes
    the index dtype.
    """
    # Prepare: valid data, valid index values, but stored as int32.
    data = np.ones((100, 3, 32, 32), dtype=np.float32)
    indices = np.arange(0, 90, 10).astype(np.int32)

    # Execute / Assert
    with self.assertRaises(ValueError):
        split_array_at_indices(data, indices)
def test_split_array_at_indices_wrong_num_dims(self) -> None:
    """Test that a ValueError is raised when the split indices are 2-D instead of 1-D."""
    # Prepare: same index values as a valid vector, but with an extra
    # leading axis, giving shape (1, 9).
    data = np.ones((100, 3, 32, 32), dtype=np.float32)
    indices = np.arange(start=0, stop=90, step=10, dtype=np.int64)[np.newaxis, :]

    # Execute / Assert
    with self.assertRaises(ValueError):
        split_array_at_indices(data, indices)
def test_split_array_at_indices_not_increasing(self) -> None:
    """Test that a ValueError is raised when the split indices are not increasing."""
    # Prepare: the usual 0..80 step-10 indices with the second entry
    # replaced by 70, so the sequence drops from 70 back to 20.
    data = np.ones((100, 3, 32, 32), dtype=np.float32)
    indices = np.array([0, 70, 20, 30, 40, 50, 60, 70, 80], dtype=np.int64)

    # Execute / Assert
    with self.assertRaises(ValueError):
        split_array_at_indices(data, indices)
def test_split_array(self) -> None:
    """Test that the array is split into the expected groups at the given indices."""
    # Prepare: three groups of four (3, 32, 32) float32 elements, filled with
    # 0, 1 and 2 respectively, concatenated into a (12, 3, 32, 32) array.
    num_elements_per_split = 4
    split_expected = [
        [
            np.full((3, 32, 32), value, dtype=np.float32)
            for _ in range(num_elements_per_split)
        ]
        for value in range(3)
    ]
    x = np.concatenate(split_expected)
    # Split start positions 0, 4, 8, one per group.
    split_idx = np.arange(start=0, stop=12, step=num_elements_per_split, dtype=np.int64)

    # Execute
    list_splits = split_array_at_indices(x, split_idx)

    # Assert: compare the split and element counts first. Otherwise an empty or
    # truncated result would make the element-wise loops below pass vacuously.
    self.assertEqual(len(list_splits), len(split_expected))
    for split, expected_split in zip(list_splits, split_expected):
        self.assertEqual(len(split), len(expected_split))
        for element, expected_element in zip(split, expected_split):
            np.testing.assert_equal(expected_element, element)