コード例 #1
0
ファイル: task_pvps.py プロジェクト: alibaba/EasyTransfer
    def get_parts(self, example: InputExample) -> FilledPattern:
        """Build the WSC prompt parts for the pattern chosen by ``self.pattern_id``.

        The pronoun at ``example.meta['span2_index']`` is wrapped in ``*``
        markers inside ``text_a``. The answer slot is a run of mask tokens as
        long as the tokenized target (``span1_text``) plus ``num_pad`` extra
        masks.

        :param example: the WSC example; ``meta`` must provide ``span1_text``,
            ``span2_text`` and ``span2_index``.
        :return: ``(string_list_a, string_list_b, block_flag_a, block_flag_b)``.
            Each block flag list is parallel to its string list.
            NOTE(review): a flag of 1 presumably marks a position replaced by a
            learnable prompt token; confirm against the consumer.
        :raises ValueError: if ``self.pattern_id`` is not 1 or 2.
        """
        pronoun = example.meta['span2_text']
        target = example.meta['span1_text']
        pronoun_idx = example.meta['span2_index']

        # Highlight the pronoun in the passage so the prompt can refer to it.
        words_a = example.text_a.split()
        words_a[pronoun_idx] = '*' + words_a[pronoun_idx] + '*'
        text_a = self.shortenable(' '.join(words_a))

        # Extra mask slots beyond the target's own token count: a random number
        # for training examples, exactly one otherwise.
        num_pad = self.rng.randint(0, 3) if 'train' in example.guid else 1
        num_masks = len(get_verbalization_ids(target, self.wrapper.tokenizer, force_single_token=False)) + num_pad
        masks = self.mask * num_masks

        if self.pattern_id == 1:
            string_list_a = [text_a, "the", "pronoun '*", pronoun, "*' refers to",  masks + '.']
            block_flag_a = [0, 1, 0, 0, 0, 0]
        elif self.pattern_id == 2:
            string_list_a = ["the", text_a, "the", "pronoun '*", pronoun, "*' refers to",  masks + '.']
            block_flag_a = [1, 0, 1, 0, 0, 0, 0]
        else:
            # Fail loudly instead of implicitly returning None, which would only
            # surface later as an opaque unpacking error in the caller.
            raise ValueError("unknown pattern_ids.")

        string_list_b = []
        block_flag_b = []
        assert len(string_list_a) == len(block_flag_a)
        assert len(string_list_b) == len(block_flag_b)
        return string_list_a, string_list_b, block_flag_a, block_flag_b
コード例 #2
0
ファイル: pvp.py プロジェクト: zeqiufan/pet
    def get_parts(self, example: InputExample) -> FilledPattern:
        """Build the WSC prompt parts ``(parts_a, parts_b)`` for ``self.pattern_id``.

        The pronoun at ``example.meta['span2_index']`` is wrapped in ``*``
        markers inside ``text_a``. The answer slot is a run of mask tokens as
        long as the tokenized target (``span1_text``) plus ``num_pad`` extra
        masks: a random number for training examples, exactly one otherwise.

        :param example: the WSC example; ``meta`` must provide ``span1_text``,
            ``span2_text`` and ``span2_index``.
        :return: a list of prompt parts for the first segment and an empty
            second segment.
        :raises ValueError: if ``self.pattern_id`` is not 0, 1 or 2.
        """
        pronoun = example.meta['span2_text']
        target = example.meta['span1_text']
        pronoun_idx = example.meta['span2_index']

        # Highlight the pronoun in the passage so the prompt can refer to it.
        words_a = example.text_a.split()
        words_a[pronoun_idx] = '*' + words_a[pronoun_idx] + '*'
        text_a = ' '.join(words_a)
        text_a = self.shortenable(text_a)

        num_pad = self.rng.randint(0, 3) if 'train' in example.guid else 1
        num_masks = len(
            get_verbalization_ids(target,
                                  self.wrapper.tokenizer,
                                  force_single_token=False)) + num_pad
        masks = self.mask * num_masks

        if self.pattern_id == 0:
            return [
                text_a, "The pronoun '*" + pronoun + "*' refers to",
                masks + '.'
            ], []
        elif self.pattern_id == 1:
            return [
                text_a, "In the previous sentence, the pronoun '*" + pronoun +
                "*' refers to", masks + '.'
            ], []
        elif self.pattern_id == 2:
            return [
                text_a,
                "Question: In the passage above, what does the pronoun '*" +
                pronoun + "*' refer to? Answer: ", masks + '.'
            ], []
        # Previously an unknown id fell through to an implicit None, which the
        # caller could not unpack; report the real problem instead.
        raise ValueError(f"unknown pattern_id for WSC: {self.pattern_id}")
