コード例 #1
0
 def test_predict_proba_distance_kernel(self):
     """Check 'predict_proba' when the kernel is 'distance'.

     The query point coincides with a training point labelled 0, and
     the expected probabilities put all the mass on that label.
     """
     points = gs.array([
         [0.0, 0.0],
         [1.0, 0.0],
         [2.0, 0.0],
         [3.0, 0.0]])
     point_labels = [0, 0, 1, 1]
     query = gs.array([[1.0, 0.0]])

     classifier = KernelDensityEstimationClassifier(
         kernel="distance", distance=self.distance)
     classifier.fit(points, point_labels)
     probabilities = classifier.predict_proba(query)

     self.assertAllClose(
         gs.array([[1, 0]]), probabilities, atol=gs.atol)
コード例 #2
0
 def test_predict_proba_triangular_kernel_callable_distance(self):
     """Check 'predict_proba' with a callable triangular kernel.

     The kernel is passed as the function 'triangular_radial_kernel'
     (not as a string) together with a bandwidth of 2.0.
     """
     points = gs.array([
         [0.0, 0.0],
         [1.0, 0.0],
         [2.0, 0.0],
         [3.0, 0.0]])
     point_labels = [0, 0, 1, 1]
     query = gs.array([[1.0, 0.0]])

     classifier = KernelDensityEstimationClassifier(
         kernel=triangular_radial_kernel,
         bandwidth=2.0,
         distance=self.distance)
     classifier.fit(points, point_labels)
     probabilities = classifier.predict_proba(query)

     self.assertAllClose(
         gs.array([[3 / 4, 1 / 4]]), probabilities, atol=gs.atol)
コード例 #3
0
    def test_predict_proba_uniform_kernel_one_dimensional_data(self):
        """Check 'predict_proba' with the 'uniform' kernel on 1-D input.

        Both the training set and the query are one-dimensional data of
        shape [n_samples,]; the expected result is an equal probability
        for the two labels.
        """
        points = gs.array([0, 1, 2, 3])
        point_labels = [0, 0, 1, 1]
        query = gs.array([0.9])

        classifier = KernelDensityEstimationClassifier(
            kernel="uniform", distance=self.distance)
        classifier.fit(points, point_labels)
        probabilities = classifier.predict_proba(query)

        self.assertAllClose(
            gs.array([[1 / 2, 1 / 2]]), probabilities, atol=gs.atol)
コード例 #4
0
def _draw_points_on_sphere(figure_id, title, points, colors):
    """Draw colored points on a sphere in their own 3D figure.

    Parameters
    ----------
    figure_id : int
        Identifier passed to ``plt.figure``.
    title : str
        Title of the figure.
    points : array-like
        Points handed to ``visualization.Sphere.draw_points``.
    colors : array-like
        Colors forwarded as the ``c`` argument of ``draw_points``
        (presumably one RGB row per point -- confirm against the
        visualization module).
    """
    plt.figure(figure_id)
    ax = plt.subplot(111, projection='3d')
    plt.title(title)
    sphere_plot = visualization.Sphere()
    sphere_plot.draw(ax=ax)
    sphere_plot.draw_points(ax=ax, points=points, c=colors)


def main():
    """Plot a Kernel Density Estimation Classification on the sphere.

    Two training datasets are sampled with ``random_von_mises_fisher``
    (the second one is negated) and labelled 0 and 1. A
    ``KernelDensityEstimationClassifier`` using the triangular radial
    kernel is fitted on them and applied to uniformly sampled targets.

    Three figures are shown: the training set, the predicted labels of
    the targets, and the targets colored by mixing the two label colors
    with the predicted probabilities.
    """
    sphere = Hypersphere(dim=2)
    sphere_distance = sphere.metric.dist

    n_samples_per_dataset = 10
    n_targets = 200
    radius = np.inf

    kernel = triangular_radial_kernel
    bandwidth = 3

    dataset_1 = sphere.random_von_mises_fisher(
        kappa=10,
        n_samples=n_samples_per_dataset)
    dataset_2 = - sphere.random_von_mises_fisher(
        kappa=10,
        n_samples=n_samples_per_dataset)
    training_dataset = gs.concatenate((dataset_1, dataset_2), axis=0)
    labels = gs.concatenate((
        gs.zeros([n_samples_per_dataset], dtype=gs.int64),
        gs.ones([n_samples_per_dataset], dtype=gs.int64)))
    target = sphere.random_uniform(n_samples=n_targets)

    # One color row per label: row 0 for label 0, row 1 for label 1.
    labels_colors = gs.array([
        [0.0, 0.0, 1.0],
        [1.0, 0.0, 0.0]])

    kde = KernelDensityEstimationClassifier(
        radius=radius,
        distance=sphere_distance,
        kernel=kernel,
        bandwidth=bandwidth,
        outlier_label='most_frequent')
    kde.fit(training_dataset, labels)
    target_labels = kde.predict(target)
    target_labels_proba = kde.predict_proba(target)

    # Indexing the color table with the integer labels selects one color
    # row per sample, replacing an explicit per-sample loop.
    _draw_points_on_sphere(
        0, 'Training set', training_dataset, labels_colors[labels])
    _draw_points_on_sphere(
        1, 'Classification', target, labels_colors[target_labels])

    # Each target color is the probability-weighted mix of label colors.
    _draw_points_on_sphere(
        2, 'Probabilistic classification', target,
        target_labels_proba @ labels_colors)

    plt.show()