Пример #1
0
    def load_item(self, idx):
        """Build the ``Sample`` for the annotation entry at ``idx``.

        The hypothesis text is taken from the entry's ``"sentence2"`` field.
        When ``self._use_features`` is ``True``, precomputed image features
        are attached as well. The target is the integer looked up for the
        entry's ``"gold_label"`` in ``LABEL_TO_INT_MAPPING``.

        Args:
            idx: Index into ``self.annotation_db``.

        Returns:
            The populated ``Sample``.
        """
        entry = self.annotation_db[idx]
        item = Sample()

        text_out = self.text_processor({"text": entry["sentence2"]})
        item.text = text_out["text"]
        if "input_ids" in text_out:
            item.update(text_out)

        if self._use_features is True:
            # Keep only the part of the Flickr30k id before the first "."
            # and use it as the feature-file stem.
            stem = entry["Flikr30kID"].split(".")[0]
            # The feature path is written into the annotation dict itself and
            # features_db is then indexed by idx; presumably the features
            # database reads "feature_path" from that same entry — confirm.
            entry["feature_path"] = "{}.npy".format(stem)
            feats = self.features_db[idx]
            if hasattr(self, "transformer_bbox_processor"):
                feats["image_info_0"] = self.transformer_bbox_processor(
                    feats["image_info_0"])
            item.update(feats)

        target_id = LABEL_TO_INT_MAPPING[entry["gold_label"]]
        item.targets = torch.tensor(target_id, dtype=torch.long)
        return item
Пример #2
0
    def __getitem__(self, idx):
        """Return the ``Sample`` at ``idx``.

        The sample holds the processed text, the integer id, the image
        features and, when the annotation carries a ``"label"``, a long-tensor
        target.

        Args:
            idx: Index into ``self.annotation_db``.

        Returns:
            The populated ``Sample``.
        """
        record = self.preprocess_sample_info(self.annotation_db[idx])
        out = Sample()

        text_result = self.text_processor({"text": record["text"]})
        out.text = text_result["text"]
        if "input_ids" in text_result:
            out.update(text_result)

        out.id = torch.tensor(int(record["id"]), dtype=torch.int)

        # Look features up from the preprocessed record rather than by idx:
        # "feature_path" is added to the record dynamically during
        # preprocessing.
        feats = self.features_db.get(record)
        if hasattr(self, "transformer_bbox_processor"):
            feats["image_info_0"] = self.transformer_bbox_processor(
                feats["image_info_0"])
        out.update(feats)

        if "label" in record:
            out.targets = torch.tensor(record["label"], dtype=torch.long)

        return out
Пример #3
0
    def load_item(self, idx):
        """Build the ``Sample`` for the image-pair annotation at ``idx``.

        The processed ``"sentence"`` becomes the text. When
        ``self._use_features`` is ``True``, features for both images of the
        pair are loaded into the ``img0`` and ``img1`` sub-samples, in that
        order. The target is 1 when the entry's ``"label"`` is the string
        ``"True"`` and 0 otherwise.

        Args:
            idx: Index into ``self.annotation_db``.

        Returns:
            The populated ``Sample``.
        """
        entry = self.annotation_db[idx]
        item = Sample()

        text_out = self.text_processor({"text": entry["sentence"]})
        item.text = text_out["text"]
        if "input_ids" in text_out:
            item.update(text_out)

        if self._use_features is True:
            # Drop the last "-"-separated component (the sentence id) so the
            # remaining prefix names the image pair.
            pair_id = "-".join(entry["identifier"].split("-")[:-1])
            slots = (("img0", "{}-img0.npy"), ("img1", "{}-img1.npy"))
            for slot, path_template in slots:
                # The path is written into the annotation entry and then
                # features_db is indexed by idx; presumably it reads the
                # updated "feature_path" from that entry — confirm.
                entry["feature_path"] = path_template.format(pair_id)
                feats = self.features_db[idx]
                if hasattr(self, "transformer_bbox_processor"):
                    feats["image_info_0"] = self.transformer_bbox_processor(
                        feats["image_info_0"])
                holder = Sample()
                holder.update(feats)
                setattr(item, slot, holder)

        label_value = 1 if entry["label"] == "True" else 0
        item.targets = torch.tensor(label_value, dtype=torch.long)
        return item
