Пример #1
0
    def test_mmf_loss(self):
        """Exercise MMFLoss construction, argument validation and output keys.

        ``registry.get_loss_class`` is replaced by a MagicMock for the
        duration of this test; the original lookup is restored via
        ``addCleanup`` so later tests are not affected by the patch.
        """
        original_get_loss_class = registry.get_loss_class
        # Restore the real lookup even if an assertion below fails.
        self.addCleanup(setattr, registry, "get_loss_class",
                        original_get_loss_class)
        get_loss_class_mock = MagicMock(side_effect=build_loss_side_effect())
        registry.get_loss_class = get_loss_class_mock
        # Test if MMFLoss accepts empty parameters
        self.assertRaises(ValueError, losses.MMFLoss)
        # assertEqual (not assertTrue): assertTrue(x, msg) treats the second
        # argument as a failure message and only checks that x is truthy.
        self.assertEqual(
            losses.MMFLoss({
                "type": "cross_entropy"
            }).name, "cross_entropy")
        self.assertEqual(losses.MMFLoss("cross_entropy").name, "cross_entropy")
        self.assertRaises(AssertionError, losses.MMFLoss, [])
        # Multi requires dict
        self.assertRaises(AssertionError, losses.MMFLoss, "multi")

        cross_entropy = losses.MMFLoss("cross_entropy")
        cross_entropy_from_dict = losses.MMFLoss({"type": "cross_entropy"})
        sample_list = SampleList()
        sample_list.dataset_type = "val"
        sample_list.dataset_name = "vqa2"

        output = cross_entropy(sample_list, {})
        output_from_dict = cross_entropy_from_dict(sample_list, {})

        self.assertEqual(output, {"val/vqa2/cross_entropy": torch.tensor(1.0)})
        self.assertEqual(output_from_dict, output)

        get_loss_class_mock.side_effect = build_loss_side_effect(1.0)
        output = cross_entropy(sample_list, {})

        self.assertEqual(output, {"val/vqa2/cross_entropy": torch.tensor(1.0)})
        self.assertEqual(output_from_dict, output)

        self.assertTrue(get_loss_class_mock.called)
        self.assertEqual(get_loss_class_mock.call_count, 5)
Пример #2
0
    def test_stacked_feature_preprocessing(self):
        """Preprocess a batch whose two text modalities share one input tensor.

        The image modality plus two text modalities ("body" and "ocr") are
        configured; ``input_ids`` and ``lm_label_ids`` are stacked as
        (2, 2, 128) tensors. The processed output is validated by
        ``self._compare_processed_for_multimodality`` against the sum of
        the LM labels.
        """
        # NOTE(review): this mutates the shared fixture config in place;
        # presumably setUp rebuilds it for every test — confirm.
        self._text_modality_config.key = "body"
        second_text_modality_config = MMFTransformerModalityConfig(
            type="text",
            key="ocr",
            embedding_dim=756,
            position_dim=128,
            segment_id=2,
            encoder=TextEncoderFactory.Config(type=TextEncoderTypes.identity),
        )

        modalities_config = [
            self._image_modality_config,
            self._text_modality_config,
            second_text_modality_config,
        ]
        config = MMFTransformer.Config(modalities=modalities_config,
                                       num_labels=2)
        mmft = build_model(config)

        sample_list = SampleList()
        sample_list.image = torch.rand(2, 256)
        # In stacked case, input_ids should represent all texts
        sample_list.input_ids = torch.randint(0, 512, (2, 2, 128))
        sample_list.lm_label_ids = torch.randint(-1, 30522, (2, 2, 128))
        # Reference value computed before preprocessing so the helper can
        # check the labels survive preprocessing.
        lm_labels_sum = sample_list.lm_label_ids.sum().item()

        transformer_input = mmft.preprocess_sample(sample_list)
        self._compare_processed_for_multimodality(transformer_input,
                                                  lm_labels_sum)
Пример #3
0
    def test_forward(self):
        """Run a CNNLSTM forward pass on one CLEVR-style sample.

        Checks that the model is a ``torch.nn.Module``, that the
        ``train/clevr/logit_bce`` loss matches a fixed reference value to
        4 decimals, and that the scores have shape (1, 32).
        """
        model_config = self.config.model_config.cnn_lstm

        cnn_lstm = CNNLSTM(model_config)
        cnn_lstm.build()
        cnn_lstm.init_losses()

        self.assertTrue(isinstance(cnn_lstm, torch.nn.Module))

        test_sample = Sample()
        test_sample.text = torch.randint(1, 79, (10, ), dtype=torch.long)
        test_sample.image = torch.randn(3, 320, 480)
        test_sample.targets = torch.randn(32)

        test_sample_list = SampleList([test_sample])
        test_sample_list.dataset_type = "train"
        test_sample_list.dataset_name = "clevr"

        test_sample_list = test_sample_list.to(get_current_device())
        cnn_lstm = cnn_lstm.to(get_current_device())

        output = cnn_lstm(test_sample_list)

        scores = output["scores"]
        # The loss key is "<dataset_type>/<dataset_name>/<loss name>".
        loss = output["losses"]["train/clevr/logit_bce"]

        # The reference value depends on the random inputs above and on model
        # initialisation; presumably the harness seeds the RNG — confirm.
        np.testing.assert_almost_equal(loss.item(), 19.2635, decimal=4)
        self.assertEqual(scores.size(), torch.Size((1, 32)))
