def get_parts(self, example: InputExample) -> FilledPattern:
    premise = self.remove_final_punc(self.shortenable(example.text_a))
    choice1 = self.remove_final_punc(
        self.lowercase_first(example.meta['choice1']))
    choice2 = self.remove_final_punc(
        self.lowercase_first(example.meta['choice2']))

    question = example.meta['question']
    assert question in ['cause', 'effect']

    example.meta['choice1'], example.meta['choice2'] = choice1, choice2
    # reserve enough <mask> slots to fit the longer of the two choices
    num_masks = max(
        len(get_verbalization_ids(c, self.wrapper.tokenizer, False))
        for c in [choice1, choice2])

    # joiner is only used by the searched fully-supervised patterns below
    if question == "cause":
        joiner = "because"
    else:
        joiner = "so"

    # searched patterns in fully-supervised learning
    # string_list_a = [choice1, 'or', choice2, '?', 'the', premise, joiner, 'the', self.mask]
    # string_list_a = [choice1, 'or', choice2, '?', premise, joiner, 'the', self.mask * num_masks]
    # string_list_a = ['"', choice1, '" or "', choice2, '"?', 'the', premise, 'the', joiner, self.mask*num_masks]
    # string_list_a = ['"', choice1, '" or "', choice2, '"?', premise, joiner, 'the', self.mask*num_masks]

    # few-shot
    if self.pattern_id == 1:
        if question == "cause":
            string_list_a = [
                choice1, 'or', choice2, '?', premise, 'because', 'the',
                self.mask * num_masks, '.'
            ]
            string_list_b = []
            # block_flag == 1 marks the part ('the') replaced by a trainable prompt embedding
            block_flag_a = [0, 0, 0, 0, 0, 0, 1, 0, 0]
            block_flag_b = []
            assert len(string_list_a) == len(block_flag_a)
            assert len(string_list_b) == len(block_flag_b)
            return string_list_a, string_list_b, block_flag_a, block_flag_b

        elif question == "effect":
            string_list_a = [
                choice1, 'or', choice2, '?', premise, 'so', 'the',
                self.mask * num_masks, '.'
            ]
            string_list_b = []
            block_flag_a = [0, 0, 0, 0, 0, 0, 1, 0, 0]
            block_flag_b = []
            assert len(string_list_a) == len(block_flag_a)
            assert len(string_list_b) == len(block_flag_b)
            return string_list_a, string_list_b, block_flag_a, block_flag_b

        else:
            raise ValueError("currently unsupported question type.")
    else:
        raise ValueError("unknown pattern_id.")
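# Hedged illustration (the helper below is hypothetical, not part of this module):
# one way to preview the aligned (string_list_a, block_flag_a) pair returned above.
# Parts wrapped by shortenable() are (text, True) tuples; a block flag of 1 marks the
# slot that is later swapped for a trainable continuous-prompt embedding.
def preview_pattern(string_list_a, block_flag_a):
    rendered = []
    for part, is_prompt in zip(string_list_a, block_flag_a):
        text = part[0] if isinstance(part, tuple) else part
        rendered.append('[PROMPT:%s]' % text if is_prompt else text)
    return ' '.join(rendered)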
def _build_mlm_logits_to_cls_logits_tensor(self):
    label_list = self.wrapper.config.label_list
    m2c_tensor = torch.ones([len(label_list), self.max_num_verbalizers],
                            dtype=torch.long) * -1

    for label_idx, label in enumerate(label_list):
        verbalizers = self.verbalize(label)
        for verbalizer_idx, verbalizer in enumerate(verbalizers):
            verbalizer_id = get_verbalization_ids(verbalizer,
                                                  self.wrapper.tokenizer,
                                                  force_single_token=True)
            assert verbalizer_id != self.wrapper.tokenizer.unk_token_id, \
                "verbalization was tokenized as <UNK>"
            m2c_tensor[label_idx, verbalizer_idx] = verbalizer_id
    return m2c_tensor
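# Hedged sketch (the function name and averaging rule below are assumptions, not taken
# from this file): one way the m2c tensor built above is typically consumed.
# `mask_logits` is a [vocab_size] tensor of MLM logits at the <mask> position; the -1
# entries mark unused verbalizer slots and are excluded from the average.
import torch

def mlm_logits_to_cls_logits_sketch(mask_logits: torch.Tensor,
                                    m2c: torch.Tensor) -> torch.Tensor:
    valid = (m2c >= 0).float()                         # [num_labels, max_num_verbalizers]
    gathered = mask_logits[m2c.clamp(min=0)] * valid   # zero out unused slots
    return gathered.sum(dim=1) / valid.sum(dim=1)      # [num_labels] cls logits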
def get_parts(self, example: InputExample) -> FilledPattern:
    pronoun = example.meta['span2_text']
    target = example.meta['span1_text']
    pronoun_idx = example.meta['span2_index']

    # mark the ambiguous pronoun in the text with surrounding asterisks
    words_a = example.text_a.split()
    words_a[pronoun_idx] = '*' + words_a[pronoun_idx] + '*'
    text_a = ' '.join(words_a)
    text_a = self.shortenable(text_a)

    # at training time, pad the mask span with a random number of extra masks;
    # at evaluation time, always add exactly one
    num_pad = self.rng.randint(0, 3) if 'train' in example.guid else 1
    num_masks = len(
        get_verbalization_ids(target,
                              self.wrapper.tokenizer,
                              force_single_token=False)) + num_pad
    masks = self.mask * num_masks

    # searched patterns in fully-supervised learning
    # string_list_a = [text_a, "the", "'*", pronoun, "*'", "the", masks]
    # string_list_a = [text_a, "the", "pronoun '*", pronoun, "*' refers to", masks]
    # string_list_a = [text_a, "the", "pronoun '*", pronoun, "*'", "the", masks]
    # string_list_a = [text_a, "the", "pronoun '*", pronoun, "*' refers to", "the", masks]

    # few-shot
    if self.pattern_id == 1:
        string_list_a = [
            text_a, "the", "pronoun '*", pronoun, "*' refers to", masks + '.'
        ]
        string_list_b = []
        block_flag_a = [0, 1, 0, 0, 0, 0]
        block_flag_b = []
        assert len(string_list_a) == len(block_flag_a)
        assert len(string_list_b) == len(block_flag_b)
        return string_list_a, string_list_b, block_flag_a, block_flag_b

    elif self.pattern_id == 2:
        string_list_a = [
            "the", text_a, "the", "pronoun '*", pronoun, "*' refers to",
            masks + '.'
        ]
        string_list_b = []
        block_flag_a = [1, 0, 1, 0, 0, 0, 0]
        block_flag_b = []
        assert len(string_list_a) == len(block_flag_a)
        assert len(string_list_b) == len(block_flag_b)
        return string_list_a, string_list_b, block_flag_a, block_flag_b

    else:
        raise ValueError("unknown pattern_id.")
def add_special_input_features(self, input_example: InputExample,
                               input_features: InputFeatures) -> None:
    mask_start = input_features.input_ids.index(
        self.wrapper.tokenizer.mask_token_id)

    for choice in ['choice1', 'choice2']:
        choice_text = input_example.meta[choice]
        choice_token_ids = get_verbalization_ids(choice_text,
                                                 self.wrapper.tokenizer,
                                                 force_single_token=False)
        mask_end = mask_start + len(choice_token_ids)
        # positions outside the <mask> span stay -100 and are ignored when scoring
        input_features.meta[f'{choice}_token_ids'] = [-100] * len(
            input_features.input_ids)
        input_features.meta[f'{choice}_token_ids'][
            mask_start:mask_end] = choice_token_ids
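# Hedged sketch (function name and scoring rule are assumptions): how the per-choice
# token ids stored above can be turned into a score for a multi-token COPA choice.
# Positions marked -100 lie outside the <mask> span and are skipped.
import torch

def score_choice_sketch(mlm_logits: torch.Tensor, choice_token_ids) -> float:
    # mlm_logits: [seq_len, vocab_size] logits for one example
    ids = torch.tensor(choice_token_ids)
    keep = ids != -100
    log_probs = torch.log_softmax(mlm_logits[keep], dim=-1)
    return log_probs.gather(1, ids[keep].unsqueeze(1)).sum().item()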
def add_special_input_features(self, input_example: InputExample,
                               input_features: InputFeatures) -> None:
    mask_start = input_features.input_ids.index(
        self.wrapper.tokenizer.mask_token_id)
    num_masks = input_features.input_ids.count(
        self.wrapper.tokenizer.mask_token_id)
    mask_end = mask_start + num_masks

    target = input_example.meta['span1_text']
    input_features.meta['target'] = target
    target_token_ids = get_verbalization_ids(target,
                                             self.wrapper.tokenizer,
                                             force_single_token=False)
    input_features.meta['target_token_ids'] = [-100] * len(
        input_features.input_ids)

    # we also predict <pad> tokens at the missing positions
    target_token_ids += [self.wrapper.tokenizer.pad_token_id] * \
        (num_masks - len(target_token_ids))
    input_features.meta['target_token_ids'][
        mask_start:mask_end] = target_token_ids
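# Hedged sketch (not from this file): the -100 filler above matches the default
# ignore_index of torch.nn.functional.cross_entropy, so a token-level loss over the
# <mask> span (target tokens plus <pad> fillers) can be computed directly, with every
# position outside the span ignored.
import torch
import torch.nn.functional as F

def wsc_span_loss_sketch(mlm_logits: torch.Tensor, target_token_ids) -> torch.Tensor:
    # mlm_logits: [seq_len, vocab_size]; target_token_ids: length seq_len, -100 outside the span
    labels = torch.tensor(target_token_ids)
    return F.cross_entropy(mlm_logits, labels, ignore_index=-100)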
def __init__(self, tokenizer, pvp, label_list):
    # Record prompt tokens
    pattern_token_set, pattern_token_indices = set(), []
    # RoBERTa's tokenizer inherits from GPT2Tokenizer, which tokenizes the same word
    # differently depending on its position:
    # e.g. 'Hello world!' -> ['Hello', 'Ġworld', '!'];
    #      'Hello', 'world' -> ['Hello'], ['world']
    # So we add a prefix space to match how the tokens appear inside a sentence.
    kwargs = {
        'add_prefix_space': True
    } if isinstance(tokenizer, GPT2Tokenizer) else {}
    for idx, part in enumerate(pvp.PATTERN):
        if pvp.BLOCK_FLAG[idx] == 1:
            token_ids = tokenizer.encode(part,
                                         add_special_tokens=False,
                                         **kwargs)
            pattern_token_set.update(token_ids)
            pattern_token_indices.extend(token_ids)

    # Record label tokens
    label_token_set = set()
    for label_idx, label in enumerate(label_list):
        verbalizers = pvp.verbalize(label)
        for verbalizer_idx, verbalizer in enumerate(verbalizers):
            verbalizer_id = get_verbalization_ids(verbalizer,
                                                  tokenizer,
                                                  force_single_token=True)
            assert verbalizer_id != tokenizer.unk_token_id, \
                "verbalization was tokenized as <UNK>"
            label_token_set.add(verbalizer_id)

    assert len(pattern_token_set) < 50 and len(label_token_set) < 49

    # Convert tokens in the manual prompt / label to unused tokens
    # Note that `AlbertTokenizer` and `RobertaTokenizer` don't have a `vocab` attribute
    if hasattr(tokenizer, 'vocab') and '[unused0]' in tokenizer.vocab:
        # BERT
        self.pattern_convert = {
            token_id: tokenizer.vocab['[unused%s]' % idx]
            for idx, token_id in enumerate(pattern_token_set)
        }
        self.label_convert = {
            token_id: tokenizer.vocab['[unused%s]' % (idx + 50)]
            for idx, token_id in enumerate(label_token_set)
        }
    else:
        # ALBERT, RoBERTa: reuse the last 100 vocabulary ids as "unused" slots
        start_idx = tokenizer.vocab_size - 100
        self.pattern_convert = {
            token_id: start_idx + idx
            for idx, token_id in enumerate(pattern_token_set)
        }
        self.label_convert = {
            token_id: start_idx + 50 + idx
            for idx, token_id in enumerate(label_token_set)
        }

    # Convert mlm logits to cls logits
    self.vocab_size = tokenizer.vocab_size
    self.m2c_tensor = torch.tensor(list(self.label_convert.values()),
                                   dtype=torch.long)

    # Use lookup tensor to get replacement embeddings
    self.lookup_tensor = torch.tensor(
        [self.pattern_convert[origin] for origin in pattern_token_indices],
        dtype=torch.long)
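# Hedged sketch (helper name and example ids are assumptions): applying the
# pattern_convert map built above to a sequence of input ids, so that the tokens
# flagged as prompt positions are rerouted to the reserved '[unusedN]' / tail-of-vocab
# ids whose embeddings serve as the trainable continuous prompt, while every other
# token id is left untouched.
def convert_prompt_token_ids_sketch(input_ids, pattern_convert):
    return [pattern_convert.get(token_id, token_id) for token_id in input_ids]

# e.g. convert_prompt_token_ids_sketch([101, 1996, 102], {1996: 1}) -> [101, 1, 102]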