Пример #4
0
    def __getitem__(self, idx):
        """Return the ``Sample`` at ``idx`` for plot-based genre prediction.

        The plot text (the first entry if several are given) is processed
        into ``text``. Image features are attached when
        ``self._use_features`` is ``True``. The ``"genres"`` annotation is
        processed into ``answers`` and ``targets``.

        Args:
            idx: Index into ``self.annotation_db``.

        Returns:
            The populated ``Sample``.
        """
        entry = self.annotation_db[idx]
        item = Sample()

        plot_text = entry["plot"]
        # Only the first plot is used when a list of plots is provided.
        if isinstance(plot_text, list):
            plot_text = plot_text[0]
        text_out = self.text_processor({"text": plot_text})

        item.text = text_out["text"]
        if "input_ids" in text_out:
            item.update(text_out)

        if self._use_features is True:
            feats = self.features_db[idx]
            if hasattr(self, "transformer_bbox_processor"):
                feats["image_info_0"] = self.transformer_bbox_processor(
                    feats["image_info_0"])
            item.update(feats)

        genre_out = self.answer_processor({"answers": entry["genres"]})
        item.answers = genre_out["answers"]
        item.targets = genre_out["answers_scores"]
        return item
Пример #5
0
    def _test_multiclass_metric(self, metric, value):
        """Assert ``metric`` yields ``value`` on a fixed 4-sample, 3-class case.

        The same ground truth is supplied twice: first one-hot encoded as a
        float tensor, then as integer class indices. Both must give
        ``value`` to 4 decimal places.
        """
        predictions = {
            "scores": torch.tensor(
                [
                    [-0.9332, 0.8149, 0.3491],
                    [-0.8391, 0.6797, -0.3410],
                    [-0.7235, 0.7220, 0.9104],
                    [0.9043, 0.3078, -0.4210],
                ],
                dtype=torch.float,
            )
        }
        one_hot_targets = torch.tensor(
            [[0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 0, 1]], dtype=torch.float)
        index_targets = torch.tensor([1, 2, 0, 2], dtype=torch.long)

        sample = Sample()
        for targets in (one_hot_targets, index_targets):
            sample.targets = targets
            self.assertAlmostEqual(
                metric.calculate(sample, predictions).item(), value, 4)

        sample.targets = torch.tensor([1, 2, 0, 2], dtype=torch.long)
        self.assertAlmostEqual(
            metric.calculate(sample, predicted).item(), value, 4)
Пример #6
0
    def _test_binary_metric(self, metric, value):
        """Assert ``metric`` yields ``value`` on a fixed 4-sample binary case.

        The same ground truth is supplied twice: first one-hot encoded as a
        float tensor, then as integer class indices. Both must give
        ``value`` to 4 decimal places.
        """
        predictions = {
            "scores": torch.tensor(
                [
                    [-0.9332, 0.8149],
                    [-0.8391, 0.6797],
                    [-0.7235, 0.7220],
                    [-0.9043, 0.3078],
                ],
                dtype=torch.float,
            )
        }
        one_hot_targets = torch.tensor(
            [[0, 1], [1, 0], [1, 0], [0, 1]], dtype=torch.float)
        index_targets = torch.tensor([1, 0, 0, 1], dtype=torch.long)

        sample = Sample()
        for targets in (one_hot_targets, index_targets):
            sample.targets = targets
            self.assertAlmostEqual(
                metric.calculate(sample, predictions).item(), value, 4)
Пример #7
0
    def __getitem__(self, idx: int) -> Sample:
        """Return the ``Sample`` for the question at ``idx``.

        The sample holds the processed question fields, the integer question
        id, and the first image loaded from ``self.image_db``. When the
        annotation includes ``"answers"``, their processed scores are stored
        as ``targets``.

        Args:
            idx: Index into ``self.annotation_db``.

        Returns:
            A populated ``Sample`` instance. This is an instance, not the
            class, so the annotation is ``Sample`` rather than ``Type[Sample]``.
        """
        sample_info = self.annotation_db[idx]
        current_sample = Sample()
        processed_question = self.text_processor(
            {"text": sample_info["question"]})
        current_sample.update(processed_question)
        current_sample.id = torch.tensor(int(sample_info["question_id"]),
                                         dtype=torch.int)

        # Get the first image from the set of images returned from the image_db
        image_path = self.get_image_path(sample_info["image_id"])
        current_sample.image = self.image_db.from_path(image_path)["images"][0]

        if "answers" in sample_info:
            answers = self.answer_processor(
                {"answers": sample_info["answers"]})
            current_sample.targets = answers["answers_scores"]

        return current_sample