Пример #4
0
    def predict(self, url, text):
        """Run ``self.visual_bert`` on the image at ``url`` paired with ``text``.

        The image features come from ``self.get_detectron_features``. The
        token ids, attention mask and segment ids come from the externally
        defined ``tokenizer``. ``self.text_processor`` is only used to
        obtain ``text_len``.

        Returns:
            The raw output of ``self.visual_bert`` for a one-sample batch.
        """
        with torch.no_grad():
            features = self.get_detectron_features(url)

            sample = Sample()

            # Only the token count from the MMF processor is kept.
            tokens = self.text_processor({"text": text})["tokens"]
            sample.text_len = len(tokens)

            encoded = tokenizer(text, return_tensors='pt')
            sample.input_ids = encoded.input_ids
            sample.input_mask = encoded.attention_mask
            sample.segment_ids = encoded.token_type_ids

            sample.image_feature_0 = features
            sample.image_info_0 = Sample(
                {"max_features": torch.tensor(100, dtype=torch.long)})

            batch = SampleList([sample]).to("cuda")
            output = self.visual_bert(batch)

        # Free host/GPU memory between requests.
        gc.collect()
        torch.cuda.empty_cache()

        return output
Пример #5
0
    def __call__(self, image_tensor, text_input=None):
        """Return softmaxed model scores for an image and optional text.

        Allow model to receive both multi-inputs and single image-inputs // Bojia Mao

        Args:
            image_tensor: image input stored on ``sample.image``.
            text_input: optional text used as ``sample.text``. When it is
                None, the processed form of ``self.text`` is used instead.

        Returns:
            ``softmax(output["scores"], dim=1)`` from ``self.model``.
        """
        text = self.processor_dict["text_processor"]({"text": self.text})

        sample = Sample()

        # `is None` rather than `== None`: `==` on arrays/tensors may be
        # elementwise and is not a reliable None check.
        if text_input is None:
            sample.text = text["text"]
        else:
            self.__text = text_input
            sample.text = text_input

        # NOTE(review): when the processor output contains "input_ids", this
        # update presumably also overwrites sample.text with the processed
        # self.text, discarding text_input — confirm this is intended.
        if "input_ids" in text:
            sample.update(text)

        sample.image = image_tensor
        sample_list = SampleList([sample])
        sample_list = sample_list.to(
            torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))

        output = self.model(sample_list)
        scores = nn.functional.softmax(output["scores"], dim=1)

        return scores
Пример #6
0
    def _build_report(self):
        """Build a small Report fixture.

        The batch has a single 2x4 integer field ``a``; the model output has
        random 2x2 ``scores``.
        """
        batch = SampleList()
        batch.add_field("a", torch.tensor([[1, 2, 3, 4], [2, 3, 4, 5]]))
        return Report(batch, {"scores": torch.rand(2, 2)})
Пример #7
0
    def classify(self, image: ImageType, text: str, image_tensor=None,
                 zero_image=False, zero_text=False):
        """Classifies a given image and text in it into Hateful/Non-Hateful.

        Image can be a url or a local path or you can directly pass a
        PIL.Image.Image object. Text needs to be a sentence containing all
        text in the image.

        Args:
            image (ImageType): Image to be classified. Ignored when
                ``image_tensor`` is given.
            text (str): Text in the image
            image_tensor: optional pre-built image input. When given, it is
                used as ``sample.image`` directly and no image loading or
                image processing is done.
            zero_image: zero out the image features when classifying
                (forwarded to the model).
            zero_text: zero out the text features when classifying
                (forwarded to the model).

        Returns:
            If ``image_tensor`` is None, a dict such as
            ``{"label": 0, "confidence": 0.56}`` taken from the argmax of the
            softmaxed scores. Otherwise, the full softmaxed ``scores``
            tensor.
        """
        sample = Sample()

        # `is not None` rather than `!= None`: `!=` on arrays/tensors may be
        # elementwise and is not a reliable None check.
        if image_tensor is not None:
            sample.image = image_tensor
        else:
            if isinstance(image, str):
                if image.startswith("http"):
                    # The context manager deletes the temporary file even if
                    # the download or the image load raises.
                    with tempfile.NamedTemporaryFile() as temp_file:
                        download(image, *os.path.split(temp_file.name),
                                 disable_tqdm=True)
                        image = tv_helpers.default_loader(temp_file.name)
                else:
                    image = tv_helpers.default_loader(image)

            image = self.processor_dict["image_processor"](image)
            sample.image = image

        text = self.processor_dict["text_processor"]({"text": text})

        sample.text = text["text"]
        if "input_ids" in text:
            sample.update(text)

        sample_list = SampleList([sample])
        device = next(self.model.parameters()).device
        sample_list = sample_list.to(device)
        output = self.model(sample_list, zero_image=zero_image,
                            zero_text=zero_text)
        scores = nn.functional.softmax(output["scores"], dim=1)

        if image_tensor is not None:
            return scores

        confidence, label = torch.max(scores, dim=1)

        return {"label": label.item(), "confidence": confidence.item()}
