def load_item(self, idx):
    """Assemble the ``Sample`` for annotation ``idx``.

    The sample carries the processed sentence, optionally the features of
    the two images stored under ``img0`` / ``img1``, and a scalar
    ``targets`` tensor that is 1 when the annotation label equals the
    string ``"True"`` and 0 otherwise.
    """
    annotation = self.annotation_db[idx]
    item = Sample()

    sentence = self.text_processor({"text": annotation["sentence"]})
    item.text = sentence["text"]
    if "input_ids" in sentence:
        item.update(sentence)

    if self._use_features is True:
        # Drop the last "-"-separated piece (the sentence id) from the
        # annotation identifier.
        pieces = annotation["identifier"].split("-")
        identifier = "-".join(pieces[:-1])

        slots = (("img0", "{}-img0.npy"), ("img1", "{}-img1.npy"))
        for slot, path_template in slots:
            # The feature path is written onto the annotation right before
            # the lookup; presumably features_db reads it from there --
            # confirm against the features database implementation.
            annotation["feature_path"] = path_template.format(identifier)
            image_features = self.features_db[idx]
            if hasattr(self, "transformer_bbox_processor"):
                image_features["image_info_0"] = (
                    self.transformer_bbox_processor(
                        image_features["image_info_0"]))
            setattr(item, slot, Sample())
            getattr(item, slot).update(image_features)

    label = int(annotation["label"] == "True")
    item.targets = torch.tensor(label, dtype=torch.long)

    return item
Beispiel #2
0
    def __getitem__(self, idx):
        """Return the ``Sample`` for annotation ``idx``.

        The sample holds the processed question, the question id, the
        image id (as an int tensor when it is an ``int``, unchanged
        otherwise), optional image features, and whatever
        ``add_answer_info`` attaches.
        """
        annotation = self.annotation_db[idx]
        item = Sample()

        question = self.text_processor({"text": annotation["question_str"]})
        item.text = question["text"]
        if "input_ids" in question:
            item.update(question)

        item.question_id = torch.tensor(annotation["question_id"],
                                        dtype=torch.int)

        image_id = annotation["image_id"]
        if isinstance(image_id, int):
            image_id = torch.tensor(image_id, dtype=torch.int)
        item.image_id = image_id

        if self._use_features is True:
            image_features = self.features_db[idx]
            if hasattr(self, "transformer_bbox_processor"):
                image_features["image_info_0"] = (
                    self.transformer_bbox_processor(
                        image_features["image_info_0"]))
            item.update(image_features)

        # add_answer_info may extend the sample with a dynamic answer space
        # when soft copy is in use.
        return self.add_answer_info(annotation, item)
Beispiel #3
0
    def load_item(self, idx):
        """Build the captioning ``Sample`` for annotation ``idx``.

        Caption text, id and length are attached for every dataset type
        except ``"test"``. The image comes either from the features
        database or, when features are disabled, from ``image_db`` using
        ``<image_name>.jpg``. Reference captions are added last.
        """
        annotation = self.preprocess_sample_info(self.annotation_db[idx])
        item = Sample()

        if self._dataset_type != "test":
            caption = self.text_processor(
                {"tokens": annotation["caption_tokens"]})
            item.text = caption["text"]
            item.caption_id = torch.tensor(annotation["caption_id"],
                                           dtype=torch.int)
            token_count = len(annotation["caption_tokens"])
            item.caption_len = torch.tensor(token_count, dtype=torch.int)

        item.image_id = object_to_byte_tensor(annotation["image_id"])

        if not self._use_features:
            image_path = str(annotation["image_name"]) + ".jpg"
            loaded = self.image_db.from_path(image_path)
            item.image = loaded["images"][0]
        else:
            item.update(self.features_db[idx])

        # Attach the reference captions to the sample.
        return self.add_reference_caption(annotation, item)