コード例 #3
0
    def add_special_input_features(self, input_example: InputExample,
                                   input_features: InputFeatures) -> None:
        """Attach per-candidate target ids and labels to ``input_features.meta``.

        Each candidate in ``input_example.meta['candidates']`` gets a row the
        length of ``input_features.input_ids``. The row is -100 everywhere
        except at the positions starting at the first mask token, which hold
        the candidate's token ids. A candidate's label is 1 when it appears in
        ``input_example.meta['answers']`` and 0 otherwise. The raw candidate
        strings are also recorded in ``self.original_choices[question_idx]``.
        """
        tokenizer = self.wrapper.tokenizer
        first_mask = input_features.input_ids.index(tokenizer.mask_token_id)

        candidates = input_example.meta['candidates']
        question_idx = input_example.meta['question_idx']

        meta = input_features.meta
        meta['candidate_token_ids'] = []
        meta['candidate_labels'] = []
        meta['question_idx'] = question_idx

        self.original_choices[question_idx] = []

        seq_len = len(input_features.input_ids)
        for candidate in candidates:
            token_ids = get_verbalization_ids(candidate, tokenizer, force_single_token=False)
            label = int(candidate in input_example.meta['answers'])

            row = [-100] * seq_len
            row[first_mask:first_mask + len(token_ids)] = token_ids

            meta['candidate_token_ids'].append(row)
            meta['candidate_labels'].append(label)
            self.original_choices[question_idx].append(candidate)
コード例 #4
0
ファイル: task_pvps.py プロジェクト: EliverQ/xp-tuning
    def get_parts(self, example: InputExample) -> FilledPattern:
        """Build the few-shot COPA prompt parts.

        The prompt is ``[choice1, 'or', choice2, '?', premise, joiner, 'the',
        masks, '.']``. The joiner is ``because`` for ``cause`` questions and
        ``so`` for ``effect`` questions. The mask run is as long as the longer
        tokenized choice.

        Side effect: ``example.meta['choice1']`` and ``example.meta['choice2']``
        are overwritten with the choices after ``lowercase_first`` and
        ``remove_final_punc``.

        :return: ``(string_list_a, string_list_b, block_flag_a, block_flag_b)``.
            The single 1 in ``block_flag_a`` flags the ``'the'`` before the masks.
        :raises ValueError: if ``self.pattern_id`` is not 1. Also raised, only
            when assertions are disabled, if the question is neither ``cause``
            nor ``effect``.
        """
        premise = self.remove_final_punc(self.shortenable(example.text_a))
        choice1 = self.remove_final_punc(
            self.lowercase_first(example.meta['choice1']))
        choice2 = self.remove_final_punc(
            self.lowercase_first(example.meta['choice2']))

        question = example.meta['question']
        assert question in ['cause', 'effect']

        example.meta['choice1'], example.meta['choice2'] = choice1, choice2
        num_masks = max(
            len(get_verbalization_ids(c, self.wrapper.tokenizer, False))
            for c in [choice1, choice2])

        # Other patterns were searched for the fully-supervised setting; only
        # the few-shot pattern (pattern_id == 1) is implemented here.
        if self.pattern_id != 1:
            raise ValueError("unknown pattern_ids.")

        joiners = {'cause': 'because', 'effect': 'so'}
        if question not in joiners:
            # Only reachable when the assert above is stripped (python -O).
            raise ValueError(
                "currently not support the kind of questions.")

        string_list_a = [
            choice1, 'or', choice2, '?', premise, joiners[question], 'the',
            self.mask * num_masks, '.'
        ]
        string_list_b = []
        block_flag_a = [0, 0, 0, 0, 0, 0, 1, 0, 0]
        block_flag_b = []
        assert len(string_list_a) == len(block_flag_a)
        assert len(string_list_b) == len(block_flag_b)
        return string_list_a, string_list_b, block_flag_a, block_flag_b