Пример #8
0
    def test_mmf_dict_loss(self):
        """The mse_mae loss reports zero for both parts on identical tensors.

        The same random tensor is used as both ``targets`` and ``scores``, so
        the mse and mae entries must both be exactly 0.0.
        """
        loss_fn = losses.MMFLoss("mse_mae")
        torch.manual_seed(1234)
        shared = torch.rand((1, 768))

        batch = SampleList()
        batch.dataset_type = "val"
        batch.dataset_name = "vqa2"
        batch["targets"] = shared

        result = loss_fn(batch, {"scores": shared})

        for key in ("val/vqa2/mse_mae/mse", "val/vqa2/mse_mae/mae"):
            self.assertEqual(result[key].item(), 0.0)
Пример #9
0
    def test_preprocessing_with_resnet_encoder(self):
        """Check MMFTransformer preprocessing when the image uses ResNet152.

        An untrained (``pretrained=False``) ResNet152 encodes a (2, 3, 224, 224)
        image batch. The test checks the shape and contents of the
        processed input ids, position ids, masks and segment ids for both
        the image and text modalities.
        """
        self._image_modality_config = MMFTransformerModalityConfig(
            type="image",
            key="image",
            embedding_dim=2048,
            position_dim=1,
            segment_id=0,
            encoder=ImageEncoderFactory.Config(
                type=ImageEncoderTypes.resnet152,
                params=ResNet152ImageEncoder.Config(pretrained=False),
            ),
        )
        modalities_config = [
            self._image_modality_config, self._text_modality_config
        ]
        config = MMFTransformer.Config(modalities=modalities_config,
                                       num_labels=2)
        mmft = build_model(config)

        sample_list = SampleList()
        sample_list.image = torch.rand(2, 3, 224, 224)
        sample_list.text = torch.randint(0, 512, (2, 128))

        transformer_input = mmft.preprocess_sample(sample_list)

        # The encoded image becomes a single 2048-dim token per sample.
        input_ids = transformer_input["input_ids"]
        self.assertEqual(input_ids["image"].dim(), 3)
        self.assertEqual(list(input_ids["image"].size()), [2, 1, 2048])

        self.assertEqual(input_ids["text"].dim(), 2)
        self.assertEqual(list(input_ids["text"].size()), [2, 128])

        # NOTE(review): if test_utils.compare_tensors only returns a bool
        # (rather than asserting), the checks below are no-ops — confirm.
        position_ids = transformer_input["position_ids"]
        test_utils.compare_tensors(position_ids["image"],
                                   torch.tensor([[0], [0]]))
        test_utils.compare_tensors(
            position_ids["text"],
            torch.arange(0, 128).unsqueeze(0).expand((2, 128)))

        masks = transformer_input["masks"]
        test_utils.compare_tensors(masks["image"], torch.tensor([[1], [1]]))
        test_utils.compare_tensors(masks["text"], torch.ones((2, 128)).long())

        segment_ids = transformer_input["segment_ids"]
        test_utils.compare_tensors(segment_ids["image"],
                                   torch.tensor([[0], [0]]))
        test_utils.compare_tensors(segment_ids["text"],
                                   torch.ones((2, 128)).long())
Пример #10
0
    def _load_objects(self, idx):
        """Load the objects of image ``idx`` with processed, normalised fields.

        For every object in the image info:
          - ``synsets``, ``names`` and ``attributes`` are run through their
            token processors.
          - ``h``/``w`` are renamed to ``height``/``width`` and divided by
            the image height/width.
          - ``y``/``x`` are divided by the image height/width.
        Each object dict is mutated in place and then wrapped in a Sample.

        Returns:
            tuple: ``(SampleList of objects, dict object_id -> Sample)``.
        """
        info = self._get_image_info(idx)
        height = info["height"]
        width = info["width"]
        objects_by_id = {}
        object_samples = []

        for raw in info["objects"]:
            raw["synsets"] = self.synset_processor(
                {"tokens": raw["synsets"]})["text"]
            raw["names"] = self.name_processor(
                {"tokens": raw["names"]})["text"]
            raw["height"] = raw["h"] / height
            del raw["h"]
            raw["width"] = raw["w"] / width
            del raw["w"]
            raw["y"] /= height
            raw["x"] /= width
            raw["attributes"] = self.attribute_processor(
                {"tokens": raw["attributes"]})["text"]
            wrapped = Sample(raw)
            objects_by_id[wrapped["object_id"]] = wrapped
            object_samples.append(wrapped)

        return SampleList(object_samples), objects_by_id
Пример #11
0
    def __init__(self,
                 batch: SampleList = None,
                 model_output: Dict[str, Any] = None,
                 *args):
        """Build a report by merging ``batch``, ``model_output`` and ``args``.

        Args:
            batch: the input SampleList. If None, an empty report is created.
            model_output: mapping produced by the model. None is treated as
                an empty dict.
            *args: further mappings whose keys are merged in, in order.

        Raises:
            TypeError: if any argument is not a ``collections.abc.Mapping``.

        A key supplied by one of the extra ``args`` that already exists in
        the report triggers a warning; later values overwrite earlier ones.
        """
        # Start from an empty mapping. Passing ``self`` here (as before)
        # only "updated" the report from itself, which was a no-op.
        super().__init__()
        if batch is None:
            return
        if model_output is None:
            model_output = {}
        if self._check_and_load_tuple(batch):
            return

        all_args = [batch, model_output, *args]
        for idx, arg in enumerate(all_args):
            if not isinstance(arg, collections.abc.Mapping):
                raise TypeError("Argument {:d}, {} must be of instance of "
                                "collections.abc.Mapping".format(idx, arg))

        self.batch_size = batch.get_batch_size()
        self.warning_string = ("Updating forward report with key {}"
                               "{}, but it already exists in {}. "
                               "Please consider using a different key, "
                               "as this can cause issues during loss and "
                               "metric calculations.")

        for idx, arg in enumerate(all_args):
            for key, item in arg.items():
                # Only keys coming from the extra positional args (idx >= 2)
                # are warned about.
                if key in self and idx >= 2:
                    # The template already contains "exists in {}", so the
                    # argument must not start with "in" again.
                    log = self.warning_string.format(
                        key, "", "previous arguments to report")
                    warnings.warn(log)
                self[key] = item
