Пример #1
0
    def _build_report(self):
        """Create a ``Report`` from a one-field SampleList and random scores.

        The SampleList carries a single integer tensor under the key ``"a"``
        (two rows of four values); the model output is a dict whose
        ``"scores"`` entry is a random 2x2 tensor.

        Returns:
            The ``Report`` constructed from the sample list and model output.
        """
        samples = SampleList()
        samples.add_field("a", torch.tensor([[1, 2, 3, 4], [2, 3, 4, 5]]))

        outputs = {"scores": torch.rand(2, 2)}
        return Report(samples, outputs)
Пример #2
0
    def test_pretrained_model(self):
        """Run the pretraining model on one random sample and check its loss.

        Builds a SampleList holding random token ids, an all-ones input mask,
        all-zero segment ids, a random image feature tensor and LM labels
        filled with -1. The model is put in eval mode, moved to the current
        device and called under ``torch.no_grad()``. The output must contain
        a ``"losses"`` entry whose ``"random/test/masked_lm_loss"`` value
        compares equal to zero.
        """
        text_shape = (1, 128)
        loss_key = "random/test/masked_lm_loss"

        batch = SampleList()
        token_ids = torch.randint(low=0, high=BERT_VOCAB_SIZE, size=text_shape)
        batch.add_field("input_ids", token_ids.long())
        batch.add_field("input_mask", torch.ones(text_shape).long())
        batch.add_field("segment_ids", torch.zeros(*text_shape).long())
        batch.add_field("image_feature_0", torch.rand((1, 100, 2048)).float())
        # Every LM label is -1 in this batch.
        lm_labels = torch.zeros(text_shape, dtype=torch.long).fill_(-1)
        batch.add_field("lm_label_ids", lm_labels)

        self.pretrain_model.eval()
        self.pretrain_model = self.pretrain_model.to(get_current_device())
        batch = batch.to(get_current_device())

        # Dataset name and type make up the prefix of the loss key checked below.
        batch.dataset_name = "random"
        batch.dataset_type = "test"

        with torch.no_grad():
            model_output = self.pretrain_model(batch)

        self.assertTrue("losses" in model_output)
        losses = model_output["losses"]
        self.assertTrue(loss_key in losses)
        self.assertTrue(losses[loss_key] == 0)
Пример #3
0
    def test_convert_batch_to_sample_list(self):
        """Check ``convert_batch_to_sample_list`` on three kinds of input.

        Covers a list of dicts with tensor values, a single-element list that
        already contains a ``SampleList``, and a list of dicts whose values
        are plain Python lists rather than tensors.
        """
        # A list of per-sample dicts with tensor values.
        batch = [{"a": torch.tensor([1.0, 1.0])}, {"a": torch.tensor([2.0, 2.0])}]
        sample_list = convert_batch_to_sample_list(batch)
        expected_a = torch.tensor([[1.0, 1.0], [2.0, 2.0]])
        self.assertTrue(torch.equal(expected_a, sample_list.a))

        # A single-element list wrapping an existing SampleList.
        sample_list = SampleList()
        sample_list.add_field("a", expected_a)
        parsed_sample_list = convert_batch_to_sample_list([sample_list])
        self.assertIsInstance(parsed_sample_list, SampleList)
        self.assertIn("a", parsed_sample_list)
        self.assertTrue(torch.equal(expected_a, parsed_sample_list.a))

        # Non-tensor field values.
        batch = [{"a": [1]}, {"a": [2]}]
        sample_list = convert_batch_to_sample_list(batch)
        # assertEqual, not assertTrue: assertTrue would take the expected value
        # as its failure message and only check that the field is truthy.
        self.assertEqual(sample_list.a, [[1], [2]])