Beispiel #4
0
    def __getitem__(self, idx):
        """Return the ``Sample`` for annotation ``idx``.

        The sample contains the processed text, the integer ``id``, the
        image features fetched through the preprocessed annotation, and a
        ``targets`` tensor when the annotation has a ``"label"`` entry.
        """
        annotation = self.preprocess_sample_info(self.annotation_db[idx])
        item = Sample()

        text = self.text_processor({"text": annotation["text"]})
        item.text = text["text"]
        if "input_ids" in text:
            item.update(text)

        item.id = torch.tensor(int(annotation["id"]), dtype=torch.int)

        # Features are looked up with the annotation itself rather than
        # with idx, because feature_path has been added to it dynamically.
        image_features = self.features_db.get(annotation)
        if hasattr(self, "transformer_bbox_processor"):
            image_features["image_info_0"] = self.transformer_bbox_processor(
                image_features["image_info_0"])
        item.update(image_features)

        if "label" in annotation:
            item.targets = torch.tensor(annotation["label"],
                                        dtype=torch.long)

        return item
    def load_item(self, idx):
        """Build the ``Sample`` for annotation ``idx``.

        Processes the ``"sentence2"`` text, optionally loads the image
        features, and maps ``"gold_label"`` through
        ``LABEL_TO_INT_MAPPING`` into a long ``targets`` tensor.
        """
        annotation = self.annotation_db[idx]
        item = Sample()

        sentence = self.text_processor({"text": annotation["sentence2"]})
        item.text = sentence["text"]
        if "input_ids" in sentence:
            item.update(sentence)

        if self._use_features is True:
            # Keep only the part of the image id before the first ".".
            stem = annotation["Flikr30kID"].split(".")[0]
            # The feature path is stored on the annotation before the
            # lookup; presumably features_db reads it from there -- confirm.
            annotation["feature_path"] = "{}.npy".format(stem)
            image_features = self.features_db[idx]
            if hasattr(self, "transformer_bbox_processor"):
                image_features["image_info_0"] = (
                    self.transformer_bbox_processor(
                        image_features["image_info_0"]))
            item.update(image_features)

        item.targets = torch.tensor(
            LABEL_TO_INT_MAPPING[annotation["gold_label"]],
            dtype=torch.long)

        return item
    def __getitem__(self, idx: int) -> Sample:
        """Return the masked-caption ``Sample`` for annotation ``idx``.

        The caption is passed to ``masked_token_processor`` as ``text_a``
        with an empty ``text_b`` and ``is_correct=True``. When features
        are enabled, an ``image_id`` entry inside ``image_info_0`` is
        renamed to ``feature_path`` before the features are merged in.
        """
        annotation = self.annotation_db[idx]
        item = Sample()

        processor_input = {
            "text_a": annotation["caption"],
            "text_b": "",
            "is_correct": True,
        }
        item.update(self.masked_token_processor(processor_input))
        item.image_id = annotation["image_id"]
        item.feature_path = annotation["feature_path"]

        # Image features are optional.
        if self._use_features:
            image_features = self.features_db[idx]
            info = image_features["image_info_0"]
            if info and "image_id" in info:
                # Move the value from "image_id" to "feature_path".
                info["feature_path"] = info.pop("image_id")
            item.update(image_features)

        return item
    def _load_objects(self, idx):
        """Load, normalise and wrap the objects of image ``idx``.

        Each object dict is modified in place: its synsets, names and
        attributes are run through the matching processors, ``h`` / ``w``
        are replaced by ``height`` / ``width`` divided by the image
        height / width, and ``y`` / ``x`` are divided likewise.

        Returns:
            tuple: a ``SampleList`` of the objects and a dict mapping each
            ``object_id`` to its ``Sample``.
        """
        info = self._get_image_info(idx)
        img_h, img_w = info["height"], info["width"]
        by_id = {}
        collected = []

        for raw in info["objects"]:
            raw["synsets"] = self.synset_processor(
                {"tokens": raw["synsets"]})["text"]
            raw["names"] = self.name_processor(
                {"tokens": raw["names"]})["text"]
            # Replace the pixel extents with image-relative ones.
            raw["height"] = raw.pop("h") / img_h
            raw["width"] = raw.pop("w") / img_w
            raw["y"] /= img_h
            raw["x"] /= img_w
            raw["attributes"] = self.attribute_processor(
                {"tokens": raw["attributes"]})["text"]

            wrapped = Sample(raw)
            by_id[wrapped["object_id"]] = wrapped
            collected.append(wrapped)

        return SampleList(collected), by_id
