Code example #1
    def load_item(self, idx):
        sample_info = self.imdb[idx]
        current_sample = Sample()

        processed_sentence = self.text_processor(
            {"text": sample_info["sentence"]})

        current_sample.text = processed_sentence["text"]
        if "input_ids" in processed_sentence:
            current_sample.update(processed_sentence)

        if self._use_features is True:
            # Remove sentence id from end
            identifier = "-".join(sample_info["identifier"].split("-")[:-1])
            # Load img0 and img1 features
            sample_info["feature_path"] = "{}-img0.npy".format(identifier)
            features = self.features_db[idx]
            current_sample.img0 = Sample()
            current_sample.img0.update(features)

            sample_info["feature_path"] = "{}-img1.npy".format(identifier)
            features = self.features_db[idx]
            current_sample.img1 = Sample()
            current_sample.img1.update(features)

        is_correct = 1 if sample_info["label"] == "True" else 0
        current_sample.targets = torch.tensor(is_correct, dtype=torch.long)

        return current_sample
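
A minimal usage sketch (not from any of the projects above; the `dataset` variable is hypothetical): items produced by `load_item` are typically collated with MMF's `SampleList`, which stacks the per-sample tensors into batch tensors.

    from mmf.common.sample import SampleList

    # Hypothetical dataset instance exposing the load_item above
    batch = SampleList([dataset.load_item(0), dataset.load_item(1)])
    batch.targets  # per-sample scalar targets stacked to shape torch.Size([2])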
Code example #2
File: dataset.py Project: Yui010206/mmf
    def __getitem__(self, idx):
        record = self.video_meta[idx]
        sampled_frame_indices = self._sample_indices(record)

        video_path = record.path
        video_start = record.start
        video_end = record.end
        video_level_label = record.label

        image_tmpl, indices = self._get_img_file_temp(record, sampled_frame_indices)

        ref_img_filename = reading_with_exception_handling(record, image_tmpl, int(indices[0]))
        temp_img_ = self._load_image(ref_img_filename)[0]
        ref_size = temp_img_.size

        tracklet_info = self._parse_kmot_tracklet(record.rois)
        filtered_boxes_with_track_by_indice = self._pick_mot_roi_by_indices(indices, tracklet_info, ref_size)
        filtered_mot_box_dict = self._mot_list_to_indice_dict(indices, filtered_boxes_with_track_by_indice)

        raw_gt_boxes_by_frames, box_labels_by_frames = self._get_gt_patch(indices, video_path, filtered_mot_box_dict)

        processed_patch_boxes, targets = self._generate_patch_input(raw_gt_boxes_by_frames, box_labels_by_frames, ref_size, image_tmpl)

        current_sample = Sample()
        current_sample.targets = targets
        current_sample.human_box = processed_patch_boxes

        return current_sample
Code example #3
File: test_utils.py Project: facebookresearch/mmf
    def __getitem__(self, idx: int) -> Sample:
        sample = Sample()
        sample[self.data_item_key] = torch.tensor(
            idx, dtype=torch.float32).unsqueeze(-1)
        if self.always_one:
            sample.targets = torch.tensor(0, dtype=torch.long)
        return sample
Code example #4
File: test_cnn_lstm.py Project: vishalbelsare/pythia
    def test_forward(self):
        model_config = self.config.model_config.cnn_lstm

        cnn_lstm = CNNLSTM(model_config)
        cnn_lstm.build()
        cnn_lstm.init_losses()

        self.assertTrue(isinstance(cnn_lstm, torch.nn.Module))

        test_sample = Sample()
        test_sample.text = torch.randint(1, 79, (10, ), dtype=torch.long)
        test_sample.image = torch.randn(3, 320, 480)
        test_sample.targets = torch.randn(32)

        test_sample_list = SampleList([test_sample])
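        # SampleList adds the batch dimension, so the model sees a batch of one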
        test_sample_list.dataset_type = "train"
        test_sample_list.dataset_name = "clevr"

        test_sample_list = test_sample_list.to(get_current_device())
        cnn_lstm = cnn_lstm.to(get_current_device())

        output = cnn_lstm(test_sample_list)

        scores = output["scores"]
        loss = output["losses"]["train/clevr/logit_bce"]

        np.testing.assert_almost_equal(loss.item(), 19.2635, decimal=4)
        self.assertEqual(scores.size(), torch.Size((1, 32)))
Code example #5
    def __getitem__(self, idx):
        sample_info = self.annotation_db[idx]
        current_sample = Sample()

        text_processor_argument = {"text": sample_info["question_str"]}
        processed_question = self.text_processor(text_processor_argument)
        current_sample.text = processed_question["text"]
        if "input_ids" in processed_question:
            current_sample.update(processed_question)

        current_sample.question_id = torch.tensor(sample_info["question_id"],
                                                  dtype=torch.int)

        if isinstance(sample_info["image_id"], int):
            current_sample.image_id = torch.tensor(sample_info["image_id"],
                                                   dtype=torch.int)
        else:
            current_sample.image_id = sample_info["image_id"]

        if self._use_features is True:
            features = self.features_db[idx]
            current_sample.update(features)

        # Depending on whether we are using soft copy this can add
        # dynamic answer space
        current_sample = self.add_answer_info(sample_info, current_sample)
        return current_sample
Code example #6
    def load_item(self, idx):
        sample_info = self.annotation_db[idx]
        sample_info = self.preprocess_sample_info(sample_info)
        current_sample = Sample()

        if self._dataset_type != "test":
            text_processor_argument = {"tokens": sample_info["caption_tokens"]}
            processed_caption = self.text_processor(text_processor_argument)
            current_sample.text = processed_caption["text"]
            current_sample.caption_id = torch.tensor(sample_info["caption_id"],
                                                     dtype=torch.int)
            current_sample.caption_len = torch.tensor(
                len(sample_info["caption_tokens"]), dtype=torch.int)

        current_sample.image_id = object_to_byte_tensor(
            sample_info["image_id"])

        if self._use_features:
            features = self.features_db[idx]
            current_sample.update(features)
        else:
            image_path = str(sample_info["image_name"]) + ".jpg"
            current_sample.image = self.image_db.from_path(
                image_path)["images"][0]

        # Add reference captions to sample
        current_sample = self.add_reference_caption(sample_info,
                                                    current_sample)

        return current_sample
Code example #7
File: dataset.py Project: EXYNOS-999/DeepMeMes
    def load_item(self, idx):
        sample_info = self.annotation_db[idx]
        sample_info = self.preprocess_sample_info(sample_info)
        current_sample = Sample()

        if self._dataset_type != "test":
            text_processor_argument = {"tokens": sample_info["caption_tokens"]}
            processed_caption = self.text_processor(text_processor_argument)
            current_sample.text = processed_caption["text"]
            current_sample.caption_id = torch.tensor(sample_info["caption_id"],
                                                     dtype=torch.int)
            current_sample.caption_len = torch.tensor(
                len(sample_info["caption_tokens"]), dtype=torch.int)

        if isinstance(sample_info["image_id"], int):
            current_sample.image_id = torch.tensor(sample_info["image_id"],
                                                   dtype=torch.int)
        else:
            current_sample.image_id = object_to_byte_tensor(
                sample_info["image_id"])

        if self._use_features is True:
            features = self.features_db[idx]
            current_sample.update(features)

        # Add reference captions to sample
        current_sample = self.add_reference_caption(sample_info,
                                                    current_sample)

        return current_sample
Code example #8
    def __getitem__(self, idx):
        sample_info = self.annotation_db[idx]
        sample_info = self.preprocess_sample_info(sample_info)

        current_sample = Sample()

        processed_text = self.text_processor({"text": sample_info["text"]})
        current_sample.text = processed_text["text"]
        if "input_ids" in processed_text:
            current_sample.update(processed_text)

        current_sample.id = torch.tensor(int(sample_info["id"]),
                                         dtype=torch.int)

        # Instead of using idx directly here, use sample_info to fetch
        # the features as feature_path has been dynamically added
        features = self.features_db.get(sample_info)
        if hasattr(self, "transformer_bbox_processor"):
            features["image_info_0"] = self.transformer_bbox_processor(
                features["image_info_0"])
        current_sample.update(features)

        if "label" in sample_info:
            current_sample.targets = torch.tensor(sample_info["label"],
                                                  dtype=torch.long)

        return current_sample
Code example #9
File: dataset.py Project: zhang703652632/mmf
    def load_item(self, idx):
        sample_info = self.annotation_db[idx]
        current_sample = Sample()

        processed_sentence = self.text_processor({"text": sample_info["sentence2"]})

        current_sample.text = processed_sentence["text"]
        if "input_ids" in processed_sentence:
            current_sample.update(processed_sentence)

        if self._use_features is True:
            # Strip the file extension to get the image identifier
            identifier = sample_info["Flikr30kID"].split(".")[0]
            # Load the image features
            sample_info["feature_path"] = "{}.npy".format(identifier)
            features = self.features_db[idx]
            if hasattr(self, "transformer_bbox_processor"):
                features["image_info_0"] = self.transformer_bbox_processor(
                    features["image_info_0"]
                )
            current_sample.update(features)
        else:
            image_path = sample_info["Flikr30kID"]
            current_sample.image = self.image_db.from_path(image_path)["images"][0]

        label = LABEL_TO_INT_MAPPING[sample_info["gold_label"]]
        current_sample.targets = torch.tensor(label, dtype=torch.long)

        return current_sample
Code example #10
    def predict(self, url, text):
        with torch.no_grad():
            detectron_features = self.get_detectron_features(url)

            sample = Sample()

            processed_text = self.text_processor({"text": text})
            sample.text_len = len(processed_text["tokens"])

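            # Assumption: `tokenizer` is a pretrained HuggingFace tokenizer
            # (e.g. BertTokenizer) available in the surrounding scope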
            encoded_input = tokenizer(text, return_tensors='pt')
            sample.input_ids = encoded_input.input_ids
            sample.input_mask = encoded_input.attention_mask
            sample.segment_ids = encoded_input.token_type_ids

            sample.image_feature_0 = detectron_features
            sample.image_info_0 = Sample(
                {"max_features": torch.tensor(100, dtype=torch.long)})

            sample_list = SampleList([sample])
            sample_list = sample_list.to("cuda")

            output = self.visual_bert(sample_list)

        gc.collect()
        torch.cuda.empty_cache()

        return output
Code example #11
    def __getitem__(self, idx: int) -> Sample:
        sample_info = self.annotation_db[idx]
        current_sample = Sample()
        processed_caption = self.masked_token_processor(
            {"text_a": sample_info["caption"], "text_b": "", "is_correct": True}
        )
        current_sample.update(processed_caption)
        current_sample.image_id = sample_info["image_id"]
        current_sample.feature_path = sample_info["feature_path"]

        # Get the image features
        if self._use_features:
            features = self.features_db[idx]
            image_info_0 = features["image_info_0"]
            if image_info_0 and "image_id" in image_info_0.keys():
                image_info_0["feature_path"] = image_info_0["image_id"]
                image_info_0.pop("image_id")
            current_sample.update(features)
        elif self._use_images:
            image_id = sample_info["image_id"]
            dataset = sample_info["dataset_id"]
            if "mscoco" in dataset:
                image_id = image_id.rjust(12, "0")

            assert (len(self.image_db.from_path(image_id)["images"]) !=
                    0), f"image id: {image_id} not found"
            current_sample.image = self.image_db.from_path(
                image_id)["images"][0]

        return current_sample
Code example #12
File: dataset.py Project: vishalbelsare/pythia
    def __getitem__(self, idx: int) -> Sample:
        sample_info = self.annotation_db[idx]
        current_sample = Sample()

        if "question_tokens" in sample_info:
            text_processor_argument = {
                "tokens": sample_info["question_tokens"],
                "text": sample_info["question_str"],
            }
        else:
            text_processor_argument = {"text": sample_info["question"]}
        processed_question = self.text_processor(text_processor_argument)
        current_sample.update(processed_question)
        current_sample.id = torch.tensor(int(sample_info["question_id"]),
                                         dtype=torch.int)

        if self._use_features:
            features = self.features_db[idx]
            if hasattr(self, "transformer_bbox_processor"):
                features["image_info_0"] = self.transformer_bbox_processor(
                    features["image_info_0"])
            current_sample.update(features)
        else:
            image_path = sample_info["image_name"] + ".jpg"
            current_sample.image = self.image_db.from_path(
                image_path)["images"][0]

        current_sample = self.add_answer_info(sample_info, current_sample)
        return current_sample
Code example #13
File: builder.py Project: zhang703652632/mmf
    def __getitem__(self, idx):
        annotation = self.dataset[idx]
        current_sample = Sample()
        text_processor_input = {
            "text_a": annotation[self.DATASET_KEY_MAP["text_a"][self.dataset_name]]
        }

        text_b = annotation.get(self.DATASET_KEY_MAP["text_b"][self.dataset_name], None)
        if text_b is not None:
            text_processor_input["text_b"] = text_b

        current_sample.update(self.text_processor(text_processor_input))
        current_sample.targets = torch.tensor(annotation["label"], dtype=torch.long)
        return current_sample
Code example #14
    def __getitem__(self, idx):
        if len(self.video_clips) == 0:
            self.load_df()
        video, audio, info = self.video_clips.get_clip(idx)
        text = self.text_list[idx]
        actual_idx = self.ids_list[idx]
        label = [
            self.class_to_idx[class_name] for class_name in self.labels[idx]
        ]
        one_hot_label = torch.zeros(len(self.class_to_idx))
        one_hot_label[label] = 1
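        # multi-hot target vector: 1 at every ground-truth class index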

        if self.video_processor is not None:
            video = self.video_processor(video)

        if self.audio_processor is not None:
            audio = self.audio_processor(audio)

        sample = Sample()
        sample.id = object_to_byte_tensor(actual_idx)
        sample.video = video
        sample.audio = audio
        sample.update(self.text_processor({"text": text}))
        sample.targets = one_hot_label
        return sample
Code example #15
    def __getitem__(self, idx: int) -> Sample:
        sample_info = self.annotation_db[idx]
        current_sample = Sample()
        processed_caption = self.masked_token_processor(
            {"text_a": sample_info["caption"], "text_b": "", "is_correct": True}
        )
        current_sample.update(processed_caption)
        current_sample.image_id = sample_info["image_id"]
        current_sample.feature_path = sample_info["feature_path"]

        # Get the image features
        if self._use_features:
            features = self.features_db[idx]
            image_info_0 = features["image_info_0"]
            if image_info_0 and "image_id" in image_info_0.keys():
                image_info_0["feature_path"] = image_info_0["image_id"]
                image_info_0.pop("image_id")
            current_sample.update(features)

        return current_sample
Code example #16
File: dataset.py Project: flashjames/refundr-mmf
    def _load_objects(self, idx):
        image_info = self._get_image_info(idx)
        image_height = image_info["height"]
        image_width = image_info["width"]
        object_map = {}
        objects = []

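        # Tokenize synsets/names/attributes and normalize box geometry to [0, 1]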
        for obj in image_info["objects"]:
            obj["synsets"] = self.synset_processor({"tokens":
                                                    obj["synsets"]})["text"]
            obj["names"] = self.name_processor({"tokens":
                                                obj["names"]})["text"]
            obj["height"] = obj["h"] / image_height
            obj.pop("h")
            obj["width"] = obj["w"] / image_width
            obj.pop("w")
            obj["y"] /= image_height
            obj["x"] /= image_width
            obj["attributes"] = self.attribute_processor(
                {"tokens": obj["attributes"]})["text"]
            obj = Sample(obj)
            object_map[obj["object_id"]] = obj
            objects.append(obj)
        objects = SampleList(objects)

        return objects, object_map
Code example #17
    def add_ocr_details(self, sample_info, sample):
        if self.use_ocr:
            # Preprocess OCR tokens
            ocr_tokens = [
                self.ocr_token_processor({"text": token})["text"]
                for token in sample_info["ocr_tokens"]
            ]
            # Get embeddings for tokens
            context = self.context_processor({"tokens": ocr_tokens})
            sample.context = context["text"]
            sample.context_tokens = context["tokens"]
            sample.context_feature_0 = context["text"]
            sample.context_info_0 = Sample()
            sample.context_info_0.max_features = context["length"]

            order_vectors = torch.eye(len(sample.context_tokens))
            order_vectors[context["length"] :] = 0
            sample.order_vectors = order_vectors

        if self.use_ocr_info and "ocr_info" in sample_info:
            sample.ocr_bbox = self.bbox_processor({"info": sample_info["ocr_info"]})[
                "bbox"
            ]

        return sample
Code example #18
File: test_metrics.py Project: zhang703652632/mmf
    def _test_multilabel_metric(self, metric, value):
        sample = Sample()
        predicted = dict()

        sample.targets = torch.tensor(
            [[0, 1, 1], [1, 0, 1], [1, 0, 1], [0, 0, 1]], dtype=torch.float)
        predicted["scores"] = torch.tensor(
            [
                [-0.9332, 0.8149, 0.3491],
                [-0.8391, 0.6797, -0.3410],
                [-0.7235, 0.7220, 0.9104],
                [0.9043, 0.3078, -0.4210],
            ],
            dtype=torch.float,
        )
        self.assertAlmostEqual(
            metric.calculate(sample, predicted).item(), value, 4)
Code example #19
    def test_sample_working(self):
        initial = Sample()
        initial.x = 1
        initial["y"] = 2
        # Assert setter and getter
        self.assertEqual(initial.x, 1)
        self.assertEqual(initial["x"], 1)
        self.assertEqual(initial.y, 2)
        self.assertEqual(initial["y"], 2)

        update_dict = {"a": 3, "b": {"c": 4}}

        initial.update(update_dict)
        self.assertEqual(initial.a, 3)
        self.assertEqual(initial["a"], 3)
        self.assertEqual(initial.b.c, 4)
        self.assertEqual(initial["b"].c, 4)
Code example #20
    def __call__(self, image_tensor, text_input=None):
        """Allow the model to receive both multi-modal inputs and
        single image inputs. // Bojia Mao
        """
        text = self.processor_dict["text_processor"]({"text": self.text})

        sample = Sample()

        if text_input is None:
            sample.text = text["text"]
        else:
            self.__text = text_input
            sample.text = text_input

        if "input_ids" in text:
            sample.update(text)

        sample.image = image_tensor
        sample_list = SampleList([sample])
        sample_list = sample_list.to(
            torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))

        output = self.model(sample_list)
        scores = nn.functional.softmax(output["scores"], dim=1)

        return scores
Code example #21
File: dataset.py Project: zhangshengHust/mmf
    def __getitem__(self, idx):
        sample_info = self.annotation_db[idx]
        sample_info = self.preprocess_sample_info(sample_info)
        current_sample = Sample()

        # breaking change from VQA2Dataset: load question_id
        current_sample.question_id = torch.tensor(
            sample_info["question_id"], dtype=torch.int
        )

        if isinstance(sample_info["image_id"], int):
            current_sample.image_id = str(sample_info["image_id"])
        else:
            current_sample.image_id = sample_info["image_id"]
        if self._use_features is True:
            features = self.features_db[idx]
            current_sample.update(features)

        current_sample = self.add_sample_details(sample_info, current_sample)
        current_sample = self.add_answer_info(sample_info, current_sample)

        # only the 'max_features' key is needed
        # pop other keys to minimize data loading overhead
        if hasattr(current_sample, "image_info_0"):
            for k in list(current_sample.image_info_0):
                if k != "max_features":
                    current_sample.image_info_0.pop(k)
        if hasattr(current_sample, "image_info_1"):
            for k in list(current_sample.image_info_1):
                if k != "max_features":
                    current_sample.image_info_1.pop(k)

        return current_sample
Code example #22
File: dataset.py Project: flashjames/refundr-mmf
    def __getitem__(self, idx):
        sample_info = self.annotation_db[idx]
        current_sample = Sample()

        raw_text = sample_info["name"]
        processed_sentence = self.text_processor({"text": raw_text})

        current_sample.text = processed_sentence["text"]
        if "input_ids" in processed_sentence:
            current_sample.update(processed_sentence)

        current_sample.answers, current_sample.targets = self.labels(sample_info, idx)
        current_sample.id = sample_info["idx"]

        return current_sample
Code example #23
File: masked_dataset.py Project: rationalmale/mmf
    def load_item(self, idx):
        sample_info = self.annotation_db[idx]
        current_sample = Sample()

        if self._use_features:
            features = self.features_db[idx]
            if hasattr(self, "transformer_bbox_processor"):
                features["image_info_0"] = self.transformer_bbox_processor(
                    features["image_info_0"]
                )

            if self.config.get("use_image_feature_masks", False):
                current_sample.update(
                    {
                        "image_labels": self.masked_region_processor(
                            features["image_feature_0"]
                        )
                    }
                )

            current_sample.update(features)
        else:
            image_path = str(sample_info["image_name"]) + ".jpg"
            current_sample.image = self.image_db.from_path(image_path)["images"][0]

        current_sample = self._add_masked_question(sample_info, current_sample)
        if self._add_answer:
            current_sample = self.add_answer_info(sample_info, current_sample)
        return current_sample
Code example #24
File: bert_processors.py Project: hellcodes/mmf
    def __call__(self, item):
        texts = item["text"]
        if not isinstance(texts, list):
            texts = [texts]

        processed = []
        for idx, text in enumerate(texts):
            sample = Sample()
            processed_text = super().__call__({"text": text})
            sample.update(processed_text)
            sample.segment_ids.fill_(idx)
            processed.append(sample)
        # Use SampleList to convert list of tensors to stacked tensors
        processed = SampleList(processed)
        processed.input_ids = processed.input_ids.view(-1)
        processed.input_mask = processed.input_mask.view(-1)
        processed.segment_ids = processed.segment_ids.view(-1)
        return processed.to_dict()
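
A hypothetical call to the processor above (the `processor` instance and the sequence length are assumptions, not from the source): each text in the list is tokenized separately, tagged with its own segment id, and the per-text tensors are then flattened into one long sequence by the `.view(-1)` calls.

    out = processor({"text": ["first sentence", "second sentence"]})
    # out["input_ids"] has shape (2 * max_seq_length,); segment_ids are 0 for
    # the first sentence and 1 for the second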
Code example #25
File: test_metrics.py Project: zhang703652632/mmf
    def _test_retrieval_recall_at_k_metric(self, metric, value):
        sample = Sample()
        predicted = dict()

        torch.manual_seed(1234)
        predicted["targets"] = torch.rand((10, 4))
        predicted["scores"] = torch.rand((10, 4))

        self.assertAlmostEqual(float(metric.calculate(sample, predicted)),
                               value)
Code example #26
File: test_heads.py Project: facebookresearch/mmf
    def setUp(self):
        bs = 8
        num_feat = 64
        feat_dim = 768
        self.sequence_input = torch.ones(size=(bs, num_feat, feat_dim),
                                         dtype=torch.float)
        contrastive_labels = torch.randint(3, (bs, ))

        self.processed_sample_list = Sample()
        self.processed_sample_list["contrastive_labels"] = contrastive_labels
Code example #27
File: handler.py Project: nskool/serve
        def create_sample(video_transformed, audio_transformed, text_tensor,
                          video_label):

            label = [self.class_to_idx[l] for l in video_label]

            one_hot_label = torch.zeros(len(self.class_to_idx))
            one_hot_label[label] = 1

            current_sample = Sample()
            current_sample.video = video_transformed
            current_sample.audio = audio_transformed
            current_sample.update(text_tensor)
            current_sample.targets = one_hot_label
            current_sample.dataset_type = 'test'
            current_sample.dataset_name = 'charades'
            return SampleList([current_sample]).to(self.device)
Code example #28
    def __call__(self, item: Dict[str, Any]):
        texts = item["text"]
        if not isinstance(texts, list):
            texts = [texts]

        processed = []
        for idx, text in enumerate(texts):
            sample = Sample()
            processed_text = self.tokenizer({"text": text})
            sample.update(processed_text)
            sample.segment_ids.fill_(idx)
            processed.append(sample)
        # Use SampleList to convert list of tensors to stacked tensors
        processed = SampleList(processed)
        if self.fusion_strategy == "concat":
            processed.input_ids = processed.input_ids.view(-1)
            processed.input_mask = processed.input_mask.view(-1)
            processed.segment_ids = processed.segment_ids.view(-1)
            processed.lm_label_ids = processed.lm_label_ids.view(-1)
        return processed.to_dict()
Code example #29
File: test_metrics.py Project: weexiaolong/mmf
    def _test_binary_metric(self, metric, value):
        sample = Sample()
        predicted = dict()

        sample.targets = torch.tensor(
            [[0, 1], [1, 0], [1, 0], [0, 1]], dtype=torch.float
        )
        predicted["scores"] = torch.tensor(
            [
                [-0.9332, 0.8149],
                [-0.8391, 0.6797],
                [-0.7235, 0.7220],
                [-0.9043, 0.3078],
            ],
            dtype=torch.float,
        )
        self.assertAlmostEqual(metric.calculate(sample, predicted).item(), value, 4)

        sample.targets = torch.tensor([1, 0, 0, 1], dtype=torch.long)
        self.assertAlmostEqual(metric.calculate(sample, predicted).item(), value, 4)
Code example #30
    def load_item(self, idx):
        sample_info = self.imdb[idx]
        current_sample = Sample()

        if self._use_features is True:
            features = self.features_db[idx]
            image_labels = []

            for i in range(features["image_feature_0"].shape[0]):
                prob = random.random()
                # mask token with 15% probability
                if prob < 0.15:
                    prob /= 0.15

                    if prob < 0.9:
                        features["image_feature_0"][i] = 0
                    image_labels.append(1)
                else:
                    # no masking token (will be ignored by loss function later)
                    image_labels.append(-1)
            item = {}

            if self.config.get("use_image_feature_masks", False):
                item["image_labels"] = image_labels
            current_sample.update(item)
            current_sample.update(features)

        current_sample = self._add_masked_caption(sample_info, current_sample)
        return current_sample