コード例 #5
0
    def add_special_input_features(self, input_example: InputExample,
                                   input_features: InputFeatures) -> None:
        """Store one target-id row per answer choice in ``input_features.meta``.

        This does nothing for the ``sequence_classifier`` wrapper. Otherwise,
        the choices come from ``input_example.meta['choices']`` when present,
        and from the first verbalization of every configured label otherwise.
        Each row under ``meta['choice_token_ids']`` has the length of
        ``input_ids``. It is filled with -100 except at the positions starting
        at the first mask token, which hold the choice's token ids.
        """
        if self.wrapper.config.wrapper_type == 'sequence_classifier':
            return

        tokenizer = self.wrapper.tokenizer
        mask_start = input_features.input_ids.index(tokenizer.mask_token_id)

        example_meta = input_example.meta
        if 'choices' in example_meta:
            choices = list(example_meta['choices'])
        else:
            verbalize = self.wrapper.preprocessor.pvp.verbalize
            choices = [verbalize(label)[0] for label in self.wrapper.config.label_list]

        rows = []
        input_features.meta['choice_token_ids'] = rows

        seq_len = len(input_features.input_ids)
        for text in choices:
            ids = get_verbalization_ids(text, tokenizer, force_single_token=False)
            row = [-100] * seq_len
            row[mask_start:mask_start + len(ids)] = ids
            rows.append(row)
コード例 #6
0
ファイル: pvp.py プロジェクト: timoschick/pet
    def get_parts(self, example: InputExample) -> FilledPattern:
        """Return ``[passage, query], []`` with the query's placeholder masked.

        ``@placeholder`` in ``example.text_b`` is replaced by as many mask tokens
        as the longest tokenized entry of ``example.meta['candidates']``.
        """
        passage = self.shortenable(example.text_a)
        candidates = example.meta['candidates']
        query = example.text_b

        assert '@placeholder' in query, f'question "{query}" does not contain a @placeholder token'
        longest = max(
            len(get_verbalization_ids(candidate, self.wrapper.tokenizer, False))
            for candidate in candidates
        )
        masked_query = query.replace('@placeholder', self.mask * longest)
        return [passage, masked_query], []
コード例 #7
0
ファイル: pvp.py プロジェクト: timoschick/pet
    def _build_mlm_logits_to_cls_logits_tensor(self):
        """Return a ``(num_labels, max_num_verbalizers)`` long tensor of token ids.

        Row ``i`` holds the single-token ids of the verbalizers of the ``i``-th
        entry in ``config.label_list``. Slots without a verbalizer stay -1.

        :raises AssertionError: if a verbalizer tokenizes to the unknown token.
        """
        tokenizer = self.wrapper.tokenizer
        labels = self.wrapper.config.label_list
        table = torch.full((len(labels), self.max_num_verbalizers), -1, dtype=torch.long)

        for row, label in enumerate(labels):
            for col, word in enumerate(self.verbalize(label)):
                token_id = get_verbalization_ids(word, tokenizer, force_single_token=True)
                assert token_id != tokenizer.unk_token_id, "verbalization was tokenized as <UNK>"
                table[row, col] = token_id
        return table
コード例 #8
0
    def add_special_input_features(self, input_example: InputExample, input_features: InputFeatures) -> None:
        """Store ``choice1_token_ids`` / ``choice2_token_ids`` in ``input_features.meta``.

        This does nothing for the ``sequence_classifier`` wrapper. Each stored
        list has the length of ``input_ids``. It is -100 everywhere except at
        the positions starting at the first mask token, which hold that
        choice's token ids.
        """
        if self.wrapper.config.wrapper_type == 'sequence_classifier':
            return

        tokenizer = self.wrapper.tokenizer
        input_ids = input_features.input_ids
        mask_start = input_ids.index(tokenizer.mask_token_id)

        for key in ('choice1', 'choice2'):
            ids = get_verbalization_ids(input_example.meta[key], tokenizer, force_single_token=False)
            end = mask_start + len(ids)
            row = input_features.meta[f'{key}_token_ids'] = [-100] * len(input_ids)
            row[mask_start:end] = ids
