def test_clustering_result_value_errors_loss_label_only_one_init(self):
    """Passing only labels, or only a loss, to ClusteringResult raises."""
    points = np.zeros((4, 3))
    zero_centers = np.zeros((2, 3))
    data = clustering_params.Data(datapoints=points, radius=2)
    labels_only = np.array([0, 0, 1, 1], dtype=int)

    # Labels supplied without a loss.
    self.assertRaises(
        ValueError,
        clustering_algorithm.ClusteringResult,
        data, zero_centers, labels_only)

    # Loss supplied without labels.
    self.assertRaises(
        ValueError,
        clustering_algorithm.ClusteringResult,
        data, zero_centers, loss=1.0)
def test_clustering_result_value_errors_unequal_points(self):
    """A label array with fewer entries than datapoints raises ValueError."""
    points = np.array([[1, 0, 1], [101, 101, 99], [4, 0, 4]])
    data = clustering_params.Data(datapoints=points, radius=200)
    two_centers = np.array([[0, 0, 0], [1, 1, 1]])
    # Two labels for three datapoints.
    short_labels = np.array([0, 1], dtype=int)

    self.assertRaises(
        ValueError,
        clustering_algorithm.ClusteringResult,
        data, two_centers, short_labels, loss=1.0)
def test_value_error_no_true_labels(self):
    """Label-based metrics raise ValueError when Data is built without labels."""
    radius = 1.0
    datapoints = np.zeros(shape=(6, 4))
    # No labels argument is passed to Data here.
    unlabeled_data = clustering_params.Data(datapoints, radius)

    assignments = np.array([0, 0, 1, 1, 2, 2])
    zero_centers = np.zeros(shape=(3, 4))
    result = clustering_algorithm.ClusteringResult(
        unlabeled_data, zero_centers, assignments, loss=1.0)

    with self.assertRaises(ValueError):
        result.cross_label_histogram()
    with self.assertRaises(ValueError):
        result.get_clustering_metrics()
def test_clustering_result_value_errors_labels_out_of_bounds(self):
    """Cluster labels that are not valid center indices raise ValueError.

    Each invalid label array runs in its own subTest, so a failure names the
    offending case and the remaining cases are still exercised.
    """
    centers = np.array([[0, 0, 0], [1, 1, 1]])
    datapoints = np.array([[1, 0, 1], [101, 101, 99], [4, 0, 4]])
    data = clustering_params.Data(datapoints=datapoints, radius=200)

    # There are two centers, so the only in-range integer labels are 0 and 1.
    invalid_label_cases = {
        "negative_label": np.array([-1, 0, 1], dtype=int),
        "label_past_last_center": np.array([0, 1, 2], dtype=int),
        "non_integer_label": np.array([0, 1, 1.1]),
    }
    for case_name, labels in invalid_label_cases.items():
        with self.subTest(case=case_name):
            with self.assertRaises(ValueError):
                clustering_algorithm.ClusteringResult(
                    data, centers, labels, loss=1.0)
def test_get_clustering_result(self):
    """ClusteringResult built from data and centers alone fills labels and loss.

    The size checks are made on the result's own fields rather than on the
    local inputs; otherwise a result holding too few datapoints or centers
    would still pass the element-wise comparisons below.
    """
    centers = np.array([[0, 0, 0], [100, 100, 100]])
    datapoints = np.array([[1, 0, 1], [101, 101, 99], [4, 0, 4]])
    data = clustering_params.Data(datapoints=datapoints, radius=200)

    clustering_result = clustering_algorithm.ClusteringResult(data, centers)

    # The stored datapoints match the input, in count and in value.
    self.assertLen(clustering_result.data.datapoints, 3)
    for expected, actual in zip(datapoints,
                                clustering_result.data.datapoints):
        self.assertSequenceAlmostEqual(expected, actual)

    # The stored centers match the input, in count and in value.
    self.assertLen(clustering_result.centers, 2)
    for expected, actual in zip(centers, clustering_result.centers):
        self.assertSequenceAlmostEqual(expected, actual)

    # Expected values are consistent with assigning each point to its nearest
    # center and summing squared distances: 2 + 3 + 32 = 37.
    self.assertListEqual(list(clustering_result.labels), [0, 1, 0])
    self.assertAlmostEqual(clustering_result.loss, 37)
def test_get_clustering_metrics(self):
    """get_clustering_metrics reports the expected histogram and pair counts.

    Six points carry true labels [0, 0, 0, 1, 1, 1] and are assigned to
    clusters [0, 0, 1, 1, 2, 2].
    """
    datapoints, radius = np.zeros(shape=(6, 4)), 1.0
    labels = np.array([0, 0, 0, 1, 1, 1])
    data = clustering_params.Data(datapoints, radius, labels)
    centers = np.zeros(shape=(3, 4))
    cluster_labels = np.array([0, 0, 1, 1, 2, 2])
    clustering_result = clustering_algorithm.ClusteringResult(
        data, centers, cluster_labels, loss=1.0)

    clustering_metrics = clustering_result.get_clustering_metrics()

    # One row per cluster, one column per true label.
    expected_cross_label_histogram = np.array(
        [[2, 0], [1, 1], [0, 2]], dtype=int)
    # assert_array_equal fails on a shape mismatch instead of broadcasting,
    # and prints the differing entries when the assertion fails.
    np.testing.assert_array_equal(
        clustering_metrics.cross_label_histogram,
        expected_cross_label_histogram)

    self.assertEqual(clustering_metrics.num_points, 6)
    self.assertEqual(clustering_metrics.dominant_label_correct_count, 5)
    self.assertAlmostEqual(clustering_metrics.dominant_label_accuracy, 5 / 6)

    # Six pairs share a true label; the expectation is that four of them are
    # split across clusters.
    self.assertEqual(clustering_metrics.true_pairs, 6)
    self.assertEqual(clustering_metrics.true_nonmatch_count, 4)
    self.assertAlmostEqual(clustering_metrics.true_nonmatch_frac, 4 / 6)

    # Nine pairs differ in true label; the expectation is that one of them
    # shares a cluster.
    self.assertEqual(clustering_metrics.false_pairs, 9)
    self.assertEqual(clustering_metrics.false_match_count, 1)
    self.assertAlmostEqual(clustering_metrics.false_match_frac, 1 / 9)