def test_forward(self):
        """Run CNNLSTM end-to-end on one synthetic CLEVR sample.

        Builds the model from ``self.config.model_config.cnn_lstm``, moves a
        single random sample and the model to the current device, and pins
        both the ``train/clevr/logit_bce`` loss value and the score shape.

        NOTE(review): the expected loss 19.2635 depends on the global RNG
        state (model initialisation plus the three random tensors below, in
        this order); presumably the fixture seeds torch — verify before
        reordering anything here.
        """
        model_config = self.config.model_config.cnn_lstm

        cnn_lstm = CNNLSTM(model_config)
        cnn_lstm.build()
        cnn_lstm.init_losses()

        # assertIsInstance reports the actual type on failure.
        self.assertIsInstance(cnn_lstm, torch.nn.Module)

        # Keep these random draws in this exact order: the pinned loss below
        # depends on the RNG sequence.
        test_sample = Sample()
        test_sample.text = torch.randint(1, 79, (10, ), dtype=torch.long)
        test_sample.image = torch.randn(3, 320, 480)
        test_sample.targets = torch.randn(32)

        test_sample_list = SampleList([test_sample])
        test_sample_list.dataset_type = "train"
        test_sample_list.dataset_name = "clevr"

        test_sample_list = test_sample_list.to(get_current_device())
        cnn_lstm = cnn_lstm.to(get_current_device())

        output = cnn_lstm(test_sample_list)

        scores = output["scores"]
        # Loss keys follow "<dataset_type>/<dataset_name>/<loss name>".
        loss = output["losses"]["train/clevr/logit_bce"]

        np.testing.assert_almost_equal(loss.item(), 19.2635, decimal=4)
        self.assertEqual(scores.size(), torch.Size((1, 32)))
Example #2
0
    def test_mmf_loss(self):
        """Exercise MMFLoss construction, validation and forward output.

        ``registry.get_loss_class`` is swapped for a MagicMock driven by
        ``build_loss_side_effect`` so the registry lookup is controlled by
        the test; the final assertions pin how many times the registry was
        consulted. The original attribute is restored afterwards so the mock
        cannot leak into other tests.
        """
        get_loss_class_mock = MagicMock(side_effect=build_loss_side_effect())
        original_get_loss_class = registry.get_loss_class
        registry.get_loss_class = get_loss_class_mock
        try:
            # Test if MMFLoss accepts empty parameters
            self.assertRaises(ValueError, losses.MMFLoss)
            # assertEqual, not assertTrue: assertTrue(x, msg) treats its second
            # argument as a failure message and only checks that x is truthy,
            # so the name was never actually compared before.
            self.assertEqual(
                losses.MMFLoss({"type": "cross_entropy"}).name, "cross_entropy"
            )
            self.assertEqual(losses.MMFLoss("cross_entropy").name, "cross_entropy")
            # A list is not accepted as loss params.
            self.assertRaises(AssertionError, losses.MMFLoss, [])
            # Multi requires dict
            self.assertRaises(AssertionError, losses.MMFLoss, "multi")

            cross_entropy = losses.MMFLoss("cross_entropy")
            cross_entropy_from_dict = losses.MMFLoss({"type": "cross_entropy"})
            sample_list = SampleList()
            sample_list.dataset_type = "val"
            sample_list.dataset_name = "vqa2"

            output = cross_entropy(sample_list, {})
            output_from_dict = cross_entropy_from_dict(sample_list, {})

            # Output keys are "<dataset_type>/<dataset_name>/<loss name>".
            self.assertEqual(output, {"val/vqa2/cross_entropy": torch.tensor(1.0)})
            self.assertEqual(output_from_dict, output)

            # NOTE(review): the new side effect only affects registry lookups
            # made after this point; if MMFLoss resolves its class only at
            # construction (as call_count == 5 suggests), the asserts below
            # re-check the already-built losses — confirm intent.
            get_loss_class_mock.side_effect = build_loss_side_effect(1.0)
            output = cross_entropy(sample_list, {})

            self.assertEqual(output, {"val/vqa2/cross_entropy": torch.tensor(1.0)})
            self.assertEqual(output_from_dict, output)

            self.assertTrue(get_loss_class_mock.called)
            self.assertEqual(get_loss_class_mock.call_count, 5)
        finally:
            registry.get_loss_class = original_get_loss_class
    def test_pretrained_model(self):
        """Run the pretraining model in eval mode on one random batch.

        Every ``lm_label_ids`` entry is -1, and the test expects the
        ``random/test/masked_lm_loss`` entry to be exactly zero.
        NOTE(review): presumably -1 is the ignore index of the masked-LM
        loss — confirm against the model.
        """
        sample_list = SampleList()

        sample_list.add_field(
            "input_ids",
            torch.randint(low=0, high=BERT_VOCAB_SIZE, size=(1, 128)).long(),
        )
        sample_list.add_field("input_mask", torch.ones((1, 128)).long())
        sample_list.add_field("segment_ids", torch.zeros(1, 128).long())
        sample_list.add_field("image_feature_0", torch.rand((1, 100, 2048)).float())
        sample_list.add_field(
            "lm_label_ids", torch.zeros((1, 128), dtype=torch.long).fill_(-1)
        )

        self.pretrain_model.eval()
        self.pretrain_model = self.pretrain_model.to(get_current_device())
        sample_list = sample_list.to(get_current_device())

        sample_list.dataset_name = "random"
        sample_list.dataset_type = "test"
        with torch.no_grad():
            model_output = self.pretrain_model(sample_list)

        # assertIn / assertEqual report the offending values on failure,
        # unlike assertTrue(<bool expression>).
        self.assertIn("losses", model_output)
        self.assertIn("random/test/masked_lm_loss", model_output["losses"])
        self.assertEqual(
            float(model_output["losses"]["random/test/masked_lm_loss"]), 0.0
        )