Beispiel #8
0
    def add_ocr_details(self, sample_info, sample):
        """Attach OCR-derived fields to ``sample`` and return it.

        With ``use_ocr`` set, the OCR tokens are processed one by one, fed
        to ``context_processor``, and the result is stored as the context
        fields together with order vectors (an identity matrix whose rows
        from ``context["length"]`` onwards are zeroed). With
        ``use_ocr_info`` set and ``"ocr_info"`` present, the processed
        bounding boxes are stored as ``ocr_bbox``.
        """
        if self.use_ocr:
            # Clean every OCR token individually.
            cleaned_tokens = []
            for raw_token in sample_info["ocr_tokens"]:
                cleaned = self.ocr_token_processor({"text": raw_token})
                cleaned_tokens.append(cleaned["text"])

            # Turn the tokens into the context representation.
            context = self.context_processor({"tokens": cleaned_tokens})
            valid_length = context["length"]
            sample.context = context["text"]
            sample.context_tokens = context["tokens"]
            sample.context_feature_0 = context["text"]
            sample.context_info_0 = Sample()
            sample.context_info_0.max_features = valid_length

            identity = torch.eye(len(sample.context_tokens))
            identity[valid_length:] = 0
            sample.order_vectors = identity

        if self.use_ocr_info and "ocr_info" in sample_info:
            processed_boxes = self.bbox_processor(
                {"info": sample_info["ocr_info"]})
            sample.ocr_bbox = processed_boxes["bbox"]

        return sample
    def test_sample_working(self):
        """Check that ``Sample`` exposes values by attribute and by key."""
        sample = Sample()
        sample.x = 1
        sample["y"] = 2

        # Values set either way must be readable both ways.
        for key, expected in (("x", 1), ("y", 2)):
            self.assertEqual(getattr(sample, key), expected)
            self.assertEqual(sample[key], expected)

        # update() must also expose nested dicts through attribute access.
        sample.update({"a": 3, "b": {"c": 4}})
        self.assertEqual(sample.a, 3)
        self.assertEqual(sample["a"], 3)
        self.assertEqual(sample.b.c, 4)
        self.assertEqual(sample["b"].c, 4)
    def _test_multilabel_metric(self, metric, value):
        """Assert ``metric`` yields ``value`` (4 places) on a fixed
        four-row, three-label batch."""
        targets = torch.tensor(
            [[0, 1, 1], [1, 0, 1], [1, 0, 1], [0, 0, 1]],
            dtype=torch.float,
        )
        scores = torch.tensor(
            [
                [-0.9332, 0.8149, 0.3491],
                [-0.8391, 0.6797, -0.3410],
                [-0.7235, 0.7220, 0.9104],
                [0.9043, 0.3078, -0.4210],
            ],
            dtype=torch.float,
        )

        expected = Sample()
        expected.targets = targets
        model_output = {"scores": scores}

        result = metric.calculate(expected, model_output).item()
        self.assertAlmostEqual(result, value, 4)
Beispiel #11
0
    def __getitem__(self, idx):
        """Return the ``Sample`` for annotation ``idx``.

        Adds the question id, the image id (an ``int`` id is converted to
        ``str``), optional features, the sample details and the answer
        info, then strips ``image_info_0`` / ``image_info_1`` down to
        their ``max_features`` entry.
        """
        annotation = self.preprocess_sample_info(self.annotation_db[idx])
        item = Sample()

        # Breaking change from VQA2Dataset: load question_id.
        item.question_id = torch.tensor(annotation["question_id"],
                                        dtype=torch.int)

        image_id = annotation["image_id"]
        item.image_id = str(image_id) if isinstance(image_id, int) else image_id

        if self._use_features is True:
            item.update(self.features_db[idx])

        item = self.add_sample_details(annotation, item)
        item = self.add_answer_info(annotation, item)

        # Only "max_features" is needed downstream; the other keys are
        # dropped to minimize data loading overhead.
        for info_name in ("image_info_0", "image_info_1"):
            if not hasattr(item, info_name):
                continue
            info = getattr(item, info_name)
            for key in list(info):
                if key != "max_features":
                    info.pop(key)

        return item
    def __getitem__(self, idx):
        """Return the masked-question ``Sample`` for annotation ``idx``.

        The image enters either as features (optionally with
        ``image_labels`` from ``masked_region_processor`` when the config
        enables ``use_image_feature_masks``) or as the first image loaded
        from ``<image_name>.jpg``. Answer info is added only when
        ``_add_answer`` is set.
        """
        annotation = self.annotation_db[idx]
        item = Sample()

        if not self._use_features:
            image_path = str(annotation["image_name"]) + ".jpg"
            loaded = self.image_db.from_path(image_path)
            item.image = loaded["images"][0]
        else:
            image_features = self.features_db[idx]
            if hasattr(self, "transformer_bbox_processor"):
                image_features["image_info_0"] = (
                    self.transformer_bbox_processor(
                        image_features["image_info_0"]))

            if self.config.get("use_image_feature_masks", False):
                region_labels = self.masked_region_processor(
                    image_features["image_feature_0"])
                item.update({"image_labels": region_labels})

            item.update(image_features)

        item = self._add_masked_question(annotation, item)
        if self._add_answer:
            item = self.add_answer_info(annotation, item)
        return item
