Пример #1
0
 def test_clustering_result_value_errors_loss_label_only_one_init(self):
     """ClusteringResult raises ValueError if only one of labels/loss is given."""
     num_points, dim = 4, 3
     data = clustering_params.Data(datapoints=np.zeros((num_points, dim)),
                                   radius=2)
     centers = np.zeros((2, dim))
     labels_only = np.array([0, 0, 1, 1], dtype=int)
     loss_only = 1.0
     # Cluster labels passed without an accompanying loss.
     with self.assertRaises(ValueError):
         clustering_algorithm.ClusteringResult(data, centers, labels_only)
     # A loss passed without accompanying cluster labels.
     with self.assertRaises(ValueError):
         clustering_algorithm.ClusteringResult(data, centers, loss=loss_only)
Пример #2
0
 def test_clustering_result_value_errors_unequal_points(self):
     """A label array shorter than the datapoints raises ValueError."""
     points = np.array([[1, 0, 1], [101, 101, 99], [4, 0, 4]])
     data = clustering_params.Data(datapoints=points, radius=200)
     centers = np.array([[0, 0, 0], [1, 1, 1]])
     # Only two labels for three datapoints.
     too_few_labels = np.array([0, 1], dtype=int)
     with self.assertRaises(ValueError):
         clustering_algorithm.ClusteringResult(
             data, centers, too_few_labels, loss=1.0)
Пример #3
0
 def test_value_error_no_true_labels(self):
     """Label-based metrics raise ValueError when Data has no true labels."""
     # Data is built from datapoints and radius only -- no true labels.
     data = clustering_params.Data(np.zeros(shape=(6, 4)), 1.0)
     result = clustering_algorithm.ClusteringResult(
         data,
         np.zeros(shape=(3, 4)),
         np.array([0, 0, 1, 1, 2, 2]),
         loss=1.0)
     for metric_fn in (result.cross_label_histogram,
                       result.get_clustering_metrics):
         with self.assertRaises(ValueError):
             metric_fn()
Пример #4
0
 def test_clustering_result_value_errors_labels_out_of_bounds(self):
     """Invalid cluster labels are rejected with ValueError.

     Covers a negative label, a label equal to the number of centers, and a
     non-integer label. Each case runs in its own subTest, so a failure
     reports the offending labels and the remaining cases still run.
     """
     centers = np.array([[0, 0, 0], [1, 1, 1]])
     datapoints = np.array([[1, 0, 1], [101, 101, 99], [4, 0, 4]])
     data = clustering_params.Data(datapoints=datapoints, radius=200)
     invalid_label_cases = [
         np.array([-1, 0, 1], dtype=int),  # negative label
         np.array([0, 1, 2], dtype=int),  # 2 == number of centers
         np.array([0, 1, 1.1]),  # non-integer label
     ]
     for labels in invalid_label_cases:
         with self.subTest(labels=labels.tolist()):
             with self.assertRaises(ValueError):
                 clustering_algorithm.ClusteringResult(data,
                                                       centers,
                                                       labels,
                                                       loss=1.0)
Пример #5
0
    def test_get_clustering_result(self):
        """ClusteringResult derives labels and loss when given only centers.

        Each point is expected to be labelled with its nearest center. The
        expected loss of 37 equals the sum of squared Euclidean distances from
        each point to that center (2 + 3 + 32).
        """
        centers = np.array([[0, 0, 0], [100, 100, 100]])
        datapoints = np.array([[1, 0, 1], [101, 101, 99], [4, 0, 4]])
        data = clustering_params.Data(datapoints=datapoints, radius=200)

        clustering_result = clustering_algorithm.ClusteringResult(
            data, centers)

        # Check lengths on the result itself, not on the test inputs. zip()
        # below stops at the shorter sequence, so without these checks a
        # truncated result would pass silently.
        self.assertLen(clustering_result.data.datapoints, 3)
        for expected, actual in zip(datapoints,
                                    clustering_result.data.datapoints):
            self.assertSequenceAlmostEqual(expected, actual)
        self.assertLen(clustering_result.centers, 2)
        for expected, actual in zip(centers, clustering_result.centers):
            self.assertSequenceAlmostEqual(expected, actual)

        self.assertListEqual(list(clustering_result.labels), [0, 1, 0])
        self.assertAlmostEqual(clustering_result.loss, 37)
Пример #6
0
    def test_get_clustering_metrics(self):
        """get_clustering_metrics matches hand-computed values on a tiny case.

        True labels are [0, 0, 0, 1, 1, 1] and cluster labels are
        [0, 0, 1, 1, 2, 2]. The expected values below were derived by hand,
        assuming the pair-counting semantics suggested by the metric names.
        """
        datapoints, radius = np.zeros(shape=(6, 4)), 1.0
        labels = np.array([0, 0, 0, 1, 1, 1])
        data = clustering_params.Data(datapoints, radius, labels)
        centers = np.zeros(shape=(3, 4))
        cluster_labels = np.array([0, 0, 1, 1, 2, 2])
        clustering_result = clustering_algorithm.ClusteringResult(
            data, centers, cluster_labels, loss=1.0)
        clustering_metrics = clustering_result.get_clustering_metrics()

        # Row = cluster, column = true label.
        expected_cross_label_histogram = np.array([[2, 0], [1, 1], [0, 2]],
                                                  dtype=int)
        # assert_array_equal reports which elements differ, unlike
        # assertTrue((a == b).all()).
        np.testing.assert_array_equal(
            clustering_metrics.cross_label_histogram,
            expected_cross_label_histogram)
        self.assertEqual(clustering_metrics.num_points, 6)
        # Sum of each cluster's largest histogram entry: 2 + 1 + 2.
        self.assertEqual(clustering_metrics.dominant_label_correct_count, 5)
        self.assertAlmostEqual(clustering_metrics.dominant_label_accuracy,
                               5 / 6)
        # Pairs sharing a true label: 2 * C(3, 2) = 6. Four of them are split
        # across clusters: (0, 2), (1, 2), (3, 4), (3, 5).
        self.assertEqual(clustering_metrics.true_pairs, 6)
        self.assertEqual(clustering_metrics.true_nonmatch_count, 4)
        self.assertAlmostEqual(clustering_metrics.true_nonmatch_frac, 4 / 6)
        # Pairs with different true labels: 3 * 3 = 9. Only (2, 3) shares a
        # cluster.
        self.assertEqual(clustering_metrics.false_pairs, 9)
        self.assertEqual(clustering_metrics.false_match_count, 1)
        self.assertAlmostEqual(clustering_metrics.false_match_frac, 1 / 9)