Example #4
0
    def test_mmf_dict_loss(self):
        """Check that a dict-valued loss ("mse_mae") reports each component.

        The same tensor is used as both targets and scores, so the MSE and
        MAE entries are both expected to be exactly 0.0.
        """
        loss_fn = losses.MMFLoss("mse_mae")
        torch.manual_seed(1234)
        shared_tensor = torch.rand((1, 768))

        batch = SampleList()
        batch.dataset_type = "val"
        batch.dataset_name = "vqa2"
        batch["targets"] = shared_tensor

        result = loss_fn(batch, {"scores": shared_tensor})

        for key in ("val/vqa2/mse_mae/mse", "val/vqa2/mse_mae/mae"):
            self.assertEqual(result[key].item(), 0.0)
    def __call__(self, batch):
        """Return ``batch`` as a SampleList tagged with this dataset's name/type.

        - A one-element list holding a SampleList (batched iterators) is
          unwrapped and that SampleList is used directly.
        - An existing SampleList is reused as-is.
        - Anything else is wrapped via ``SampleList(batch)``.

        The resulting object is mutated in place to set ``dataset_name`` and
        ``dataset_type`` before being returned.
        """
        # Test list-ness first: batch may already be a SampleList, in which
        # case len()/indexing would mean something different.
        wrapped_single = (
            isinstance(batch, list)
            and len(batch) == 1
            and isinstance(batch[0], SampleList)
        )

        if wrapped_single:
            result = batch[0]
        elif isinstance(batch, SampleList):
            result = batch
        else:
            result = SampleList(batch)

        result.dataset_name = self._dataset_name
        result.dataset_type = self._dataset_type
        return result
Example #6
0
    def _get_sample_list(self):
        """Build a synthetic SampleList (batch of 8, 70 features) on the current device.

        Text-side tensors come from ``mock_vinvl_input_tensors``. Image
        metadata (max_features, bbox, image_height, image_width) is
        synthesised here. The ``*_corrupt`` / ``*_masked`` fields reuse the
        very same tensors as their clean counterparts.
        """
        batch_size, feat_count = 8, 70

        class MockObj:
            pass

        inputs = MockObj()
        mock_vinvl_input_tensors(inputs, bs=batch_size, num_feats=feat_count)

        mask = torch.ones_like(inputs.input_ids)
        # Dict displays evaluate left to right, so the random draws happen
        # in the order bbox -> image_height -> image_width.
        image_info = {
            "max_features": torch.ones((batch_size, feat_count)) * feat_count,
            "bbox": torch.randint(50, 200, (batch_size, feat_count, 4)).float(),
            "image_height": torch.randint(100, 300, (batch_size,)),
            "image_width": torch.randint(100, 300, (batch_size,)),
        }

        field_table = (
            ("input_ids", inputs.input_ids),
            ("input_ids_corrupt", inputs.input_ids),
            ("input_ids_masked", inputs.input_ids),
            ("image_feature_0", inputs.img_feats),
            ("image_info_0", image_info),
            ("input_mask", mask),
            ("input_mask_corrupt", mask),
            ("segment_ids", inputs.token_type_ids),
            ("segment_ids_corrupt", inputs.token_type_ids),
            ("labels", inputs.labels),
            ("contrastive_labels", inputs.contrastive_labels),
            ("lm_label_ids", inputs.lm_label_ids),
        )

        result = SampleList()
        for field_name, field_value in field_table:
            result.add_field(field_name, field_value)

        result = result.to(get_current_device())
        result.dataset_name = "test"
        result.dataset_type = "test"
        return result
Example #7
0
    def test_pretrained_model(self):
        """Run the pretraining model in eval mode on one random image+text batch.

        Only checks that a ``test/test/logit_bce`` entry is present in the
        model's ``losses`` output.
        """
        sample_list = SampleList()

        sample_list.add_field(
            "input_ids",
            torch.randint(low=0, high=BERT_VOCAB_SIZE, size=(1, 128)).long(),
        )
        sample_list.add_field("input_mask", torch.ones((1, 128)).long())
        sample_list.add_field("segment_ids", torch.zeros(1, 128).long())
        sample_list.add_field("image", torch.rand((1, 3, 224, 224)).float())
        sample_list.add_field("targets", torch.rand((1, 3129)).float())

        self.pretrain_model.eval()
        self.pretrain_model = self.pretrain_model.to(get_current_device())
        sample_list = sample_list.to(get_current_device())

        sample_list.dataset_name = "test"
        sample_list.dataset_type = "test"
        with torch.no_grad():
            model_output = self.pretrain_model(sample_list)

        # assertIn reports the container contents on failure, unlike
        # assertTrue(x in y).
        self.assertIn("losses", model_output)
        self.assertIn("test/test/logit_bce", model_output["losses"])
Example #8
0
 def __call__(self, batch):
     """Wrap ``batch`` in a new SampleList labelled with this dataset's name and type."""
     collated = SampleList(batch)
     # Tuple assignment sets dataset_name first, then dataset_type.
     collated.dataset_name, collated.dataset_type = (
         self._dataset_name,
         self._dataset_type,
     )
     return collated