コード例 #9
0
ファイル: pvp.py プロジェクト: timoschick/pet
    def encode(self, example: InputExample, priming: bool = False, labeled: bool = False) \
            -> Tuple[List[int], List[int]]:
        """
        Encode an input example using this pattern-verbalizer pair.

        The parts returned by ``get_parts`` are tokenized without special tokens
        (GPT-2 tokenizers additionally get ``add_prefix_space=True``). They are
        passed to ``self.truncate`` with ``config.max_seq_length``, presumably
        shortened in place since its return value is unused, and then
        concatenated. Plain string parts are treated as ``(text, False)`` pairs,
        and empty texts are dropped.

        :param example: the input example to encode
        :param priming: whether to use this example for priming
        :param labeled: if ``priming=True``, whether the first mask token should be replaced by the
            (single-token) verbalization of the example's label
        :return: A tuple, consisting of a list of input ids and a list of token type ids. When
            priming, the ids carry no special tokens and the token type id list is empty.
        :raises ValueError: if ``priming`` and ``labeled`` are set but the encoded sequence contains
            no mask token.
        """

        if not priming:
            assert not labeled, "'labeled' can only be set to true if 'priming' is also set to true"

        tokenizer = self.wrapper.tokenizer  # type: PreTrainedTokenizer
        parts_a, parts_b = self.get_parts(example)

        kwargs = {'add_prefix_space': True} if isinstance(tokenizer, GPT2Tokenizer) else {}

        parts_a = [x if isinstance(x, tuple) else (x, False) for x in parts_a]
        parts_a = [(tokenizer.encode(x, add_special_tokens=False, **kwargs), s) for x, s in parts_a if x]

        if parts_b:
            parts_b = [x if isinstance(x, tuple) else (x, False) for x in parts_b]
            parts_b = [(tokenizer.encode(x, add_special_tokens=False, **kwargs), s) for x, s in parts_b if x]

        self.truncate(parts_a, parts_b, max_length=self.wrapper.config.max_seq_length)

        tokens_a = [token_id for part, _ in parts_a for token_id in part]
        tokens_b = [token_id for part, _ in parts_b for token_id in part] if parts_b else None

        if priming:
            input_ids = tokens_a + (tokens_b or [])
            if labeled:
                # list.index raises instead of returning -1, so the old
                # ``assert mask_idx >= 0`` could never fire; check explicitly.
                if self.mask_id not in input_ids:
                    raise ValueError('sequence of input_ids must contain a mask token')
                mask_idx = input_ids.index(self.mask_id)
                verbalizers = self.verbalize(example.label)
                assert len(verbalizers) == 1, 'priming only supports one verbalization per label'
                input_ids[mask_idx] = get_verbalization_ids(verbalizers[0], tokenizer, force_single_token=True)
            return input_ids, []

        input_ids = tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
        token_type_ids = tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)

        return input_ids, token_type_ids
コード例 #10
0
    def add_special_input_features(self, input_example: InputExample, input_features: InputFeatures) -> None:
        """Store the WSC target text and its per-position target ids in ``meta``.

        This does nothing for the ``sequence_classifier`` wrapper. Otherwise it
        sets ``meta['target']`` to ``span1_text``. It also sets
        ``meta['target_token_ids']`` to a list the length of ``input_ids``: -100
        everywhere except the mask span, which starts at the first mask token
        and is as long as the total number of mask tokens. That span holds the
        target's token ids followed by ``<pad>`` ids.
        """
        if self.wrapper.config.wrapper_type == 'sequence_classifier':
            return

        tokenizer = self.wrapper.tokenizer
        input_ids = input_features.input_ids
        first_mask = input_ids.index(tokenizer.mask_token_id)
        mask_count = input_ids.count(tokenizer.mask_token_id)

        target = input_example.meta['span1_text']
        input_features.meta['target'] = target
        target_ids = get_verbalization_ids(target, tokenizer, force_single_token=False)
        input_features.meta['target_token_ids'] = [-100] * len(input_ids)

        # Mask slots beyond the target's own tokens are filled with <pad> ids,
        # so padding is predicted at those positions as well.
        target_ids += [tokenizer.pad_token_id] * (mask_count - len(target_ids))
        input_features.meta['target_token_ids'][first_mask:first_mask + mask_count] = target_ids
コード例 #11
0
ファイル: pvp.py プロジェクト: puraminy/pet
    def get_parts(self, example: InputExample) -> FilledPattern:
        """Build prompt parts for the configured task.

        ``self.max_label_tokens`` is computed on first use and cached on the
        instance. It is the longest tokenization among the first verbalization
        of each label in ``self.VERBALIZER[task]``. The answer slot is that
        many mask tokens. Trailing punctuation is stripped from ``text_b``.
        """
        task = self.wrapper.config.task_name
        verbalizer = self.VERBALIZER[task]
        if not hasattr(self, "max_label_tokens"):
            self.max_label_tokens = max(
                len(get_verbalization_ids(words[0], self.wrapper.tokenizer, False))
                for words in verbalizer.values())

        text_a = example.text_a
        text_b = example.text_b.rstrip(string.punctuation)

        if task == "atomic_xintent":
            parts_a = [
                self.shortenable(text_a),
                self.mask * self.max_label_tokens,
                ' the person ',
                self.shortenable(text_b),
            ]
            return parts_a, []

        if self.pattern_id == 0:
            parts_a = [self.shortenable(text_a), '.']
            parts_b = [self.mask * self.max_label_tokens, self.shortenable(text_b)]
            return parts_a, parts_b

        # NOTE(review): any other pattern_id yields None, which a caller that
        # unpacks two parts cannot handle; confirm whether more patterns exist.
        return None
