コード例 #1
0
    def add_answer_info(self, sample_info, sample):
        """Add caption targets to ``sample`` on top of the parent's handling.

        When ``sample_info`` contains a ``"caption_str"`` entry, that caption
        is written into ``sample_info["answers"]`` as a one-element list
        before the parent implementation runs. Afterwards the caption and
        ``sample_info["reference_strs"]`` are stored on the sample through
        ``object_to_byte_tensor`` and the ``"answers"`` entry is popped from
        the sample.

        Args:
            sample_info: Mapping with the annotation for one sample. It is
                modified in place when a caption is present.
            sample: Sample object handed to the parent implementation.

        Returns:
            The sample returned by the parent implementation, with the
            caption fields added when a caption is present.
        """
        if "caption_str" not in sample_info:
            # No caption available: defer entirely to the parent class.
            return super().add_answer_info(sample_info, sample)

        # Expose the caption to the parent implementation as the sole answer.
        sample_info["answers"] = [sample_info["caption_str"]]
        sample = super().add_answer_info(sample_info, sample)

        caption_bytes = object_to_byte_tensor(sample_info["caption_str"])
        sample.caption_str = caption_bytes
        reference_bytes = object_to_byte_tensor(sample_info["reference_strs"])
        sample.ref_strs = reference_bytes

        # The "answers" entry is dropped from the sample again once the
        # caption fields are in place.
        sample.pop("answers")
        return sample
コード例 #2
0
    def load_item(self, idx):
        """Build the sample for annotation index ``idx``.

        For every dataset type other than ``"test"`` the caption tokens are
        run through ``self.text_processor`` and the caption id and caption
        length are stored as ``torch.int`` tensors. The image id is always
        stored via ``object_to_byte_tensor``. Depending on
        ``self._use_features`` the sample receives either the entries of
        ``self.features_db[idx]`` or the first image loaded from
        ``"<image_name>.jpg"`` through ``self.image_db``.

        Args:
            idx: Index into ``self.annotation_db`` (and ``self.features_db``
                when features are used).

        Returns:
            The sample after ``self.add_reference_caption`` has been applied.
        """
        info = self.preprocess_sample_info(self.annotation_db[idx])
        item = Sample()

        if self._dataset_type != "test":
            tokens = info["caption_tokens"]
            item.text = self.text_processor({"tokens": tokens})["text"]
            item.caption_id = torch.tensor(info["caption_id"],
                                           dtype=torch.int)
            item.caption_len = torch.tensor(len(tokens), dtype=torch.int)

        item.image_id = object_to_byte_tensor(info["image_id"])

        if not self._use_features:
            # Raw-image mode: load the picture by file name.
            jpg_name = f"{info['image_name']!s}.jpg"
            item.image = self.image_db.from_path(jpg_name)["images"][0]
        else:
            # Feature mode: merge the stored feature entries into the sample.
            item.update(self.features_db[idx])

        # Add reference captions to the sample.
        return self.add_reference_caption(info, item)
コード例 #3
0
    def add_answer_info(self, sample_info, sample):
        """Process the answers for one sample and store them on ``sample``.

        The answers from ``sample_info`` (an empty list when absent) and the
        ``"ocr_tokens"`` popped from ``sample`` (an empty list when absent)
        are passed to ``self.answer_processor``; its output is merged into
        the sample. The raw answers are stored as ``sample.answers`` via
        ``object_to_byte_tensor``, and an ``"answers_scores"`` entry, when
        present, is moved to ``sample.targets``.

        Args:
            sample_info: Mapping with the annotation for one sample.
            sample: Sample object to update in place.

        Returns:
            The updated ``sample``.

        Raises:
            AssertionError: If ``self.config.fast_read`` is truthy (only
                checked after the answer processor has run).
        """
        # Real answers come from sample_info; OCR tokens are taken off the
        # sample itself.
        gt_answers = sample_info.get("answers", [])
        processor_input = {
            "answers": gt_answers,
            "tokens": sample.pop("ocr_tokens", []),
        }
        processed = self.answer_processor(processor_input)

        assert not self.config.fast_read, (
            "In TextVQADataset, online OCR sampling is incompatible "
            "with fast_read, so fast_read is currently not supported.")

        sample.update(processed)
        sample.answers = object_to_byte_tensor(gt_answers)

        # Expose the processed scores under the name "targets".
        if "answers_scores" in sample:
            sample.targets = sample.pop("answers_scores")

        return sample
