Ejemplo n.º 1
0
    def test_split_array_at_indices_wrong_dtype(self) -> None:
        """Check that int32 split indices make the splitter raise ValueError.

        The indices mirror the int64 ones used by the sibling tests; only the
        dtype (np.int32) is changed.
        """
        # Prepare
        data = np.ones((100, 3, 32, 32), dtype=np.float32)
        indices = np.arange(0, 90, 10, dtype=np.int32)

        # Execute / Assert
        self.assertRaises(ValueError, split_array_at_indices, data, indices)
Ejemplo n.º 2
0
    def test_split_array_at_indices_wrong_num_dims(self) -> None:
        """Check that a 2-D split-index array makes the splitter raise ValueError.

        NOTE(review): any ValueError satisfies this test, not necessarily one
        raised by a dimensionality check -- consider pinning the message.
        """
        # Prepare: lift the nine int64 start indices to shape (1, 9).
        data = np.ones((100, 3, 32, 32), dtype=np.float32)
        flat_indices = np.arange(0, 90, 10, dtype=np.int64)
        indices = flat_indices[np.newaxis, :]

        # Execute / Assert
        self.assertRaises(ValueError, split_array_at_indices, data, indices)
Ejemplo n.º 3
0
    def test_split_array_at_indices_not_increasing(self) -> None:
        """Check that non-increasing split indices make the splitter raise
        ValueError."""
        # Prepare
        data = np.ones((100, 3, 32, 32), dtype=np.float32)
        indices = np.arange(0, 90, 10, dtype=np.int64)
        # Break monotonicity: indices become [0, 70, 20, 30, ..., 80].
        indices[1] = 70

        # Execute / Assert
        self.assertRaises(ValueError, split_array_at_indices, data, indices)
Ejemplo n.º 4
0
    def test_split_array(self) -> None:
        """Tests if split is correct.

        Builds 12 float32 samples of shape (3, 32, 32) in three groups of four
        (filled with 0, 1 and 2), splits them at start indices [0, 4, 8] and
        checks that every group is recovered element by element.
        """
        # Prepare
        sample_shape = (3, 32, 32)
        samples_per_split = 4
        split_expected = [
            [
                np.full(sample_shape, fill_value, dtype=np.float32)
                for _ in range(samples_per_split)
            ]
            for fill_value in range(3)
        ]

        # Concatenating the nested lists yields shape (12, 3, 32, 32).
        x = np.concatenate(split_expected)
        split_idx = np.arange(start=0, stop=12, step=4, dtype=np.int64)

        # Execute
        # Materialise the result so the number of splits/elements can be checked.
        list_splits = [list(split) for split in split_array_at_indices(x, split_idx)]

        # Assert
        # Check counts first: the element-wise comparison alone would pass
        # vacuously if splits or elements were missing from the result.
        self.assertEqual(len(list_splits), len(split_expected))
        for split, expected_split in zip(list_splits, split_expected):
            self.assertEqual(len(split), len(expected_split))
            for element, expected_element in zip(split, expected_split):
                np.testing.assert_equal(expected_element, element)