예제 #1
0
    def test_evaluate_regression_probe(self):
        """Evaluate the trained regression probe on random data.

        Passes float32 random features and random float targets to
        ``linear_probe.evaluate_probe`` and checks that the result contains
        an ``__OVERALL__`` entry.
        """
        feature_shape = (self.num_examples, self.num_features)
        # Features are drawn before targets so the global RNG stream is
        # consumed in a fixed order.
        features = np.random.random(feature_shape).astype(np.float32)
        targets = np.random.random((self.num_examples, ))

        scores = linear_probe.evaluate_probe(
            self.trained_regression_probe, features, targets)

        self.assertIn("__OVERALL__", scores)
예제 #2
0
    def test_evaluate_classification_probe(self):
        """Evaluate the trained classification probe on random data.

        Passes float32 random features and random integer labels in
        ``[0, self.num_classes)`` to ``linear_probe.evaluate_probe`` and
        checks that the result contains an ``__OVERALL__`` entry.
        """
        feature_shape = (self.num_examples, self.num_features)
        features = np.random.random(feature_shape).astype(np.float32)
        labels = np.random.randint(
            0, self.num_classes, size=self.num_examples)

        scores = linear_probe.evaluate_probe(
            self.trained_probe, features, labels)

        self.assertIn("__OVERALL__", scores)
예제 #3
0
    def test_evaluate_probe_with_class_labels_float16(self):
        """Evaluate with an ``idx_to_class`` mapping and float16 features.

        Same flow as the plain classification evaluation, except the
        features are cast to ``np.float16`` and a mapping from class index
        to class name is supplied. The result must contain ``__OVERALL__``
        as well as the ``class0`` and ``class1`` names.
        """
        # Maps 0 -> "class0", 1 -> "class1", 2 -> "class2".
        class_names = dict(enumerate(("class0", "class1", "class2")))
        half_precision_features = np.random.random(
            (self.num_examples, self.num_features)).astype(np.float16)
        labels = np.random.randint(
            0, self.num_classes, size=self.num_examples)

        scores = linear_probe.evaluate_probe(
            self.trained_probe,
            half_precision_features,
            labels,
            idx_to_class=class_names,
        )

        # Only the overall entry and the first two class names are checked;
        # "class2" is mapped but not asserted here.
        for expected_key in ("__OVERALL__", "class0", "class1"):
            self.assertIn(expected_key, scores)
예제 #4
0
    def test_evaluate_probe_with_return_predictions(self):
        """Evaluate with ``return_predictions=True`` and inspect the output.

        The call must return a ``(scores, predictions)`` pair where
        ``predictions`` is a list of tuples, one per example.
        """
        # Labels are drawn before features so the global RNG stream is
        # consumed in a fixed order.
        y_true = np.random.randint(0, self.num_classes, size=self.num_examples)
        features = np.random.random(
            (self.num_examples, self.num_features)).astype(np.float32)

        result = linear_probe.evaluate_probe(
            self.trained_probe, features, y_true, return_predictions=True)
        scores, predictions = result

        self.assertIn("__OVERALL__", scores)
        self.assertIsInstance(predictions, list)
        self.assertEqual(len(predictions), self.num_examples)
        self.assertIsInstance(predictions[0], tuple)

        # No source_tokens were passed, so the first field of every
        # prediction should simply be the example index 0..num_examples-1.
        first_fields = [entry[0] for entry in predictions]
        expected_indices = list(range(self.num_examples))
        self.assertListEqual(first_fields, expected_indices)

        # NOTE(review): the second field is presumably the predicted class;
        # on random inputs it is expected not to reproduce y_true exactly.
        second_fields = [entry[1] for entry in predictions]
        self.assertNotEqual(second_fields, list(y_true))