Example #1
    def __getitem__(self, idx):
        sample_info = self.annotation_db[idx]
        current_sample = Sample()

        if self._use_features:
            features = self.features_db[idx]
            if hasattr(self, "transformer_bbox_processor"):
                features["image_info_0"] = self.transformer_bbox_processor(
                    features["image_info_0"])

            if self.config.get("use_image_feature_masks", False):
                current_sample.update({
                    "image_labels":
                    self.masked_region_processor(features["image_feature_0"])
                })

            current_sample.update(features)
        else:
            image_path = str(sample_info["image_name"]) + ".jpg"
            current_sample.image = self.image_db.from_path(
                image_path)["images"][0]

        current_sample = self._add_masked_question(sample_info, current_sample)
        if self._add_answer:
            current_sample = self.add_answer_info(sample_info, current_sample)
        return current_sample
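
The __getitem__ above returns a single Sample; MMF's batch collator merges
these into a SampleList. A minimal sketch of driving such a dataset through a
vanilla DataLoader, assuming `dataset` is an instance of the class above and
that SampleList accepts a plain list of Samples (as the later examples
suggest); MMF normally wires this up through its own builders:

    from torch.utils.data import DataLoader

    # Illustration only: `dataset` is assumed to be constructed elsewhere.
    loader = DataLoader(dataset, batch_size=2, collate_fn=SampleList)
    batch = next(iter(loader))  # a SampleList whose fields are batched tensors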
Example #2
    def load_item(self, idx):
        sample_info = self.annotation_db[idx]
        sample_info = self.preprocess_sample_info(sample_info)
        current_sample = Sample()

        if self._dataset_type != "test":
            text_processor_argument = {"tokens": sample_info["caption_tokens"]}
            processed_caption = self.text_processor(text_processor_argument)
            current_sample.text = processed_caption["text"]
            current_sample.caption_id = torch.tensor(sample_info["caption_id"],
                                                     dtype=torch.int)
        current_sample.caption_len = torch.tensor(
            len(sample_info["caption_tokens"]), dtype=torch.int)

        current_sample.image_id = object_to_byte_tensor(
            sample_info["image_id"])

        if self._use_features:
            features = self.features_db[idx]
            current_sample.update(features)
        else:
            image_path = str(sample_info["image_name"]) + ".jpg"
            current_sample.image = self.image_db.from_path(
                image_path)["images"][0]

        # Add reference captions to sample
        current_sample = self.add_reference_caption(sample_info,
                                                    current_sample)

        return current_sample
Example #3
    def load_item(self, idx):
        sample_info = self.annotation_db[idx]
        current_sample = Sample()

        if "question_tokens" in sample_info:
            text_processor_argument = {
                "tokens": sample_info["question_tokens"],
                "text": sample_info["question_str"],
            }
        else:
            text_processor_argument = {"text": sample_info["question"]}

        processed_question = self.text_processor(text_processor_argument)

        current_sample.text = processed_question["text"]
        if "input_ids" in processed_question:
            current_sample.update(processed_question)

        current_sample.question_id = torch.tensor(sample_info["question_id"],
                                                  dtype=torch.int)

        if isinstance(sample_info["image_id"], int):
            current_sample.image_id = torch.tensor(sample_info["image_id"],
                                                   dtype=torch.int)
        else:
            current_sample.image_id = sample_info["image_id"]

        if "question_tokens" in sample_info:
            current_sample.text_len = torch.tensor(
                len(sample_info["question_tokens"]), dtype=torch.int)

        if self._use_features:
            features = self.features_db[idx]
            if hasattr(self, "transformer_bbox_processor"):
                features["image_info_0"] = self.transformer_bbox_processor(
                    features["image_info_0"])
            current_sample.update(features)
        else:
            image_path = sample_info["image_name"] + ".jpg"
            current_sample.image = self.image_db.from_path(
                image_path)["images"][0]

        # Add OCR details such as OCR bboxes, vectors, and tokens here
        current_sample = self.add_ocr_details(sample_info, current_sample)
        # Depending on whether we are using soft copy, this can add a
        # dynamic answer space
        current_sample = self.add_answer_info(sample_info, current_sample)
        return current_sample
Example #4
File: mmbt.py Project: hahaxun/mmf
    def classify(self, image: ImageType, text: str):
        """Classifies a given image and text in it into Hateful/Non-Hateful.
        Image can be a url or a local path or you can directly pass a PIL.Image.Image
        object. Text needs to be a sentence containing all text in the image.

            >>> from multimodelity.models.mmbt import MMBT
            >>> model = MMBT.from_pretrained("mmbt.hateful_memes.images")
            >>> model.classify("some_url", "some_text")
            {"label": 0, "confidence": 0.56}

        Args:
            image (ImageType): Image to be classified
            text (str): Text in the image

        Returns:
            dict: The predicted "label" (1 for hateful, 0 for non-hateful)
            and its softmax "confidence".
        """
        if isinstance(image, str):
            if image.startswith("http"):
                temp_file = tempfile.NamedTemporaryFile()
                download(image,
                         *os.path.split(temp_file.name),
                         disable_tqdm=True)
                image = tv_helpers.default_loader(temp_file.name)
                temp_file.close()
            else:
                image = tv_helpers.default_loader(image)

        text = self.processor_dict["text_processor"]({"text": text})
        image = self.processor_dict["image_processor"](image)

        sample = Sample()
        sample.text = text["text"]
        if "input_ids" in text:
            sample.update(text)

        sample.image = image
        sample_list = SampleList([sample])
        device = next(self.model.parameters()).device
        sample_list = sample_list.to(device)

        output = self.model(sample_list)
        scores = nn.functional.softmax(output["scores"], dim=1)
        confidence, label = torch.max(scores, dim=1)

        return {"label": label.item(), "confidence": confidence.item()}
Example #5
    def __getitem__(self, idx: int) -> Type[Sample]:
        sample_info = self.annotation_db[idx]
        current_sample = Sample()
        processed_question = self.text_processor(
            {"text": sample_info["question"]})
        current_sample.update(processed_question)
        current_sample.id = torch.tensor(int(sample_info["question_id"]),
                                         dtype=torch.int)

        # Get the first image from the set of images returned from the image_db
        image_path = self.get_image_path(sample_info["image_id"])
        current_sample.image = self.image_db.from_path(image_path)["images"][0]

        if "answers" in sample_info:
            answers = self.answer_processor(
                {"answers": sample_info["answers"]})
            current_sample.targets = answers["answers_scores"]

        return current_sample
Example #6
    def __getitem__(self, idx):
        sample_info = self.annotation_db[idx]
        current_sample = Sample()

        processed_text = self.text_processor({"text": sample_info["text"]})
        current_sample.text = processed_text["text"]
        if "input_ids" in processed_text:
            current_sample.update(processed_text)

        current_sample.id = torch.tensor(int(sample_info["id"]),
                                         dtype=torch.int)

        # Get the first image from the set of images returned from the image_db
        current_sample.image = self.image_db[idx]["images"][0]

        if "label" in sample_info:
            current_sample.targets = torch.tensor(sample_info["label"],
                                                  dtype=torch.long)

        return current_sample
Example #7
    def __getitem__(self, idx):
        sample_info = self.annotation_db[idx]
        current_sample = Sample()
        plot = sample_info["plot"]
        if isinstance(plot, list):
            plot = plot[0]
        processed_sentence = self.text_processor({"text": plot})

        current_sample.text = processed_sentence["text"]
        if "input_ids" in processed_sentence:
            current_sample.update(processed_sentence)

        if self._use_images is True:
            current_sample.image = self.image_db[idx]["images"][0]

        processed = self.answer_processor({"answers": sample_info["genres"]})
        current_sample.answers = processed["answers"]
        current_sample.targets = processed["answers_scores"]

        return current_sample
Example #8
    def test_finetune_model(self):
        self.finetune_model.eval()
        test_sample = Sample()
        test_sample.input_ids = torch.randint(low=0, high=30255,
                                              size=(128, )).long()
        test_sample.input_mask = torch.ones(128).long()
        test_sample.segment_ids = torch.zeros(128).long()
        test_sample.image = torch.rand((3, 300, 300)).float()
        test_sample_list = SampleList([test_sample.copy()])

        with torch.no_grad():
            model_output = self.finetune_model.model(test_sample_list)

        test_sample_list = SampleList([test_sample])
        script_model = torch.jit.script(self.finetune_model.model)
        with torch.no_grad():
            script_output = script_model(test_sample_list)

        self.assertTrue(
            torch.equal(model_output["scores"], script_output["scores"]))
Example #9
    def __getitem__(self, idx):
        data = self.questions[idx]

        # Each call to __getitem__ from the dataloader returns a Sample object,
        # which is collated by our special batch collator into a SampleList:
        # essentially an attribute-based batch, in layman's terms.
        current_sample = Sample()

        question = data["question"]
        tokens = tokenize(question, keep=[";", ","], remove=["?", "."])
        processed = self.text_processor({"tokens": tokens})
        current_sample.text = processed["text"]

        processed = self.answer_processor({"answers": [data["answer"]]})
        current_sample.answers = processed["answers"]
        current_sample.targets = processed["answers_scores"]

        image_path = os.path.join(self.image_path, data["image_filename"])
        image = np.true_divide(Image.open(image_path).convert("RGB"), 255)
        image = image.astype(np.float32)
        current_sample.image = torch.from_numpy(image.transpose(2, 0, 1))

        return current_sample
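
The comment above refers to MMF's collation step: individual Samples are
merged into a SampleList whose fields read like a batch. A tiny illustration,
assuming same-shaped fields (the tensor sizes here are made up):

    s1, s2 = Sample(), Sample()
    s1.text = torch.zeros(10, dtype=torch.long)
    s2.text = torch.ones(10, dtype=torch.long)
    batch = SampleList([s1, s2])
    # batch.text is a stacked (2, 10) tensor, accessed as an attribute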
Example #10
    def test_modal_end_token(self):
        self.finetune_model.eval()

        # Suppose 0 for <cls>, 1 for <pad>, 2 for <sep>
        CLS = 0
        PAD = 1
        SEP = 2
        size = 128

        input_ids = torch.randint(low=0, high=30255, size=(size, )).long()
        input_mask = torch.ones(size).long()

        input_ids[0] = CLS
        length = torch.randint(low=2, high=size - 1, size=(1, ))
        input_ids[length] = SEP
        input_ids[length + 1:] = PAD
        input_mask[length + 1:] = 0

        test_sample = Sample()
        test_sample.input_ids = input_ids.clone()
        test_sample.input_mask = input_mask.clone()
        test_sample.segment_ids = torch.zeros(size).long()
        test_sample.image = torch.rand((3, 300, 300)).float()
        test_sample_list = SampleList([test_sample])

        mmbt_base = self.finetune_model.model.bert
        with torch.no_grad():
            actual_modal_end_token = mmbt_base.extract_modal_end_token(
                test_sample_list)

        expected_modal_end_token = torch.zeros([1]).fill_(SEP).long()
        self.assertTrue(
            torch.equal(actual_modal_end_token, expected_modal_end_token))
        self.assertTrue(
            torch.equal(test_sample_list.input_ids[0, :-1], input_ids[1:]))
        self.assertTrue(
            torch.equal(test_sample_list.input_mask[0, :-1], input_mask[1:]))
Example #11
    def test_forward(self):
        model_config = self.config.model_config.cnn_lstm

        cnn_lstm = CNNLSTM(model_config)
        cnn_lstm.build()
        cnn_lstm.init_losses()

        self.assertTrue(isinstance(cnn_lstm, torch.nn.Module))

        test_sample = Sample()
        test_sample.text = torch.randint(1, 79, (10, ), dtype=torch.long)
        test_sample.image = torch.randn(3, 320, 480)
        test_sample.targets = torch.randn(32)

        test_sample_list = SampleList([test_sample])
        test_sample_list.dataset_type = "train"
        test_sample_list.dataset_name = "clevr"
        output = cnn_lstm(test_sample_list)

        scores = output["scores"]
        loss = output["losses"]["train/clevr/logit_bce"]

        np.testing.assert_almost_equal(loss.item(), 19.2635, decimal=4)
        self.assertEqual(scores.size(), torch.Size((1, 32)))