コード例 #12
0
ファイル: pvp.py プロジェクト: timoschick/pet
    def get_parts(self, example: InputExample) -> FilledPattern:
        """Build COPA prompt parts for ``pattern_id`` 0 (quoted choices) or 1 (bare choices).

        The joiner between premise and masks is ``because`` for ``cause``
        questions and ``, so`` otherwise. The mask run is as long as the longer
        tokenized choice.

        Side effect: ``example.meta['choice1']`` and ``example.meta['choice2']``
        are overwritten with the choices after ``lowercase_first`` and
        ``remove_final_punc``.

        :return: a list of prompt parts for the first segment and an empty
            second segment.
        :raises ValueError: if ``self.pattern_id`` is not 0 or 1.
        """

        premise = self.remove_final_punc(self.shortenable(example.text_a))
        choice1 = self.remove_final_punc(self.lowercase_first(example.meta['choice1']))
        choice2 = self.remove_final_punc(self.lowercase_first(example.meta['choice2']))

        question = example.meta['question']
        assert question in ['cause', 'effect']

        example.meta['choice1'], example.meta['choice2'] = choice1, choice2
        num_masks = max(len(get_verbalization_ids(c, self.wrapper.tokenizer, False)) for c in [choice1, choice2])

        joiner = 'because' if question == 'cause' else ', so'
        masks = self.mask * num_masks

        if self.pattern_id == 0:
            return ['"', choice1, '" or "', choice2, '"?', premise, joiner, masks, '.'], []
        if self.pattern_id == 1:
            return [choice1, 'or', choice2, '?', premise, joiner, masks, '.'], []
        # Previously fell through to an implicit None that callers cannot unpack.
        raise ValueError(f'unknown pattern_id for COPA: {self.pattern_id}')
コード例 #13
0
    def get_parts(self, example: InputExample) -> FilledPattern:
        """Build COPA prompt parts for ``pattern_id`` 0 (quoted choices) or 1 (bare choices).

        ``cause`` questions join premise and masks with ``because``; all others
        use ``, so``. The mask run is as long as the longer tokenized choice.

        Side effect: ``example.meta["choice1"]`` and ``example.meta["choice2"]``
        are overwritten with the choices after ``lowercase_first`` and
        ``remove_final_punc``.

        :return: a list of prompt parts for the first segment and an empty
            second segment.
        :raises ValueError: if ``self.pattern_id`` is not 0 or 1.
        """

        premise = self.remove_final_punc(self.shortenable(example.text_a))
        choice1 = self.remove_final_punc(
            self.lowercase_first(example.meta["choice1"]))
        choice2 = self.remove_final_punc(
            self.lowercase_first(example.meta["choice2"]))

        question = example.meta["question"]
        assert question in ["cause", "effect"]

        example.meta["choice1"], example.meta["choice2"] = choice1, choice2
        num_masks = max(
            len(get_verbalization_ids(c, self.wrapper.tokenizer, False))
            for c in [choice1, choice2])

        if question == "cause":
            if self.pattern_id == 0:
                return [
                    '"', choice1, '" or "', choice2, '"?', premise, "because",
                    self.mask * num_masks, "."
                ], []
            elif self.pattern_id == 1:
                return [
                    choice1, "or", choice2, "?", premise, "because",
                    self.mask * num_masks, "."
                ], []
        else:
            if self.pattern_id == 0:
                return [
                    '"', choice1, '" or "', choice2, '"?', premise, ", so",
                    self.mask * num_masks, "."
                ], []
            elif self.pattern_id == 1:
                return [
                    choice1, "or", choice2, "?", premise, ", so",
                    self.mask * num_masks, "."
                ], []
        # Unknown pattern ids used to fall through to an implicit None, which
        # callers unpacking two parts cannot handle.
        raise ValueError(f"unknown pattern_id for COPA: {self.pattern_id}")