Beispiel #13
0
    def __call__(self, item):
        """Tokenize one or several texts and return them as a dict.

        ``item["text"]`` may be a single text or a list of texts. Each
        text is tokenized separately and its ``segment_ids`` are filled
        with the text's position in the list. The per-text results are
        stacked through ``SampleList``; with the ``"concat"`` fusion
        strategy the stacked tensors are flattened to one dimension.
        """
        raw = item["text"]
        texts = raw if isinstance(raw, list) else [raw]

        per_text = []
        for position, text in enumerate(texts):
            entry = Sample()
            entry.update(self.tokenizer({"text": text}))
            entry.segment_ids.fill_(position)
            per_text.append(entry)

        # SampleList turns the list of per-text tensors into stacked tensors.
        stacked = SampleList(per_text)
        if self.fusion_strategy == "concat":
            for field in ("input_ids", "input_mask", "segment_ids",
                          "lm_label_ids"):
                setattr(stacked, field, getattr(stacked, field).view(-1))
        return stacked.to_dict()
    def _test_binary_metric(self, metric, value):
        """Assert ``metric`` yields ``value`` (4 places) on a fixed binary
        batch, first with one-hot float targets and then with the same
        targets given as class indices."""
        expected = Sample()
        model_output = {
            "scores": torch.tensor(
                [
                    [-0.9332, 0.8149],
                    [-0.8391, 0.6797],
                    [-0.7235, 0.7220],
                    [-0.9043, 0.3078],
                ],
                dtype=torch.float,
            )
        }

        target_variants = (
            torch.tensor([[0, 1], [1, 0], [1, 0], [0, 1]],
                         dtype=torch.float),
            torch.tensor([1, 0, 0, 1], dtype=torch.long),
        )
        for targets in target_variants:
            expected.targets = targets
            result = metric.calculate(expected, model_output).item()
            self.assertAlmostEqual(result, value, 4)
    def test_beam_search(self):
        """Compare the first decoded caption against fixed token
        sequences for batch sizes 1, 2, 8 and 16.

        NOTE(review): the inputs come from ``torch.randn``, so the
        expected tokens presumably depend on a seed fixed outside this
        method -- confirm in the test setup.
        """
        vocab = text_utils.VocabFromText(self.VOCAB_EXAMPLE_SENTENCES)
        decoder = TestDecoderModel(self.config.model_config.butd, vocab)
        decoder.build()
        decoder.eval()

        expected_by_batch_size = {
            1: [1.0, 23.0, 1.0, 24.0, 29.0, 37.0, 40.0, 17.0, 29.0, 2.0],
            2: [1.0, 0.0, 8.0, 1.0, 28.0, 25.0, 2.0],
            8: [1.0, 34.0, 1.0, 13.0, 1.0, 2.0],
            16: [1.0, 25.0, 18.0, 2.0],
        }

        for batch_size in (1, 2, 8, 16):
            batch = []
            while len(batch) < batch_size:
                entry = Sample()
                entry.dataset_name = "coco"
                entry.dataset_type = "test"
                entry.image_feature_0 = torch.randn(100, 2048)
                entry.answers = torch.zeros((5, 10), dtype=torch.long)
                batch.append(entry)

            captions = decoder(SampleList(batch))["captions"]
            first_caption = np.trim_zeros(captions[0].tolist())
            self.assertEqual(first_caption,
                             expected_by_batch_size[batch_size])
Beispiel #16
0
    def __getitem__(self, idx):
        """Return the masked-question ``Sample`` for annotation ``idx``.

        The annotation is run through ``_preprocess_answer`` and its
        ``"id"`` is copied to ``"question_id"``. Features (and optionally
        ``image_labels``) are added when enabled; answer info is added
        only when ``_add_answer`` is set.
        """
        annotation = self._preprocess_answer(self.annotation_db[idx])
        annotation["question_id"] = annotation["id"]
        item = Sample()

        if self._use_features is True:
            image_features = self.features_db[idx]

            if hasattr(self, "transformer_bbox_processor"):
                image_features["image_info_0"] = (
                    self.transformer_bbox_processor(
                        image_features["image_info_0"]))

            if self.config.get("use_image_feature_masks", False):
                region_labels = self.masked_region_processor(
                    image_features["image_feature_0"])
                item.update({"image_labels": region_labels})

            item.update(image_features)

        item = self._add_masked_question(annotation, item)
        if not self._add_answer:
            return item
        return self.add_answer_info(annotation, item)
