def test_mmf_loss(self):
    """Check MMFLoss construction, argument validation and forward output.

    ``registry.get_loss_class`` is replaced by a ``MagicMock`` driven by
    ``build_loss_side_effect`` only for the duration of this test, so the
    substitution cannot leak into other tests that use the real registry.
    """
    # Local import: only ``MagicMock`` is known to be in scope at file level.
    from unittest.mock import patch

    get_loss_class_mock = MagicMock(side_effect=build_loss_side_effect())
    # patch.object restores the original attribute when the block exits
    # (including when an assertion fails); a bare assignment would leave
    # the mock installed on the shared registry for the rest of the run.
    with patch.object(registry, "get_loss_class", get_loss_class_mock):
        # MMFLoss must reject being built without any parameters.
        self.assertRaises(ValueError, losses.MMFLoss)
        # assertEqual, not assertTrue: assertTrue's second argument is only
        # the failure message, so it never compared the name.
        self.assertEqual(
            losses.MMFLoss({"type": "cross_entropy"}).name, "cross_entropy"
        )
        self.assertEqual(losses.MMFLoss("cross_entropy").name, "cross_entropy")
        self.assertRaises(AssertionError, losses.MMFLoss, [])
        # Multi requires dict
        self.assertRaises(AssertionError, losses.MMFLoss, "multi")

        cross_entropy = losses.MMFLoss("cross_entropy")
        cross_entropy_from_dict = losses.MMFLoss({"type": "cross_entropy"})
        sample_list = SampleList()
        sample_list.dataset_type = "val"
        sample_list.dataset_name = "vqa2"

        output = cross_entropy(sample_list, {})
        output_from_dict = cross_entropy_from_dict(sample_list, {})

        # The output is keyed by "<dataset_type>/<dataset_name>/<loss name>".
        self.assertEqual(output, {"val/vqa2/cross_entropy": torch.tensor(1.0)})
        self.assertEqual(output_from_dict, output)

        get_loss_class_mock.side_effect = build_loss_side_effect(1.0)
        output = cross_entropy(sample_list, {})

        self.assertEqual(output, {"val/vqa2/cross_entropy": torch.tensor(1.0)})
        # NOTE(review): output_from_dict is not recomputed after the
        # side-effect change, and both losses were constructed before it, so
        # this comparison may not exercise the new side effect at all —
        # confirm this is the intended check.
        self.assertEqual(output_from_dict, output)

        self.assertTrue(get_loss_class_mock.called)
        self.assertEqual(get_loss_class_mock.call_count, 5)
def test_mmf_dict_loss(self):
    """A dict-returning loss ("mse_mae") reports one entry per sub-loss.

    Scores and targets are the very same tensor, so both the MSE and the
    MAE entries are expected to be exactly zero.
    """
    combined_loss = losses.MMFLoss("mse_mae")

    # Seed before drawing the shared tensor so the input is reproducible.
    torch.manual_seed(1234)
    shared = torch.rand((1, 768))

    batch = SampleList()
    batch.dataset_type = "val"
    batch.dataset_name = "vqa2"
    batch["targets"] = shared

    result = combined_loss(batch, {"scores": shared})

    for key in ("val/vqa2/mse_mae/mse", "val/vqa2/mse_mae/mae"):
        self.assertEqual(result[key].item(), 0.0)