def get_parts(self, example: InputExample) -> FilledPattern:
    pronoun = example.meta['span2_text']
    target = example.meta['span1_text']
    pronoun_idx = example.meta['span2_index']

    # Highlight the pronoun in the input text by wrapping it in asterisks.
    words_a = example.text_a.split()
    words_a[pronoun_idx] = '*' + words_a[pronoun_idx] + '*'
    text_a = ' '.join(words_a)
    text_a = self.shortenable(text_a)

    # During training, pad the target with a random number (0-3) of extra
    # mask tokens; at inference time, always pad with exactly one.
    num_pad = self.rng.randint(0, 3) if 'train' in example.guid else 1
    num_masks = len(get_verbalization_ids(target, self.wrapper.tokenizer,
                                          force_single_token=False)) + num_pad
    masks = self.mask * num_masks

    if self.pattern_id == 1:
        string_list_a = [text_a, "the", "pronoun '*", pronoun, "*' refers to", masks + '.']
        string_list_b = []
        block_flag_a = [0, 1, 0, 0, 0, 0]
        block_flag_b = []
        assert len(string_list_a) == len(block_flag_a)
        assert len(string_list_b) == len(block_flag_b)
        return string_list_a, string_list_b, block_flag_a, block_flag_b
    elif self.pattern_id == 2:
        string_list_a = ["the", text_a, "the", "pronoun '*", pronoun, "*' refers to", masks + '.']
        string_list_b = []
        block_flag_a = [1, 0, 1, 0, 0, 0, 0]
        block_flag_b = []
        assert len(string_list_a) == len(block_flag_a)
        assert len(string_list_b) == len(block_flag_b)
        return string_list_a, string_list_b, block_flag_a, block_flag_b

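# Reading of the block_flag convention (our assumption, based on how the flags
# line up with the pattern parts): a 1 marks a part whose embedding is replaced
# by a trainable continuous-prompt token, a 0 marks ordinary frozen text. In
# pattern 1 above, only the inserted 'the' is trainable:
#   [text_a, 'the', "pronoun '*", pronoun, "*' refers to", masks + '.']
#   [  0,      1,        0,          0,          0,             0     ]
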
def get_parts(self, example: InputExample) -> FilledPattern:
    pronoun = example.meta['span2_text']
    target = example.meta['span1_text']
    pronoun_idx = example.meta['span2_index']

    # Highlight the pronoun in the input text by wrapping it in asterisks.
    words_a = example.text_a.split()
    words_a[pronoun_idx] = '*' + words_a[pronoun_idx] + '*'
    text_a = ' '.join(words_a)
    text_a = self.shortenable(text_a)

    num_pad = self.rng.randint(0, 3) if 'train' in example.guid else 1
    num_masks = len(get_verbalization_ids(target, self.wrapper.tokenizer,
                                          force_single_token=False)) + num_pad
    masks = self.mask * num_masks

    if self.pattern_id == 0:
        return [text_a, "The pronoun '*" + pronoun + "*' refers to", masks + '.'], []
    elif self.pattern_id == 1:
        return [text_a, "In the previous sentence, the pronoun '*" + pronoun + "*' refers to",
                masks + '.'], []
    elif self.pattern_id == 2:
        return [text_a, "Question: In the passage above, what does the pronoun '*" + pronoun
                + "*' refer to? Answer: ", masks + '.'], []

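# Illustration (hypothetical example, not from the dataset): with
#   text_a = "The trophy didn't fit into the suitcase because it was too big ."
#   span2_text = 'it', span1_text = 'the trophy', pattern_id = 0,
# the filled pattern reads roughly:
#   "The trophy didn't fit into the suitcase because *it* was too big .
#    The pronoun '*it*' refers to <mask><mask><mask> ."
# with one mask per target token plus `num_pad` extra masks.
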
def add_special_input_features(self, input_example: InputExample,
                               input_features: InputFeatures) -> None:
    mask_start = input_features.input_ids.index(self.wrapper.tokenizer.mask_token_id)

    choices = input_example.meta['candidates']
    question_idx = input_example.meta['question_idx']

    input_features.meta['candidate_token_ids'] = []
    input_features.meta['candidate_labels'] = []
    input_features.meta['question_idx'] = question_idx

    self.original_choices[question_idx] = []

    for choice_text in choices:
        choice_token_ids = get_verbalization_ids(choice_text, self.wrapper.tokenizer,
                                                 force_single_token=False)
        choice_label = 1 if choice_text in input_example.meta['answers'] else 0

        mask_end = mask_start + len(choice_token_ids)
        # -100 marks positions that are ignored when computing the LM loss.
        candidate_token_ids = [-100] * len(input_features.input_ids)
        candidate_token_ids[mask_start:mask_end] = choice_token_ids

        input_features.meta['candidate_token_ids'].append(candidate_token_ids)
        input_features.meta['candidate_labels'].append(choice_label)
        self.original_choices[question_idx].append(choice_text)

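# Layout sketch: for a candidate tokenized to ids [t1, t2] and masks starting
# at position k, the stored row is
#   [-100, ..., -100, t1, t2, -100, ..., -100]
#                     ^k
# so only the mask positions contribute when the candidate is scored.
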
def get_parts(self, example: InputExample) -> FilledPattern:
    premise = self.remove_final_punc(self.shortenable(example.text_a))
    choice1 = self.remove_final_punc(self.lowercase_first(example.meta['choice1']))
    choice2 = self.remove_final_punc(self.lowercase_first(example.meta['choice2']))

    question = example.meta['question']
    assert question in ['cause', 'effect']
    example.meta['choice1'], example.meta['choice2'] = choice1, choice2

    # Reserve enough mask tokens for the longer of the two choices.
    num_masks = max(len(get_verbalization_ids(c, self.wrapper.tokenizer, False))
                    for c in [choice1, choice2])

    # `joiner` is only referenced by the commented-out supervised patterns below;
    # the few-shot branches spell out 'because' / 'so' directly.
    joiner = 'because' if question == 'cause' else 'so'

    # Patterns searched in fully-supervised learning:
    # string_list_a = [choice1, 'or', choice2, '?', 'the', premise, joiner, 'the', self.mask]
    # string_list_a = [choice1, 'or', choice2, '?', premise, joiner, 'the', self.mask * num_masks]
    # string_list_a = ['"', choice1, '" or "', choice2, '"?', 'the', premise, 'the', joiner, self.mask * num_masks]
    # string_list_a = ['"', choice1, '" or "', choice2, '"?', premise, joiner, 'the', self.mask * num_masks]

    # Few-shot pattern.
    if self.pattern_id == 1:
        if question == 'cause':
            string_list_a = [choice1, 'or', choice2, '?', premise, 'because', 'the',
                             self.mask * num_masks, '.']
            string_list_b = []
            block_flag_a = [0, 0, 0, 0, 0, 0, 1, 0, 0]
            block_flag_b = []
            assert len(string_list_a) == len(block_flag_a)
            assert len(string_list_b) == len(block_flag_b)
            return string_list_a, string_list_b, block_flag_a, block_flag_b
        elif question == 'effect':
            string_list_a = [choice1, 'or', choice2, '?', premise, 'so', 'the',
                             self.mask * num_masks, '.']
            string_list_b = []
            block_flag_a = [0, 0, 0, 0, 0, 0, 1, 0, 0]
            block_flag_b = []
            assert len(string_list_a) == len(block_flag_a)
            assert len(string_list_b) == len(block_flag_b)
            return string_list_a, string_list_b, block_flag_a, block_flag_b
        else:
            raise ValueError('unsupported question type.')
    else:
        raise ValueError('unknown pattern_id.')

def add_special_input_features(self, input_example: InputExample,
                               input_features: InputFeatures) -> None:
    if self.wrapper.config.wrapper_type == 'sequence_classifier':
        return

    mask_start = input_features.input_ids.index(self.wrapper.tokenizer.mask_token_id)

    if 'choices' in input_example.meta:
        choices = list(input_example.meta['choices'])
    else:
        # Fall back to the first verbalization of each label.
        label_list = self.wrapper.config.label_list
        choices = [self.wrapper.preprocessor.pvp.verbalize(label)[0] for label in label_list]

    input_features.meta['choice_token_ids'] = []

    for choice_text in choices:
        choice_token_ids = get_verbalization_ids(choice_text, self.wrapper.tokenizer,
                                                 force_single_token=False)
        mask_end = mask_start + len(choice_token_ids)
        # -100 marks positions that are ignored when computing the LM loss.
        candidate_token_ids = [-100] * len(input_features.input_ids)
        candidate_token_ids[mask_start:mask_end] = choice_token_ids
        input_features.meta['choice_token_ids'].append(candidate_token_ids)

def get_parts(self, example: InputExample) -> FilledPattern:
    premise = self.shortenable(example.text_a)
    choices = example.meta['candidates']

    assert '@placeholder' in example.text_b, \
        f'question "{example.text_b}" does not contain a @placeholder token'

    num_masks = max(len(get_verbalization_ids(c, self.wrapper.tokenizer, False)) for c in choices)
    question = example.text_b.replace('@placeholder', self.mask * num_masks)

    return [premise, question], []

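# Example (hypothetical): with num_masks = 2,
#   text_b = "The award was given to @placeholder."
# becomes
#   "The award was given to <mask><mask>."
# Enough masks are reserved for the longest candidate.
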
def _build_mlm_logits_to_cls_logits_tensor(self):
    label_list = self.wrapper.config.label_list
    # Initialize with -1 so that unused verbalizer slots can be masked out later.
    m2c_tensor = torch.ones([len(label_list), self.max_num_verbalizers], dtype=torch.long) * -1

    for label_idx, label in enumerate(label_list):
        verbalizers = self.verbalize(label)
        for verbalizer_idx, verbalizer in enumerate(verbalizers):
            verbalizer_id = get_verbalization_ids(verbalizer, self.wrapper.tokenizer,
                                                  force_single_token=True)
            assert verbalizer_id != self.wrapper.tokenizer.unk_token_id, \
                'verbalization was tokenized as <UNK>'
            m2c_tensor[label_idx, verbalizer_idx] = verbalizer_id

    return m2c_tensor

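# Minimal standalone sketch (an assumption for illustration, not code from this
# module): how a [num_labels x max_verbalizers] id tensor like the one built
# above can map MLM logits to class logits. The ids and vocab size are made up.
import torch

_mlm_logits = torch.randn(50265)               # logits over the vocabulary
_m2c = torch.tensor([[232, -1], [318, 404]])   # -1 pads shorter verbalizer lists
_valid = _m2c >= 0
# Gather the logit at each verbalizer id; padded (-1) slots contribute zero.
_gathered = torch.where(_valid, _mlm_logits[_m2c.clamp(min=0)], torch.zeros(1))
_cls_logits = _gathered.sum(dim=1) / _valid.sum(dim=1)  # average over verbalizers
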
def add_special_input_features(self, input_example: InputExample,
                               input_features: InputFeatures) -> None:
    if self.wrapper.config.wrapper_type == 'sequence_classifier':
        return

    mask_start = input_features.input_ids.index(self.wrapper.tokenizer.mask_token_id)

    for choice in ['choice1', 'choice2']:
        choice_text = input_example.meta[choice]
        choice_token_ids = get_verbalization_ids(choice_text, self.wrapper.tokenizer,
                                                 force_single_token=False)
        mask_end = mask_start + len(choice_token_ids)
        input_features.meta[f'{choice}_token_ids'] = [-100] * len(input_features.input_ids)
        input_features.meta[f'{choice}_token_ids'][mask_start:mask_end] = choice_token_ids

def encode(self, example: InputExample, priming: bool = False, labeled: bool = False) \
        -> Tuple[List[int], List[int]]:
    """
    Encode an input example using this pattern-verbalizer pair.

    :param example: the input example to encode
    :param priming: whether to use this example for priming
    :param labeled: if ``priming=True``, whether the label should be appended to this example
    :return: a tuple consisting of a list of input ids and a list of token type ids
    """
    if not priming:
        assert not labeled, "'labeled' can only be set to true if 'priming' is also set to true"

    tokenizer = self.wrapper.tokenizer  # type: PreTrainedTokenizer
    parts_a, parts_b = self.get_parts(example)

    kwargs = {'add_prefix_space': True} if isinstance(tokenizer, GPT2Tokenizer) else {}

    # Normalize all parts to (text, shortenable) tuples, then tokenize them.
    parts_a = [x if isinstance(x, tuple) else (x, False) for x in parts_a]
    parts_a = [(tokenizer.encode(x, add_special_tokens=False, **kwargs), s) for x, s in parts_a if x]

    if parts_b:
        parts_b = [x if isinstance(x, tuple) else (x, False) for x in parts_b]
        parts_b = [(tokenizer.encode(x, add_special_tokens=False, **kwargs), s) for x, s in parts_b if x]

    self.truncate(parts_a, parts_b, max_length=self.wrapper.config.max_seq_length)

    tokens_a = [token_id for part, _ in parts_a for token_id in part]
    tokens_b = [token_id for part, _ in parts_b for token_id in part] if parts_b else None

    if priming:
        input_ids = tokens_a
        if tokens_b:
            input_ids += tokens_b

        if labeled:
            # Replace the mask with the (single-token) verbalization of the gold label.
            mask_idx = input_ids.index(self.mask_id)
            assert mask_idx >= 0, 'sequence of input_ids must contain a mask token'
            assert len(self.verbalize(example.label)) == 1, \
                'priming only supports one verbalization per label'
            verbalizer = self.verbalize(example.label)[0]
            verbalizer_id = get_verbalization_ids(verbalizer, self.wrapper.tokenizer,
                                                  force_single_token=True)
            input_ids[mask_idx] = verbalizer_id

        return input_ids, []

    input_ids = tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
    token_type_ids = tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)

    return input_ids, token_type_ids

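# Hypothetical usage (the PVP instance `pvp` and the example fields are
# assumptions for illustration):
#   example = InputExample(guid='train-0', text_a='...', text_b='...', label='1')
#   input_ids, token_type_ids = pvp.encode(example)
# The returned `input_ids` already contain mask token(s) where the verbalizer
# will be scored; with priming=True and labeled=True, the mask is instead
# replaced by the gold label's verbalization.
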
def add_special_input_features(self, input_example: InputExample,
                               input_features: InputFeatures) -> None:
    if self.wrapper.config.wrapper_type == 'sequence_classifier':
        return

    mask_start = input_features.input_ids.index(self.wrapper.tokenizer.mask_token_id)
    num_masks = input_features.input_ids.count(self.wrapper.tokenizer.mask_token_id)
    mask_end = mask_start + num_masks

    target = input_example.meta['span1_text']
    input_features.meta['target'] = target
    target_token_ids = get_verbalization_ids(target, self.wrapper.tokenizer,
                                             force_single_token=False)
    input_features.meta['target_token_ids'] = [-100] * len(input_features.input_ids)

    # We also predict <pad> tokens at the missing positions.
    target_token_ids += [self.wrapper.tokenizer.pad_token_id] * (num_masks - len(target_token_ids))
    input_features.meta['target_token_ids'][mask_start:mask_end] = target_token_ids

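# Sketch: with 4 masks starting at position k and a target tokenized to
# [t1, t2], the stored row is
#   [-100, ..., t1, t2, <pad_id>, <pad_id>, -100, ...]
#               ^k
# i.e. the surplus mask positions are trained to predict the pad token.
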
def get_parts(self, example: InputExample) -> FilledPattern:
    task = self.wrapper.config.task_name
    vb = self.VERBALIZER[task]

    # Cache the maximum verbalization length (in tokens) across all labels.
    if not hasattr(self, 'max_label_tokens'):
        self.max_label_tokens = max(
            len(get_verbalization_ids(v[0], self.wrapper.tokenizer, False)) for v in vb.values())

    text_a = example.text_a
    text_b = example.text_b.rstrip(string.punctuation)

    if task in ['atomic_xintent']:
        return [self.shortenable(text_a), self.mask * self.max_label_tokens,
                ' the person ', self.shortenable(text_b)], []

    if self.pattern_id == 0:
        return [self.shortenable(text_a), '.'], \
               [self.mask * self.max_label_tokens, self.shortenable(text_b)]
    # Note: other pattern ids fall through and implicitly return None.

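# Sketch of the two-segment pattern for pattern_id 0 (hypothetical texts):
#   segment A: "<text_a> ."
#   segment B: "<mask>... <text_b>"
# i.e. the masks for the verbalized label open the second segment.
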
def get_parts(self, example: InputExample) -> FilledPattern:
    premise = self.remove_final_punc(self.shortenable(example.text_a))
    choice1 = self.remove_final_punc(self.lowercase_first(example.meta['choice1']))
    choice2 = self.remove_final_punc(self.lowercase_first(example.meta['choice2']))

    question = example.meta['question']
    assert question in ['cause', 'effect']
    example.meta['choice1'], example.meta['choice2'] = choice1, choice2

    num_masks = max(len(get_verbalization_ids(c, self.wrapper.tokenizer, False))
                    for c in [choice1, choice2])

    if question == 'cause':
        if self.pattern_id == 0:
            return ['"', choice1, '" or "', choice2, '"?', premise, 'because',
                    self.mask * num_masks, '.'], []
        elif self.pattern_id == 1:
            return [choice1, 'or', choice2, '?', premise, 'because',
                    self.mask * num_masks, '.'], []
    else:
        if self.pattern_id == 0:
            return ['"', choice1, '" or "', choice2, '"?', premise, ', so',
                    self.mask * num_masks, '.'], []
        elif self.pattern_id == 1:
            return [choice1, 'or', choice2, '?', premise, ', so',
                    self.mask * num_masks, '.'], []

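# Illustration (hypothetical values): with choice1 = "the stars were visible",
# choice2 = "the moon was full", question = 'cause' and pattern_id = 1:
#   "the stars were visible or the moon was full ? <premise> because <mask>... ."
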