コード例 #1
0
    def test_predict(self):
        """Check the structure and contents of the ``predict`` output.

        A default call must return only the "golds" and "probs" entries;
        passing ``return_preds=True`` must additionally return "preds".
        """
        model = MultitaskClassifier([self.task1])

        # Default call: gold labels and probabilities only.
        results = model.predict(self.dataloader)
        self.assertEqual(sorted(results.keys()), ["golds", "probs"])
        # Golds must be passed through unchanged from the dataset labels.
        np.testing.assert_array_equal(
            results["golds"]["task1"],
            self.dataloader.dataset.Y_dict["task1"].numpy())
        # Every example is expected to receive a uniform [0.5, 0.5] row.
        np.testing.assert_array_equal(results["probs"]["task1"],
                                      np.ones((NUM_EXAMPLES, 2)) * 0.5)

        results = model.predict(self.dataloader, return_preds=True)
        self.assertEqual(sorted(results.keys()),
                         ["golds", "preds", "probs"])
        # All probabilities are tied at 0.5, so the predicted labels come
        # from tie breaking. The expected sequence is fixed but NOT strictly
        # alternating, i.e. tie breaking has to be deterministic.
        # NOTE(review): the exact pattern presumably follows the library's
        # tie-break policy -- re-confirm these values if that policy changes.
        np.testing.assert_array_equal(
            results["preds"]["task1"],
            np.array([0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0]),
        )
コード例 #2
0
    def test_remapped_labels(self):
        """Check handling of an extra label key, with and without remapping.

        The dataset stores the same labels under the task's own name and
        under "other_task". The test asserts that without ``remap_labels``
        only task1 appears in predictions and scores, and that remapping
        "other_task" onto the task makes both appear.
        """
        task_name = self.task1.name
        features = torch.FloatTensor(
            [[idx, idx] for idx in range(NUM_EXAMPLES)])
        labels = torch.ones(NUM_EXAMPLES).long()

        # One label tensor registered under the task name and an extra key.
        dataset = DictDataset(
            name="dataset",
            split="train",
            X_dict={"data": features},
            Y_dict={task_name: labels, "other_task": labels},
        )
        loader = DictDataLoader(dataset, batch_size=BATCH_SIZE)

        model = MultitaskClassifier([self.task1])
        # Only the loss dict is inspected; the second return value is unused.
        losses, _ = model.calculate_loss(dataset.X_dict, dataset.Y_dict)
        self.assertIn("task1", losses)

        # No remapping: the additional label key must not show up.
        plain_preds = model.predict(loader)
        self.assertIn("task1", plain_preds["golds"])
        self.assertNotIn("other_task", plain_preds["golds"])
        plain_scores = model.score([loader])
        self.assertIn("task1/dataset/train/accuracy", plain_scores)
        self.assertNotIn("other_task/dataset/train/accuracy", plain_scores)

        # With remapping: both label sets must be reported.
        remapped_preds = model.predict(
            loader, remap_labels={"other_task": task_name})
        self.assertIn("task1", remapped_preds["golds"])
        self.assertIn("other_task", remapped_preds["golds"])
        remapped_scores = model.score(
            [loader], remap_labels={"other_task": task_name})
        self.assertIn("task1/dataset/train/accuracy", remapped_scores)
        self.assertIn("other_task/dataset/train/accuracy", remapped_scores)