Пример #8
0
    def __getitem__(self, idx):
        """Return the image-based ``Sample`` at ``idx``.

        The sample holds the processed text, the integer id, and the first
        image from ``self.image_db``. When the annotation has a ``"label"``,
        it is stored as a long-tensor target.

        Args:
            idx: Index into ``self.annotation_db`` and ``self.image_db``.

        Returns:
            The populated ``Sample``.
        """
        record = self.annotation_db[idx]
        out = Sample()

        text_result = self.text_processor({"text": record["text"]})
        out.text = text_result["text"]
        if "input_ids" in text_result:
            out.update(text_result)

        out.id = torch.tensor(int(record["id"]), dtype=torch.int)

        # image_db returns a collection under "images"; keep only the first.
        images = self.image_db[idx]["images"]
        out.image = images[0]

        if "label" in record:
            out.targets = torch.tensor(record["label"], dtype=torch.long)

        return out
Пример #9
0
    def __getitem__(self, idx):
        """Return the ``Sample`` at ``idx`` for plot-based genre prediction.

        The plot text (the first entry if several are given) is processed
        into ``text``. The first raw image is attached when
        ``self._use_images`` is ``True``. The ``"genres"`` annotation is
        processed into ``answers`` and ``targets``.

        Args:
            idx: Index into ``self.annotation_db`` and ``self.image_db``.

        Returns:
            The populated ``Sample``.
        """
        entry = self.annotation_db[idx]
        item = Sample()

        plot_text = entry["plot"]
        # Only the first plot is used when a list of plots is provided.
        if isinstance(plot_text, list):
            plot_text = plot_text[0]
        text_out = self.text_processor({"text": plot_text})

        item.text = text_out["text"]
        if "input_ids" in text_out:
            item.update(text_out)

        if self._use_images is True:
            images = self.image_db[idx]["images"]
            item.image = images[0]

        genre_out = self.answer_processor({"answers": entry["genres"]})
        item.answers = genre_out["answers"]
        item.targets = genre_out["answers_scores"]
        return item
Пример #10
0
    def __getitem__(self, idx):
        """Build the ``Sample`` for the question at ``idx``.

        Each call from the dataloader returns a ``Sample``. The batch collator
        later combines these into a ``SampleList``, an attribute-based batch.

        The sample contains:

        - ``text``: the tokenized question after ``self.text_processor``.
        - ``answers`` / ``targets``: the single answer after
          ``self.answer_processor``.
        - ``image``: the RGB image as a float32 tensor of shape
          (3, H, W) with values divided by 255.

        Args:
            idx: Index into ``self.questions``.

        Returns:
            The populated ``Sample``.
        """
        data = self.questions[idx]
        current_sample = Sample()

        tokens = tokenize(data["question"], keep=[";", ","], remove=["?", "."])
        processed = self.text_processor({"tokens": tokens})
        current_sample.text = processed["text"]

        processed = self.answer_processor({"answers": [data["answer"]]})
        current_sample.answers = processed["answers"]
        current_sample.targets = processed["answers_scores"]

        image_path = os.path.join(self.image_path, data["image_filename"])
        # Image.open is lazy and keeps the file handle open. The context
        # manager closes it deterministically instead of leaking one handle
        # per sample until garbage collection. convert() loads the pixel
        # data, so the array is fully materialized before the file closes.
        with Image.open(image_path) as img:
            image = np.true_divide(img.convert("RGB"), 255)
        image = image.astype(np.float32)
        # (H, W, C) -> (C, H, W)
        current_sample.image = torch.from_numpy(image.transpose(2, 0, 1))

        return current_sample
Пример #11
0
    def test_forward(self):
        """Run one CNN-LSTM forward pass on a synthetic sample and check it.

        The test checks that the model is a ``torch.nn.Module``. It also
        checks that the ``train/clevr/logit_bce`` loss matches the recorded
        reference to 4 decimals, and that the scores have shape ``(1, 32)``.
        """
        cfg = self.config.model_config.cnn_lstm

        model = CNNLSTM(cfg)
        model.build()
        model.init_losses()

        self.assertTrue(isinstance(model, torch.nn.Module))

        # Keep model construction and these random draws in this exact order.
        # The reference loss below presumably relies on a seeded RNG (set
        # elsewhere, e.g. setUp — confirm), so reordering would change it.
        sample = Sample()
        sample.text = torch.randint(1, 79, (10, ), dtype=torch.long)
        sample.image = torch.randn(3, 320, 480)
        sample.targets = torch.randn(32)

        batch = SampleList([sample])
        batch.dataset_type = "train"
        batch.dataset_name = "clevr"
        result = model(batch)

        bce_loss = result["losses"]["train/clevr/logit_bce"]
        np.testing.assert_almost_equal(bce_loss.item(), 19.2635, decimal=4)
        self.assertEqual(result["scores"].size(), torch.Size((1, 32)))