Example #1
File: dataset.py  Project: wzk1015/CNMT
    def add_answer_info(self, sample_info, sample):
        sample_has_caption = ('caption_str' in sample_info)
        if sample_has_caption:
            sample_info['answers'] = [sample_info['caption_str']]

        sample = super().add_answer_info(sample_info, sample)

        if sample_has_caption:
            sample.caption_str = enc_obj2bytes(sample_info['caption_str'])
            sample.ref_strs = enc_obj2bytes(sample_info['reference_strs'])
            sample.pop('gt_answers_enc')

        return sample
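
Note: enc_obj2bytes is used throughout these snippets to stash arbitrary Python objects (caption strings, reference strings, OCR token lists) on the Sample as byte tensors, so they survive tensor-only collation and can be decoded back at evaluation time. The project's own implementation is not shown on this page; below is a minimal, hypothetical sketch of such a pickle-based encoder/decoder pair (function names suffixed with _sketch, buffer size chosen arbitrarily), not the actual CNMT code.

import pickle
import torch

def enc_obj2bytes_sketch(obj, max_size=4094):
    # Hypothetical stand-in: pickle the object into a fixed-size uint8 tensor,
    # storing the payload length in the first two bytes.
    payload = pickle.dumps(obj)
    assert len(payload) < max_size - 2, "object too large for the buffer"
    buf = torch.zeros(max_size, dtype=torch.uint8)
    buf[0] = len(payload) // 256
    buf[1] = len(payload) % 256
    buf[2:2 + len(payload)] = torch.tensor(list(payload), dtype=torch.uint8)
    return buf

def dec_bytes2obj_sketch(buf):
    # Inverse of the sketch above: recover the pickled object from the tensor.
    length = int(buf[0]) * 256 + int(buf[1])
    return pickle.loads(bytes(buf[2:2 + length].tolist()))

# Round trip on a token list, the kind of object stored in context_tokens_enc
assert dec_bytes2obj_sketch(enc_obj2bytes_sketch(["hello", "world"])) == ["hello", "world"]
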
Example #2
    def add_answer_info(self, sample_info, sample):
        sample_has_answer = ("answers" in sample_info)
        if sample_has_answer:
            # Load real answers from sample_info
            answers = sample_info["answers"]
            sample.gt_answers_enc = enc_obj2bytes(answers)
            answer_processor_arg = {
                "answers": answers,
                "context_tokens": sample.context_tokens,
            }
            processed_answers = self.answer_processor(answer_processor_arg)

            assert not self.config.fast_read, \
                'In M4CTextVQADataset, online OCR sampling is incompatible ' \
                'with fast_read, so fast_read is currently not supported.'
            sample.targets = processed_answers["answers_scores"]
            sample.sampled_idx_seq = processed_answers["sampled_idx_seq"]
            sample.train_prev_inds = processed_answers["train_prev_inds"]
            sample.train_loss_mask = processed_answers["train_loss_mask"]
        else:
            # Load dummy answers as placeholders
            answer_params = self.config.processors.answer_processor.params
            sample.sampled_idx_seq = None
            sample.train_prev_inds = torch.zeros(answer_params.max_copy_steps,
                                                 dtype=torch.long)

        return sample
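
Note: only the answer-bearing branch above sets sampled_idx_seq; the placeholder branch leaves it as None and fills train_prev_inds with zeros, one entry per decoding step. A small illustrative check follows; the concrete max_copy_steps value below is hypothetical, in practice it comes from config.processors.answer_processor.params.

import torch

def is_inference_only(sample):
    # Illustrative assumption: a sample with no ground-truth answers can be
    # recognized by its sampled_idx_seq being None (see the else-branch above).
    return sample.sampled_idx_seq is None

max_copy_steps = 12  # hypothetical value; read from the answer-processor config in practice
dummy_prev_inds = torch.zeros(max_copy_steps, dtype=torch.long)
assert dummy_prev_inds.shape == (max_copy_steps,)
assert dummy_prev_inds.dtype == torch.long
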
Example #3
    def add_sample_details(self, sample_info, sample):
        # 1. Load text (question words)
        # breaking change from VQA2Dataset:
        # load the entire question string, not tokenized questions, since we
        # switch to BERT tokenizer in M4C and do online tokenization
        question_str = (sample_info['question'] if 'question' in sample_info
                        else sample_info['question_str'])
        processed_question = self.text_processor({"question": question_str})
        sample.text = processed_question['token_inds']
        sample.text_len = processed_question['token_num']

        # 2. Load object
        # object bounding box information
        sample.obj_bbox_coordinates = self.copy_processor(
            {"blob": sample_info["obj_normalized_boxes"]})["blob"]

        # 3. Load OCR
        assert self.use_ocr and self.use_ocr_info, \
            'use_ocr and use_ocr_info must be both True for M4CTextVQADataset'
        # Preprocess OCR tokens
        ocr_tokens = [
            self.ocr_token_processor({"text": token})["text"]
            for token in sample_info["ocr_tokens"]
        ]
        # Get FastText embeddings for OCR tokens
        context = self.context_processor({"tokens": ocr_tokens})
        sample.context = context["text"]
        sample.context_tokens = context["tokens"]
        sample.context_tokens_enc = enc_obj2bytes(context["tokens"])
        sample.context_feature_0 = context["text"]
        sample.context_info_0 = Sample()
        sample.context_info_0.max_features = context["length"]
        # Get PHOC embeddings for OCR tokens
        context_phoc = self.phoc_processor({"tokens": ocr_tokens})
        sample.context_feature_1 = context_phoc["text"]
        sample.context_info_1 = Sample()
        sample.context_info_1.max_features = context_phoc["length"]
        # OCR order vectors
        # TODO remove order_vectors -- it is no longer needed in M4C
        order_vectors = np.eye(len(sample.context_tokens), dtype=np.float32)
        order_vectors = torch.from_numpy(order_vectors)
        order_vectors[context["length"]:] = 0
        sample.order_vectors = order_vectors
        # OCR bounding box information
        if 'ocr_normalized_boxes' in sample_info:
            # New imdb format: OCR bounding boxes are already pre-computed
            max_len = self.config.processors.answer_processor.params.max_length
            sample.ocr_bbox_coordinates = self.copy_processor(
                {"blob": sample_info['ocr_normalized_boxes']}
            )["blob"][:max_len]
        else:
            # Old imdb format: OCR bounding boxes are computed on-the-fly
            # from ocr_info
            sample.ocr_bbox_coordinates = self.bbox_processor(
                {"info": sample_info["ocr_info"]})["bbox"].coordinates
        # sample.iou_info = box_iou(sample.obj_bbox_coordinates, sample.ocr_bbox_coordinates)

        return sample
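
Note: the order-vector block above (already marked for removal by the TODO) builds one one-hot row per OCR slot and then zeroes the rows past the number of real tokens. A standalone illustration with hypothetical sizes (5 slots, 3 real tokens):

import numpy as np
import torch

num_slots = 5        # hypothetical padded context length (len(sample.context_tokens))
num_real_tokens = 3  # hypothetical value of context["length"]

order_vectors = np.eye(num_slots, dtype=np.float32)   # one one-hot position row per slot
order_vectors = torch.from_numpy(order_vectors)
order_vectors[num_real_tokens:] = 0                   # padding slots become all-zero rows

assert order_vectors[2, 2] == 1.0    # a real token keeps its one-hot row
assert order_vectors[4].sum() == 0   # a padding slot is zeroed out
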
Example #4
    def add_ocr_details(self, sample_info, sample):
        assert self.use_ocr and self.use_ocr_info, \
            'use_ocr and use_ocr_info must be both True for Dataset'
        # Preprocess OCR tokens
        ocr_tokens = [
            self.ocr_token_processor({"text": token})["text"]
            for token in sample_info["ocr_tokens"]
        ]
        # Get FastText embeddings for tokens
        context = self.context_processor({"tokens": ocr_tokens})
        sample.context = context["text"]  # torch.Size([50, 300])
        sample.context_tokens = context["tokens"]
        sample.context_tokens_enc = enc_obj2bytes(context["tokens"])
        sample.context_feature_0 = context["text"]
        sample.context_info_0 = Sample()
        sample.context_info_0.max_features = context["length"]
        # Get PHOC embeddings for OCR tokens
        context_phoc = self.phoc_processor({"tokens": ocr_tokens})
        sample.context_phoc = context_phoc["text"]
        sample.context_info_phoc = Sample()
        sample.context_info_phoc.max_features = context_phoc["length"]

        # if 'ocr_normalized_boxes' in sample_info:
        #     max_len = self.config.processors.answer_processor.params.max_length
        #     sample.ocr_bbox = self.copy_processor(
        #         {"blob": sample_info['ocr_normalized_boxes']}
        #     )["blob"][:max_len]
        if "ocr_info" in sample_info:
            sample.ocr_bbox = self.bbox_processor({
                "info": sample_info["ocr_info"],
                "feats": context["text"],
                "img_id": sample.image_id,
                "obj_bbox": sample.obj_bbox,
            })["bbox"]

        return sample
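
Note: the inline comment above records that context["text"] arrives as a 50 x 300 tensor, i.e. one 300-dimensional FastText vector per OCR slot, padded or truncated to 50 slots, with context["length"] giving the number of real tokens. A shape-only illustration (random stand-in values, not real FastText output):

import torch

max_ocr_slots, fasttext_dim = 50, 300   # figures taken from the inline comment above
num_real_tokens = 7                     # hypothetical OCR token count for one image

context_text = torch.zeros(max_ocr_slots, fasttext_dim)
context_text[:num_real_tokens] = torch.randn(num_real_tokens, fasttext_dim)
assert context_text.shape == (max_ocr_slots, fasttext_dim)
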
Example #5
File: dataset.py  Project: wzk1015/CNMT
    def add_sample_details(self, sample_info, sample):
        # 1. Load text (question words)
        # breaking change from VQA2Dataset:
        # load the entire question string, not tokenized questions, since we
        # switch to BERT tokenizer in M4C and do online tokenization
        question_str = (
            sample_info['question'] if 'question' in sample_info
            else sample_info['question_str']
        )
        processed_question = self.text_processor({"question": question_str})
        sample.text = processed_question['token_inds']
        sample.text_len = processed_question['token_num']

        # 2. Load object
        # object bounding box information
        sample.obj_bbox_coordinates = self.copy_processor(
            {"blob": sample_info["obj_normalized_boxes"]}
        )["blob"]

        # 3. Load OCR
        if not self.use_ocr:
            # remove all OCRs from the sample
            # (i.e. make an empty OCR list)
            sample_info['ocr_tokens'] = []
            sample_info['ocr_info'] = []
            if 'ocr_normalized_boxes' in sample_info:
                sample_info['ocr_normalized_boxes'] = np.zeros(
                    (0, 4), np.float32
                )
            # clear OCR visual features
            sample.image_feature_1 = torch.zeros_like(sample.image_feature_1)

        # Preprocess OCR tokens
        ocr_tokens = [
            self.ocr_token_processor({"text": token})["text"]
            for token in sample_info["ocr_tokens"]
        ]
        # Get FastText embeddings for OCR tokens
        context = self.context_processor({"tokens": ocr_tokens})
        sample.context = context["text"]
        sample.context_tokens = context["tokens"]
        sample.context_tokens_enc = enc_obj2bytes(context["tokens"])
        sample.context_feature_0 = context["text"]
        sample.context_info_0 = Sample()
        sample.context_info_0.max_features = context["length"]
        # Get PHOC embeddings for OCR tokens
        context_phoc = self.phoc_processor({"tokens": ocr_tokens})
        sample.context_feature_1 = context_phoc["text"]
        sample.context_info_1 = Sample()
        sample.context_info_1.max_features = context_phoc["length"]
        # OCR order vectors
        # TODO remove order_vectors -- it is no longer needed in M4C
        order_vectors = np.eye(len(sample.context_tokens), dtype=np.float32)
        order_vectors = torch.from_numpy(order_vectors)
        order_vectors[context["length"]:] = 0
        sample.order_vectors = order_vectors
        # OCR bounding box information
        if 'ocr_normalized_boxes' in sample_info:
            # New imdb format: OCR bounding boxes are already pre-computed
            max_len = self.config.processors.answer_processor.params.max_length
            sample.ocr_bbox_coordinates = self.copy_processor(
                {"blob": sample_info['ocr_normalized_boxes']}
            )["blob"][:max_len]
        else:
            # Old imdb format: OCR bounding boxes are computed on-the-fly
            # from ocr_info
            sample.ocr_bbox_coordinates = self.bbox_processor(
                {"info": sample_info["ocr_info"]}
            )["bbox"].coordinates
        
        max_len = self.config.processors.answer_processor.params.max_length
        sample.ocr_confidence = self.copy_processor(
            {"blob": np.expand_dims(
                np.array(sample_info['ocr_confidence'], dtype="float32"), 1)}
        )["blob"][:max_len]
        sample.ocr_tokens = ocr_tokens[:max_len]
        
        return sample
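
Note: the ocr_confidence block reshapes the per-token confidence list from shape (N,) to (N, 1) before handing it to copy_processor, which, judging from the [:max_len] slice, copies it into a fixed-length blob. A shape-only illustration with hypothetical numbers (the padding behaviour shown for copy_processor is an assumption):

import numpy as np

ocr_confidence = [0.9, 0.72, 0.31]  # hypothetical per-token confidences from sample_info
conf = np.expand_dims(np.array(ocr_confidence, dtype="float32"), 1)
assert conf.shape == (3, 1)  # a column vector, one confidence per OCR token

max_length = 50  # hypothetical answer-processor max_length
padded = np.zeros((max_length, 1), dtype="float32")
padded[:conf.shape[0]] = conf  # roughly what a copy-style processor would do (assumption)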