Пример #4
0
    def test_pretrained_model(self):
        """Run the pretraining model on one random sample and check loss keys.

        Builds a SampleList with random token ids, an all-ones input mask,
        all-zero segment ids, a random ``(1, 3, 224, 224)`` image tensor and
        random ``(1, 3129)`` targets. The model is put in eval mode, moved to
        the current device and called under ``torch.no_grad()``. The output
        must contain ``"losses"`` with a ``"test/test/logit_bce"`` entry.
        """
        text_shape = (1, 128)

        batch = SampleList()
        token_ids = torch.randint(low=0, high=BERT_VOCAB_SIZE, size=text_shape)
        batch.add_field("input_ids", token_ids.long())
        batch.add_field("input_mask", torch.ones(text_shape).long())
        batch.add_field("segment_ids", torch.zeros(*text_shape).long())
        batch.add_field("image", torch.rand((1, 3, 224, 224)).float())
        batch.add_field("targets", torch.rand((1, 3129)).float())

        self.pretrain_model.eval()
        self.pretrain_model = self.pretrain_model.to(get_current_device())
        batch = batch.to(get_current_device())

        # Dataset name and type make up the prefix of the loss key checked below.
        batch.dataset_name = "test"
        batch.dataset_type = "test"

        with torch.no_grad():
            model_output = self.pretrain_model(batch)

        self.assertTrue("losses" in model_output)
        losses = model_output["losses"]
        self.assertTrue("test/test/logit_bce" in losses)
Пример #5
0
    def test_finetune_model(self):
        """Check that the TorchScript model reproduces the eager model's scores.

        Runs ``self.finetune_model`` in eval mode on one random sample, then
        scripts it with ``torch.jit.script`` and runs the scripted module on
        the same SampleList. The ``"scores"`` entries of both outputs must be
        exactly equal according to ``torch.equal``.
        """
        self.finetune_model.eval()

        text_shape = (1, 128)
        batch = SampleList()
        token_ids = torch.randint(low=0, high=30255, size=text_shape)
        batch.add_field("input_ids", token_ids.long())
        batch.add_field("input_mask", torch.ones(text_shape).long())
        batch.add_field("segment_ids", torch.zeros(*text_shape).long())
        image_features = torch.rand((1, 100, 2048)).float()
        batch.add_field("image_feature_0", image_features)

        # Eager forward pass.
        with torch.no_grad():
            eager_output = self.finetune_model(batch)

        # Scripted forward pass on the very same input.
        scripted = torch.jit.script(self.finetune_model)
        with torch.no_grad():
            scripted_output = scripted(batch)

        scores_match = torch.equal(
            eager_output["scores"], scripted_output["scores"]
        )
        self.assertTrue(scores_match)
Пример #6
0
    def _get_sample_list(self):
        """Build a synthetic SampleList fixture and move it to the current device.

        The batch has 8 samples, 25 text positions and 100 image features of
        dimension 2048. Token ids and the input mask are all ones, the
        attention and image masks are all zeros, and image features, image
        position features and targets are random. ``dataset_name`` and
        ``dataset_type`` are both added as the string ``"test"``.

        Returns:
            The SampleList returned by ``.to(get_current_device())``.
        """
        bs = 8
        num_feats = 100
        max_sentence_len = 25
        img_dim = 2048
        cls_dim = 3129
        input_ids = torch.ones((bs, max_sentence_len), dtype=torch.long)
        input_mask = torch.ones((bs, max_sentence_len), dtype=torch.long)
        image_feat = torch.rand((bs, num_feats, img_dim))
        # One row of 0..max_sentence_len-1, expanded to every sample in the batch.
        position_ids = (
            torch.arange(
                0, max_sentence_len, dtype=torch.long, device=image_feat.device
            )
            .unsqueeze(0)
            .expand(bs, -1)
        )
        img_pos_feat = torch.rand((bs, num_feats, 7))
        # Covers the text positions followed by the image features.
        attention_mask = torch.zeros(
            (bs, max_sentence_len + num_feats), dtype=torch.long
        )
        image_mask = torch.zeros((bs, num_feats), dtype=torch.long)
        targets = torch.rand((bs, cls_dim))

        sample_list = SampleList()
        sample_list.add_field("input_ids", input_ids)
        sample_list.add_field("input_mask", input_mask)
        sample_list.add_field("image_feat", image_feat)
        sample_list.add_field("img_pos_feat", img_pos_feat)
        sample_list.add_field("attention_mask", attention_mask)
        sample_list.add_field("image_mask", image_mask)
        sample_list.add_field("targets", targets)
        sample_list.add_field("dataset_name", "test")
        sample_list.add_field("dataset_type", "test")
        sample_list.add_field("position_ids", position_ids)
        # Rebind to the value returned by ``.to(...)`` instead of discarding it,
        # so the list handed back is the one that was moved to the device.
        # NOTE(review): assumes SampleList.to returns the moved list rather
        # than mutating in place -- confirm against the SampleList API.
        sample_list = sample_list.to(get_current_device())

        return sample_list