def build_bbox_tensors(infos, max_length):
    """Pack bounding-box annotations into fixed-size tensors.

    Args:
        infos: Sequence of dicts, each with a ``"bounding_box"`` dict that
            holds ``"width"``, ``"height"`` and the top-left corner under
            either ``"top_left_x"`` / ``"top_left_y"`` or ``"topLeftX"`` /
            ``"topLeftY"``.
        max_length: Number of rows in the output tensors. Boxes beyond
            this count are ignored; unused rows stay zero.

    Returns:
        Sample: with ``coordinates`` (``max_length`` x 4 float tensor of
        x1, y1, x2, y2), ``width`` and ``height`` (float tensors of length
        ``max_length``) and ``bbox_types`` (a list of ``"xyxy"`` repeated
        ``max_length`` times).

    Raises:
        KeyError: if a box has neither spelling of a corner key, or lacks
            ``"width"`` / ``"height"``.
    """
    num_bbox = min(max_length, len(infos))

    # Rows at index num_bbox and beyond are left as zeros.
    coord_tensor = torch.zeros((max_length, 4), dtype=torch.float)
    width_tensor = torch.zeros(max_length, dtype=torch.float)
    height_tensor = torch.zeros(max_length, dtype=torch.float)
    bbox_types = ["xyxy"] * max_length

    sample = Sample()

    for idx, info in enumerate(infos[:num_bbox]):
        bbox = info["bounding_box"]
        # dict.get() evaluates its default eagerly, so
        # bbox.get("top_left_x", bbox["topLeftX"]) raised KeyError for boxes
        # that only carry the snake_case keys. Look up the camelCase key
        # only when the snake_case one is absent.
        x = bbox["top_left_x"] if "top_left_x" in bbox else bbox["topLeftX"]
        y = bbox["top_left_y"] if "top_left_y" in bbox else bbox["topLeftY"]
        width = bbox["width"]
        height = bbox["height"]

        coord_tensor[idx][0] = x
        coord_tensor[idx][1] = y
        coord_tensor[idx][2] = x + width
        coord_tensor[idx][3] = y + height

        width_tensor[idx] = width
        height_tensor[idx] = height

    sample.coordinates = coord_tensor
    sample.width = width_tensor
    sample.height = height_tensor
    sample.bbox_types = bbox_types

    return sample
    def test_caption_bleu4(self):
        """Check ``CaptionBleu4Metric`` on a full match (score 1.0) and on
        a half match (score 0.3928 to 4 places)."""
        this_file = os.path.abspath(__file__)

        config_path = os.path.abspath(os.path.join(
            this_file,
            "../../../mmf/configs/datasets/coco/defaults.yaml",
        ))
        config = load_yaml(config_path)
        processor_config = (
            config.dataset_config.coco.processors.caption_processor)

        vocab_path = os.path.join(this_file, "..", "..", "data", "vocab.txt")
        processor_config.params.vocab.type = "random"
        processor_config.params.vocab.vocab_file = os.path.abspath(vocab_path)
        registry.register("coco_caption_processor",
                          CaptionProcessor(processor_config.params))

        bleu4 = metrics.CaptionBleu4Metric()
        expected = Sample()
        predicted = {}

        # Complete match: every predicted position scores highest on
        # token 4, the value all reference answers are filled with.
        expected.answers = torch.empty((5, 5, 10)).fill_(4)
        scores = torch.zeros((5, 10, 19))
        scores[:, :, 4] = 1.0
        predicted["scores"] = scores

        self.assertEqual(bleu4.calculate(expected, predicted).item(), 1.0)

        # Partial match: the first five positions predict token 4, the
        # rest predict token 18.
        expected.answers = torch.empty((5, 5, 10)).fill_(4)
        scores = torch.zeros((5, 10, 19))
        scores[:, 0:5, 4] = 1.0
        scores[:, 5:, 18] = 1.0
        predicted["scores"] = scores

        self.assertAlmostEqual(
            bleu4.calculate(expected, predicted).item(), 0.3928, 4)
