Example #1
import collections.abc
from typing import Any, Dict, Union

import torch

# BaseProcessor, MultiClassFromFileConfig and VocabDict are assumed to be
# provided by the surrounding codebase.
class MultiClassFromFile(BaseProcessor):
    """Label processor for multi class cases where the labels are
    saved in a file.
    """
    def __init__(self, config: MultiClassFromFileConfig, *args, **kwargs):
        self.label_vocab = VocabDict(config.vocab_file, *args, **kwargs)

    def __call__(self, item: Union[Dict[str, Any], str]) -> Dict[str, Any]:
        if isinstance(item, collections.abc.Mapping):
            label = item["label"]
        else:
            label = item

        # VocabDict always places UNK at index 0 (it is added automatically
        # even when it is absent from the vocab file), so subtract 1 to drop
        # it from the output index space
        class_index = self.label_vocab.word2idx(label) - 1
        assert class_index != -1, f"{label} is not present in vocab file"

        return {"class_index": torch.tensor(class_index, dtype=torch.long)}
Example #2
import warnings

import torch

# BaseProcessor, Processor and VocabDict are assumed to be provided by the
# surrounding codebase.
class VQAAnswerProcessor(BaseProcessor):
    """Processor for generating answer scores for answers passed using VQA
    accuracy formula. Using VocabDict class to represent answer vocabulary,
    so parameters must specify "vocab_file". "num_answers" in parameter config
    specify the max number of answers possible. Takes in dict containing
    "answers" or "answers_tokens". "answers" are preprocessed to generate
    "answers_tokens" if passed.

    Args:
        config (DictConfig): Configuration for the processor

    Attributes:
        answer_vocab (VocabDict): Class representing answer vocabulary
    """

    DEFAULT_NUM_ANSWERS = 10

    def __init__(self, config, *args, **kwargs):
        if not hasattr(config, "vocab_file"):
            raise AttributeError("'vocab_file' argument required, but not "
                                 "present in AnswerProcessor's config")

        self.answer_vocab = VocabDict(config.vocab_file, *args, **kwargs)
        self.PAD_IDX = self.answer_vocab.word2idx("<pad>")
        self.BOS_IDX = self.answer_vocab.word2idx("<s>")
        self.EOS_IDX = self.answer_vocab.word2idx("</s>")
        self.UNK_IDX = self.answer_vocab.UNK_INDEX

        # If </s> is not in the vocab, point EOS at an out-of-range index so
        # it can never be matched
        if self.EOS_IDX == self.UNK_IDX:
            self.EOS_IDX = len(self.answer_vocab)

        self.preprocessor = None

        if hasattr(config, "preprocessor"):
            self.preprocessor = Processor(config.preprocessor)

            if self.preprocessor is None:
                raise ValueError(
                    f"No processor named {config.preprocessor} is defined.")

        if hasattr(config, "num_answers"):
            self.num_answers = config.num_answers
        else:
            self.num_answers = self.DEFAULT_NUM_ANSWERS
            warnings.warn("'num_answers' not defined in the config. "
                          "Setting to default of {}".format(
                              self.DEFAULT_NUM_ANSWERS))

    def __call__(self, item):
        """Takes in dict with answers or answers_tokens, and returns back
        a dict with answers (processed), "answers_indices" which point to
        indices of the answers if present and "answers_scores" which represent
        VQA style scores for the answers.

        Args:
            item (Dict): Dict containing answers or answers_tokens

        Returns:
            Dict: Processed answers, indices and scores.

        """
        tokens = []

        if not isinstance(item, dict):
            raise TypeError("'item' passed to processor must be a dict")

        if "answer_tokens" in item:
            tokens = item["answer_tokens"]
        elif "answers" in item and item["answers"] is not None:
            if self.preprocessor is None:
                raise AssertionError("'preprocessor' must be defined if you "
                                     "don't pass 'answer_tokens'")

            tokens = [
                self.preprocessor({"text": answer})["text"]
                for answer in item["answers"]
            ]
        else:
            raise AssertionError("'answers' or 'answer_tokens' must be passed"
                                 " to answer processor in a dict")

        if len(tokens) != 0:
            tokens = self._pad_tokens(tokens)

        # Size the index tensor from the configured number of answers so
        # that "num_answers" from the config is actually honored
        answers_indices = torch.zeros(self.num_answers, dtype=torch.long)
        answers_indices.fill_(self.answer_vocab.get_unk_index())

        for idx, token in enumerate(tokens):
            answers_indices[idx] = self.answer_vocab.word2idx(token)

        answers_scores = self.compute_answers_scores(answers_indices)

        return {
            "answers": tokens,
            "answers_indices": answers_indices,
            "answers_scores": answers_scores,
        }

    def get_vocab_size(self):
        """Get vocab size of the answer vocabulary. Can also include
        soft copy dynamic answer space size.

        Returns:
            int: size of the answer vocabulary

        """
        return self.answer_vocab.num_vocab

    def get_true_vocab_size(self):
        """True vocab size can be different from normal vocab size in some cases
        such as soft copy where dynamic answer space is added.

        Returns:
            int: True vocab size.

        """
        return self.answer_vocab.num_vocab

    def word2idx(self, word):
        """Convert a word to its index according to vocabulary

        Args:
            word (str): Word to be converted to index.

        Returns:
            int: Index of the word.

        """
        return self.answer_vocab.word2idx(word)

    def idx2word(self, idx):
        """Index to word according to the vocabulary.

        Args:
            idx (int): Index to be converted to the word.

        Returns:
            str: Word corresponding to the index.

        """
        return self.answer_vocab.idx2word(idx)

    def compute_answers_scores(self, answers_indices):
        """Generate VQA based answer scores for answers_indices.

        Args:
            answers_indices (torch.LongTensor): tensor containing indices of the answers

        Returns:
            torch.FloatTensor: tensor containing scores.

        """
        scores = torch.zeros(self.get_vocab_size(), dtype=torch.float)
        gt_answers = list(enumerate(answers_indices))
        unique_answers = set(answers_indices.tolist())

        for answer in unique_answers:
            accs = []
            for gt_answer in gt_answers:
                other_answers = [
                    item for item in gt_answers if item != gt_answer
                ]

                matching_answers = [
                    item for item in other_answers if item[1] == answer
                ]
                acc = min(1, float(len(matching_answers)) / 3)
                accs.append(acc)
            avg_acc = sum(accs) / len(accs)

            if answer != self.answer_vocab.UNK_INDEX:
                scores[answer] = avg_acc

        return scores

    def _pad_tokens(self, tokens):
        # Cycle through the existing tokens until there are "num_answers"
        # of them
        while len(tokens) < self.num_answers:
            tokens += tokens[:self.num_answers - len(tokens)]

        return tokens
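
To make the soft-score computation above concrete, here is a minimal
standalone sketch of the same leave-one-out VQA accuracy formula run on plain
strings instead of vocabulary indices. The answer list is made up for
illustration.

# 10 hypothetical ground-truth answers from 10 annotators
answers = ["cat"] * 6 + ["kitten"] * 3 + ["dog"]

gt_answers = list(enumerate(answers))
for answer in sorted(set(answers)):
    accs = []
    for gt_answer in gt_answers:
        # Leave one annotator out, then count how many of the remaining nine
        # gave this answer; three or more matches count as fully correct
        others = [item for item in gt_answers if item != gt_answer]
        matching = [item for item in others if item[1] == answer]
        accs.append(min(1, len(matching) / 3))
    print(answer, sum(accs) / len(accs))
# cat    1.0   (every leave-one-out subset still has >= 3 "cat")
# dog    ~0.3  (at most 1 match in any subset)
# kitten ~0.9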
Example #3
from collections import defaultdict

import numpy as np
import torch

# BaseProcessor, Processor and VocabDict are assumed to be provided by the
# surrounding codebase.
class M4CAnswerProcessor(BaseProcessor):
    """
    Process a TextVQA answer for iterative decoding in M4C
    """
    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)

        self.answer_vocab = VocabDict(config.vocab_file, *args, **kwargs)
        self.PAD_IDX = self.answer_vocab.word2idx("<pad>")
        self.BOS_IDX = self.answer_vocab.word2idx("<s>")
        self.EOS_IDX = self.answer_vocab.word2idx("</s>")
        self.UNK_IDX = self.answer_vocab.UNK_INDEX

        # make sure PAD_IDX, BOS_IDX and EOS_IDX are valid (not <unk>)
        assert self.PAD_IDX != self.answer_vocab.UNK_INDEX
        assert self.BOS_IDX != self.answer_vocab.UNK_INDEX
        assert self.EOS_IDX != self.answer_vocab.UNK_INDEX
        assert self.PAD_IDX == 0

        self.answer_preprocessor = Processor(config.preprocessor)
        assert self.answer_preprocessor is not None

        self.num_answers = config.num_answers
        self.max_length = config.max_length
        self.max_copy_steps = config.max_copy_steps
        assert self.max_copy_steps >= 1

        self.match_answer_to_unk = False

    def tokenize(self, sentence):
        return sentence.split()

    def match_answer_to_vocab_ocr_seq(self,
                                      answer,
                                      vocab2idx_dict,
                                      ocr2inds_dict,
                                      max_match_num=20):
        """
        Match an answer to a list of sequences of indices
        each index corresponds to either a fixed vocabulary or an OCR token
        (in the index address space, the OCR tokens are after the fixed vocab)
        """
        num_vocab = len(vocab2idx_dict)

        answer_words = self.tokenize(answer)
        answer_word_matches = []
        for word in answer_words:
            # match answer word to fixed vocabulary
            matched_inds = []
            if word in vocab2idx_dict:
                matched_inds.append(vocab2idx_dict.get(word))
            # match answer word to OCR
            # we put OCR after the fixed vocabulary in the answer index space
            # so add num_vocab offset to the OCR index
            matched_inds.extend(
                [num_vocab + idx for idx in ocr2inds_dict[word]])
            if len(matched_inds) == 0:
                if self.match_answer_to_unk:
                    matched_inds.append(vocab2idx_dict.get("<unk>"))
                else:
                    return []
            answer_word_matches.append(matched_inds)

        # expand per-word matched indices into the list of matched sequences
        if len(answer_word_matches) == 0:
            return []
        idx_seq_list = [()]
        for matched_inds in answer_word_matches:
            idx_seq_list = [
                seq + (idx, ) for seq in idx_seq_list for idx in matched_inds
            ]
            if len(idx_seq_list) > max_match_num:
                idx_seq_list = idx_seq_list[:max_match_num]

        return idx_seq_list

    def get_vocab_size(self):
        answer_vocab_nums = self.answer_vocab.num_vocab
        answer_vocab_nums += self.max_length

        return answer_vocab_nums

    def get_true_vocab_size(self):
        return self.answer_vocab.num_vocab

    def compute_answer_scores(self, answers):
        gt_answers = list(enumerate(answers))
        unique_answers = sorted(set(answers))
        unique_answer_scores = [0] * len(unique_answers)
        for idx, unique_answer in enumerate(unique_answers):
            accs = []
            for gt_answer in gt_answers:
                other_answers = [
                    item for item in gt_answers if item != gt_answer
                ]
                matching_answers = [
                    item for item in other_answers if item[1] == unique_answer
                ]
                acc = min(1, float(len(matching_answers)) / 3)
                accs.append(acc)
            unique_answer_scores[idx] = sum(accs) / len(accs)
        unique_answer2score = dict(zip(unique_answers, unique_answer_scores))
        return unique_answer2score

    def __call__(self, item):
        answers = item["answers"]

        if not answers:
            return {
                "sampled_idx_seq":
                None,
                "train_prev_inds":
                torch.zeros(self.max_copy_steps, dtype=torch.long),
            }

        answers = [
            self.answer_preprocessor({"text": a})["text"] for a in answers
        ]
        assert len(answers) == self.num_answers

        # Step 1: calculate the soft score of ground-truth answers
        unique_answer2score = self.compute_answer_scores(answers)

        # Step 2: fill the first step soft scores for tokens
        scores = torch.zeros(self.max_copy_steps,
                             self.get_vocab_size(),
                             dtype=torch.float)

        # match answers to fixed vocabularies and OCR tokens.
        ocr2inds_dict = defaultdict(list)
        for idx, token in enumerate(item["tokens"]):
            ocr2inds_dict[token].append(idx)
        answer_dec_inds = [
            self.match_answer_to_vocab_ocr_seq(a,
                                               self.answer_vocab.word2idx_dict,
                                               ocr2inds_dict) for a in answers
        ]

        # Collect all the valid decoding sequences for each answer.
        # This part (idx_seq_list) was pre-computed in imdb (instead of online)
        # to save time
        all_idx_seq_list = []
        for answer, idx_seq_list in zip(answers, answer_dec_inds):
            all_idx_seq_list.extend(idx_seq_list)
            # fill in the soft score for the first decoding step
            score = unique_answer2score[answer]
            for idx_seq in idx_seq_list:
                score_idx = idx_seq[0]
                # the scores for the decoding Step 0 will be the maximum
                # among all answers starting with that vocab
                # for example:
                # if "red apple" has score 0.7 and "red flag" has score 0.8
                # the score for "red" at Step 0 will be max(0.7, 0.8) = 0.8
                scores[0, score_idx] = max(scores[0, score_idx], score)

        # train_prev_inds is the previous prediction indices in auto-regressive
        # decoding
        train_prev_inds = torch.zeros(self.max_copy_steps, dtype=torch.long)
        # train_loss_mask records the decoding steps where losses are applied
        train_loss_mask = torch.zeros(self.max_copy_steps, dtype=torch.float)
        if len(all_idx_seq_list) > 0:
            # sample a random decoding answer sequence for teacher-forcing
            idx_seq = all_idx_seq_list[np.random.choice(len(all_idx_seq_list))]
            dec_step_num = min(1 + len(idx_seq), self.max_copy_steps)
            train_loss_mask[:dec_step_num] = 1.0

            train_prev_inds[0] = self.BOS_IDX
            for t in range(1, dec_step_num):
                train_prev_inds[t] = idx_seq[t - 1]
                score_idx = idx_seq[t] if t < len(idx_seq) else self.EOS_IDX
                scores[t, score_idx] = 1.0
        else:
            idx_seq = ()

        answer_info = {
            "answers": answers,
            "answers_scores": scores,
            "sampled_idx_seq": idx_seq,
            "train_prev_inds": train_prev_inds,
            "train_loss_mask": train_loss_mask,
        }
        return answer_info
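
The trickiest part of this processor is the expansion in
match_answer_to_vocab_ocr_seq, where per-word matches are combined into every
possible decoding sequence. Below is a minimal standalone sketch of that
expansion; the vocabulary, OCR tokens and indices are made up for
illustration.

# Hypothetical fixed vocab and OCR occurrences for the answer "red apple"
vocab2idx = {"red": 3, "apple": 7}
ocr2inds = {"red": [], "apple": [0, 2]}  # "apple" sits at OCR slots 0 and 2
num_vocab = len(vocab2idx)

# "red" only matches the fixed vocab; "apple" matches the fixed vocab and
# both OCR slots (offset by num_vocab, as in the processor above)
answer_word_matches = [
    [3],
    [7, num_vocab + 0, num_vocab + 2],
]

# Cross product of per-word matches, exactly as in the processor
idx_seq_list = [()]
for matched_inds in answer_word_matches:
    idx_seq_list = [
        seq + (idx, ) for seq in idx_seq_list for idx in matched_inds
    ]

print(idx_seq_list)  # [(3, 7), (3, 2), (3, 4)]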