Пример #7
0
    def _get_sample_list(self):
        """Build a synthetic SampleList fixture on the current device.

        The batch has 8 samples with 25 text positions and 100 image features
        of dimension 2048. Token ids, input mask and ``is_correct`` are all
        ones; image features and targets are random; ``image_info_0`` is a
        dict with ``max_features``, random ``bbox`` values and random image
        heights and widths.

        Returns:
            The SampleList returned by ``.to(get_current_device())``.
        """
        batch_size = 8
        feature_count = 100
        sentence_len = 25
        feature_dim = 2048
        num_answers = 3129
        text_shape = (batch_size, sentence_len)

        samples = SampleList()
        samples.add_field("input_ids", torch.ones(text_shape, dtype=torch.long))
        samples.add_field(
            "image_feature_0", torch.rand((batch_size, feature_count, feature_dim))
        )
        samples.add_field("input_mask", torch.ones(text_shape, dtype=torch.long))
        samples.add_field(
            "image_info_0",
            {
                # Every entry holds the feature count itself.
                "max_features": torch.ones((batch_size, feature_count))
                * feature_count,
                "bbox": torch.randint(
                    50, 200, (batch_size, feature_count, 4)
                ).float(),
                "image_height": torch.randint(100, 300, (batch_size,)),
                "image_width": torch.randint(100, 300, (batch_size,)),
            },
        )
        samples.add_field("targets", torch.rand((batch_size, num_answers)))
        samples.add_field(
            "is_correct", torch.ones((batch_size,), dtype=torch.long)
        )
        return samples.to(get_current_device())
Пример #8
0
    def _get_sample_list(self):
        """Build a SampleList fixture from mocked VinVL input tensors.

        ``mock_vinvl_input_tensors`` populates a bare holder object for a
        batch of 8 samples with 70 image features; its attributes are then
        added to a SampleList. The same ``input_ids`` tensor is also used for
        the corrupt and masked variants, and the input mask and segment ids
        are likewise shared with their corrupt variants. The list is moved to
        the current device before ``dataset_name`` and ``dataset_type`` are
        both set to ``"test"``.

        Returns:
            The populated SampleList.
        """
        batch_size = 8
        feature_count = 70

        class _TensorHolder:
            pass

        holder = _TensorHolder()
        mock_vinvl_input_tensors(holder, bs=batch_size, num_feats=feature_count)

        ones_mask = torch.ones_like(holder.input_ids)
        image_info = {
            # Every entry holds the feature count itself.
            "max_features": torch.ones((batch_size, feature_count)) * feature_count,
            "bbox": torch.randint(50, 200, (batch_size, feature_count, 4)).float(),
            "image_height": torch.randint(100, 300, (batch_size,)),
            "image_width": torch.randint(100, 300, (batch_size,)),
        }

        # (field name, value) pairs, in the order they are added.
        fields = (
            ("input_ids", holder.input_ids),
            ("input_ids_corrupt", holder.input_ids),
            ("input_ids_masked", holder.input_ids),
            ("image_feature_0", holder.img_feats),
            ("image_info_0", image_info),
            ("input_mask", ones_mask),
            ("input_mask_corrupt", ones_mask),
            ("segment_ids", holder.token_type_ids),
            ("segment_ids_corrupt", holder.token_type_ids),
            ("labels", holder.labels),
            ("contrastive_labels", holder.contrastive_labels),
            ("lm_label_ids", holder.lm_label_ids),
        )

        samples = SampleList()
        for field_name, value in fields:
            samples.add_field(field_name, value)

        samples = samples.to(get_current_device())
        samples.dataset_name = "test"
        samples.dataset_type = "test"
        return samples