Пример #12
0
    def __init__(self, num_train_data, max_updates, max_epochs):
        """Assemble a minimal trainer-like object for tests.

        The training config always has ``detect_anomaly`` and
        ``evaluation_interval``; ``max_updates`` / ``max_epochs`` are added
        only when not None. The model is a ``SimpleModel(1)``, moved to CUDA
        when available. The data loader iterates ``NumbersDataset`` one item
        at a time in order. Optimizer, callbacks and meter are MagicMocks.
        """
        config = {"detect_anomaly": False, "evaluation_interval": 10000}
        for key, value in (("max_updates", max_updates),
                           ("max_epochs", max_epochs)):
            if value is not None:
                config[key] = value
        self.training_config = OmegaConf.create(config)

        model = SimpleModel(1)
        self.model = model.cuda() if torch.cuda.is_available() else model

        loader = MagicMock()
        loader.seed_sampler = MagicMock(return_value=None)
        loader.prepare_batch = lambda x: SampleList(x)
        self.dataset_loader = loader

        optimizer = MagicMock()
        optimizer.step = MagicMock(return_value=None)
        optimizer.zero_grad = MagicMock(return_value=None)
        self.optimizer = optimizer

        self.train_loader = torch.utils.data.DataLoader(
            dataset=NumbersDataset(num_train_data),
            batch_size=1,
            shuffle=False,
            num_workers=1,
            drop_last=False,
        )

        self.on_batch_start = MagicMock(return_value=None)
        callback = MagicMock(return_value=None)
        callback.log_interval = MagicMock(return_value=None)
        self.logistics_callback = callback
        self.on_batch_end = MagicMock(return_value=None)
        self.meter = MagicMock(return_value=None)
        self.after_training_loop = MagicMock(return_value=None)
Пример #13
0
    def test_nucleus_sampling(self):
        """Nucleus-sampling decoding produces a fixed token sequence.

        Builds a ``TestDecoderModel`` from the butd config, runs it in eval
        mode on a single COCO test sample, and compares the first caption
        with a hard-coded expected token list.
        """
        vocab = text_utils.VocabFromText(self.VOCAB_EXAMPLE_SENTENCES)

        model_config = self.config.model_config.butd
        model = TestDecoderModel(model_config, vocab)
        model.build()
        model.eval()

        sample = Sample()
        sample.dataset_name = "coco"
        sample.dataset_type = "test"
        # NOTE(review): random features and sampling make the expected tokens
        # depend on RNG state; presumably the harness seeds it — confirm.
        sample.image_feature_0 = torch.randn(100, 2048)
        sample.answers = torch.zeros((5, 10), dtype=torch.long)
        sample_list = SampleList([sample])

        tokens = model(sample_list)["captions"]

        # these are expected tokens for sum_threshold = 0.5
        expected_tokens = [
            1.0,
            29.0,
            11.0,
            11.0,
            39.0,
            10.0,
            31.0,
            4.0,
            19.0,
            39.0,
            2.0,
        ]

        self.assertEqual(tokens[0].tolist(), expected_tokens)
Пример #14
0
    def test_beam_search(self):
        """Beam-search decoding yields a fixed first caption per batch size.

        For batch sizes 1, 2, 8 and 16, random COCO test samples are decoded
        and the first caption (with zero padding trimmed) is compared with
        a hard-coded expectation.
        """
        vocab = text_utils.VocabFromText(self.VOCAB_EXAMPLE_SENTENCES)
        model_config = self.config.model_config.butd
        model = TestDecoderModel(model_config, vocab)
        model.build()
        model.eval()

        # NOTE(review): these values depend on the RNG stream consumed by the
        # torch.randn calls below (in this exact order); presumably the
        # harness seeds the RNG — confirm.
        expected_tokens = {
            1: [1.0, 23.0, 1.0, 24.0, 29.0, 37.0, 40.0, 17.0, 29.0, 2.0],
            2: [1.0, 0.0, 8.0, 1.0, 28.0, 25.0, 2.0],
            8: [1.0, 34.0, 1.0, 13.0, 1.0, 2.0],
            16: [1.0, 25.0, 18.0, 2.0],
        }

        for batch_size in [1, 2, 8, 16]:
            samples = []
            for _ in range(batch_size):
                sample = Sample()
                sample.dataset_name = "coco"
                sample.dataset_type = "test"
                sample.image_feature_0 = torch.randn(100, 2048)
                sample.answers = torch.zeros((5, 10), dtype=torch.long)
                samples.append(sample)

            sample_list = SampleList(samples)
            tokens = model(sample_list)["captions"]
            # np.trim_zeros strips leading/trailing zero padding.
            self.assertEqual(np.trim_zeros(tokens[0].tolist()),
                             expected_tokens[batch_size])