Beispiel #19
0
    def test_call(self):
        """Exercise ``BatchCollator`` on a ready-made sample list, on a
        plain list of samples, and on a sample list wrapped in a list."""
        collate = BatchCollator("vqa2", "train")

        # An already built sample list gets the dataset name/type attached.
        collated = collate(test_utils.build_random_sample_list())
        self.assertEqual(collated.dataset_name, "vqa2")
        self.assertEqual(collated.dataset_type, "train")

        # A list of samples gets stacked field by field.
        single = Sample()
        single.a = torch.tensor([1, 2], dtype=torch.int)
        collated = collate([single, single])
        expected_a = torch.tensor([[1, 2], [1, 2]], dtype=torch.int)
        self.assertTrue(test_utils.compare_tensors(collated.a, expected_a))

        # IterableDataset case: a sample list wrapped in a list must compare
        # equal to the original after collation.
        original = test_utils.build_random_sample_list()
        self.assertEqual(collate([original]), original)
    def test_forward(self):
        """Run one forward pass of ``CNNLSTM`` and check the loss value
        and the shape of the scores.

        NOTE(review): the inputs are random, so the expected loss
        presumably depends on a seed fixed outside this method -- confirm
        in the test setup.
        """
        model = CNNLSTM(self.config.model_config.cnn_lstm)
        model.build()
        model.init_losses()

        self.assertIsInstance(model, torch.nn.Module)

        # The random calls stay in text, image, targets order.
        entry = Sample()
        entry.text = torch.randint(1, 79, (10, ), dtype=torch.long)
        entry.image = torch.randn(3, 320, 480)
        entry.targets = torch.randn(32)

        batch = SampleList([entry])
        batch.dataset_type = "train"
        batch.dataset_name = "clevr"
        output = model(batch)

        loss_value = output["losses"]["train/clevr/logit_bce"].item()
        np.testing.assert_almost_equal(loss_value, 19.2635, decimal=4)
        self.assertEqual(output["scores"].size(), torch.Size((1, 32)))
Beispiel #21
0
    def __getitem__(self, idx):
        """Return the ``Sample`` for annotation ``idx``: processed text,
        integer ``id``, the first image from ``image_db`` and, when the
        annotation has a ``"label"``, a long ``targets`` tensor."""
        annotation = self.annotation_db[idx]
        item = Sample()

        text = self.text_processor({"text": annotation["text"]})
        item.text = text["text"]
        if "input_ids" in text:
            item.update(text)

        item.id = torch.tensor(int(annotation["id"]), dtype=torch.int)

        # image_db returns a set of images; only the first one is used.
        images = self.image_db[idx]["images"]
        item.image = images[0]

        if "label" in annotation:
            item.targets = torch.tensor(annotation["label"],
                                        dtype=torch.long)

        return item
    def __getitem__(self, idx):
        """Return the ``Sample`` for annotation ``idx``: processed plot
        text (first element when the plot is a list), optionally the first
        image, and the answers / scores produced from the genres."""
        annotation = self.annotation_db[idx]
        item = Sample()

        plot = annotation["plot"]
        plot_text = plot[0] if isinstance(plot, list) else plot
        sentence = self.text_processor({"text": plot_text})

        item.text = sentence["text"]
        if "input_ids" in sentence:
            item.update(sentence)

        if self._use_images is True:
            item.image = self.image_db[idx]["images"][0]

        genre_output = self.answer_processor(
            {"answers": annotation["genres"]})
        item.answers = genre_output["answers"]
        item.targets = genre_output["answers_scores"]

        return item
