示例#1
0
    def test_mmf_loss(self):
        """Check MMFLoss construction, argument validation and forward pass.

        ``registry.get_loss_class`` is replaced by a ``MagicMock`` driven by
        ``build_loss_side_effect`` for the duration of this test only; the
        original attribute is restored during test cleanup.
        """
        from unittest import mock

        get_loss_class_mock = MagicMock(side_effect=build_loss_side_effect())
        # Patch instead of assigning ``registry.get_loss_class`` directly: a
        # plain assignment is never undone, so the mock would leak into every
        # test that runs afterwards (including ones that need real losses).
        # ``addCleanup`` restores the original even if an assertion fails.
        patcher = mock.patch.object(
            registry, "get_loss_class", get_loss_class_mock
        )
        patcher.start()
        self.addCleanup(patcher.stop)

        # MMFLoss must reject being built without parameters.
        self.assertRaises(ValueError, losses.MMFLoss)
        # assertEqual rather than assertTrue: assertTrue(x, msg) treats the
        # second argument as a failure message and only checks truthiness.
        self.assertEqual(
            losses.MMFLoss({"type": "cross_entropy"}).name, "cross_entropy"
        )
        self.assertEqual(losses.MMFLoss("cross_entropy").name, "cross_entropy")
        self.assertRaises(AssertionError, losses.MMFLoss, [])
        # Multi requires dict
        self.assertRaises(AssertionError, losses.MMFLoss, "multi")

        cross_entropy = losses.MMFLoss("cross_entropy")
        cross_entropy_from_dict = losses.MMFLoss({"type": "cross_entropy"})
        sample_list = SampleList()
        sample_list.dataset_type = "val"
        sample_list.dataset_name = "vqa2"

        output = cross_entropy(sample_list, {})
        output_from_dict = cross_entropy_from_dict(sample_list, {})

        # Expected key combines dataset_type, dataset_name and the loss name.
        self.assertEqual(output, {"val/vqa2/cross_entropy": torch.tensor(1.0)})
        self.assertEqual(output_from_dict, output)

        # NOTE(review): the new side effect only applies to later
        # ``get_loss_class`` calls; if MMFLoss looks up its class only at
        # construction time, this re-run of the already-built
        # ``cross_entropy`` does not exercise it — confirm intent.
        get_loss_class_mock.side_effect = build_loss_side_effect(1.0)
        output = cross_entropy(sample_list, {})

        self.assertEqual(output, {"val/vqa2/cross_entropy": torch.tensor(1.0)})
        self.assertEqual(output_from_dict, output)

        self.assertTrue(get_loss_class_mock.called)
        # Presumably one lookup per MMFLoss(...) call above that gets past
        # argument validation — verify if construction logic changes.
        self.assertEqual(get_loss_class_mock.call_count, 5)
示例#2
0
    def test_mmf_dict_loss(self):
        """Check that the "mse_mae" loss reports separate mse and mae entries.

        The same tensor is used as both the model's ``scores`` and the sample
        ``targets``, so each reported component is expected to be exactly 0.
        """
        loss_fn = losses.MMFLoss("mse_mae")
        torch.manual_seed(1234)
        shared = torch.rand((1, 768))

        samples = SampleList()
        samples.dataset_type = "val"
        samples.dataset_name = "vqa2"
        samples["targets"] = shared

        result = loss_fn(samples, {"scores": shared})

        # Keys are prefixed "<dataset_type>/<dataset_name>/mse_mae/".
        for key in ("val/vqa2/mse_mae/mse", "val/vqa2/mse_mae/mae"):
            self.assertEqual(result[key].item(), 0.0)