Пример #15
0
    def mlm_forward(
        self,
        input_ids_masked: Tensor,
        lm_label_ids: Tensor,
        token_type_ids: Tensor,
        attention_mask: Tensor,
        img_feats: Tensor,
        position_ids: Optional[Tensor] = None,
    ) -> Dict[str, Tensor]:
        """Compute masked-language-modelling losses over text + image tokens.

        Text positions use ``lm_label_ids``. Every image position gets label
        -1, presumably so the MLM head ignores it (confirm against the head).
        Text and image labels are concatenated along the last dim into
        ``combined_labels``.

        Returns:
            The ``"losses"`` entry of ``self.mlm_head``'s output.
        """
        encoder_output = self.bert(
            input_ids_masked,
            img_feats=img_feats,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        sequence_output = encoder_output.last_hidden_state

        image_labels = torch.full(
            img_feats.shape[:2],
            fill_value=-1,
            dtype=torch.long,
            device=lm_label_ids.device,
        )
        mlm_labels = {
            "text": lm_label_ids,
            "image": image_labels,
            "combined_labels": torch.cat([lm_label_ids, image_labels],
                                         dim=-1),
        }

        head_output = self.mlm_head(
            sequence_output,
            processed_sample_list=SampleList({"mlm_labels": mlm_labels}),
        )
        return head_output["losses"]
Пример #16
0
    def classify(self, image: ImageType, text: str):
        """Classifies a given image and text in it into Hateful/Non-Hateful.
        Image can be a url or a local path or you can directly pass a PIL.Image.Image
        object. Text needs to be a sentence containing all text in the image.

            >>> from mmf.models.mmbt import MMBT
            >>> model = MMBT.from_pretrained("mmbt.hateful_memes.images")
            >>> model.classify("some_url", "some_text")
            {"label": 0, "confidence": 0.56}

        Args:
            image (ImageType): Image to be classified
            text (str): Text in the image

        Returns:
            dict: ``{"label": int, "confidence": float}``. ``label`` is the
            argmax class of the softmaxed scores (hateful (1) or non
            hateful (0)); ``confidence`` is that class's probability.
        """
        if isinstance(image, str):
            if image.startswith("http"):
                # The context manager deletes the temporary file even if the
                # download or the image load raises.
                with tempfile.NamedTemporaryFile() as temp_file:
                    download(image,
                             *os.path.split(temp_file.name),
                             disable_tqdm=True)
                    image = tv_helpers.default_loader(temp_file.name)
            else:
                image = tv_helpers.default_loader(image)

        text = self.processor_dict["text_processor"]({"text": text})
        image = self.processor_dict["image_processor"](image)

        sample = Sample()
        sample.text = text["text"]
        if "input_ids" in text:
            sample.update(text)

        sample.image = image
        sample_list = SampleList([sample])
        # Put the batch on whatever device the model's parameters live on.
        device = next(self.model.parameters()).device
        sample_list = sample_list.to(device)

        output = self.model(sample_list)
        scores = nn.functional.softmax(output["scores"], dim=1)
        confidence, label = torch.max(scores, dim=1)

        return {"label": label.item(), "confidence": confidence.item()}
Пример #17
0
    def load_datasets(self):
        """Create the numbers data module and its train/val/test loaders.

        ``seed_sampler`` is stubbed out and ``prepare_batch`` simply wraps a
        batch in a SampleList.
        """
        self.dataset_loader = loader = MultiDataModuleNumbersTestObject(
            num_data=self.num_data,
            batch_size=self.config.training.batch_size,
        )
        loader.seed_sampler = MagicMock(return_value=None)
        loader.prepare_batch = lambda x: SampleList(x)

        self.train_loader = loader.train_dataloader()
        self.val_loader = loader.val_dataloader()
        self.test_loader = loader.test_dataloader()
Пример #18
0
    def __call__(self, batch):
        """Return ``batch`` as a SampleList tagged with dataset name and type.

        Three cases are handled:
          - A one-element list wrapping a SampleList (as produced by batched
            iterators) is unwrapped.
          - An existing SampleList is used as-is.
          - Anything else is converted with ``SampleList(batch)``.
        """
        # The list check comes first because batch[0] / len() are only
        # meaningful for lists; a batch may already be a SampleList.
        is_wrapped_sample_list = (
            isinstance(batch, list)
            and len(batch) == 1
            and isinstance(batch[0], SampleList)
        )
        if is_wrapped_sample_list:
            sample_list = batch[0]
        elif isinstance(batch, SampleList):
            sample_list = batch
        else:
            sample_list = SampleList(batch)

        sample_list.dataset_name = self._dataset_name
        sample_list.dataset_type = self._dataset_type
        return sample_list
Пример #19
0
    def test_convert_batch_to_sample_list(self):
        """convert_batch_to_sample_list handles three kinds of input.

        Covers lists of dicts with tensor fields, a one-element list holding
        a SampleList, and dicts whose fields are not tensors.
        """
        # Test list conversion
        batch = [{"a": torch.tensor([1.0, 1.0])}, {"a": torch.tensor([2.0, 2.0])}]
        sample_list = convert_batch_to_sample_list(batch)
        expected_a = torch.tensor([[1.0, 1.0], [2.0, 2.0]])
        self.assertTrue(torch.equal(expected_a, sample_list.a))

        # Test single element list, samplelist
        sample_list = SampleList()
        sample_list.add_field("a", expected_a)
        parsed_sample_list = convert_batch_to_sample_list([sample_list])
        self.assertTrue(isinstance(parsed_sample_list, SampleList))
        self.assertTrue("a" in parsed_sample_list)
        self.assertTrue(torch.equal(expected_a, parsed_sample_list.a))

        # Test no tensor field
        batch = [{"a": [1]}, {"a": [2]}]
        sample_list = convert_batch_to_sample_list(batch)
        # assertEqual, not assertTrue: assertTrue(x, msg) treats the second
        # argument as a failure message and only checks that x is truthy.
        self.assertEqual(sample_list.a, [[1], [2]])
Пример #20
0
    def test_one_dim_feature_preprocessing(self):
        """Check MMFTransformer preprocessing for a 1-D image feature.

        A (2, 256) image feature is treated as a single image token per
        sample. The test checks input ids, position ids, masks (both those
        returned by ``preprocess_sample`` and those from ``_infer_masks``),
        segment ids and the combined MLM labels.
        """
        modalities_config = [
            self._image_modality_config, self._text_modality_config
        ]
        config = MMFTransformer.Config(modalities=modalities_config,
                                       num_labels=2)
        mmft = build_model(config)

        sample_list = SampleList()
        sample_list.image = torch.rand(2, 256)
        sample_list.text = torch.randint(0, 512, (2, 128))

        transformer_input = mmft.preprocess_sample(sample_list)
        input_ids = transformer_input["input_ids"]
        self.assertEqual(input_ids["image"].dim(), 3)
        self.assertEqual(list(input_ids["image"].size()), [2, 1, 256])

        self.assertEqual(input_ids["text"].dim(), 2)
        self.assertEqual(list(input_ids["text"].size()), [2, 128])

        # NOTE(review): if test_utils.compare_tensors only returns a bool
        # (rather than asserting), these checks are no-ops — confirm.
        position_ids = transformer_input["position_ids"]
        test_utils.compare_tensors(position_ids["image"],
                                   torch.tensor([[0], [0]]))
        test_utils.compare_tensors(
            position_ids["text"],
            torch.arange(0, 128).unsqueeze(0).expand((2, 128)))

        # Check the masks produced by preprocess_sample; previously these
        # were overwritten by _infer_masks before being compared.
        masks = transformer_input["masks"]
        test_utils.compare_tensors(masks["image"], torch.tensor([[1], [1]]))
        test_utils.compare_tensors(masks["text"], torch.ones((2, 128)).long())

        # Also check masks inferred directly from the input ids.
        inferred_masks = mmft._infer_masks(sample_list, input_ids)
        test_utils.compare_tensors(inferred_masks["image"],
                                   torch.tensor([[1], [1]]))
        test_utils.compare_tensors(inferred_masks["text"],
                                   torch.ones((2, 128)).long())

        segment_ids = transformer_input["segment_ids"]
        test_utils.compare_tensors(segment_ids["image"],
                                   torch.tensor([[0], [0]]))
        test_utils.compare_tensors(segment_ids["text"],
                                   torch.ones((2, 128)).long())

        # 128 text tokens + 1 image token, all labelled -1.
        mlm_labels = transformer_input["mlm_labels"]
        test_utils.compare_tensors(
            mlm_labels["combined_labels"],
            torch.full((2, 129), dtype=torch.long, fill_value=-1),
        )
Пример #21
0
    def test_finetune_model(self):
        """Eager and TorchScript versions of the finetune model agree exactly.

        The model is put in eval mode and run on one random sample, both
        eagerly and after ``torch.jit.script``. The ``scores`` outputs must
        be ``torch.equal``.
        """
        self.finetune_model.eval()
        sample = Sample()
        sample.input_ids = torch.randint(low=0, high=30255,
                                         size=(128, )).long()
        sample.input_mask = torch.ones(128).long()
        sample.segment_ids = torch.zeros(128).long()
        sample.image = torch.rand((3, 300, 300)).float()

        # The eager pass gets a copy, presumably so it cannot affect the
        # sample used for the scripted pass.
        eager_batch = SampleList([sample.copy()])
        with torch.no_grad():
            eager_output = self.finetune_model.model(eager_batch)

        scripted_batch = SampleList([sample])
        scripted_model = torch.jit.script(self.finetune_model.model)
        with torch.no_grad():
            scripted_output = scripted_model(scripted_batch)

        same_scores = torch.equal(eager_output["scores"],
                                  scripted_output["scores"])
        self.assertTrue(same_scores)
Пример #22
0
def compare_torchscript_transformer_models(model, vocab_size):
    """Return True if ``model`` and its TorchScript version give equal scores.

    A single random sample is built, with 128 token ids below
    ``vocab_size``, an all-ones mask, zero segment ids, a (1, 100, 2048)
    feature tensor and a (3, 300, 300) image. The sample and the model are
    moved to the current device, and the eager and scripted ``scores`` are
    compared with ``torch.equal``.

    NOTE(review): the model is not put in eval mode here; callers presumably
    do so, otherwise dropout could make the comparison fail — confirm.
    """
    sample = Sample()
    sample.input_ids = torch.randint(low=0, high=vocab_size, size=(128,)).long()
    sample.input_mask = torch.ones(128).long()
    sample.segment_ids = torch.zeros(128).long()
    sample.image_feature_0 = torch.rand((1, 100, 2048)).float()
    sample.image = torch.rand((3, 300, 300)).float()
    batch = SampleList([sample])

    model = model.to(get_current_device())
    batch = batch.to(get_current_device())

    with torch.no_grad():
        eager_scores = model(batch)["scores"]

    scripted = torch.jit.script(model)
    with torch.no_grad():
        scripted_scores = scripted(batch)["scores"]

    return torch.equal(eager_scores, scripted_scores)
Пример #23
0
    def predict(self, Q, F, topk):
        """Return the ``topk`` answer probabilities and strings for ``Q``.

        Args:
            Q: question text, passed to ``self.text_processor``.
            F: numpy array of detectron image features, converted with
                ``torch.from_numpy``.
            topk: number of top answers to return.

        Returns:
            tuple: ``(probs, answers)``, two lists of length ``topk`` for the
            single sample in the batch.
        """
        with torch.no_grad():
            detectron_features = torch.from_numpy(F)

            processed_text = self.text_processor({"text": Q})
            sample = Sample(processed_text)
            sample.text_len = len(processed_text["tokens"])

            sample.image_feature_0 = detectron_features
            # NOTE(review): max_features is fixed at 100 regardless of how
            # many boxes F actually holds — confirm F always has 100 rows.
            sample.image_info_0 = Sample({
                "max_features": torch.tensor(100, dtype=torch.long)
            })

            sample_list = SampleList([sample])
            sample_list = sample_list.to("cuda")

            scores = self.bert_model(sample_list)["scores"]
            scores = torch.nn.functional.softmax(scores, dim=1)
            actual, indices = scores.topk(topk, dim=1)

            # The batch holds exactly one sample, so row 0 is the answer
            # list; topk already limited each row to ``topk`` entries.
            top_scores = actual[0]
            top_indices = indices[0]

            probs = []
            answers = []

            for score, answer_index in zip(top_scores, top_indices):
                probs.append(score.item())
                answers.append(
                    self.answer_processor.idx2word(answer_index.item())
                )

        # Free host/GPU memory between requests.
        gc.collect()
        torch.cuda.empty_cache()

        return probs, answers
Пример #24
0
 def forward(self, prepared_batch):
     """Apply ``self.linear`` to the data item and compute an MSE loss.

     The loss compares the negated output against the input data itself.

     Returns:
         dict with ``losses`` ({"loss": tensor}), ``logits`` (the linear
         output) and ``input_batch`` (``prepared_batch`` as a SampleList).
     """
     wrapped_input = SampleList(prepared_batch)
     features = prepared_batch[DATA_ITEM_KEY]
     logits = self.linear(features)
     criterion = torch.nn.MSELoss()
     mse = criterion(-1 * logits, features)
     return {
         "losses": {"loss": mse},
         "logits": logits,
         "input_batch": wrapped_input,
     }
Пример #25
0
 def forward(self, prepared_batch: Dict[str, Tensor]):
     """Apply ``self.classifier`` to the data item and compute an MSE loss.

     The item is read from ``prepared_batch[self.data_item_key]``. The loss
     compares the negated output against that input.

     Returns:
         dict with ``losses`` ({"loss": tensor}), ``logits`` (the classifier
         output) and ``input_batch`` (``prepared_batch`` as a SampleList).
     """
     wrapped_input = SampleList(prepared_batch)
     features = prepared_batch[self.data_item_key]
     logits = self.classifier(features)
     criterion = torch.nn.MSELoss()
     mse = criterion(-1 * logits, features)
     return {
         "losses": {"loss": mse},
         "logits": logits,
         "input_batch": wrapped_input,
     }
Пример #26
0
    def prepare_batch(self, batch):
        """
        Can be possibly overridden in your child class

        Turn the loaded batch into what the model's forward receives: the
        batch is converted to a SampleList if it is not one already, then
        moved to ``self._device``.

        Args:
            batch (SampleList): sample list containing the currently loaded batch

        Returns:
            sample_list (SampleList): the current batch as a SampleList on
                ``self._device``
        """
        sample_list = (batch if isinstance(batch, SampleList)
                       else SampleList(batch))
        return sample_list.to(self._device)
Пример #27
0
    def forward(self, image_path: str, text: dict, image_format: str = "path"):
        """Answer ``text`` about the image at ``image_path``.

        Args:
            image_path: local path or URL of the image.
            text: dict passed to the text processor.
            image_format: ``"path"`` to open a local file, or ``"url"`` to
                fetch it over HTTP.

        Returns:
            The answer string from ``answer_processor.idx2word``.

        Raises:
            ValueError: if ``image_format`` is neither ``"path"`` nor ``"url"``.
        """
        text_output = self.processor["text_processor"](text)
        # Close the PIL image once its pixels are copied into the array.
        if image_format == "path":
            with Image.open(image_path) as pil_image:
                img = np.array(pil_image)
        elif image_format == "url":
            with Image.open(
                    requests.get(image_path, stream=True).raw) as pil_image:
                img = np.array(pil_image)
        else:
            # Previously fell through and failed later with an
            # UnboundLocalError on `img`.
            raise ValueError(
                "Unsupported image_format {!r}; expected 'path' or 'url'"
                .format(image_format))
        img = torch.as_tensor(img)

        if self.model_items["config"].image_feature_encodings.type == "frcnn":
            max_detect = self.model_items[
                "config"].image_feature_encodings.params.max_detections
            image_preprocessed, sizes, scales_yx = self.processor[
                "image_processor"](img)
            image_output = self.feature_extractor(
                image_preprocessed,
                sizes=sizes,
                scales_yx=scales_yx,
                padding=None,
                max_detections=max_detect,
                return_tensors="pt",
            )
            image_output = image_output[0]
        else:
            image_preprocessed = self.processor["image_processor"](img)
            image_output = self.feature_extractor(image_preprocessed)

        sample = Sample(text_output)
        sample.image_feature_0 = image_output
        sample_list = SampleList([sample])
        sample_list = sample_list.to(get_current_device())
        self.model = self.model.to(get_current_device())
        output = self.model(sample_list)
        # NOTE(review): the first token id of the first sample is used as the
        # report id — presumably just a placeholder; confirm.
        sample_list.id = [sample_list.input_ids[0][0]]
        report = Report(sample_list, output)
        answers = self.processor["output_processor"](report)
        answer = self.processor["answer_processor"].idx2word(
            answers[0]["answer"])

        return answer
Пример #28
0
    def _load_regions(self, idx):
        """Load and normalise the region descriptions of image ``idx``.

        Returns ``(None, None)`` when ``self._return_region_descriptions`` is
        None. Otherwise, for each region:
          - ``height``/``y`` are divided by the image height.
          - ``width``/``x`` are divided by the image width.
          - ``phrase`` is run through ``self.text_processor`` and converted
            with ``.numpy()``.
        The regions are then collected into a SampleList.

        Returns:
            tuple: ``(regions, region_map)``. ``regions`` is a SampleList
            with extra ``image_id`` (int32 tensor taken from the first
            region) and ``image_url`` fields. ``region_map`` maps each
            ``region_id`` to its Sample.
        """
        if self._return_region_descriptions is None:
            return None, None

        image_info = self._get_image_info(idx)
        image_height = image_info["height"]
        image_width = image_info["width"]
        region_map = {}
        regions = []

        # NOTE(review): region dicts are normalised in place; this assumes
        # _get_image_info returns a fresh copy per call — confirm.
        for region in image_info["regions"]:
            region["height"] /= image_height
            region["width"] /= image_width
            region["y"] /= image_height
            region["x"] /= image_width
            phrase = self.text_processor({"text": region["phrase"]})["text"]
            region["phrase"] = phrase.numpy()
            region = Sample(region)

            region_map[region["region_id"]] = region
            regions.append(region)

        # NOTE(review): an image without regions would presumably fail at
        # regions["image_id"][0] below — confirm whether that can occur.
        regions = SampleList(regions)
        regions["image_id"] = torch.tensor(regions["image_id"][0],
                                           dtype=torch.int32)
        regions["image_url"] = image_info["url"]
        return regions, region_map
Пример #29
0
    def predict(self, url, question):
        """Return the top-5 answer probabilities and strings for ``question``.

        Detectron and ResNet features are extracted from the image at ``url``,
        combined with the processed question into a one-sample batch on
        CUDA, and scored by ``self.pythia_model``.

        Returns:
            tuple: ``(probs, answers)``, two lists of length 5.
        """
        with torch.no_grad():
            detectron_features = self.get_detectron_features(url)
            resnet_features = self.get_resnet_features(url)

            sample = Sample()

            processed = self.text_processor({"text": question})
            sample.text = processed["text"]
            sample.text_len = len(processed["tokens"])

            sample.image_feature_0 = detectron_features
            sample.image_info_0 = Sample(
                {"max_features": torch.tensor(100, dtype=torch.long)})

            sample.image_feature_1 = resnet_features

            batch = SampleList([sample]).to("cuda")

            logits = self.pythia_model(batch)["scores"]
            probabilities = torch.nn.functional.softmax(logits, dim=1)
            top_values, top_ids = probabilities.topk(5, dim=1)

            probs = []
            answers = []
            # Row 0 is the single sample in the batch.
            for value, answer_id in zip(top_values[0], top_ids[0]):
                probs.append(value.item())
                answers.append(
                    self.answer_processor.idx2word(answer_id.item()))

        # Free host/GPU memory between requests.
        gc.collect()
        torch.cuda.empty_cache()

        return probs, answers
Пример #30
0
    def test_meter_update_from_report(self):
        """Meter averages scalar losses across reports.

        Losses 0..4 are fed through five reports, so both the global and
        the windowed average must be 2.0.
        """
        meter = Meter()
        batch = SampleList(
            {"targets": torch.tensor([1, 2, 3, 4]), "dataset_type": "val"}
        )
        for step in range(5):
            output = {
                "scores": torch.tensor([0, 1, 2, 3]),
                "losses": {"loss": float(step)},
            }
            meter.update_from_report(Report(batch, output))

        for average in (meter.loss.global_avg, meter.loss.avg):
            self.assertEqual(average, 2.0)