Beispiel #23
0
    def load_item(self, idx):
        """Build the ``Sample`` for annotation ``idx`` from its first
        caption.

        The processed caption is stored as ``text``, its length as
        ``caption_len``, and a stack containing just that caption as
        ``answers``. The image id becomes an int tensor when it is an
        ``int``; features are merged in when enabled.
        """
        annotation = self.annotation_db[idx]
        item = Sample()

        first_caption = annotation["captions"][0]
        caption_text = self.text_processor({"text": first_caption})["text"]
        item.text = caption_text
        item.caption_len = torch.tensor(len(caption_text), dtype=torch.int)

        image_id = annotation["image_id"]
        if isinstance(image_id, int):
            image_id = torch.tensor(image_id, dtype=torch.int)
        item.image_id = image_id

        if self._use_features is True:
            item.update(self.features_db[idx])

        item.answers = torch.stack([caption_text])

        return item
    def __getitem__(self, idx):
        """Return the ``Sample`` for annotation ``idx``: processed plot
        text (first element when the plot is a list), optional image
        features, and the answers / scores produced from the genres."""
        annotation = self.annotation_db[idx]
        item = Sample()

        plot = annotation["plot"]
        plot_text = plot[0] if isinstance(plot, list) else plot
        sentence = self.text_processor({"text": plot_text})

        item.text = sentence["text"]
        if "input_ids" in sentence:
            item.update(sentence)

        if self._use_features is True:
            image_features = self.features_db[idx]
            if hasattr(self, "transformer_bbox_processor"):
                image_features["image_info_0"] = (
                    self.transformer_bbox_processor(
                        image_features["image_info_0"]))
            item.update(image_features)

        genre_output = self.answer_processor(
            {"answers": annotation["genres"]})
        item.answers = genre_output["answers"]
        item.targets = genre_output["answers_scores"]

        return item
    def _load_regions(self, idx, object_map, relationship_map):
        """Load, normalise and wrap the regions of image ``idx``.

        Returns ``(None, None)`` when ``_return_scene_graph`` is ``None``.
        Otherwise each region dict is modified in place: synset names are
        processed, the box is divided by the image size, relationship and
        object ids are replaced by their entries from the given maps, and
        the phrase is processed.

        Returns:
            tuple: a ``SampleList`` of the regions and a dict mapping each
            ``region_id`` to its ``Sample``.
        """
        if self._return_scene_graph is None:
            return None, None

        info = self._get_image_info(idx)
        img_h, img_w = info["height"], info["width"]
        by_id = {}
        collected = []

        for raw in info["regions"]:
            for synset in raw["synsets"]:
                synset["entity_name"] = self.name_processor(
                    {"tokens": [synset["entity_name"]]})["text"]
                synset["synset_name"] = self.synset_processor(
                    {"tokens": [synset["synset_name"]]})["text"]

            # Make the box relative to the image size.
            raw["height"] /= img_h
            raw["width"] /= img_w
            raw["y"] /= img_h
            raw["x"] /= img_w

            # Resolve both id lists before writing either back.
            resolved_relationships = [
                relationship_map[rel_id] for rel_id in raw["relationships"]
            ]
            resolved_objects = [
                object_map[obj_id] for obj_id in raw["objects"]
            ]
            raw["relationships"] = resolved_relationships
            raw["objects"] = resolved_objects

            raw["phrase"] = self.text_processor(
                {"text": raw["phrase"]})["text"]

            wrapped = Sample(raw)
            by_id[wrapped["region_id"]] = wrapped
            collected.append(wrapped)

        return SampleList(collected), by_id
def compare_torchscript_transformer_models(model, vocab_size):
    """Return whether ``model`` and its TorchScript version produce
    exactly equal ``"scores"`` for the same random input batch.

    The batch has 128 random token ids below ``vocab_size``, an all-ones
    mask, all-zero segment ids, random image features and a random image.
    """
    entry = Sample()
    entry.input_ids = torch.randint(low=0, high=vocab_size,
                                    size=(128, )).long()
    entry.input_mask = torch.ones(128).long()
    entry.segment_ids = torch.zeros(128).long()
    entry.image_feature_0 = torch.rand((1, 100, 2048)).float()
    entry.image = torch.rand((3, 300, 300)).float()
    batch = SampleList([entry])

    with torch.no_grad():
        eager_scores = model(batch)["scores"]

    scripted = torch.jit.script(model)
    with torch.no_grad():
        scripted_scores = scripted(batch)["scores"]

    return torch.equal(eager_scores, scripted_scores)