コード例 #4
0
    def add_sample_details(self, sample_info, sample):
        """Fill ``sample`` with question, object-box and OCR information.

        Steps, in order:

        1. ``sample.image_id`` is re-encoded with ``object_to_byte_tensor``.
        2. The question (``"question"`` or, failing that, ``"question_str"``)
           is run through ``self.text_processor``; ``sample.text`` and
           ``sample.text_len`` are set from its output.
        3. Object boxes are passed through ``self.copy_processor`` when both
           the boxes and the processor are available.
        4. If ``self.use_ocr`` is falsy, the OCR entries of ``sample_info``
           are emptied, ``sample.image_feature_1`` (if present) is zeroed,
           and the sample is returned immediately.
        5. Otherwise OCR tokens are processed into context features,
           optional PHOC features, optional order vectors and OCR bounding
           box coordinates.

        Args:
            sample_info: Mapping with the annotation for one sample. Its OCR
                entries are overwritten in place when OCR is disabled.
            sample: Sample object to update in place.

        Returns:
            The updated ``sample``.
        """
        sample.image_id = object_to_byte_tensor(sample.image_id)

        # 1. Question text: prefer "question", fall back to "question_str".
        if "question" in sample_info:
            question = sample_info["question"]
        else:
            question = sample_info["question_str"]

        question_args = {"text": question}
        if "question_tokens" in sample_info:
            question_args["tokens"] = sample_info["question_tokens"]

        question_out = self.text_processor(question_args)

        if "input_ids" not in question_out:
            # For GloVe based processors
            sample.text = question_out["text"]
            sample.text_len = question_out["length"]
        else:
            sample.text = question_out["input_ids"]
            sample.text_len = torch.tensor(len(question_out["tokens"]),
                                           dtype=torch.long)

        # 2. Object bounding boxes (only when a copy processor exists).
        if "obj_normalized_boxes" in sample_info:
            if hasattr(self, "copy_processor"):
                obj_boxes = self.copy_processor(
                    {"blob": sample_info["obj_normalized_boxes"]})
                sample.obj_bbox_coordinates = obj_boxes["blob"]

        # 3. OCR
        if not self.use_ocr:
            # OCR disabled: replace every OCR entry with an empty one.
            sample_info["ocr_tokens"] = []
            sample_info["ocr_info"] = []
            if "ocr_normalized_boxes" in sample_info:
                sample_info["ocr_normalized_boxes"] = np.zeros(
                    (0, 4), np.float32)
            # Zero out the OCR visual features as well.
            if "image_feature_1" in sample:
                sample.image_feature_1 = torch.zeros_like(
                    sample.image_feature_1)
            return sample

        # Optionally normalise each OCR token before embedding.
        if hasattr(self, "ocr_token_processor"):
            ocr_tokens = []
            for raw_token in sample_info["ocr_tokens"]:
                cleaned = self.ocr_token_processor({"text": raw_token})
                ocr_tokens.append(cleaned["text"])
        else:
            ocr_tokens = sample_info["ocr_tokens"]

        # Context embeddings for the OCR tokens (FastText, per the original
        # comments in this method).
        ocr_context = self.context_processor({"tokens": ocr_tokens})
        sample.context = ocr_context["text"]
        sample.ocr_tokens = ocr_context["tokens"]

        sample.context_tokens = object_to_byte_tensor(ocr_context["tokens"])
        sample.context_feature_0 = ocr_context["text"]
        sample.context_info_0 = Sample()
        sample.context_info_0.max_features = ocr_context["length"]

        # PHOC embeddings for the OCR tokens, when a PHOC processor exists.
        if hasattr(self, "phoc_processor"):
            phoc_out = self.phoc_processor({"tokens": ocr_tokens})
            sample.context_feature_1 = phoc_out["text"]
            sample.context_info_1 = Sample()
            sample.context_info_1.max_features = phoc_out["length"]

        # OCR order vectors: identity rows, zeroed from the context length on.
        if self.config.get("use_order_vectors", False):
            identity = np.eye(len(sample.ocr_tokens), dtype=np.float32)
            order = torch.from_numpy(identity)
            order[ocr_context["length"]:] = 0
            sample.order_vectors = order

        # OCR bounding boxes
        if "ocr_normalized_boxes" in sample_info and hasattr(
                self, "copy_processor"):
            # New imdb format: boxes are pre-computed; keep at most
            # ``max_length`` of them.
            limit = self.config.processors.answer_processor.params.max_length
            ocr_boxes = self.copy_processor(
                {"blob": sample_info["ocr_normalized_boxes"]})
            sample.ocr_bbox_coordinates = ocr_boxes["blob"][:limit]
        elif self.use_ocr_info and "ocr_info" in sample_info:
            # Old imdb format: boxes are derived on the fly from ocr_info.
            bbox_out = self.bbox_processor({"info": sample_info["ocr_info"]})
            sample.ocr_bbox_coordinates = bbox_out["bbox"].coordinates

        return sample
コード例 #5
0
 def test_object_byte_tensor_conversion(self):
     """Check that an object survives an encode/decode round trip.

     A nested list is encoded with ``distributed.object_to_byte_tensor``,
     decoded with ``distributed.byte_tensor_to_object`` and compared with
     the original for equality.
     """
     payload = [1, "2", {3: 4}, [5]]
     encoded = distributed.object_to_byte_tensor(payload)
     self.assertEqual(distributed.byte_tensor_to_object(encoded), payload)