def test_no_split_2(self) -> None:
        """Fitting data that admits no split must leave the model without a root node.

        Runs once per entry of ``data_matrix_transforms``. After fitting, the
        model must report depth 0 and 0 leaves, be fitted, and have
        ``root_ is None``. It must also return identical class probabilities
        for a point inside and a point far outside the training range.
        """
        for transform in data_matrix_transforms:
            labelled = np.array([
                [0.0, 0, 1],
                [0.0, 1, 0],
                [1.0, 2, 1],
                [1.0, 3, 0],
                [1.0, 4, 1],
            ])
            # The last column is the label; the leading columns are the features.
            targets = labelled[:, -1]
            features = transform(labelled[:, :-1])

            # NOTE(review): pruning_factor=2 presumably suppresses any candidate split here — confirm.
            clf = AdaptiveBayesianReticulum(prior=(1, 1), pruning_factor=2, random_state=666)
            print(f'Testing {type(features).__name__}')
            clf.fit(features, targets)
            print(clf)

            self.assertEqual(clf.get_depth(), 0)
            self.assertEqual(clf.get_n_leaves(), 0)
            self.assertTrue(clf._is_fitted())
            self.assertIsNone(clf.root_)

            # With no root node, a nearby and a far-away query get the same answer.
            self.assertEqual(clf.predict([[0, 0]]), np.ones(1))
            assert_array_equal(clf.predict_proba([[0, 0], [11, 99]]), [[0.4, 0.6], [0.4, 0.6]])

    def test_one_split(self) -> None:
        """Two separated clusters must be split by a single, roughly vertical hyperplane.

        The training data has two points near x=0 labelled 0 and two near x=1
        labelled 1. For every entry of ``data_matrix_transforms`` the test
        checks the following:

        * the tree shape: depth 1, 2 leaves, and no child nodes;
        * the split's x-axis crossing and the slope of its normal;
        * predictions and class probabilities for the four corners of the unit
          square, queried one at a time and as a batch in every data-matrix
          format.
        """
        # Query corners with their expected labels and class probabilities.
        # The point-wise and batched checks below share this one list, so
        # every corner, including [1, 1], is covered both ways.
        corners = [[0, 0], [0, 1], [1, 0], [1, 1]]
        expected_labels = np.array([0, 0, 1, 1])
        expected_proba = np.array([[3/4, 1/4], [3/4, 1/4], [1/4, 3/4], [1/4, 3/4]])

        for data_matrix_transform in data_matrix_transforms:
            model = AdaptiveBayesianReticulum(
                prior=(1, 1),
                learning_rate_init=5e-2,
                initial_relative_stiffness=2,
                random_state=666)

            Xy = np.array([
                [0.0, 0, 0],
                [0.1, 1, 0],

                [0.9, 0, 1],
                [1.0, 1, 1],
            ])
            X = Xy[:, :-1]
            y = Xy[:, -1]

            X = data_matrix_transform(X)

            print(f'Testing {type(X).__name__}')
            model.fit(X, y)
            print(model)

            self.assertEqual(model.get_depth(), 1)
            self.assertEqual(model.get_n_leaves(), 2)

            # A single split: the root node has no children.
            self.assertIsNone(model.root_.left__child)
            self.assertIsNone(model.root_.right_child)

            # The test reads the weights as (bias, w_x, w_y) of the plane
            # w0 + w1*x + w2*y = 0. That plane crosses the x-axis at -w0/w1,
            # and its normal has slope w2/w1. A near-zero slope means the
            # split is close to vertical.
            x_axis_intersection = -model.root_.weights[0] / model.root_.weights[1]
            normal_slope = model.root_.weights[2] / model.root_.weights[1]
            self.assertTrue(0.4 < x_axis_intersection < 0.6)
            self.assertTrue(-0.15 < normal_slope < 0.15)

            # Point-wise predictions. assert_array_equal avoids relying on the
            # truth value of a 1-element ndarray, as assertEqual would.
            for corner, label in zip(corners, expected_labels):
                assert_array_equal(model.predict([corner]), [label])

            for data_matrix_transform2 in data_matrix_transforms:
                assert_array_equal(model.predict(data_matrix_transform2(corners)), expected_labels)

            # Probabilities are only checked to one decimal place.
            for corner, proba in zip(corners, expected_proba):
                assert_array_almost_equal(model.predict_proba([corner]), np.expand_dims(proba, 0), decimal=1)

            for data_matrix_transform2 in data_matrix_transforms:
                assert_array_almost_equal(model.predict_proba(data_matrix_transform2(corners)),
                                          expected_proba, decimal=1)

    def test_two_splits(self) -> None:
        """Three bands labelled 0 | 1 | 0 must be separated by two nested, steep splits.

        The training points lie on the lines x=0 (label 0), x=1 (label 1) and
        x=2 (label 0). For every entry of ``data_matrix_transforms`` the test
        checks the following:

        * the tree shape: depth 2, 3 leaves, and only the root's
          ``left__child`` populated;
        * the root split crosses the x-axis near 0.5 and the child split near
          1.5, and both cross the y-axis far from the origin;
        * predictions and class probabilities on six queries along y=0.5,
          issued one at a time and as a batch in every data-matrix format.
        """
        queries = [[0.0, 0.5], [0.4, 0.5], [0.6, 0.5], [1.4, 0.5], [1.6, 0.5], [100, 0.5]]
        expected_labels = np.array([0, 0, 1, 1, 0, 0])
        # Probabilities are only checked to one decimal place.
        expected_proba = np.array([[5/6, 1/6], [5/6, 1/6], [1/6, 5/6], [1/6, 5/6], [3/4, 1/4], [3/4, 1/4]])

        def axis_intersections(node):
            # Reads weights as (bias, w_x, w_y) of w0 + w1*x + w2*y = 0.
            # Returns where that plane crosses the x-axis and the y-axis.
            w = node.weights
            return -w[0] / w[1], -w[0] / w[2]

        for transform in data_matrix_transforms:
            estimator = AdaptiveBayesianReticulum(
                prior=(1, 1),
                learning_rate_init=1e-1,
                n_gradient_descent_steps=1000,
                initial_relative_stiffness=2,
                random_state=666)

            training = np.array([
                [0.0, 0.0, 0],
                [0.0, 0.3, 0],
                [0.0, 0.7, 0],
                [0.0, 1.0, 0],

                [1.0, 0.1, 1],
                [1.0, 0.2, 1],
                [1.0, 0.8, 1],
                [1.0, 0.9, 1],

                [2.0, 0.4, 0],
                [2.0, 0.6, 0],
            ])
            features = training[:, :-1]
            targets = training[:, -1]
            features = transform(features)

            print(f'Testing {type(features).__name__}')
            estimator.fit(features, targets)
            print(estimator)

            self.assertEqual(estimator.get_depth(), 2)
            self.assertEqual(estimator.get_n_leaves(), 3)
            root = estimator.root_
            self.assertIsNone(root.right_child)
            self.assertIsNotNone(root.left__child)
            inner = root.left__child

            x_cross_outer, y_cross_outer = axis_intersections(root)
            x_cross_inner, y_cross_inner = axis_intersections(inner)
            self.assertTrue(0.4 < x_cross_outer < 0.6, 'expected 1st split to cross x-axis around 0.5')
            self.assertTrue(1.4 < x_cross_inner < 1.6, 'expected 2nd split to cross x-axis around 1.5')
            self.assertTrue(abs(y_cross_outer) > 15, 'expected 1st split to cross y-axis far away from 0')
            self.assertTrue(abs(y_cross_inner) > 15, 'expected 2nd split to cross y-axis far away from 0')

            for query, label in zip(queries, expected_labels):
                self.assertEqual(estimator.predict([query]), label)

            for other_transform in data_matrix_transforms:
                assert_array_equal(estimator.predict(other_transform(queries)), expected_labels)

            for query, proba in zip(queries, expected_proba):
                assert_array_almost_equal(estimator.predict_proba([query]), np.expand_dims(proba, 0), decimal=1)

            for other_transform in data_matrix_transforms:
                assert_array_almost_equal(estimator.predict_proba(other_transform(queries)),
                                          expected_proba, decimal=1)