Beispiel #27
0
    def __getitem__(self, idx: int) -> Sample:
        """Return the ``Sample`` for annotation ``idx``.

        The sample is updated with the processed question, carries the
        integer question id as ``id``, the first image loaded from the
        path built out of ``image_id`` and ``coco_split``, and, when the
        annotation has ``"answers"``, the processed answer scores as
        ``targets``.

        The return annotation was ``Type[Sample]``, which denotes the
        class itself; the method returns an instance, so it is annotated
        as ``Sample``.
        """
        sample_info = self.annotation_db[idx]
        current_sample = Sample()
        processed_question = self.text_processor(
            {"text": sample_info["question"]})
        current_sample.update(processed_question)
        current_sample.id = torch.tensor(int(sample_info["question_id"]),
                                         dtype=torch.int)

        image_path = self.get_image_path(sample_info["image_id"],
                                         sample_info["coco_split"])
        # from_path returns a collection of images; only the first is used.
        current_sample.image = self.image_db.from_path(image_path)["images"][0]

        if "answers" in sample_info:
            answers = self.answer_processor(
                {"answers": sample_info["answers"]})
            current_sample.targets = answers["answers_scores"]

        return current_sample
    def _load_relationships(self, idx, object_map):
        """Load and wrap the relationships of image ``idx``.

        Returns ``(None, None)`` when both ``_return_relationships`` and
        ``_return_scene_graph`` are ``None``. Otherwise each relationship
        dict is modified in place: synsets and predicate are processed and
        ``object`` / ``subject`` are filled from ``object_map``.

        Returns:
            tuple: a ``SampleList`` of the relationships and a dict mapping
            each ``relationship_id`` to its ``Sample``.
        """
        nothing_requested = (self._return_relationships is None
                             and self._return_scene_graph is None)
        if nothing_requested:
            return None, None

        info = self._get_image_info(idx)
        by_id = {}
        collected = []

        for raw in info["relationships"]:
            raw["synsets"] = self.synset_processor(
                {"tokens": raw["synsets"]})["text"]
            raw["predicate"] = self.predicate_processor(
                {"tokens": raw["predicate"]})["text"]
            raw["object"] = object_map[raw["object_id"]]
            raw["subject"] = object_map[raw["subject_id"]]

            wrapped = Sample(raw)
            by_id[wrapped["relationship_id"]] = wrapped
            collected.append(wrapped)

        return SampleList(collected), by_id
    def test_finetune_model(self):
        """Check that the finetune model and its TorchScript version give
        exactly equal scores for the same random input."""
        self.finetune_model.eval()

        entry = Sample()
        entry.input_ids = torch.randint(low=0, high=30255,
                                        size=(128, )).long()
        entry.input_mask = torch.ones(128).long()
        entry.segment_ids = torch.zeros(128).long()
        entry.image = torch.rand((3, 300, 300)).float()

        # The eager model receives a copy of the sample; the scripted model
        # receives the sample itself.
        eager_batch = SampleList([entry.copy()])
        with torch.no_grad():
            eager_output = self.finetune_model.model(eager_batch)

        scripted_batch = SampleList([entry])
        scripted = torch.jit.script(self.finetune_model.model)
        with torch.no_grad():
            scripted_output = scripted(scripted_batch)

        scores_match = torch.equal(eager_output["scores"],
                                   scripted_output["scores"])
        self.assertTrue(scores_match)
    def classify(self, image: ImageType, text: str):
        """Classifies a given image and text in it into Hateful/Non-Hateful.
        Image can be a url or a local path or you can directly pass a PIL.Image.Image
        object. Text needs to be a sentence containing all text in the image.

            >>> from VisualBERT.mmf.models.mmbt import MMBT
            >>> model = MMBT.from_pretrained("mmbt.hateful_memes.images")
            >>> model.classify("some_url", "some_text")
            {"label": 0, "confidence": 0.56}

        Args:
            image (ImageType): Image to be classified
            text (str): Text in the image

        Returns:
            dict: ``"label"`` is the index of the highest-scoring class
            (hateful is 1, non hateful is 0) and ``"confidence"`` is the
            softmax probability of that class.
        """
        if isinstance(image, str):
            if image.startswith("http"):
                # The context manager cleans up the temporary file even
                # when the download or the image load raises; a manual
                # close() after the load was skipped on any error.
                with tempfile.NamedTemporaryFile() as temp_file:
                    download(image,
                             *os.path.split(temp_file.name),
                             disable_tqdm=True)
                    image = tv_helpers.default_loader(temp_file.name)
            else:
                image = tv_helpers.default_loader(image)

        text = self.processor_dict["text_processor"]({"text": text})
        image = self.processor_dict["image_processor"](image)

        sample = Sample()
        sample.text = text["text"]
        if "input_ids" in text:
            sample.update(text)

        sample.image = image
        sample_list = SampleList([sample])
        # Move the batch to the device the model parameters live on.
        device = next(self.model.parameters()).device
        sample_list = sample_list.to(device)

        output = self.model(sample_list)
        scores = nn.functional.softmax(output["scores"], dim=1)
        confidence, label = torch.max(scores, dim=1)

        return {"label": label.item(), "confidence": confidence.item()}